-- |
-- Module      : Data.Byteable
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : good
--
module Data.Byteable
    ( Byteable(..)
    , constEqBytes
    ) where

import Data.Bits (xor, (.|.))
import Data.List (foldl')
import Data.Word (Word8)
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr, plusPtr)

import Data.ByteString (ByteString)
import qualified Data.ByteString as B (length, zipWith)
import qualified Data.ByteString.Internal as B (toForeignPtr)

-- | Types whose values can be presented as a sequence of bytes.
--
-- Only 'toBytes' has to be written by an instance; 'byteableLength' and
-- 'withBytePtr' have default definitions that go through 'toBytes'.
class Byteable a where
    -- | Produce the bytes of a value as a 'ByteString'.
    toBytes :: a -> ByteString

    -- | Number of bytes in a value. By default, the length of 'toBytes'.
    byteableLength :: a -> Int
    byteableLength x = B.length (toBytes x)

    -- | Run an action with a pointer to the first byte of a value.
    --
    -- The default converts the value with 'toBytes' and hands the action
    -- the buffer's base pointer advanced by the 'ByteString' offset, inside
    -- 'withForeignPtr', so the pointer should only be used within the action.
    withBytePtr :: a -> (Ptr Word8 -> IO b) -> IO b
    withBytePtr a f =
        withForeignPtr buffer (\start -> f (start `plusPtr` offset))
      where
        (buffer, offset, _size) = B.toForeignPtr (toBytes a)

-- | A 'ByteString' is already a byte sequence, so conversion is the
-- identity function.
instance Byteable ByteString where
    toBytes = id

-- | A constant-time equality test for two byteable objects.
--
-- If the objects have different sizes, the function returns 'False'
-- early without comparing any bytes, so a difference in length is not
-- hidden.
--
-- Unlike '==', this function goes over all the bytes present before
-- yielding a result, even when the overall result is already known early
-- in the processing.
constEqBytes :: Byteable a => a -> a -> Bool
constEqBytes a b = constEqByteString (toBytes a) (toBytes b)
-- Hold off inlining until phase 1 so the rewrite rule below gets a chance
-- to fire first; without phase control GHC may inline 'constEqBytes'
-- before the rule is tried (and warns with -Winline-rule-shadowing).
{-# INLINE [1] constEqBytes #-}
{-# RULES "constEqBytes/ByteString" constEqBytes = constEqByteString #-}

{-# INLINE constEqByteString #-}
-- | Compare two 'ByteString's, examining every byte when the lengths match.
--
-- Returns 'False' immediately when the lengths differ (the length is not
-- hidden). Otherwise each pair of bytes is XORed and the results are OR-ed
-- into a single 'Word8' accumulator; the strings are equal exactly when the
-- accumulator ends up zero. 'foldl'' walks the whole list, and the bitwise
-- accumulation has no per-byte pattern match on a data-dependent 'Bool',
-- which is what the previous '&&!'-based fold relied on.
constEqByteString :: ByteString -> ByteString -> Bool
constEqByteString a b
    | B.length a /= B.length b = False
    | otherwise                = foldl' (.|.) 0 (B.zipWith xor a b) == 0