{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Rank2Types #-}

-- | The State module contains calls related to state
-- initialization and manipulation, which are used by the Receiving
-- and Sending modules.
module Network.TLS.State (
    TLSState (..),
    TLSSt,
    runTLSState,
    newTLSState,
    withTLSRNG,
    setVerifyDataForSend,
    setVerifyDataForRecv,
    getVerifyData,
    getMyVerifyData,
    getPeerVerifyData,
    getFirstVerifyData,
    finishedHandshakeTypeMaterial,
    finishedHandshakeMaterial,
    certVerifyHandshakeTypeMaterial,
    certVerifyHandshakeMaterial,
    setVersion,
    setVersionIfUnset,
    getVersion,
    getVersionWithDefault,
    setSecureRenegotiation,
    getSecureRenegotiation,
    setExtensionALPN,
    getExtensionALPN,
    setNegotiatedProtocol,
    getNegotiatedProtocol,
    setClientALPNSuggest,
    getClientALPNSuggest,
    setClientEcPointFormatSuggest,
    getClientEcPointFormatSuggest,
    setClientSNI,
    getClientSNI,
    getClientCertificateChain,
    setClientCertificateChain,
    getServerCertificateChain,
    setServerCertificateChain,
    setSession,
    getSession,
    getRole,
    --
    setTLS12SessionResuming,
    getTLS12SessionResuming,
    --
    setTLS13ExporterSecret,
    getTLS13ExporterSecret,
    setTLS13KeyShare,
    getTLS13KeyShare,
    setTLS13PreSharedKey,
    getTLS13PreSharedKey,
    setTLS13HRR,
    getTLS13HRR,
    setTLS13Cookie,
    getTLS13Cookie,
    setTLS13ClientSupportsPHA,
    getTLS13ClientSupportsPHA,
    setTLS12SessionTicket,
    getTLS12SessionTicket,

    -- * random
    genRandom,
    withRNG,
) where

import Control.Monad.State.Strict
import Crypto.Random
import qualified Data.ByteString as B
import Data.X509 (CertificateChain)
import Network.TLS.ErrT
import Network.TLS.Extension
import Network.TLS.Imports
import Network.TLS.RNG
import Network.TLS.Struct
import Network.TLS.Types (HostName, Role (..), Ticket)
import Network.TLS.Wire (GetContinuation)

-- | Per-connection TLS protocol state, threaded through the 'TLSSt' monad.
-- A fresh value is built by 'newTLSState'.
data TLSState = TLSState
    { -- | Current session (see 'setSession' / 'getSession').
      stSession :: Session
    , -- | Secure-renegotiation flag.
      -- RFC 5746, Renegotiation Indication Extension;
      -- RFC 5929, Channel Bindings for TLS, "tls-unique".
      stSecureRenegotiation :: Bool
    , -- | Client-side verify data, as recorded by 'setVerifyDataForSend'
      -- / 'setVerifyDataForRecv'.
      stClientVerifyData :: Maybe VerifyData
    , -- | Server-side verify data, as recorded by 'setVerifyDataForSend'
      -- / 'setVerifyDataForRecv'.
      stServerVerifyData :: Maybe VerifyData
    , -- | Server certificate chain.
      -- RFC 5929, Channel Bindings for TLS, "tls-server-end-point".
      stServerCertificateChain :: Maybe CertificateChain
    , -- | ALPN extension flag (RFC 7301).
      stExtensionALPN :: Bool
    , -- | Pending handshake-parsing continuation, if any (presumably used
      -- to resume a handshake message split across records — confirm).
      stHandshakeRecordCont :: Maybe (GetContinuation (HandshakeType, ByteString))
    , -- | ALPN protocol selected for the connection, if any.
      stNegotiatedProtocol :: Maybe B.ByteString
    , -- | Like 'stHandshakeRecordCont'; the @13@ suffix suggests this is
      -- the TLS 1.3 variant.
      stHandshakeRecordCont13 :: Maybe (GetContinuation (HandshakeType, ByteString))
    , -- | Protocol list from the client's ALPN suggestion.
      stClientALPNSuggest :: Maybe [B.ByteString]
    , -- | Groups suggested by the client.
      stClientGroupSuggest :: Maybe [Group]
    , -- | EC point formats suggested by the client.
      stClientEcPointFormatSuggest :: Maybe [EcPointFormat]
    , -- | Client certificate chain, if any.
      stClientCertificateChain :: Maybe CertificateChain
    , -- | Server name indication supplied by the client, if any.
      stClientSNI :: Maybe HostName
    , -- | Generator state consumed and advanced by 'withRNG'.
      stRandomGen :: StateRNG
    , -- | Role of this endpoint, as returned by 'getRole'.
      stClientContext :: Role
    , -- | Protocol version; 'Nothing' until set (see 'getVersion').
      stVersion :: Maybe Version
    , -- | TLS 1.2 session-resumption flag.
      stTLS12SessionResuming :: Bool
    , -- | TLS 1.2 session ticket, if any.
      stTLS12SessionTicket :: Maybe Ticket
    , -- | TLS 1.3 key share, if any.
      stTLS13KeyShare :: Maybe KeyShare
    , -- | TLS 1.3 pre-shared key, if any.
      stTLS13PreSharedKey :: Maybe PreSharedKey
    , -- | TLS 1.3 HelloRetryRequest flag.
      stTLS13HRR :: Bool
    , -- | TLS 1.3 cookie, if any.
      stTLS13Cookie :: Maybe Cookie
    , -- | TLS 1.3 exporter secret, if any.
      stTLS13ExporterSecret :: Maybe ByteString
    , -- | Whether the client supports TLS 1.3 Post-Handshake Authentication.
      stTLS13ClientSupportsPHA :: Bool
    }

-- | Computations over 'TLSState': a 'State' monad wrapped in 'ErrT', so a
-- computation may abort with a 'TLSError'. All instances except
-- 'MonadState' are obtained by newtype deriving.
newtype TLSSt a = TLSSt {runTLSSt :: ErrT TLSError (State TLSState) a}
    deriving (Functor, Applicative, Monad, MonadError TLSError)

-- | State access for 'TLSSt': each operation runs in the underlying
-- 'State' monad and is lifted through the 'ErrT' layer.
instance MonadState TLSState TLSSt where
    get = TLSSt (lift get)
    put = TLSSt . lift . put
    state = TLSSt . lift . state

-- | Run a 'TLSSt' computation from an initial 'TLSState'.
--
-- The result is either the error or the value. It is paired with the final
-- state, which is returned in both cases.
runTLSState :: TLSSt a -> TLSState -> (Either TLSError a, TLSState)
runTLSState = runState . runErrT . runTLSSt

-- | Initial state for a new connection.
--
-- The supplied generator seeds 'stRandomGen' and the role is stored in
-- 'stClientContext'. The session starts as @Session Nothing@, every other
-- optional field is 'Nothing', and every flag is 'False'.
newTLSState :: StateRNG -> Role -> TLSState
newTLSState rng role =
    TLSState
        { stSession = Session Nothing
        , stSecureRenegotiation = False
        , stClientVerifyData = Nothing
        , stServerVerifyData = Nothing
        , stServerCertificateChain = Nothing
        , stExtensionALPN = False
        , stHandshakeRecordCont = Nothing
        , stHandshakeRecordCont13 = Nothing
        , stNegotiatedProtocol = Nothing
        , stClientALPNSuggest = Nothing
        , stClientGroupSuggest = Nothing
        , stClientEcPointFormatSuggest = Nothing
        , stClientCertificateChain = Nothing
        , stClientSNI = Nothing
        , stRandomGen = rng
        , stClientContext = role
        , stVersion = Nothing
        , stTLS12SessionResuming = False
        , stTLS12SessionTicket = Nothing
        , stTLS13KeyShare = Nothing
        , stTLS13PreSharedKey = Nothing
        , stTLS13HRR = False
        , stTLS13Cookie = Nothing
        , stTLS13ExporterSecret = Nothing
        , stTLS13ClientSupportsPHA = False
        }

-- | Record verify data produced by this endpoint.
--
-- A client stores it in 'stClientVerifyData' and a server in
-- 'stServerVerifyData'.
setVerifyDataForSend :: VerifyData -> TLSSt ()
setVerifyDataForSend bs = getRole >>= modify . storeFor
  where
    storeFor ClientRole st = st{stClientVerifyData = Just bs}
    storeFor ServerRole st = st{stServerVerifyData = Just bs}

-- | Record verify data received from the peer.
--
-- A client stores it in 'stServerVerifyData' and a server in
-- 'stClientVerifyData', which is the opposite of 'setVerifyDataForSend'.
setVerifyDataForRecv :: VerifyData -> TLSSt ()
setVerifyDataForRecv bs = getRole >>= modify . storeFor
  where
    storeFor ClientRole st = st{stServerVerifyData = Just bs}
    storeFor ServerRole st = st{stClientVerifyData = Just bs}

-- | Handshake message types that count as Finished material.
--
-- Returns 'True' for ClientHello, ServerHello, Certificate, ServerHelloDone,
-- ClientKeyXchg, ServerKeyXchg, CertRequest, CertVerify, NewSessionTicket
-- and Finished. Returns 'False' for every other type.
finishedHandshakeTypeMaterial :: HandshakeType -> Bool
finishedHandshakeTypeMaterial ty = case ty of
    HandshakeType_ClientHello -> True
    HandshakeType_ServerHello -> True
    HandshakeType_Certificate -> True
    -- HandshakeType_HelloRequest is deliberately excluded (falls to '_').
    HandshakeType_ServerHelloDone -> True
    HandshakeType_ClientKeyXchg -> True
    HandshakeType_ServerKeyXchg -> True
    HandshakeType_CertRequest -> True
    HandshakeType_CertVerify -> True
    HandshakeType_NewSessionTicket -> True
    HandshakeType_Finished -> True
    _ -> False

-- | 'finishedHandshakeTypeMaterial' applied to the type of a handshake message.
finishedHandshakeMaterial :: Handshake -> Bool
finishedHandshakeMaterial hs = finishedHandshakeTypeMaterial (typeOfHandshake hs)

-- | Handshake message types that count as certificate-verify material.
--
-- Returns 'True' for ClientHello, ServerHello, Certificate, ServerHelloDone,
-- ClientKeyXchg, ServerKeyXchg and CertRequest. Returns 'False' for
-- everything else, notably HelloRequest, CertVerify and Finished.
certVerifyHandshakeTypeMaterial :: HandshakeType -> Bool
certVerifyHandshakeTypeMaterial ty = case ty of
    HandshakeType_ClientHello -> True
    HandshakeType_ServerHello -> True
    HandshakeType_Certificate -> True
    -- HandshakeType_HelloRequest is deliberately excluded (falls to '_').
    HandshakeType_ServerHelloDone -> True
    HandshakeType_ClientKeyXchg -> True
    HandshakeType_ServerKeyXchg -> True
    HandshakeType_CertRequest -> True
    -- HandshakeType_CertVerify and HandshakeType_Finished are also excluded.
    _ -> False

-- | 'certVerifyHandshakeTypeMaterial' applied to the type of a handshake message.
certVerifyHandshakeMaterial :: Handshake -> Bool
certVerifyHandshakeMaterial hs = certVerifyHandshakeTypeMaterial (typeOfHandshake hs)

-- | Replace the current 'Session'.
setSession :: Session -> TLSSt ()
setSession newSession = modify $ \st -> st{stSession = newSession}

-- | Read the current 'Session'.
getSession :: TLSSt Session
getSession = gets stSession

-- | Set the TLS 1.2 session-resumption flag.
setTLS12SessionResuming :: Bool -> TLSSt ()
setTLS12SessionResuming resuming = modify $ \st -> st{stTLS12SessionResuming = resuming}

-- | Read the TLS 1.2 session-resumption flag ('False' in a fresh state).
getTLS12SessionResuming :: TLSSt Bool
getTLS12SessionResuming = gets stTLS12SessionResuming

-- | Record the protocol version, overwriting any previous value.
setVersion :: Version -> TLSSt ()
setVersion v = modify $ \st -> st{stVersion = Just v}

-- | Record the protocol version only if none is set yet. An already
-- recorded version is left untouched.
setVersionIfUnset :: Version -> TLSSt ()
setVersionIfUnset ver = modify $ \st ->
    maybe st{stVersion = Just ver} (const st) (stVersion st)

-- | Read the protocol version.
--
-- If no version has been set, the result is an 'error' thunk. It only
-- fails when the returned 'Version' is forced. Use 'getVersionWithDefault'
-- when the version may legitimately be absent.
getVersion :: TLSSt Version
getVersion =
    gets $
        fromMaybe (error "internal error: version hasn't been set yet")
            . stVersion

-- | Read the protocol version, falling back to the given default when
-- none has been set.
getVersionWithDefault :: Version -> TLSSt Version
getVersionWithDefault defaultVer = gets (fromMaybe defaultVer . stVersion)

-- | Set the secure-renegotiation flag (RFC 5746).
setSecureRenegotiation :: Bool -> TLSSt ()
setSecureRenegotiation enabled = modify $ \st -> st{stSecureRenegotiation = enabled}

-- | Read the secure-renegotiation flag ('False' in a fresh state).
getSecureRenegotiation :: TLSSt Bool
getSecureRenegotiation = gets stSecureRenegotiation

-- | Set the ALPN extension flag (RFC 7301).
setExtensionALPN :: Bool -> TLSSt ()
setExtensionALPN enabled = modify $ \st -> st{stExtensionALPN = enabled}

-- | Read the ALPN extension flag ('False' in a fresh state).
getExtensionALPN :: TLSSt Bool
getExtensionALPN = gets stExtensionALPN

-- | Record the ALPN protocol selected for this connection.
setNegotiatedProtocol :: B.ByteString -> TLSSt ()
setNegotiatedProtocol proto = modify $ \st -> st{stNegotiatedProtocol = Just proto}

-- | Read the selected ALPN protocol ('Nothing' in a fresh state).
getNegotiatedProtocol :: TLSSt (Maybe B.ByteString)
getNegotiatedProtocol = gets stNegotiatedProtocol

-- | Record the protocol list suggested by the client via ALPN.
setClientALPNSuggest :: [B.ByteString] -> TLSSt ()
setClientALPNSuggest protos = modify $ \st -> st{stClientALPNSuggest = Just protos}

-- | Read the client's ALPN suggestion ('Nothing' in a fresh state).
getClientALPNSuggest :: TLSSt (Maybe [B.ByteString])
getClientALPNSuggest = gets stClientALPNSuggest

-- | Record the EC point formats suggested by the client.
setClientEcPointFormatSuggest :: [EcPointFormat] -> TLSSt ()
setClientEcPointFormatSuggest formats =
    modify $ \st -> st{stClientEcPointFormatSuggest = Just formats}

-- | Read the client's EC point format suggestion ('Nothing' in a fresh state).
getClientEcPointFormatSuggest :: TLSSt (Maybe [EcPointFormat])
getClientEcPointFormatSuggest = gets stClientEcPointFormatSuggest

-- | Record the client certificate chain.
setClientCertificateChain :: CertificateChain -> TLSSt ()
setClientCertificateChain chain = modify $ \st -> st{stClientCertificateChain = Just chain}

-- | Read the client certificate chain ('Nothing' in a fresh state).
getClientCertificateChain :: TLSSt (Maybe CertificateChain)
getClientCertificateChain = gets stClientCertificateChain

-- | Record the server certificate chain.
setServerCertificateChain :: CertificateChain -> TLSSt ()
setServerCertificateChain chain = modify $ \st -> st{stServerCertificateChain = Just chain}

-- | Read the server certificate chain ('Nothing' in a fresh state).
getServerCertificateChain :: TLSSt (Maybe CertificateChain)
getServerCertificateChain = gets stServerCertificateChain

-- | Record the server name indication supplied by the client.
setClientSNI :: HostName -> TLSSt ()
setClientSNI host = modify $ \st -> st{stClientSNI = Just host}

-- | Read the client's server name indication ('Nothing' in a fresh state).
getClientSNI :: TLSSt (Maybe HostName)
getClientSNI = gets stClientSNI

-- | Verify data recorded for the given side of the connection.
--
-- 'ClientRole' selects 'stClientVerifyData' and 'ServerRole' selects
-- 'stServerVerifyData'. Returns the empty 'ByteString' when nothing has
-- been recorded for that side.
getVerifyData :: Role -> TLSSt ByteString
getVerifyData role = fromMaybe "" <$> gets selectSide
  where
    selectSide = case role of
        ClientRole -> stClientVerifyData
        ServerRole -> stServerVerifyData

-- | Verify data belonging to this endpoint: the client's when acting as
-- a client, the server's when acting as a server.
getMyVerifyData :: TLSSt (Maybe ByteString)
getMyVerifyData = do
    role <- getRole
    gets $ case role of
        ClientRole -> stClientVerifyData
        ServerRole -> stServerVerifyData

-- | Verify data belonging to the peer: the server's when acting as a
-- client, the client's when acting as a server.
getPeerVerifyData :: TLSSt (Maybe ByteString)
getPeerVerifyData = do
    role <- getRole
    gets $ case role of
        ClientRole -> stServerVerifyData
        ServerRole -> stClientVerifyData

-- | Returns the server's verify data when the version is TLS 1.3 or when a
-- TLS 1.2 session is being resumed, and the client's verify data otherwise.
--
-- Presumably this is the verify data of the first Finished message of the
-- handshake.
--
-- Uses 'getVersion', so it fails if no version has been set.
-- The TLS 1.2 resumption flag is consulted only for non-TLS 1.3 versions.
getFirstVerifyData :: TLSSt (Maybe ByteString)
getFirstVerifyData = do
    ver <- getVersion
    serverFirst <- case ver of
        TLS13 -> return True
        _ -> getTLS12SessionResuming
    gets $ if serverFirst then stServerVerifyData else stClientVerifyData

-- | Role of this endpoint, as stored in 'stClientContext' by 'newTLSState'.
getRole :: TLSSt Role
getRole = gets stClientContext

-- | Generate @n@ random bytes from the connection's generator. The
-- generator advances inside the state (see 'withRNG').
genRandom :: Int -> TLSSt ByteString
genRandom n = withRNG (getRandomBytes n)

-- | Run a pseudo-random action with the generator held in 'stRandomGen'.
-- The generator returned by 'withTLSRNG' is stored back, so successive
-- calls draw fresh values.
withRNG :: MonadPseudoRandom StateRNG a -> TLSSt a
withRNG action = state $ \st ->
    let (result, rng') = withTLSRNG (stRandomGen st) action
     in (result, st{stRandomGen = rng'})

-- | Record the TLS 1.2 session ticket.
setTLS12SessionTicket :: Ticket -> TLSSt ()
setTLS12SessionTicket ticket = modify $ \st -> st{stTLS12SessionTicket = Just ticket}

-- | Read the TLS 1.2 session ticket ('Nothing' in a fresh state).
getTLS12SessionTicket :: TLSSt (Maybe Ticket)
getTLS12SessionTicket = gets stTLS12SessionTicket

-- | Record the TLS 1.3 exporter secret.
setTLS13ExporterSecret :: ByteString -> TLSSt ()
setTLS13ExporterSecret secret = modify $ \st -> st{stTLS13ExporterSecret = Just secret}

-- | Read the TLS 1.3 exporter secret ('Nothing' in a fresh state).
getTLS13ExporterSecret :: TLSSt (Maybe ByteString)
getTLS13ExporterSecret = gets stTLS13ExporterSecret

-- | Set (or clear, with 'Nothing') the TLS 1.3 key share.
setTLS13KeyShare :: Maybe KeyShare -> TLSSt ()
setTLS13KeyShare keyShare = modify $ \st -> st{stTLS13KeyShare = keyShare}

-- | Read the TLS 1.3 key share ('Nothing' in a fresh state).
getTLS13KeyShare :: TLSSt (Maybe KeyShare)
getTLS13KeyShare = gets stTLS13KeyShare

-- | Set (or clear, with 'Nothing') the TLS 1.3 pre-shared key.
setTLS13PreSharedKey :: Maybe PreSharedKey -> TLSSt ()
setTLS13PreSharedKey psk = modify $ \st -> st{stTLS13PreSharedKey = psk}

-- | Read the TLS 1.3 pre-shared key ('Nothing' in a fresh state).
getTLS13PreSharedKey :: TLSSt (Maybe PreSharedKey)
getTLS13PreSharedKey = gets stTLS13PreSharedKey

-- | Set the TLS 1.3 HelloRetryRequest flag.
setTLS13HRR :: Bool -> TLSSt ()
setTLS13HRR hrr = modify $ \st -> st{stTLS13HRR = hrr}

-- | Read the TLS 1.3 HelloRetryRequest flag ('False' in a fresh state).
getTLS13HRR :: TLSSt Bool
getTLS13HRR = gets stTLS13HRR

-- | Set (or clear, with 'Nothing') the TLS 1.3 cookie.
setTLS13Cookie :: Maybe Cookie -> TLSSt ()
setTLS13Cookie cookie = modify $ \st -> st{stTLS13Cookie = cookie}

-- | Read the TLS 1.3 cookie ('Nothing' in a fresh state).
getTLS13Cookie :: TLSSt (Maybe Cookie)
getTLS13Cookie = gets stTLS13Cookie

-- | Record whether the client supports TLS 1.3 Post-Handshake Authentication.
setTLS13ClientSupportsPHA :: Bool -> TLSSt ()
setTLS13ClientSupportsPHA supported = modify $ \st -> st{stTLS13ClientSupportsPHA = supported}

-- | Read the Post-Handshake Authentication flag ('False' in a fresh state).
getTLS13ClientSupportsPHA :: TLSSt Bool
getTLS13ClientSupportsPHA = gets stTLS13ClientSupportsPHA