module Crypto.Util where
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L
import Data.ByteString.Unsafe (unsafeIndex, unsafeUseAsCStringLen)
import Data.Bits (shiftL, shiftR)
import Data.Bits (xor, setBit, shiftR, shiftL)
import Control.Exception (Exception, throw)
import Data.Tagged
import System.IO.Unsafe
import Foreign.C.Types
import Foreign.Ptr
-- | @incBS bs@ treats @bs@ as a big-endian unsigned integer and adds one,
-- keeping the result exactly as long as the input.
--
-- * A string made only of @0xFF@ bytes wraps around to all zeros.
-- * The empty string maps to itself.
--
-- This is the same value as @i2bs (8 * B.length bs) (bs2i bs + 1)@, computed
-- in linear time. The trailing run of @0xFF@ bytes is found with
-- 'B.spanEnd' instead of being peeled off one byte at a time, which used to
-- cost time quadratic in the length of that run.
incBS :: B.ByteString -> B.ByteString
incBS bs
  | B.null prefix = zeros
  | otherwise     =
      B.concat [B.init prefix, B.singleton (B.last prefix + 1), zeros]
  where
    -- Every trailing 0xFF absorbs the carry and becomes 0x00. The last byte
    -- of 'prefix' (never 0xFF) receives the +1. B.init/B.last are total
    -- here because the guard above has ruled out an empty prefix.
    (prefix, carried) = B.spanEnd (== 0xFF) bs
    zeros             = B.replicate (B.length carried) 0
{-# INLINE incBS #-}
-- | @i2bs bitLen i@ encodes @i@ as a big-endian 'B.ByteString' of
-- @bitLen \`div\` 8@ bytes (no bytes when @bitLen < 8@). Higher bits of @i@
-- are discarded, because each byte is truncated with 'fromIntegral'.
--
-- Byte @k@ (0-based from the front) is @i@ shifted right by
-- @bitLen - 8 * (k + 1)@ bits. As a result, when @bitLen@ is not a multiple
-- of 8, the lowest @bitLen \`mod\` 8@ bits of @i@ are dropped.
i2bs :: Int -> Integer -> B.ByteString
i2bs l i = B.pack [ fromIntegral (i `shiftR` offset) | offset <- offsets ]
  where
    -- Most significant byte first; the descending range stops before
    -- going negative.
    offsets = [l - 8, l - 16 .. 0]
{-# INLINE i2bs #-}
-- | Encode an 'Integer' as a big-endian 'B.ByteString' using as few bytes as
-- possible. Zero is encoded as a single @0x00@ byte.
--
-- NOTE(review): any negative argument yields the empty string, because
-- the loop below stops as soon as the remaining value is @<= 0@. Callers
-- should only pass non-negative values.
i2bs_unsized :: Integer -> B.ByteString
i2bs_unsized 0 = B.singleton 0
i2bs_unsized n = B.pack (bytesBE n [])
  where
    -- Peel off the low byte and prepend it, so the accumulator ends up
    -- most-significant-byte first.
    bytesBE k acc
      | k <= 0    = acc
      | otherwise = bytesBE (k `shiftR` 8) (fromIntegral k : acc)
{-# INLINE i2bs_unsized #-}
-- | Return the payload of a 'Right'. For a 'Left', raise its exception
-- with 'throw'. Because the function is pure, the exception surfaces when
-- the result is forced, not when 'throwLeft' is applied.
throwLeft :: Exception e => Either e a -> a
throwLeft = either throw id
-- | Read the value out of a 'Tagged'. The second argument only pins down
-- the tag type @a@; it is never inspected, so a placeholder such as
-- @undefined :: a@ is safe to pass.
for :: Tagged a b -> a -> b
for = const . unTagged
-- | Infix synonym for 'for': @tagged .::. proxy@ is @for tagged proxy@.
(.::.) :: Tagged a b -> a -> b
tagged .::. proxy = for tagged proxy
-- | Compare two strict 'B.ByteString's for equality. The byte comparison
-- is delegated to the C routine 'c_constTimeEq', which is presumably
-- written so that its running time does not depend on where the inputs
-- first differ (confirm against the C source). Inputs of different lengths
-- are rejected immediately, so the lengths themselves are not hidden.
--
-- The C routine takes its byte count as a 'CInt'. Narrowing a larger 'Int'
-- length with 'fromIntegral' wraps silently. The routine could then
-- inspect only a prefix and report unequal inputs as equal. To avoid this,
-- the buffers are compared in chunks no longer than @maxBound :: CInt@.
-- For inputs that fit in one chunk, empty inputs included, this still
-- makes exactly one call with the full length.
constTimeEq :: B.ByteString -> B.ByteString -> Bool
constTimeEq s1 s2
  | B.length s1 /= B.length s2 = False
  | otherwise =
      -- unsafePerformIO: assumes c_constTimeEq only reads the two buffers
      -- and has no other side effects, so the result is a pure function of
      -- the inputs. NOTE(review): confirm against the C implementation.
      unsafePerformIO $
        unsafeUseAsCStringLen s1 $ \(p1, len) ->
          unsafeUseAsCStringLen s2 $ \(p2, _) -> do
            -- Perform every chunk comparison before looking at any result,
            -- so an early mismatch does not skip the remaining work.
            results <- mapM (compareChunk p1 p2) (chunks len)
            return (all (== 0) results)
  where
    maxChunk :: Int
    maxChunk = fromIntegral (maxBound :: CInt)

    -- (offset, length) pairs covering [0, len). The list always has at
    -- least one element, so an empty input still makes one zero-length call.
    chunks :: Int -> [(Int, Int)]
    chunks len =
      [ (off, min maxChunk (len - off)) | off <- [0, maxChunk .. max 0 (len - 1)] ]

    compareChunk :: Ptr CChar -> Ptr CChar -> (Int, Int) -> IO CInt
    compareChunk p1 p2 (off, n) =
      c_constTimeEq (p1 `plusPtr` off) (p2 `plusPtr` off) (fromIntegral n)
-- | Raw C binding behind 'constTimeEq'. The arguments are two buffers and
-- a byte count. 'constTimeEq' treats a result of 0 as "equal".
--
-- As written, no entity string is given. Under the Haskell 2010 FFI the C
-- symbol linked is therefore @c_constTimeEq@. NOTE(review): confirm this
-- matches the C source.
--
-- The call is marked @unsafe@: it is cheaper, but the C code must not call
-- back into Haskell, and it blocks the calling capability while it runs.
--
-- The type signature is indented so the layout rule keeps it part of this
-- declaration; at column 0 it would start a new, incomplete one.
foreign import ccall unsafe
  c_constTimeEq :: Ptr CChar -> Ptr CChar -> CInt -> IO CInt
-- | Interpret a strict 'B.ByteString' as a big-endian unsigned integer
-- (first byte most significant). The empty string yields 0.
bs2i :: B.ByteString -> Integer
bs2i = B.foldl' pushByte 0
  where
    -- Make room for one more byte, then add it in as the new low byte.
    pushByte acc byte = acc * 256 + fromIntegral byte
{-# INLINE bs2i #-}
-- | XOR two strict 'B.ByteString's byte by byte. The result is as long as
-- the shorter input; the excess bytes of the longer one are ignored.
zwp' :: B.ByteString -> B.ByteString -> B.ByteString
zwp' a b = B.pack (B.zipWith xor a b)
{-# INLINE zwp' #-}
-- | Lazy counterpart of 'zwp'': XOR two lazy 'L.ByteString's byte by byte,
-- truncating to the shorter input.
--
-- The two chunk lists generally break at different offsets. Each step
-- therefore XORs only the common prefix of the current pair of chunks and
-- pushes any leftover back in front of that side's remaining chunks.
-- Output chunks are emitted one per step behind a @(:)@, so the result is
-- built incrementally.
zwp :: L.ByteString -> L.ByteString -> L.ByteString
zwp a b = L.fromChunks (step (L.toChunks a) (L.toChunks b))
  where
    step (x:xs) (y:ys) =
      let n              = B.length x `min` B.length y
          (xHead, xRest) = B.splitAt n x
          (yHead, yRest) = B.splitAt n y
      in zwp' xHead yHead : step (pushBack xRest xs) (pushBack yRest ys)
    step _ _ = []

    -- Re-queue a leftover only when it is non-empty.
    pushBack rest more
      | B.null rest = more
      | otherwise   = rest : more
{-# INLINEABLE zwp #-}
-- | Take the first @n@ bytes from a list of strict chunks and return them
-- as a list of chunks. Whole chunks are passed through, and the chunk that
-- crosses the limit is trimmed with 'B.take'. If the chunks run out first,
-- everything available is returned.
--
-- NOTE(review): a negative count on a non-empty list yields @[B.empty]@
-- (because @B.take@ with a negative count is empty) rather than @[]@.
collect :: Int -> [B.ByteString] -> [B.ByteString]
collect 0 _  = []
collect _ [] = []
collect want (chunk:rest)
  | size < want = chunk : collect (want - size) rest
  | otherwise   = [B.take want chunk]
  where
    size = B.length chunk
{-# INLINE collect #-}