-- |
-- Module      : Crypto.Store.Util
-- License     : BSD-style
-- Maintainer  : Olivier Chéron <olivier.cheron@gmail.com>
-- Stability   : experimental
-- Portability : unknown
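--
-- Miscellaneous helpers: a strict boolean conjunction, byte-array utilities
-- and an 'Either' combinator.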
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE MagicHash #-}
module Crypto.Store.Util
    ( (&&!)
    , reverseBytes
    , constAllEq
    , mapLeft
    , mapAsWord64LE
    ) where

import           Data.Bits
import           Data.ByteArray (ByteArray, ByteArrayAccess)
import qualified Data.ByteArray as B
import           Data.List
import           Data.Memory.Endian
import           Data.Word

import           Foreign.Ptr (plusPtr)
import           Foreign.Storable

import GHC.Exts

-- | A strict version of '&&'.
--
-- Unlike '&&', the second operand is evaluated even when the first is
-- 'False'. The result is obtained by AND-ing the constructor tags of the two
-- 'Bool' values ('False' is tag 0, 'True' is tag 1). It does not
-- short-circuit on the first argument.
(&&!) :: Bool -> Bool -> Bool
(&&!) !lhs !rhs = isTrue# (dataToTag# lhs `andI#` dataToTag# rhs)
infixr 3 &&!

-- | Return a new bytearray holding the same bytes in reverse order.
reverseBytes :: ByteArray ba => ba -> ba
#if MIN_VERSION_memory(0,14,18)
reverseBytes input = B.reverse input
#else
-- For memory versions older than 0.14.18, reverse via a list of bytes.
reverseBytes input = B.pack (reverse (B.unpack input))
#endif

-- | Test if all bytes in a bytearray are equal to the value specified.
--
-- Every byte is visited. The XOR of each byte with the expected value is
-- OR-ed into an accumulator, and only the final accumulator is compared with
-- zero. There is no early exit on the first mismatch, so the running time
-- does not depend on where a mismatch occurs. An empty bytearray yields
-- 'True'.
constAllEq :: ByteArrayAccess ba => Word8 -> ba -> Bool
constAllEq b ba = foldl' step 0 (B.unpack ba) == 0
  where
    -- Any differing bit in any byte leaves a non-zero bit in the accumulator.
    step :: Word8 -> Word8 -> Word8
    step diff byte = diff .|. (byte `xor` b)

-- | Apply a function to the payload of a 'Left'; a 'Right' is passed through
-- unchanged.
mapLeft :: (a -> b) -> Either a c -> Either b c
mapLeft f = either (Left . f) Right

-- | Same as 'Data.ByteArray.Mapping.mapAsWord64' but with little-endian words.
--
-- The input is read as consecutive 64-bit little-endian words. Each word is
-- decoded, transformed with the given function and written back in
-- little-endian order at the same offset of a freshly allocated output of the
-- same length.
--
-- When the length is not a multiple of 8, the trailing @len `mod` 8@ bytes do
-- not form a complete word. They are copied to the output unchanged, so the
-- result never contains uninitialised memory.
mapAsWord64LE :: ByteArray bs => (Word64 -> Word64) -> bs -> bs
mapAsWord64LE f bs =
    B.allocAndFreeze len $ \dst ->
        B.withByteArray bs $ \src -> do
            loop (len `div` 8) dst src
            copyTail dst src
  where
        len :: Int
        len = B.length bs

        -- Offset of the first byte that is not part of a complete word.
        tailStart :: Int
        tailStart = len - len `mod` 8

        -- Transform @i@ words, advancing both pointers by one word each step.
        loop :: Int -> Ptr (LE Word64) -> Ptr (LE Word64) -> IO ()
        loop 0 _ _ = return ()
        loop i d s = do
            w <- peek s
            let r = f (fromLE w)
            poke d (toLE r)
            loop (i - 1) (d `plusPtr` 8) (s `plusPtr` 8)

        -- The output buffer is not guaranteed to be zeroed, so the leftover
        -- bytes must be written explicitly; they are copied unchanged.
        copyTail :: Ptr (LE Word64) -> Ptr (LE Word64) -> IO ()
        copyTail d s = mapM_ copyByte [tailStart .. len - 1]
          where
            copyByte :: Int -> IO ()
            copyByte off = do
                byte <- peekByteOff s off
                pokeByteOff d off (byte :: Word8)