{-# LANGUAGE FlexibleContexts #-}

module Network.TLS.Receiving (
    processPacket,
    processPacket13,
) where

import Control.Concurrent.MVar
import Control.Monad.State.Strict

import Network.TLS.Cipher
import Network.TLS.Context.Internal
import Network.TLS.ErrT
import Network.TLS.Handshake.State
import Network.TLS.Imports
import Network.TLS.Packet
import Network.TLS.Packet13
import Network.TLS.Record
import Network.TLS.State
import Network.TLS.Struct
import Network.TLS.Struct13
import Network.TLS.Util
import Network.TLS.Wire

-- | Convert a plaintext record into a 'Packet', dispatching on the record's
-- protocol type.
--
-- * Application data: the fragment bytes are wrapped in 'AppData'.
-- * Alert: the fragment bytes are decoded with 'decodeAlerts'.
-- * ChangeCipherSpec: the fragment is decoded with 'decodeChangeCipherSpec';
--   on success the receive-side record state is switched via
--   'switchRxEncryption' before 'ChangeCipherSpec' is returned.
-- * Handshake: the fragment is fed to the handshake record decoder, resuming
--   from any continuation saved in 'stHandshakeRecordCont'.
-- * Any other protocol type yields @Left (Error_Packet_Parsing ...)@.
processPacket :: Context -> Record Plaintext -> IO (Either TLSError Packet)
processPacket _ (Record ProtocolType_AppData _ fragment) =
    return $ Right $ AppData $ fragmentGetBytes fragment
processPacket _ (Record ProtocolType_Alert _ fragment) =
    return (Alert `fmapEither` decodeAlerts (fragmentGetBytes fragment))
processPacket ctx (Record ProtocolType_ChangeCipherSpec _ fragment) =
    case decodeChangeCipherSpec $ fragmentGetBytes fragment of
        Left err -> return $ Left err
        Right _ -> do
            switchRxEncryption ctx
            return $ Right ChangeCipherSpec
processPacket ctx (Record ProtocolType_Handshake ver fragment) = do
    -- The key exchange type is taken from the pending cipher, when both a
    -- handshake state and a pending cipher exist; otherwise it is Nothing.
    hs <- getHState ctx
    let keyxchg = cipherKeyExchange <$> (hs >>= hstPendingCipher)
    usingState ctx $ do
        let currentParams =
                CurrentParams
                    { cParamsVersion = ver
                    , cParamsKeyXchgType = keyxchg
                    }
        -- Get back the optional continuation (and clear it from the state),
        -- then parse as many handshake records as possible.
        mCont <- gets stHandshakeRecordCont
        modify (\st -> st{stHandshakeRecordCont = Nothing})
        hss <- parseMany currentParams mCont (fragmentGetBytes fragment)
        return $ Handshake hss
  where
    -- Decode handshake messages from @bs@ until the input is exhausted.
    -- When a saved continuation is supplied it is used instead of starting a
    -- fresh 'decodeHandshakeRecord'.
    parseMany currentParams mCont bs =
        case fromMaybe decodeHandshakeRecord mCont bs of
            GotError err -> throwError err
            -- Incomplete message: save the continuation so that the next
            -- handshake record resumes from it.
            GotPartial cont ->
                modify (\st -> st{stHandshakeRecordCont = Just cont}) >> return []
            GotSuccess (ty, content) ->
                either throwError (return . (: [])) $
                    decodeHandshake currentParams ty content
            -- One complete message plus leftover bytes: decode it and keep
            -- parsing the remainder with a fresh decoder.
            GotSuccessRemaining (ty, content) left ->
                case decodeHandshake currentParams ty content of
                    Left err -> throwError err
                    Right hh -> (hh :) <$> parseMany currentParams Nothing left
processPacket _ _ = return $ Left (Error_Packet_Parsing "unknown protocol type")

-- | Replace the context's receive-side record state ('ctxRxRecordState')
-- with the pending Rx state held in the handshake state
-- ('hstPendingRxState').
--
-- If no pending Rx state is present this fails immediately with a
-- descriptive 'error' and leaves the current record state untouched, rather
-- than storing an unevaluated @fromJust Nothing@ thunk in the 'MVar' that
-- would only blow up when the state is next inspected.
switchRxEncryption :: Context -> IO ()
switchRxEncryption ctx = do
    pending <- usingHState ctx (gets hstPendingRxState)
    case pending of
        Nothing -> error "switchRxEncryption: no pending Rx record state"
        Just rx -> modifyMVar_ (ctxRxRecordState ctx) (\_ -> return rx)

----------------------------------------------------------------

-- | Convert a plaintext record into a 'Packet13', dispatching on the record's
-- protocol type.
--
-- * ChangeCipherSpec: the fragment is ignored and 'ChangeCipherSpec13' is
--   returned.
-- * Application data: the fragment bytes are wrapped in 'AppData13'.
-- * Alert: the fragment bytes are decoded with 'decodeAlerts'.
-- * Handshake: the fragment is fed to the handshake record decoder, resuming
--   from any continuation saved in 'stHandshakeRecordCont13'.
-- * Any other protocol type yields @Left (Error_Packet_Parsing ...)@.
processPacket13 :: Context -> Record Plaintext -> IO (Either TLSError Packet13)
processPacket13 _ (Record ProtocolType_ChangeCipherSpec _ _) =
    return $ Right ChangeCipherSpec13
processPacket13 _ (Record ProtocolType_AppData _ fragment) =
    return $ Right $ AppData13 $ fragmentGetBytes fragment
processPacket13 _ (Record ProtocolType_Alert _ fragment) =
    return (Alert13 `fmapEither` decodeAlerts (fragmentGetBytes fragment))
processPacket13 ctx (Record ProtocolType_Handshake _ fragment) = usingState ctx $ do
    -- Take the saved continuation (if any) out of the state before parsing.
    mCont <- gets stHandshakeRecordCont13
    modify (\st -> st{stHandshakeRecordCont13 = Nothing})
    hss <- parseMany mCont (fragmentGetBytes fragment)
    return $ Handshake13 hss
  where
    -- Decode handshake messages from @bs@ until the input is exhausted.
    -- When a saved continuation is supplied it is used instead of starting a
    -- fresh 'decodeHandshakeRecord13'.
    parseMany mCont bs =
        case fromMaybe decodeHandshakeRecord13 mCont bs of
            GotError err -> throwError err
            -- Incomplete message: save the continuation so that the next
            -- handshake record resumes from it.
            GotPartial cont ->
                modify (\st -> st{stHandshakeRecordCont13 = Just cont}) >> return []
            GotSuccess (ty, content) ->
                either throwError (return . (: [])) $ decodeHandshake13 ty content
            -- One complete message plus leftover bytes: decode it and keep
            -- parsing the remainder with a fresh decoder.
            GotSuccessRemaining (ty, content) left ->
                case decodeHandshake13 ty content of
                    Left err -> throwError err
                    Right hh -> (hh :) <$> parseMany Nothing left
processPacket13 _ _ = return $ Left (Error_Packet_Parsing "unknown protocol type")