-- |
-- Module      : Crypto.KDF.PBKDF2
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
-- Password-Based Key Derivation Function 2 (PBKDF2), as specified in
-- RFC 2898 / RFC 8018 (PKCS #5).
--
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ForeignFunctionInterface #-}

module Crypto.KDF.PBKDF2
    ( PRF
    , prfHMAC
    , Parameters(..)
    , generate
    , fastPBKDF2_SHA1
    , fastPBKDF2_SHA256
    , fastPBKDF2_SHA512
    ) where

import           Data.Word
import           Data.Bits
import           Foreign.Marshal.Alloc
import           Foreign.Ptr (plusPtr, Ptr)
import           Foreign.C.Types (CUInt(..), CSize(..))

import           Crypto.Hash (HashAlgorithm)
import qualified Crypto.MAC.HMAC as HMAC

import           Crypto.Internal.ByteArray (ByteArray, ByteArrayAccess, Bytes)
import qualified Crypto.Internal.ByteArray as B
import           Data.Memory.PtrMethods

-- | Pseudo-random function plugged into PBKDF2.
--
-- It takes the password (acting as the key) and a message, and returns
-- the authenticator of that message. 'generate' measures the output size
-- once, so every call is expected to return the same number of bytes.
-- 'prfHMAC' builds such a function from an HMAC hash algorithm.
type PRF password
    =  password -- ^ the password, used as the PRF key
    -> Bytes    -- ^ the message to process
    -> Bytes    -- ^ PRF(password, message)

-- | PRF for PBKDF2 using HMAC with the hash algorithm as parameter.
--
-- The password is used to initialise an HMAC context. The returned
-- function then updates (a value of) that context with each message and
-- finalises it. The context is forced by a bang pattern before the
-- message-processing function is returned.
prfHMAC :: (HashAlgorithm a, ByteArrayAccess password)
        => a            -- ^ hash algorithm; only its type is used
        -> PRF password
prfHMAC alg k = hmacIncr alg (HMAC.initialize k)
  where
    -- The first argument only fixes the hash type of the context.
    hmacIncr :: HashAlgorithm a => a -> HMAC.Context a -> (Bytes -> Bytes)
    hmacIncr _ !ctx = \b -> B.convert $ HMAC.finalize $ HMAC.update ctx b

-- | Parameters for PBKDF2
data Parameters = Parameters
    { iterCounts   :: Int -- ^ the number of PRF iterations (@c@ in RFC 8018); should be at least 1. e.g. WPA2 uses 4096.
    , outputLength :: Int -- ^ the number of bytes to generate out of PBKDF2
    }

-- | Generate a PBKDF2-derived key using an arbitrary 'PRF'.
--
-- The output is the concatenation @T_1 || T_2 || ...@ truncated to
-- 'outputLength' bytes. Each @T_i@ is as long as one PRF output:
--
-- > T_i = U_1 xor U_2 xor .. xor U_c
-- > U_1 = PRF(password, salt || BE32(i))
-- > U_j = PRF(password, U_(j-1))
--
-- Here @c@ is 'iterCounts'. A count of 0 or less performs no PRF
-- iterations, so the output is all zero bytes. Callers should use at
-- least 1, as PBKDF2 requires.
generate :: (ByteArrayAccess password, ByteArrayAccess salt, ByteArray ba)
         => PRF password
         -> Parameters
         -> password
         -> salt
         -> ba
generate prf params password salt =
    B.allocAndFreeze (outputLength params) $ \p -> do
        -- full blocks are accumulated in place with XOR, so start from zero
        memSet p 0 (outputLength params)
        loop 1 (outputLength params) p
  where
    !runPRF = prf password
    -- block size = PRF output size, probed once with an empty message
    !hLen   = B.length $ runPRF B.empty

    -- the salt prefix is identical for every block; convert it only once
    saltBytes :: Bytes
    saltBytes = B.convert salt

    -- Fill @len@ bytes at @p@, starting with block number @iterNb@.
    -- Complete blocks are written directly; a trailing incomplete block
    -- is delegated to 'partial', which must therefore be the last call.
    loop :: Word32 -> Int -> Ptr Word8 -> IO ()
    loop iterNb len p
        | len <= 0   = return ()
        | len < hLen = partial iterNb len p
        | otherwise  = do
            xorBlock iterNb p
            loop (iterNb + 1) (len - hLen) (p `plusPtr` hLen)

    -- Compute block @iterNb@ into a zeroed scratch buffer of hLen bytes,
    -- then copy only the first @len@ bytes to the output.
    partial :: Word32 -> Int -> Ptr Word8 -> IO ()
    partial iterNb len p = allocaBytesAligned hLen 8 $ \tmp -> do
        memSet tmp 0 hLen
        xorBlock iterNb tmp
        memCopy p tmp len

    -- XOR U_1 .. U_c of block @iterNb@ into the hLen bytes at @dst@.
    xorBlock :: Word32 -> Ptr Word8 -> IO ()
    xorBlock iterNb dst =
        applyMany (iterCounts params) (saltBytes `B.append` toBS iterNb)
      where
        applyMany :: Int -> Bytes -> IO ()
        applyMany !i uprev
            -- '<= 0' (not '== 0') so a negative count terminates
            | i <= 0    = return ()
            | otherwise = do
                let uData = runPRF uprev
                B.withByteArray uData $ \u -> memXor dst dst u hLen
                applyMany (i - 1) uData

    -- big-endian encoding of a Word32 (the block index appended to the
    -- salt); fromIntegral to Word8 keeps only the low 8 bits of each shift
    toBS :: ByteArray ba => Word32 -> ba
    toBS w = B.pack [ fromIntegral (w `shiftR` s) | s <- [24, 16, 8, 0] ]
{-# NOINLINE generate #-}

-- | PBKDF2 computed by the bundled C routine
-- @cryptonite_fastpbkdf2_hmac_sha1@ instead of through 'generate'.
--
-- The iteration count and lengths are passed to C via 'fromIntegral', so a
-- negative 'iterCounts' would wrap around to a very large unsigned value.
-- Callers should pass non-negative parameters.
fastPBKDF2_SHA1 :: (ByteArrayAccess password, ByteArrayAccess salt, ByteArray out)
                => Parameters
                -> password
                -> salt
                -> out
fastPBKDF2_SHA1 params password salt =
    B.allocAndFreeze (outputLength params) $ \outPtr ->
    B.withByteArray password $ \passPtr ->
    B.withByteArray salt $ \saltPtr ->
        c_cryptonite_fastpbkdf2_hmac_sha1
            passPtr (fromIntegral $ B.length password)
            saltPtr (fromIntegral $ B.length salt)
            (fromIntegral $ iterCounts params)
            outPtr (fromIntegral $ outputLength params)

-- | PBKDF2 computed by the bundled C routine
-- @cryptonite_fastpbkdf2_hmac_sha256@ instead of through 'generate'.
--
-- The iteration count and lengths are passed to C via 'fromIntegral', so a
-- negative 'iterCounts' would wrap around to a very large unsigned value.
-- Callers should pass non-negative parameters.
fastPBKDF2_SHA256 :: (ByteArrayAccess password, ByteArrayAccess salt, ByteArray out)
                  => Parameters
                  -> password
                  -> salt
                  -> out
fastPBKDF2_SHA256 params password salt =
    B.allocAndFreeze (outputLength params) $ \outPtr ->
    B.withByteArray password $ \passPtr ->
    B.withByteArray salt $ \saltPtr ->
        c_cryptonite_fastpbkdf2_hmac_sha256
            passPtr (fromIntegral $ B.length password)
            saltPtr (fromIntegral $ B.length salt)
            (fromIntegral $ iterCounts params)
            outPtr (fromIntegral $ outputLength params)

-- | PBKDF2 computed by the bundled C routine
-- @cryptonite_fastpbkdf2_hmac_sha512@ instead of through 'generate'.
--
-- The iteration count and lengths are passed to C via 'fromIntegral', so a
-- negative 'iterCounts' would wrap around to a very large unsigned value.
-- Callers should pass non-negative parameters.
fastPBKDF2_SHA512 :: (ByteArrayAccess password, ByteArrayAccess salt, ByteArray out)
                  => Parameters
                  -> password
                  -> salt
                  -> out
fastPBKDF2_SHA512 params password salt =
    B.allocAndFreeze (outputLength params) $ \outPtr ->
    B.withByteArray password $ \passPtr ->
    B.withByteArray salt $ \saltPtr ->
        c_cryptonite_fastpbkdf2_hmac_sha512
            passPtr (fromIntegral $ B.length password)
            saltPtr (fromIntegral $ B.length salt)
            (fromIntegral $ iterCounts params)
            outPtr (fromIntegral $ outputLength params)


-- C entry point used by 'fastPBKDF2_SHA1'. Argument order:
-- password pointer and length, salt pointer and length, iteration count,
-- output pointer and output length.
foreign import ccall unsafe "cryptonite_pbkdf2.h cryptonite_fastpbkdf2_hmac_sha1"
    c_cryptonite_fastpbkdf2_hmac_sha1
        :: Ptr Word8 -> CSize   -- password, password length
        -> Ptr Word8 -> CSize   -- salt, salt length
        -> CUInt                -- iteration count
        -> Ptr Word8 -> CSize   -- output buffer, output length
        -> IO ()

-- C entry point used by 'fastPBKDF2_SHA256'. Argument order:
-- password pointer and length, salt pointer and length, iteration count,
-- output pointer and output length.
foreign import ccall unsafe "cryptonite_pbkdf2.h cryptonite_fastpbkdf2_hmac_sha256"
    c_cryptonite_fastpbkdf2_hmac_sha256
        :: Ptr Word8 -> CSize   -- password, password length
        -> Ptr Word8 -> CSize   -- salt, salt length
        -> CUInt                -- iteration count
        -> Ptr Word8 -> CSize   -- output buffer, output length
        -> IO ()

-- C entry point used by 'fastPBKDF2_SHA512'. Argument order:
-- password pointer and length, salt pointer and length, iteration count,
-- output pointer and output length.
foreign import ccall unsafe "cryptonite_pbkdf2.h cryptonite_fastpbkdf2_hmac_sha512"
    c_cryptonite_fastpbkdf2_hmac_sha512
        :: Ptr Word8 -> CSize   -- password, password length
        -> Ptr Word8 -> CSize   -- salt, salt length
        -> CUInt                -- iteration count
        -> Ptr Word8 -> CSize   -- output buffer, output length
        -> IO ()