-- |
-- Module      : Crypto.KDF.Scrypt
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
-- Scrypt key derivation function as defined in Colin Percival's paper
-- "Stronger Key Derivation via Sequential Memory-Hard Functions"
-- <http://www.tarsnap.com/scrypt/scrypt.pdf>.
--
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ForeignFunctionInterface #-}
module Crypto.KDF.Scrypt
    ( Parameters(..)
    , generate
    ) where

import           Data.Word
import           Foreign.Marshal.Alloc
import           Foreign.Ptr (Ptr, plusPtr)
import           Control.Monad (forM_)

import           Crypto.Hash (SHA256(..))
import qualified Crypto.KDF.PBKDF2 as PBKDF2
import           Crypto.Internal.Compat (popCount, unsafeDoIO)
import           Crypto.Internal.ByteArray (ByteArray, ByteArrayAccess)
import qualified Crypto.Internal.ByteArray as B

-- | Parameters for Scrypt.
--
-- The constraints listed on each field are requirements of the scrypt
-- construction (Percival's paper, also specified as RFC 7914).
data Parameters = Parameters
    { n            :: Word64 -- ^ CPU/memory cost parameter, also known as N.
                             --   Must be a power of 2 greater than 1.
    , r            :: Int    -- ^ Block size parameter. Must be positive and
                             --   satisfy @r * p < 2^30@.
    , p            :: Int    -- ^ Parallelization parameter. Must be positive
                             --   and satisfy @r * p < 2^30@.
    , outputLength :: Int    -- ^ The number of bytes to generate out of Scrypt.
    }

-- | FFI binding to the C routine @crypton_scrypt_smix@ (its C source is not
-- part of this module).
--
-- As called from 'generate':
--
-- * the first pointer is one @128 * r@-byte block inside a copy of the
--   intermediate buffer;
-- * the 'Word32' is @r@;
-- * the 'Word64' is @n@;
-- * the last two pointers are scratch buffers of @128 * n * r@ and
--   @256 * r + 64@ bytes.
--
-- NOTE(review): presumably this applies scrypt's SMix to the block in place,
-- since 'generate' reads the mutated copy afterwards. Confirm against the C
-- implementation.
foreign import ccall "crypton_scrypt_smix"
    ccrypton_scrypt_smix :: Ptr Word8 -> Word32 -> Word64 -> Ptr Word8 -> Ptr Word8 -> IO ()

-- | Generate the scrypt key derivation data.
--
-- Structure of the computation:
--
-- 1. A PBKDF2-HMAC-SHA256 pass over @password@ and @salt@ derives an
--    intermediate buffer of @p@ blocks of @128 * r@ bytes each.
-- 2. Each block is mixed by the C routine @crypton_scrypt_smix@.
-- 3. A second PBKDF2-HMAC-SHA256 pass, with the mixed buffer as salt,
--    produces 'outputLength' bytes.
--
-- Invalid parameters are reported with 'error' when the result is forced:
--
-- * @r * p >= 2^30@ (checked in 'Integer', so it cannot wrap around);
-- * @n@ not a power of 2;
-- * @r < 1@ or @p < 1@;
-- * @n < 2@;
-- * a working buffer whose size in bytes does not fit in an 'Int'. Such a
--   buffer would otherwise wrap around and be handed to C undersized.
generate :: (ByteArrayAccess password, ByteArrayAccess salt, ByteArray output)
         => Parameters
         -> password
         -> salt
         -> output
generate params password salt
    | toInteger (r params) * toInteger (p params) >= 0x40000000 =
        error "Scrypt: invalid parameters: r and p constraint"
    | popCount (n params) /= 1 =
        error "Scrypt: invalid parameters: n not a power of 2"
    | r params < 1 || p params < 1 =
        error "Scrypt: invalid parameters: r and p must be positive"
    | n params < 2 =
        error "Scrypt: invalid parameters: n must be greater than 1"
    | not (all fitsInInt [bufferSizeI, vSizeI, xySizeI]) =
        error "Scrypt: invalid parameters: memory requirements too large"
    | otherwise = unsafeDoIO $ do
        let b = PBKDF2.generate prf (PBKDF2.Parameters 1 intLen) password salt :: B.Bytes
        -- B.copy hands us a pointer into a fresh copy of b; the C routine
        -- mixes each 128*r-byte block of that copy, which becomes newSalt.
        newSalt <- B.copy b $ \bPtr ->
            allocaBytesAligned (fromInteger vSizeI) 8 $ \v ->
            allocaBytesAligned (fromInteger xySizeI) 8 $ \xy ->
                forM_ [0 .. p params - 1] $ \i ->
                    ccrypton_scrypt_smix (bPtr `plusPtr` (i * blockSize))
                                         (fromIntegral (r params))
                                         (n params)
                                         v
                                         xy
        pure $ PBKDF2.generate prf (PBKDF2.Parameters 1 (outputLength params)) password (newSalt :: B.Bytes)
  where
    prf = PBKDF2.prfHMAC SHA256

    -- Byte sizes of one block, the whole intermediate buffer and the two
    -- scratch buffers passed to C. They are computed in 'Integer' so that
    -- the bounds check in the guards cannot itself overflow.
    blockSizeI, bufferSizeI, vSizeI, xySizeI :: Integer
    blockSizeI  = 128 * toInteger (r params)
    bufferSizeI = blockSizeI * toInteger (p params)
    vSizeI      = blockSizeI * toInteger (n params)
    xySizeI     = 2 * blockSizeI + 64

    fitsInInt :: Integer -> Bool
    fitsInInt x = x <= toInteger (maxBound :: Int)

    -- 'Int' views of the sizes. They are only demanded in the 'otherwise'
    -- branch, after the guards have established that they fit.
    blockSize, intLen :: Int
    blockSize = fromInteger blockSizeI
    intLen    = fromInteger bufferSizeI
-- The body is evaluated through 'unsafeDoIO'; keep it out of inlining.
{-# NOINLINE generate #-}