-- |A small selection of utilities that might be of use to others working with bytestring/number combinations.
module Crypto.Util where

import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L
import Data.ByteString.Unsafe (unsafeIndex, unsafeUseAsCStringLen)
import Data.Bits (shiftL, shiftR)
import Data.Bits (xor, setBit, shiftR, shiftL)
import Control.Exception (Exception, throw)
import Data.Tagged
import System.IO.Unsafe
import Foreign.C.Types
import Foreign.Ptr

-- | @incBS bs@ increments @bs@, read as a big-endian unsigned integer, by
-- one. The result has the same length as the input and wraps around on
-- overflow, i.e. it equals @i2bs (8 * B.length bs) (bs2i bs + 1)@.
--
-- Trailing @0xFF@ bytes become @0x00@ and the carry is added to the byte
-- just before them. If every byte is @0xFF@ the result is all zeros, and an
-- empty input yields an empty output.
incBS :: B.ByteString -> B.ByteString
incBS bs
  -- No byte below the run of 0xFFs (or the input is empty): the carry
  -- falls off the top and every byte wraps to zero.
  | B.null prefix = B.replicate (B.length bs) 0
  | otherwise =
      B.concat
        [ B.init prefix
          -- Safe: prefix is non-empty and its last byte is not 0xFF,
          -- so adding one cannot overflow.
        , B.singleton (B.last prefix + 1)
        , B.replicate (B.length carried) 0
        ]
  where
    -- 'carried' is the maximal run of trailing 0xFF bytes; each of them
    -- rolls over to 0x00 when the carry passes through.
    (prefix, carried) = B.spanEnd (== 0xFF) bs
{-# INLINE incBS #-}


-- | @i2bs bitLen i@ converts @i@ to a big-endian 'B.ByteString' of @bitLen@
-- bits (@bitLen@ must be a multiple of 8). Only the low @bitLen@ bits of
-- @i@ are kept, because each emitted byte is truncated to 8 bits.
i2bs :: Int -> Integer -> B.ByteString
i2bs l i = B.unfoldr step (l - 8)
  where
    -- Emit the byte at bit offset @off@, then move down one byte; stop once
    -- the offset goes negative (the least-significant byte has been emitted).
    step off
      | off < 0   = Nothing
      | otherwise = Just (fromIntegral (i `shiftR` off), off - 8)
{-# INLINE i2bs #-}

-- | @i2bs_unsized i@ converts @i@ to a big-endian 'B.ByteString' using just
-- enough bytes to express the integer.
--
-- The integer must be non-negative. Zero is encoded as a single zero byte,
-- and a negative input yields an empty 'B.ByteString' rather than an error.
i2bs_unsized :: Integer -> B.ByteString
i2bs_unsized 0 = B.singleton 0
i2bs_unsized i = B.reverse (B.unfoldr step i)
  where
    -- Peel off the least-significant byte each step, so the bytes come out
    -- little-endian; the final 'B.reverse' makes them big-endian.
    step n
      | n <= 0    = Nothing
      | otherwise = Just (fromIntegral n, n `shiftR` 8)
{-# INLINE i2bs_unsized #-}

-- | Useful utility to extract the result of a generator operation and
-- translate error results to exceptions: a 'Right' value is returned
-- as-is, and a 'Left' error is raised with 'throw'.
--
-- Because 'throw' is used (not 'throwIO'), the exception surfaces when the
-- result is evaluated.
throwLeft :: Exception e => Either e a -> a
throwLeft = either throw id

-- | Obtain a tagged value for a particular instantiated type.
--
-- The second argument is never inspected; it only fixes the tag type @a@.
for :: Tagged a b -> a -> b
for t _ = unTagged t

-- | Infix synonym for 'for': @t .::. x@ is @for t x@.
(.::.) :: Tagged a b -> a -> b
(.::.) = for

-- | Checks two bytestrings for equality without breaches for
-- timing attacks.
--
-- Semantically, @constTimeEq = (==)@.  However, @x == y@ takes less
-- time when the first byte is different than when the first byte
-- is equal.  This side channel allows an attacker to mount a
-- timing attack.  On the other hand, @constTimeEq@ always takes the
-- same time regardless of the bytestrings' contents, unless they are
-- of different sizes: inputs of different lengths are rejected
-- immediately, so lengths are not kept secret.
--
-- You should always use @constTimeEq@ when comparing secrets,
-- otherwise you may leave a significant security hole
-- (cf. <http://codahale.com/a-lesson-in-timing-attacks/>).
constTimeEq :: B.ByteString -> B.ByteString -> Bool
constTimeEq s1 s2
  | B.length s1 /= B.length s2 = False
  | otherwise                  = go s1 s2 True
  where
    -- 'c_constTimeEq' takes its byte count as a 'CInt'. Converting a longer
    -- 'Int' length with 'fromIntegral' would silently wrap and compare the
    -- wrong number of bytes, so compare in pieces that each fit in a 'CInt'.
    maxChunk :: Int
    maxChunk = fromIntegral (maxBound :: CInt)

    -- Both inputs have the same length here, so the pieces stay aligned.
    -- Each piece's result is forced before being combined, so every piece
    -- is compared even after a mismatch has been found.
    go a b acc
      | B.null a  = acc
      | otherwise =
          let (a1, a2) = B.splitAt maxChunk a
              (b1, b2) = B.splitAt maxChunk b
              same     = chunkEq a1 b1
          in same `seq` go a2 b2 (acc && same)

    -- unsafePerformIO: the buffers are immutable and kept alive for the
    -- call by unsafeUseAsCStringLen. Assuming the C routine only reads
    -- them, the result is a pure function of the inputs.
    chunkEq a b = unsafePerformIO $
      unsafeUseAsCStringLen a $ \(pa, len) ->
      unsafeUseAsCStringLen b $ \(pb, _) ->
        (== 0) `fmap` c_constTimeEq pa pb (fromIntegral len)

-- | C comparison routine backing 'constTimeEq'. It receives two buffers
-- and a byte count; 'constTimeEq' treats a result of 0 as "equal".
--
-- NOTE(review): the C implementation is not visible from this module.
-- Presumably it compares the given number of bytes without a
-- data-dependent early exit; confirm against the C source. It is imported
-- @unsafe@, so it must not call back into Haskell.
foreign import ccall unsafe
   c_constTimeEq :: Ptr CChar -> Ptr CChar -> CInt -> IO CInt

-- | Helper function to convert bytestrings to integers: the bytes are read
-- as a big-endian unsigned number, and the empty string decodes to 0.
bs2i :: B.ByteString -> Integer
bs2i bs = B.foldl' step 0 bs
  where
    -- Shift the accumulated value up one byte and add the next byte.
    step acc byte = (acc `shiftL` 8) + fromIntegral byte
{-# INLINE bs2i #-}

-- | zipWith xor + pack: XOR two strict bytestrings byte by byte. The result
-- is as long as the shorter input.
--
-- It is deliberately written as @pack . zipWith xor@ so that bytestring's
-- rewrite rules, where present, can optimize it at compile time into the
-- library's 'B.zipWith'' form and avoid the intermediate list. Keep this
-- shape if editing.
zwp' :: B.ByteString -> B.ByteString -> B.ByteString
zwp' a = B.pack . B.zipWith xor a
{-# INLINE zwp' #-}

-- | zipWith xor + pack for lazy bytestrings: XOR two lazy bytestrings byte
-- by byte. The result is as long as the shorter input.
--
-- This is written intentionally to take advantage of the bytestring
-- library's 'zipWith'' rewrite rule. Each aligned pair of strict pieces is
-- combined with 'zwp''. The extra cost is that the resulting lazy
-- bytestring is split wherever either input has a chunk boundary, so it is
-- more fragmented than either of the two inputs.
zwp :: L.ByteString -> L.ByteString -> L.ByteString
zwp a b = L.fromChunks (go (L.toChunks a) (L.toChunks b))
  where
    go [] _ = []
    go _ [] = []
    go (x:xs) (y:ys) =
      -- 'L.toChunks' never yields empty chunks, so n > 0 and each step
      -- makes progress.
      let n        = min (B.length x) (B.length y)
          (x', xr) = B.splitAt n x
          (y', yr) = B.splitAt n y
          -- Push back whatever is left of the longer chunk.
          xs'      = if B.null xr then xs else xr : xs
          ys'      = if B.null yr then ys else yr : ys
      in zwp' x' y' : go xs' ys'
{-# INLINEABLE zwp #-}

-- | @collect n chunks@ gathers the first @n@ bytes from a list of strict
-- bytestrings and returns them as a list of chunks, with the last chunk
-- truncated as needed.
--
-- Fewer than @n@ bytes are returned if the list runs out. A non-positive
-- @n@ yields no chunks.
collect :: Int -> [B.ByteString] -> [B.ByteString]
collect i _ | i <= 0 = []
collect _ []         = []
collect i (b:bs)
  | len < i   = b : collect (i - len) bs
  | otherwise = [B.take i b]
  where
    len = B.length b
{-# INLINE collect #-}