module Database.CQL.Protocol.Header
    ( Header     (..)
    , HeaderType (..)
    , header
    , encodeHeader
    , decodeHeader

      -- ** Length
    , Length     (..)
    , encodeLength
    , decodeLength

      -- ** StreamId
    , StreamId
    , mkStreamId
    , fromStreamId
    , encodeStreamId
    , decodeStreamId

      -- ** Flags
    , Flags
    , compress
    , customPayload
    , tracing
    , warning
    , isSet
    , encodeFlags
    , decodeFlags
    ) where

import Control.Applicative
import Data.Bits
import Data.ByteString.Lazy (ByteString)
import Data.Int
import Data.Monoid hiding ((<>))
import Data.Semigroup
import Data.Serialize
import Data.Word
import Database.CQL.Protocol.Codec
import Database.CQL.Protocol.Types
import Prelude

-- | Protocol frame header. The fields appear in the same order in which
-- 'encodeHeader' writes them and 'decodeHeader' reads them.
data Header = Header
    { headerType :: !HeaderType -- ^ Request or response (bit 7 of the version byte).
    , version    :: !Version    -- ^ Protocol version.
    , flags      :: !Flags      -- ^ Header flags.
    , streamId   :: !StreamId   -- ^ Stream the frame belongs to.
    , opCode     :: !OpCode     -- ^ Frame operation code.
    , bodyLength :: !Length     -- ^ Length of the frame body.
    } deriving (Show)

-- | Direction of a frame; stored in bit 7 of the header's version byte.
data HeaderType
    = RqHeader -- ^ A request frame header.
    | RsHeader -- ^ A response frame header.
    deriving (Show)

-- | Encode a frame header.
--
-- The version byte has bit 7 set for responses and clear for requests.
-- It is followed by the flags, stream ID, opcode and body length, in
-- that order.
encodeHeader
    :: Version    -- ^ protocol version (also selects the stream ID encoding)
    -> HeaderType -- ^ request or response
    -> Flags
    -> StreamId
    -> OpCode
    -> Length     -- ^ body length
    -> PutM ()
encodeHeader ver typ fs sid op len = do
    encodeByte versionByte
    encodeFlags fs
    encodeStreamId ver sid
    encodeOpCode op
    encodeLength len
  where
    versionByte = case typ of
        RqHeader -> mapVersion ver
        RsHeader -> mapVersion ver `setBit` 7

-- | Decode a frame header.
--
-- The first byte carries both the direction (bit 7, see 'mapHeaderType')
-- and the protocol version (bits 0-6). The stream ID is decoded with the
-- caller-supplied version @v@, not with the version read from the frame.
decodeHeader :: Version -> Get Header
decodeHeader v = getWord8 >>= fromFirstByte
  where
    fromFirstByte b =
        Header (mapHeaderType b)
            <$> toVersion (b .&. 0x7F)
            <*> decodeFlags
            <*> decodeStreamId v
            <*> decodeOpCode
            <*> decodeLength

-- | Classify a header's first byte: bit 7 set means a response,
-- clear means a request.
mapHeaderType :: Word8 -> HeaderType
mapHeaderType b
    | testBit b 7 = RsHeader
    | otherwise   = RqHeader

-- | Deserialise a frame header from a lazy 'ByteString' using the
-- version-specific decoding format (see 'decodeHeader'). Decoding
-- failures are reported as 'Left' with an error message.
header :: Version -> ByteString -> Either String Header
header = runGetLazy . decodeHeader

------------------------------------------------------------------------------
-- Version

-- | The wire value of a protocol 'Version', before the direction bit
-- (bit 7) is added.
mapVersion :: Version -> Word8
mapVersion ver = case ver of
    V3 -> 3
    V4 -> 4

-- | Map the version bits of a header's first byte (direction bit already
-- masked off) to a 'Version'. Fails in 'Get' for unsupported versions.
toVersion :: Word8 -> Get Version
toVersion w = case w of
    3 -> pure V3
    4 -> pure V4
    _ -> fail ("decode-version: unknown: " ++ show w)

------------------------------------------------------------------------------
-- Length

-- | The type denoting a protocol frame length, stored as a signed
-- 32-bit integer.
newtype Length = Length
    { lengthRepr :: Int32
    } deriving (Show, Eq)

-- | Encode a frame 'Length' as a 32-bit int.
encodeLength :: Putter Length
encodeLength = encodeInt . lengthRepr

-- | Decode a frame 'Length' from a 32-bit int. No range check is
-- applied, so negative values are accepted as-is.
decodeLength :: Get Length
decodeLength = fmap Length decodeInt

------------------------------------------------------------------------------
-- StreamId

-- | Streams allow multiplexing of requests over a single communication
-- channel. The 'StreamId' correlates 'Request's with 'Response's; it is
-- a signed 16-bit value.
newtype StreamId = StreamId Int16
    deriving (Show, Eq)

-- | Create a 'StreamId' from the given integral value.
--
-- A 'StreamId' is a signed 16-bit value ('Int16') for all supported
-- protocol versions. The conversion uses 'fromIntegral', so values
-- outside the 'Int16' range wrap around rather than being rejected:
-- e.g. @mkStreamId (32768 :: Int)@ is the same stream ID as
-- @mkStreamId (-32768 :: Int)@.
mkStreamId :: Integral i => i -> StreamId
mkStreamId = StreamId . fromIntegral

-- | Convert the stream ID to an 'Int'. The underlying 'Int16' is widened,
-- so the result lies in @[-32768, 32767]@.
fromStreamId :: StreamId -> Int
fromStreamId sid = case sid of
    StreamId n -> fromIntegral n

-- | Encode a 'StreamId' as a signed short.
--
-- Protocol versions 3 and 4 share the same 16-bit representation. The
-- 'Int16' is passed straight through; no conversion is needed. Matching
-- each version explicitly (rather than with a wildcard) means a newly
-- added 'Version' constructor is reported as an incomplete pattern
-- under @-Wall@.
encodeStreamId :: Version -> Putter StreamId
encodeStreamId V4 (StreamId x) = encodeSignedShort x
encodeStreamId V3 (StreamId x) = encodeSignedShort x

-- | Decode a 'StreamId' from a signed short. Versions 3 and 4 use the
-- same representation.
decodeStreamId :: Version -> Get StreamId
decodeStreamId ver = case ver of
    V3 -> fmap StreamId decodeSignedShort
    V4 -> fmap StreamId decodeSignedShort

------------------------------------------------------------------------------
-- Flags

-- | Type representing header flags: a set of bits stored in one byte.
-- Flags form a monoid under bitwise OR and can be combined as in
-- @compress <> tracing <> mempty@.
newtype Flags = Flags Word8
    deriving (Show, Eq)

-- | Combining two flag sets yields the bitwise union of their bits.
instance Semigroup Flags where
    (<>) (Flags x) (Flags y) = Flags (x .|. y)

-- | The empty flag set has no bits set and is the identity of '<>'.
instance Monoid Flags where
    mempty  = Flags zeroBits
    mappend = (<>)

-- | Encode flags as a single byte.
encodeFlags :: Putter Flags
encodeFlags flagSet = case flagSet of
    Flags bits -> encodeByte bits

-- | Decode flags from a single byte. Unknown bits are kept as-is.
decodeFlags :: Get Flags
decodeFlags = fmap Flags decodeByte

-- | Compression flag (bit @0x01@). If set, the frame body is compressed.
compress :: Flags
compress = Flags 0x01

-- | Tracing flag (bit @0x02@). If a request supports tracing and the
-- tracing flag was set, the response to this request will have the
-- tracing flag set and contain tracing information.
tracing :: Flags
tracing = Flags 2

-- | Custom payload flag (bit @0x04@). See section 2.2 of the CQL native
-- protocol v4 specification for its meaning.
customPayload :: Flags
customPayload = Flags 0x04

-- | Warning flag (bit @0x08@). See section 2.2 of the CQL native
-- protocol v4 specification for its meaning.
warning :: Flags
warning = Flags 0x08

-- | Check if a particular flag is present. @isSet f fs@ is 'True' when
-- every bit set in @f@ is also set in @fs@. Consequently
-- @isSet mempty fs@ is always 'True'.
isSet :: Flags -> Flags -> Bool
isSet (Flags wanted) (Flags present) = (present .&. wanted) == wanted