-- |
-- Module      : Crypto.Internal.Builder
-- License     : BSD-style
-- Maintainer  : Olivier Chéron <olivier.cheron@gmail.com>
-- Stability   : stable
-- Portability : Good
--
-- Delaying and merging ByteArray allocations.  This is similar to module
-- "Data.ByteArray.Pack" except the total length is computed automatically based
-- on what is appended.
--
{-# LANGUAGE BangPatterns #-}
module Crypto.Internal.Builder
    ( Builder
    , buildAndFreeze
    , builderLength
    , byte
    , bytes
    , zero
    ) where

import           Data.ByteArray (ByteArray, ByteArrayAccess)
import qualified Data.ByteArray as B
import           Data.Memory.PtrMethods (memSet)

import           Foreign.Ptr (Ptr, plusPtr)
import           Foreign.Storable (poke)

import           Crypto.Internal.Imports hiding (empty)

-- | A deferred buffer write: a size paired with the action that initializes
-- the buffer once it has been allocated.
data Builder = Builder
    !Int                  -- ^ size (kept strict)
    (Ptr Word8 -> IO ())  -- ^ initializer, given the pointer to write at

-- | Appending two builders adds their sizes and sequences their
-- initializers: the left one runs at the base pointer, the right one at the
-- base pointer advanced by the size of the left builder.
instance Semigroup Builder where
    Builder s1 f1 <> Builder s2 f2 = Builder (s1 + s2) initBoth
      where
        -- The right-hand initializer must start right after the @s1@ bytes
        -- owned by the left-hand one.
        initBoth p = f1 p >> f2 (p `plusPtr` s1)

-- | Return the size recorded in the builder (the sum of the sizes of
-- everything appended to it), without running its initializer.
builderLength :: Builder -> Int
builderLength (Builder s _) = s

-- | Materialize a builder as a byte array: the builder's size and
-- initializer are handed to 'B.allocAndFreeze', which produces the result.
buildAndFreeze :: ByteArray ba => Builder -> ba
buildAndFreeze (Builder s f) = B.allocAndFreeze s f

-- | Builder of size 1 whose initializer pokes the given byte at the target
-- pointer.  The argument is matched with a bang pattern, so it is forced
-- when the builder is constructed rather than when it is run.
byte :: Word8 -> Builder
byte !b = Builder 1 (\p -> poke p b)

-- | Builder whose size is 'B.length' of the input and whose initializer
-- copies the input to the target pointer with 'B.copyByteArrayToPtr'.
bytes :: ByteArrayAccess ba => ba -> Builder
bytes bs = Builder (B.length bs) (B.copyByteArrayToPtr bs)

-- | Builder of the given size whose initializer fills the target with zero
-- bytes using 'memSet'.  A size that is zero or negative yields the 'empty'
-- builder, so no 'memSet' call is ever made with a non-positive length.
zero :: Int -> Builder
zero s
    | s > 0     = Builder s (\p -> memSet p 0 s)
    | otherwise = empty

-- | Builder of size 0 whose initializer ignores the pointer and does
-- nothing.
empty :: Builder
empty = Builder 0 (\_ -> return ())