{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}

module Network.TLS.Handshake.State (
    HandshakeState (..),
    HandshakeDigest (..),
    HandshakeMode13 (..),
    RTT0Status (..),
    CertReqCBdata,
    HandshakeM,
    newEmptyHandshake,
    runHandshake,

    -- * key accessors
    setPublicKey,
    setPublicPrivateKeys,
    getLocalPublicPrivateKeys,
    getRemotePublicKey,
    setServerDHParams,
    getServerDHParams,
    setServerECDHParams,
    getServerECDHParams,
    setDHPrivate,
    getDHPrivate,
    setGroupPrivate,
    getGroupPrivate,

    -- * cert accessors
    setClientCertSent,
    getClientCertSent,
    setCertReqSent,
    getCertReqSent,
    setClientCertChain,
    getClientCertChain,
    setCertReqToken,
    getCertReqToken,
    setCertReqCBdata,
    getCertReqCBdata,
    setCertReqSigAlgsCert,
    getCertReqSigAlgsCert,

    -- * digest accessors
    addHandshakeMessage,
    updateHandshakeDigest,
    getHandshakeMessages,
    getHandshakeMessagesRev,
    getHandshakeDigest,
    foldHandshakeDigest,

    -- * main secret
    setMainSecret,
    setMainSecretFromPre,

    -- * misc accessor
    getPendingCipher,
    setServerHelloParameters,
    setExtendedMainSecret,
    getExtendedMainSecret,
    setSupportedGroup,
    getSupportedGroup,
    setTLS13HandshakeMode,
    getTLS13HandshakeMode,
    setTLS13RTT0Status,
    getTLS13RTT0Status,
    setTLS13EarlySecret,
    getTLS13EarlySecret,
    setTLS13ResumptionSecret,
    getTLS13ResumptionSecret,
    setCCS13Sent,
    getCCS13Sent,
) where

import Control.Monad.State.Strict
import Data.ByteArray (ByteArrayAccess)
import Data.X509 (CertificateChain)

import Network.TLS.Cipher
import Network.TLS.Compression
import Network.TLS.Crypto
import Network.TLS.Imports
import Network.TLS.Packet
import Network.TLS.Record.State
import Network.TLS.Struct
import Network.TLS.Types
import Network.TLS.Util

-- | Public/private key material gathered while the handshake runs.
data HandshakeKeyState = HandshakeKeyState
    { hksRemotePublicKey :: Maybe PubKey
    -- ^ The peer's public key, once known.
    , hksLocalPublicPrivateKeys :: Maybe (PubKey, PrivKey)
    -- ^ Our own key pair, once chosen.
    }
    deriving (Show)

-- | Transcript of the handshake so far: either the raw messages collected
-- so far, or a running hash context.
data HandshakeDigest
    = -- | Raw handshake messages.
      HandshakeMessages [ByteString]
    | -- | An incremental hash context.
      HandshakeDigestContext HashCtx
    deriving (Show)

-- | Everything tracked for one handshake; threaded through 'HandshakeM'.
data HandshakeState = HandshakeState
    { hstClientVersion :: Version
    , hstClientRandom :: ClientRandom
    , hstServerRandom :: Maybe ServerRandom
    , hstMainSecret :: Maybe ByteString
    , hstKeyState :: HandshakeKeyState
    , hstServerDHParams :: Maybe ServerDHParams
    , hstDHPrivate :: Maybe DHPrivate
    , hstServerECDHParams :: Maybe ServerECDHParams
    , hstGroupPrivate :: Maybe GroupPrivate
    , hstHandshakeDigest :: HandshakeDigest
    , hstHandshakeMessages :: [ByteString]
    , hstCertReqToken :: Maybe ByteString
    -- ^ Set to Just-value when a TLS13 certificate request is received
    , hstCertReqCBdata :: Maybe CertReqCBdata
    -- ^ Set to Just-value when a certificate request is received
    , hstCertReqSigAlgsCert :: Maybe [HashAndSignatureAlgorithm]
    -- ^ In TLS 1.3, these are separate from the certificate
    -- issuer signature algorithm hints in the callback data.
    -- In TLS 1.2 the same list is overloaded for both purposes.
    -- Not present in TLS 1.1 and earlier
    , hstClientCertSent :: Bool
    -- ^ Set to true when a client certificate chain was sent
    , hstCertReqSent :: Bool
    -- ^ Set to true when a certificate request was sent.  This applies
    -- only to requests sent during handshake (not post-handshake).
    , hstClientCertChain :: Maybe CertificateChain
    , hstPendingTxState :: Maybe RecordState
    , hstPendingRxState :: Maybe RecordState
    , hstPendingCipher :: Maybe Cipher
    , hstPendingCompression :: Compression
    , hstExtendedMainSecret :: Bool
    , hstSupportedGroup :: Maybe Group
    , hstTLS13HandshakeMode :: HandshakeMode13
    , hstTLS13RTT0Status :: RTT0Status
    , hstTLS13EarlySecret :: Maybe (BaseSecret EarlySecret)
    , hstTLS13ResumptionSecret :: Maybe (BaseSecret ResumptionSecret)
    , hstCCS13Sent :: Bool
    }
    deriving (Show)

-- | When we receive a CertificateRequest from a server, a just-in-time
--    callback is issued to the application to obtain a suitable certificate.
--    Somewhat unfortunately, the callback parameters don't abstract away the
--    details of the TLS 1.2 Certificate Request message, which combines the
--    legacy @certificate_types@ and new @supported_signature_algorithms@
--    parameters in a rather subtle way.
--
--    TLS 1.2 also (again unfortunately, in the opinion of the author of this
--    comment) overloads the signature algorithms parameter to constrain not only
--    the algorithms used in TLS, but also the algorithms used by issuing CAs in
--    the X.509 chain.  Best practice is NOT to treat such a restriction as a
--    MUST, but rather take it as merely a preference, when a choice exists.  If
--    the best chain available does not match the provided signature algorithm
--    list, go ahead and use it anyway: it will probably work, and the server may
--    not even care about the issuer CAs at all; it may be doing DANE or have
--    explicit mappings for the client's public key, ...
--
--    The TLS 1.3 @CertificateRequest@ message drops @certificate_types@ and no
--    longer overloads @supported_signature_algorithms@ to cover X.509.  It also
--    includes a new opaque context token that the client must echo back, which
--    makes certain client authentication replay attacks more difficult.  We
--    store that context separately; it does not need to be presented in the user
--    callback.  The certificate signature algorithms preferred by the peer are
--    now in the separate @signature_algorithms_cert@ extension, but we cannot
--    report these to the application callback without an API change.  The good
--    news is that filtering the X.509 signature types is generally unnecessary,
--    unwise and difficult.  So we just ignore this extension.
--
--    As a result, the information we provide to the callback is no longer a
--    verbatim copy of the certificate request payload.  In the case of TLS 1.3,
--    the 'CertificateType' list is synthetically generated from the server's
--    @signature_algorithms@ extension, and the @signature_algorithms_cert@
--    extension is ignored.
--
--    Since the original TLS 1.2 'CertificateType' has no provision for the newer
--    certificate types that have appeared in TLS 1.3, we add some synthetic
--    values that have no equivalent in the TLS 1.2 'CertificateType' as
--    defined in the IANA
--    <https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-2 TLS ClientCertificateType Identifiers>
--    registry.  These values are inferred from the TLS 1.3
--    @signature_algorithms@ extension, and will allow clients to present
--    Ed25519 and Ed448 certificates when these become supported.
--
--    The tuple holds, in order: the certificate types, the optional
--    signature algorithm list, and the acceptable CA distinguished names.
type CertReqCBdata =
    ( [CertificateType] -- certificate types (synthetic for TLS 1.3)
    , Maybe [HashAndSignatureAlgorithm] -- signature algorithm hints
    , [DistinguishedName] -- acceptable issuer names
    )

-- | The handshake monad: a strict 'State' computation over 'HandshakeState'.
--
-- 'Functor', 'Applicative' and 'Monad' are inherited from the wrapped
-- 'State' through GeneralizedNewtypeDeriving.
newtype HandshakeM a = HandshakeM
    { runHandshakeM :: State HandshakeState a
    }
    deriving (Functor, Applicative, Monad)

-- | Each 'MonadState' operation simply delegates to the wrapped 'State'.
instance MonadState HandshakeState HandshakeM where
    get = HandshakeM get
    put = HandshakeM . put
    state = HandshakeM . state

-- | Build a fresh handshake state from the client's version and random.
--
-- Every optional field starts as 'Nothing' and every flag as 'False'. No
-- handshake messages are recorded, compression is 'nullCompression', the
-- TLS 1.3 mode is 'FullHandshake', and the 0-RTT status is 'RTT0None'.
newEmptyHandshake :: Version -> ClientRandom -> HandshakeState
newEmptyHandshake version clientRandom =
    HandshakeState
        { hstClientVersion = version
        , hstClientRandom = clientRandom
        , hstServerRandom = Nothing
        , hstMainSecret = Nothing
        , hstKeyState = noKeys
        , hstServerDHParams = Nothing
        , hstDHPrivate = Nothing
        , hstServerECDHParams = Nothing
        , hstGroupPrivate = Nothing
        , hstHandshakeDigest = emptyTranscript
        , hstHandshakeMessages = []
        , hstCertReqToken = Nothing
        , hstCertReqCBdata = Nothing
        , hstCertReqSigAlgsCert = Nothing
        , hstClientCertSent = False
        , hstCertReqSent = False
        , hstClientCertChain = Nothing
        , hstPendingTxState = Nothing
        , hstPendingRxState = Nothing
        , hstPendingCipher = Nothing
        , hstPendingCompression = nullCompression
        , hstExtendedMainSecret = False
        , hstSupportedGroup = Nothing
        , hstTLS13HandshakeMode = FullHandshake
        , hstTLS13RTT0Status = RTT0None
        , hstTLS13EarlySecret = Nothing
        , hstTLS13ResumptionSecret = Nothing
        , hstCCS13Sent = False
        }
  where
    -- No keys are known yet on either side.
    noKeys =
        HandshakeKeyState
            { hksRemotePublicKey = Nothing
            , hksLocalPublicPrivateKeys = Nothing
            }
    -- The transcript starts as an empty message list.
    emptyTranscript = HandshakeMessages []

-- | Run a handshake action from the given state, returning both the
-- action's result and the final state.
runHandshake :: HandshakeState -> HandshakeM a -> (a, HandshakeState)
runHandshake initial = flip runState initial . runHandshakeM

-- | Store the peer's public key, replacing any previous one. The local key
-- pair is left untouched.
setPublicKey :: PubKey -> HandshakeM ()
setPublicKey pk = modify $ \st ->
    let keys = hstKeyState st
     in st{hstKeyState = keys{hksRemotePublicKey = Just pk}}

-- | Store our own key pair, replacing any previous one. The peer's public
-- key is left untouched.
setPublicPrivateKeys :: (PubKey, PrivKey) -> HandshakeM ()
setPublicPrivateKeys keys = modify $ \st ->
    let ks = hstKeyState st
     in st{hstKeyState = ks{hksLocalPublicPrivateKeys = Just keys}}

-- | Return the peer's public key stored by 'setPublicKey'.
--
-- Calling this before the key has been stored is a programming error. The
-- resulting 'error' names the missing value, instead of the anonymous
-- 'fromJust' failure. As before, it is only raised when the returned value
-- is actually forced.
getRemotePublicKey :: HandshakeM PubKey
getRemotePublicKey =
    maybe (error "getRemotePublicKey: remote public key not set") id
        <$> gets (hksRemotePublicKey . hstKeyState)

-- | Return our own key pair stored by 'setPublicPrivateKeys'.
--
-- Calling this before the key pair has been stored is a programming error.
-- The resulting 'error' names the missing value, instead of the anonymous
-- 'fromJust' failure. It is still raised lazily, only when the result is
-- forced.
getLocalPublicPrivateKeys :: HandshakeM (PubKey, PrivKey)
getLocalPublicPrivateKeys =
    maybe (error "getLocalPublicPrivateKeys: local key pair not set") id
        <$> gets (hksLocalPublicPrivateKeys . hstKeyState)

-- | Store the server's DH parameters, replacing any previous value.
setServerDHParams :: ServerDHParams -> HandshakeM ()
setServerDHParams params = modify withParams
  where
    withParams st = st{hstServerDHParams = Just params}

-- | Return the server DH parameters stored by 'setServerDHParams'.
--
-- Calling this before they have been stored is a programming error. The
-- lazily raised 'error' names the missing value, instead of the anonymous
-- 'fromJust' failure.
getServerDHParams :: HandshakeM ServerDHParams
getServerDHParams =
    maybe (error "getServerDHParams: server DH parameters not set") id
        <$> gets hstServerDHParams

-- | Store the server's ECDH parameters, replacing any previous value.
setServerECDHParams :: ServerECDHParams -> HandshakeM ()
setServerECDHParams params = modify withParams
  where
    withParams st = st{hstServerECDHParams = Just params}

-- | Return the server ECDH parameters stored by 'setServerECDHParams'.
--
-- Calling this before they have been stored is a programming error. The
-- lazily raised 'error' names the missing value, instead of the anonymous
-- 'fromJust' failure.
getServerECDHParams :: HandshakeM ServerECDHParams
getServerECDHParams =
    maybe (error "getServerECDHParams: server ECDH parameters not set") id
        <$> gets hstServerECDHParams

-- | Store our DH private value, replacing any previous one.
setDHPrivate :: DHPrivate -> HandshakeM ()
setDHPrivate priv = modify $ \st -> st{hstDHPrivate = Just priv}

-- | Return the DH private value stored by 'setDHPrivate'.
--
-- Calling this before it has been stored is a programming error. The
-- lazily raised 'error' names the missing value, instead of the anonymous
-- 'fromJust' failure.
getDHPrivate :: HandshakeM DHPrivate
getDHPrivate =
    maybe (error "getDHPrivate: DH private value not set") id
        <$> gets hstDHPrivate

-- | Return the group private value stored by 'setGroupPrivate'.
--
-- Calling this before it has been stored is a programming error. The
-- lazily raised 'error' names the missing value, instead of the anonymous
-- 'fromJust' failure.
getGroupPrivate :: HandshakeM GroupPrivate
getGroupPrivate =
    maybe (error "getGroupPrivate: group private value not set") id
        <$> gets hstGroupPrivate

-- | Store our group private value, replacing any previous one.
setGroupPrivate :: GroupPrivate -> HandshakeM ()
setGroupPrivate priv = modify $ \st -> st{hstGroupPrivate = Just priv}

-- | Set the extended-main-secret flag (starts out 'False').
setExtendedMainSecret :: Bool -> HandshakeM ()
setExtendedMainSecret flag = modify $ \st -> st{hstExtendedMainSecret = flag}

-- | Read the extended-main-secret flag.
getExtendedMainSecret :: HandshakeM Bool
getExtendedMainSecret = hstExtendedMainSecret <$> get

-- | Record the selected group, replacing any previous choice.
setSupportedGroup :: Group -> HandshakeM ()
setSupportedGroup grp = modify $ \st -> st{hstSupportedGroup = Just grp}

-- | Read the recorded group; 'Nothing' until 'setSupportedGroup' is called.
getSupportedGroup :: HandshakeM (Maybe Group)
getSupportedGroup = hstSupportedGroup <$> get

-- | Which handshake mode was used in TLS 1.3.
data HandshakeMode13
    = -- | A full handshake.
      FullHandshake
    | -- | A full handshake preceded by a hello retry request.
      HelloRetryRequest
    | -- | Server authentication is skipped.
      PreSharedKey
    | -- | Server authentication is skipped and early data is sent.
      RTT0
    deriving (Eq, Show)

-- | Record the TLS 1.3 handshake mode (starts out 'FullHandshake').
setTLS13HandshakeMode :: HandshakeMode13 -> HandshakeM ()
setTLS13HandshakeMode mode = modify withMode
  where
    withMode st = st{hstTLS13HandshakeMode = mode}

-- | Read the recorded TLS 1.3 handshake mode.
getTLS13HandshakeMode :: HandshakeM HandshakeMode13
getTLS13HandshakeMode = hstTLS13HandshakeMode <$> get

-- | Status of a TLS 1.3 0-RTT attempt; a fresh handshake starts at
-- 'RTT0None'. NOTE(review): the per-constructor meanings below are inferred
-- from the names — confirm against the client/server handshake code.
data RTT0Status
    = -- | Presumably: no early data involved.
      RTT0None
    | -- | Presumably: early data was sent.
      RTT0Sent
    | -- | Presumably: the peer accepted early data.
      RTT0Accepted
    | -- | Presumably: the peer rejected early data.
      RTT0Rejected
    deriving (Eq, Show)

-- | Record the current 0-RTT status.
setTLS13RTT0Status :: RTT0Status -> HandshakeM ()
setTLS13RTT0Status status = modify $ \st -> st{hstTLS13RTT0Status = status}

-- | Read the current 0-RTT status.
getTLS13RTT0Status :: HandshakeM RTT0Status
getTLS13RTT0Status = hstTLS13RTT0Status <$> get

-- | Store the TLS 1.3 early secret, replacing any previous one.
setTLS13EarlySecret :: BaseSecret EarlySecret -> HandshakeM ()
setTLS13EarlySecret early = modify withEarly
  where
    withEarly st = st{hstTLS13EarlySecret = Just early}

-- | Read the TLS 1.3 early secret; 'Nothing' until one has been stored.
getTLS13EarlySecret :: HandshakeM (Maybe (BaseSecret EarlySecret))
getTLS13EarlySecret = hstTLS13EarlySecret <$> get

-- | Store the TLS 1.3 resumption secret, replacing any previous one.
setTLS13ResumptionSecret :: BaseSecret ResumptionSecret -> HandshakeM ()
setTLS13ResumptionSecret resumption = modify withResumption
  where
    withResumption st = st{hstTLS13ResumptionSecret = Just resumption}

-- | Read the TLS 1.3 resumption secret; 'Nothing' until one has been stored.
getTLS13ResumptionSecret :: HandshakeM (Maybe (BaseSecret ResumptionSecret))
getTLS13ResumptionSecret = hstTLS13ResumptionSecret <$> get

-- | Set the TLS 1.3 CCS-sent flag (starts out 'False'). NOTE(review):
-- presumably tracks whether a ChangeCipherSpec was sent — confirm with callers.
setCCS13Sent :: Bool -> HandshakeM ()
setCCS13Sent flag = modify $ \st -> st{hstCCS13Sent = flag}

-- | Read the TLS 1.3 CCS-sent flag.
getCCS13Sent :: HandshakeM Bool
getCCS13Sent = hstCCS13Sent <$> get

-- | Set the flag recording that a certificate request was sent during the
-- handshake (starts out 'False').
setCertReqSent :: Bool -> HandshakeM ()
setCertReqSent flag = modify $ \st -> st{hstCertReqSent = flag}

-- | Read the certificate-request-sent flag.
getCertReqSent :: HandshakeM Bool
getCertReqSent = hstCertReqSent <$> get

-- | Set the flag recording that a client certificate chain was sent
-- (starts out 'False').
setClientCertSent :: Bool -> HandshakeM ()
setClientCertSent flag = modify $ \st -> st{hstClientCertSent = flag}

-- | Read the client-certificate-sent flag.
getClientCertSent :: HandshakeM Bool
getClientCertSent = hstClientCertSent <$> get

-- | Store the client certificate chain, replacing any previous one.
setClientCertChain :: CertificateChain -> HandshakeM ()
setClientCertChain chain = modify withChain
  where
    withChain st = st{hstClientCertChain = Just chain}

-- | Read the client certificate chain; 'Nothing' until one has been stored.
getClientCertChain :: HandshakeM (Maybe CertificateChain)
getClientCertChain = hstClientCertChain <$> get

--
-- | Replace the stored certificate-request token ('hstCertReqToken').
-- Passing 'Nothing' clears it.
setCertReqToken :: Maybe ByteString -> HandshakeM ()
setCertReqToken token = modify storeToken
  where
    storeToken hst = hst{hstCertReqToken = token}

-- | Read the certificate-request token stored by 'setCertReqToken'.
getCertReqToken :: HandshakeM (Maybe ByteString)
getCertReqToken = hstCertReqToken <$> get

--
-- | Replace the stored certificate-request callback data
-- ('hstCertReqCBdata').  Passing 'Nothing' clears it.
setCertReqCBdata :: Maybe CertReqCBdata -> HandshakeM ()
setCertReqCBdata cbdata = modify storeData
  where
    storeData hst = hst{hstCertReqCBdata = cbdata}

-- | Read the certificate-request callback data stored by 'setCertReqCBdata'.
getCertReqCBdata :: HandshakeM (Maybe CertReqCBdata)
getCertReqCBdata = hstCertReqCBdata <$> get

-- | Replace the stored certificate-request signature algorithms
-- ('hstCertReqSigAlgsCert').  Passing 'Nothing' clears them.
--
-- Dead code, until we find some use for the extension.
setCertReqSigAlgsCert :: Maybe [HashAndSignatureAlgorithm] -> HandshakeM ()
setCertReqSigAlgsCert algs = modify storeAlgs
  where
    storeAlgs hst = hst{hstCertReqSigAlgsCert = algs}

-- | Read the signature algorithms stored by 'setCertReqSigAlgsCert'.
getCertReqSigAlgsCert :: HandshakeM (Maybe [HashAndSignatureAlgorithm])
getCertReqSigAlgsCert = hstCertReqSigAlgsCert <$> get

--
-- | Return the pending cipher, as set for example by
-- 'setServerHelloParameters'.
--
-- Partial: the result is @'fromJust' Nothing@, which errors when forced, if
-- no pending cipher has been recorded yet.
getPendingCipher :: HandshakeM Cipher
getPendingCipher = gets (fromJust . hstPendingCipher)

-- | Record a raw handshake message in 'hstHandshakeMessages'.
--
-- Messages are stored newest first (prepended).  'getHandshakeMessages'
-- reverses them back into insertion order.
addHandshakeMessage :: ByteString -> HandshakeM ()
addHandshakeMessage content = modify pushMessage
  where
    pushMessage hs = hs{hstHandshakeMessages = content : hstHandshakeMessages hs}

-- | All recorded handshake messages in insertion order (oldest first).
getHandshakeMessages :: HandshakeM [ByteString]
getHandshakeMessages = fmap (reverse . hstHandshakeMessages) get

-- | All recorded handshake messages in stored order (newest first), without
-- the reversal done by 'getHandshakeMessages'.
getHandshakeMessagesRev :: HandshakeM [ByteString]
getHandshakeMessagesRev = hstHandshakeMessages <$> get

-- | Feed one handshake message into the transcript digest.
--
-- Before a hash context exists, the digest is a buffered list of messages
-- ('HandshakeMessages', newest first) and the message is prepended to it.
-- Once a 'HandshakeDigestContext' has been established (see
-- 'setServerHelloParameters'), the message is hashed into that context.
updateHandshakeDigest :: ByteString -> HandshakeM ()
updateHandshakeDigest content = modify absorb
  where
    absorb hs = hs{hstHandshakeDigest = feed (hstHandshakeDigest hs)}

    feed (HandshakeMessages msgs) = HandshakeMessages (content : msgs)
    feed (HandshakeDigestContext ctx) = HandshakeDigestContext (hashUpdate ctx content)

-- | Compress the whole transcript with the specified function.  Function @f@
-- takes the handshake digest as input and returns an encoded handshake message
-- to replace the transcript with.
--
-- The digest fed to @f@ is obtained in one of two ways:
--
-- * If messages are still buffered, they are hashed with @hashAlg@ in
--   insertion order.
-- * If a hash context already exists, that context is finalised.
--
-- Afterwards both 'hstHandshakeDigest' and 'hstHandshakeMessages' hold only
-- the folded message.  In the hash-context case, the digest becomes a fresh
-- @hashAlg@ context that has absorbed just that message.
foldHandshakeDigest :: Hash -> (ByteString -> ByteString) -> HandshakeM ()
foldHandshakeDigest hashAlg f = modify compress
  where
    compress hs = case hstHandshakeDigest hs of
        HandshakeMessages bytes ->
            -- Buffered messages are stored newest first, hence the reverse.
            let folded = f (hashFinal (hashMessages bytes))
             in replaceWith folded (HandshakeMessages [folded]) hs
        HandshakeDigestContext ctx ->
            let folded = f (hashFinal ctx)
                restarted = hashUpdate (hashInit hashAlg) folded
             in replaceWith folded (HandshakeDigestContext restarted) hs

    hashMessages bytes = foldl hashUpdate (hashInit hashAlg) (reverse bytes)

    replaceWith folded digest hs =
        hs
            { hstHandshakeDigest = digest
            , hstHandshakeMessages = [folded]
            }

-- | Finalise the running transcript hash context and return the digest.
--
-- Errors with "un-initialized session hash" if the transcript is still a
-- buffered message list, i.e. no hash context has been established yet.
getSessionHash :: HandshakeM ByteString
getSessionHash = gets (sessionHash . hstHandshakeDigest)
  where
    sessionHash (HandshakeDigestContext hashCtx) = hashFinal hashCtx
    sessionHash (HandshakeMessages _) = error "un-initialized session hash"

-- | Produce the Finished payload for @role@ from the running transcript hash
-- context, the main secret and the pending cipher.
--
-- Uses 'generateClientFinished' for 'ClientRole' and
-- 'generateServerFinished' for any other role.  Errors with
-- "un-initialized handshake digest" if no hash context has been established.
-- The main secret and the pending cipher are extracted with 'fromJust', so
-- both must have been set beforehand.
getHandshakeDigest :: Version -> Role -> HandshakeM ByteString
getHandshakeDigest ver role = finishedFor <$> get
  where
    finishedFor hst = case hstHandshakeDigest hst of
        HandshakeMessages _ ->
            error "un-initialized handshake digest"
        HandshakeDigestContext hashCtx ->
            finish
                ver
                (fromJust (hstPendingCipher hst))
                (fromJust (hstMainSecret hst))
                hashCtx

    finish =
        if role == ClientRole
            then generateClientFinished
            else generateServerFinished

-- | Generate the main secret from the pre-main secret, install it with
-- 'setMainSecret', and return it.
--
-- The derivation depends on 'getExtendedMainSecret':
--
-- * When it is set, the secret comes from 'generateExtendedMainSecret',
--   using the current session hash ('getSessionHash').
-- * Otherwise it comes from 'generateMainSecret', using the client and server
--   randoms.
--
-- Both paths extract the pending cipher with 'fromJust'; the second path
-- does the same for the server random.
setMainSecretFromPre
    :: ByteArrayAccess preMain
    => Version
    -- ^ chosen transmission version
    -> Role
    -- ^ the role (Client or Server) of the generating side
    -> preMain
    -- ^ the pre-main secret
    -> HandshakeM ByteString
setMainSecretFromPre ver role preMainSecret = do
    ems <- getExtendedMainSecret
    hst <- get
    let cipher = fromJust (hstPendingCipher hst)
    secret <-
        if ems
            then generateExtendedMainSecret ver cipher preMainSecret <$> getSessionHash
            else
                return $
                    generateMainSecret
                        ver
                        cipher
                        preMainSecret
                        (hstClientRandom hst)
                        (fromJust (hstServerRandom hst))
    setMainSecret ver role secret
    return secret

-- | Store the main secret in the handshake state.
--
-- As part of the same state update, derives the key block via
-- 'computeKeyBlock' and installs the resulting pending transmit and receive
-- record states.
setMainSecret :: Version -> Role -> ByteString -> HandshakeM ()
setMainSecret ver role mainSecret = modify install
  where
    install hst =
        let keys = computeKeyBlock hst mainSecret ver role
         in hst
                { hstMainSecret = Just mainSecret
                , hstPendingTxState = Just (fst keys)
                , hstPendingRxState = Just (snd keys)
                }

-- | Derive the pending transmit and receive record states from a main secret.
--
-- The key block from 'generateKeyBlock' (of length 'cipherKeyBlockSize') is
-- split, in order, into:
--
-- * the client MAC secret and the server MAC secret;
-- * the client write key and the server write key;
-- * the client write IV and the server write IV.
--
-- The MAC secrets have length 0 when the bulk cipher has no separate MAC.
--
-- Each side encrypts with its own write key and decrypts with its peer's.
-- So for @cc == 'ClientRole'@ the client material goes into the transmit
-- state and the server material into the receive state; otherwise the
-- assignment is reversed.
--
-- Both returned states start with a MAC sequence number of 0, at crypt level
-- 'CryptMainSecret', and carry the pending cipher and pending compression.
computeKeyBlock
    :: HandshakeState -> ByteString -> Version -> Role -> (RecordState, RecordState)
computeKeyBlock hst mainSecret ver cc = (pendingTx, pendingRx)
  where
    isClient = cc == ClientRole

    cipher = fromJust $ hstPendingCipher hst
    bulk = cipherBulk cipher

    -- MAC secrets only occupy space in the key block for MAC-using bulks.
    macLen
        | hasMAC (bulkF bulk) = hashDigestSize (cipherHash cipher)
        | otherwise = 0
    keyLen = bulkKeySize bulk
    ivLen = bulkIVSize bulk

    keyBlock =
        generateKeyBlock
            ver
            cipher
            (hstClientRandom hst)
            (fromJust $ hstServerRandom hst)
            mainSecret
            (cipherKeyBlockSize cipher)

    (clientMac, serverMac, clientKey, serverKey, clientIV, serverIV) =
        fromJust $
            partition6 keyBlock (macLen, macLen, keyLen, keyLen, ivLen, ivLen)

    -- The owner of a write key encrypts with it; its peer decrypts.
    (clientDir, serverDir)
        | isClient = (BulkEncrypt, BulkDecrypt)
        | otherwise = (BulkDecrypt, BulkEncrypt)

    clientCrypt =
        CryptState
            { cstKey = bulkInit bulk clientDir clientKey
            , cstIV = clientIV
            , cstMacSecret = clientMac
            }
    serverCrypt =
        CryptState
            { cstKey = bulkInit bulk serverDir serverKey
            , cstIV = serverIV
            , cstMacSecret = serverMac
            }

    (txCrypt, rxCrypt)
        | isClient = (clientCrypt, serverCrypt)
        | otherwise = (serverCrypt, clientCrypt)

    mkRecordState crypt =
        RecordState
            { stCryptState = crypt
            , stMacState = MacState{msSequence = 0}
            , stCryptLevel = CryptMainSecret
            , stCipher = Just cipher
            , stCompression = hstPendingCompression hst
            }

    pendingTx = mkRecordState txCrypt
    pendingRx = mkRecordState rxCrypt

-- | Record the parameters chosen in ServerHello: the server random, the
-- pending cipher and the pending compression.
--
-- This also switches the transcript from a buffered message list to a
-- running hash context.  The buffered messages are hashed in insertion order
-- with the algorithm that 'getHash' selects for the chosen version and
-- cipher.  If the transcript is already a hash context, the new digest is an
-- error value ("cannot initialize digest with another digest").
setServerHelloParameters
    :: Version
    -- ^ chosen version
    -> ServerRandom
    -> Cipher
    -> Compression
    -> HandshakeM ()
setServerHelloParameters ver sran cipher compression = modify apply
  where
    apply hst =
        hst
            { hstServerRandom = Just sran
            , hstPendingCipher = Just cipher
            , hstPendingCompression = compression
            , hstHandshakeDigest = startHashing (hstHandshakeDigest hst)
            }

    startHashing digest = case digest of
        HandshakeMessages bytes ->
            -- Buffered messages are stored newest first, hence the reverse.
            HandshakeDigestContext $
                foldl hashUpdate (hashInit (getHash ver cipher)) (reverse bytes)
        HandshakeDigestContext _ ->
            error "cannot initialize digest with another digest"

-- | Transcript hash algorithm for a negotiated version and cipher.
--
-- Versions before TLS 1.2 use 'SHA1_MD5'.  From TLS 1.2 on, the hash is
-- cipher specific:
--
-- * Ciphers that declare no minimum version, or a minimum below TLS 1.2,
--   use the default 'SHA256'.
-- * All other ciphers use their own 'cipherHash'; some TLS12 algorithms use
--   SHA384 instead of SHA256.
getHash :: Version -> Cipher -> Hash
getHash ver ciph
    | ver < TLS12 = SHA1_MD5
    | otherwise = case cipherMinVer ciph of
        Nothing -> SHA256
        Just minVer
            | minVer < TLS12 -> SHA256
            | otherwise -> cipherHash ciph