{-# LANGUAGE OverloadedStrings #-}
module Network.TLS.Context (
    
    TLSParams,
    
    Context (..),
    Hooks (..),
    Established (..),
    RecordLayer (..),
    ctxEOF,
    ctxEstablished,
    withLog,
    ctxWithHooks,
    contextModifyHooks,
    setEOF,
    setEstablished,
    contextFlush,
    contextClose,
    contextSend,
    contextRecv,
    updateMeasure,
    withMeasure,
    withReadLock,
    withWriteLock,
    withStateLock,
    withRWLock,
    
    Information (..),
    contextGetInformation,
    
    contextNew,
    
    contextHookSetHandshakeRecv,
    contextHookSetHandshake13Recv,
    contextHookSetCertificateRecv,
    contextHookSetLogging,
    
    throwCore,
    usingState,
    usingState_,
    runTxRecordState,
    runRxRecordState,
    usingHState,
    getHState,
    getStateRNG,
    tls13orLater,
    getTLSUnique,
    getTLSExporter,
    getTLSServerEndPoint,
    getFinished,
    getPeerFinished,
    TLS13State (..),
    getTLS13State,
    modifyTLS13State,
) where
import Control.Concurrent.MVar
import Control.Monad.State.Strict
import Data.IORef
import Network.TLS.Backend
import Network.TLS.Cipher
import Network.TLS.Context.Internal
import Network.TLS.Crypto
import Network.TLS.Handshake (
    handshakeClient,
    handshakeClientWith,
    handshakeServer,
    handshakeServerWith,
 )
import Network.TLS.Handshake.State13
import Network.TLS.Hooks
import Network.TLS.Imports
import Network.TLS.KeySchedule
import Network.TLS.Measurement
import Network.TLS.Packet
import Network.TLS.Parameters
import Network.TLS.PostHandshake (
    postHandshakeAuthClientWith,
    postHandshakeAuthServerWith,
    requestCertificateServer,
 )
import Network.TLS.RNG
import Network.TLS.Record.Reading
import Network.TLS.Record.State
import Network.TLS.Record.Writing
import Network.TLS.State
import Network.TLS.Struct
import Network.TLS.Struct13
import Network.TLS.Types (Role (..))
import Network.TLS.X509
-- | Role-specific parameters from which a 'Context' can be built.
--
-- 'contextNew' reads the common parameters and the role through this
-- class, and stores the four handshake-related methods in the context's
-- 'RoleParams'.
class TLSParams params where
    -- | The (supported, shared, debug) parameter triple common to both roles.
    getTLSCommonParams :: params -> CommonParams

    -- | Which side of the connection these parameters describe.
    getTLSRole :: params -> Role

    -- | Handshake entry point stored as 'doHandshake_' by 'contextNew'.
    doHandshake :: params -> Context -> IO ()

    -- | Handshake entry point taking a 'Handshake' message; stored as
    -- 'doHandshakeWith_' by 'contextNew'.
    doHandshakeWith :: params -> Context -> Handshake -> IO ()

    -- | Certificate-request action stored as 'doRequestCertificate_'.
    -- The client instance always answers 'False'.
    doRequestCertificate :: params -> Context -> IO Bool

    -- | Handler for a 'Handshake13' post-handshake authentication message,
    -- stored as 'doPostHandshakeAuthWith_'.
    doPostHandshakeAuthWith :: params -> Context -> Handshake13 -> IO ()
-- | Client parameters: the common triple is taken from the client's
-- supported/shared/debug fields and every handshake action delegates to
-- the client-side implementation.
instance TLSParams ClientParams where
    getTLSCommonParams cparams =
        ( clientSupported cparams
        , clientShared cparams
        , clientDebug cparams
        )
    getTLSRole _ = ClientRole
    doHandshake = handshakeClient
    doHandshakeWith = handshakeClientWith
    -- Only servers request certificates, so the client has nothing to do
    -- and reports 'False'.
    doRequestCertificate _ _ = pure False
    doPostHandshakeAuthWith = postHandshakeAuthClientWith
-- | Server parameters: the common triple is taken from the server's
-- supported/shared/debug fields and every handshake action delegates to
-- the server-side implementation.
instance TLSParams ServerParams where
    getTLSCommonParams sparams =
        ( serverSupported sparams
        , serverShared sparams
        , serverDebug sparams
        )
    getTLSRole _ = ServerRole
    doHandshake = handshakeServer
    doHandshakeWith = handshakeServerWith
    doRequestCertificate = requestCertificateServer
    doPostHandshakeAuthWith = postHandshakeAuthServerWith
-- | Create a new 'Context' for the given backend and role parameters.
--
-- The backend is initialised with 'initializeBackend' before anything else.
-- The TLS state's random generator is seeded from 'debugSeed' when the
-- debug parameters provide one; otherwise a fresh seed is generated with
-- 'seedNew' and passed to 'debugPrintSeed'. All mutable per-connection
-- state is freshly allocated: the context starts 'NotEstablished', not at
-- EOF, with 'defaultHooks', new record states, no handshake state, a
-- fragment size of @Just 16384@ (2^14), the standard record layer, no-op
-- handshake synchronisation and QUIC mode disabled.
contextNew
    :: (MonadIO m, HasBackend backend, TLSParams params)
    => backend
    -- ^ Transport the context sends and receives bytes over.
    -> params
    -- ^ Client or server parameters ('ClientParams' or 'ServerParams').
    -> m Context
contextNew backend params = liftIO $ do
    initializeBackend backend

    let (supported, shared, debug) = getTLSCommonParams params
        role = getTLSRole params

    -- A seed supplied through the debug parameters makes the RNG
    -- reproducible; otherwise draw a fresh one and hand it to the
    -- debug printer.
    seed <- case debugSeed debug of
        Just determ -> pure determ
        Nothing -> do
            fresh <- seedNew
            debugPrintSeed debug fresh
            pure fresh

    tlsstate <- newMVar (newTLSState (newStateRNG seed) role)
    eof <- newIORef False
    established <- newIORef NotEstablished
    stats <- newIORef newMeasurement
    needEmptyPacket <- newIORef False
    hooks <- newIORef defaultHooks
    tx <- newMVar newRecordState
    rx <- newMVar newRecordState
    hs <- newMVar Nothing
    recvActionsRef <- newIORef []
    sendActionRef <- newIORef Nothing
    crs <- newIORef []
    -- 'Locks' holds three independent unit MVars.
    locks <- Locks <$> newMVar () <*> newMVar () <*> newMVar ()
    st13ref <- newIORef defaultTLS13State

    let roleParams =
            RoleParams
                { doHandshake_ = doHandshake params
                , doHandshakeWith_ = doHandshakeWith params
                , doRequestCertificate_ = doRequestCertificate params
                , doPostHandshakeAuthWith_ = doPostHandshakeAuthWith params
                }
        recordLayer =
            RecordLayer
                { recordEncode = encodeRecord
                , recordEncode13 = encodeRecord13
                , recordSendBytes = sendBytes
                , recordRecv = recvRecord
                , recordRecv13 = recvRecord13
                }
        -- Used for both the client and the server callback of
        -- 'HandshakeSync', so it stays polymorphic in its ignored arguments.
        syncNoOp _ _ = pure ()
        ctx =
            Context
                { ctxBackend = getBackend backend
                , ctxShared = shared
                , ctxSupported = supported
                , ctxTLSState = tlsstate
                , ctxFragmentSize = Just 16384
                , ctxTxRecordState = tx
                , ctxRxRecordState = rx
                , ctxHandshakeState = hs
                , ctxRoleParams = roleParams
                , ctxMeasurement = stats
                , ctxEOF_ = eof
                , ctxEstablished_ = established
                , ctxNeedEmptyPacket = needEmptyPacket
                , ctxHooks = hooks
                , ctxLocks = locks
                , ctxPendingRecvActions = recvActionsRef
                , ctxPendingSendAction = sendActionRef
                , ctxCertRequests = crs
                , ctxKeyLogger = debugKeyLogger debug
                , ctxRecordLayer = recordLayer
                , ctxHandshakeSync = HandshakeSync syncNoOp syncNoOp
                , ctxQUICMode = False
                , ctxTLS13State = st13ref
                }
    pure ctx
-- | Install @f@ as the context's 'hookRecvHandshake', replacing any hook
-- previously set in that field.
contextHookSetHandshakeRecv :: Context -> (Handshake -> IO Handshake) -> IO ()
contextHookSetHandshakeRecv context f =
    contextModifyHooks context (\hooks -> hooks{hookRecvHandshake = f})
-- | Install @f@ as the context's 'hookRecvHandshake13', replacing any hook
-- previously set in that field.
contextHookSetHandshake13Recv
    :: Context -> (Handshake13 -> IO Handshake13) -> IO ()
contextHookSetHandshake13Recv context f =
    contextModifyHooks context (\hooks -> hooks{hookRecvHandshake13 = f})
-- | Install @f@ as the context's 'hookRecvCertificates', replacing any hook
-- previously set in that field.
contextHookSetCertificateRecv :: Context -> (CertificateChain -> IO ()) -> IO ()
contextHookSetCertificateRecv context f =
    contextModifyHooks context (\hooks -> hooks{hookRecvCertificates = f})
-- | Install the given 'Logging' callbacks as the context's 'hookLogging',
-- replacing any callbacks previously set in that field.
contextHookSetLogging :: Context -> Logging -> IO ()
contextHookSetLogging context loggingCallbacks =
    contextModifyHooks context (\hooks -> hooks{hookLogging = loggingCallbacks})
{-# DEPRECATED getFinished "Use getTLSUnique instead" #-}

-- | The local verify data recorded in the TLS state ('getMyVerifyData'),
-- or 'Nothing' if none has been recorded.
getFinished :: Context -> IO (Maybe VerifyData)
getFinished ctx = usingState_ ctx getMyVerifyData
{-# DEPRECATED getPeerFinished "Use getTLSUnique instead" #-}

-- | The peer's verify data recorded in the TLS state ('getPeerVerifyData'),
-- or 'Nothing' if none has been recorded.
getPeerFinished :: Context -> IO (Maybe VerifyData)
getPeerFinished ctx = usingState_ ctx getPeerVerifyData
-- | The @tls-unique@ channel binding (RFC 5929).
--
-- Only produced when the negotiated version is TLS 1.2, in which case it
-- is the first verify data recorded in the TLS state
-- ('getFirstVerifyData'). For any other version the result is 'Nothing'.
getTLSUnique :: Context -> IO (Maybe ByteString)
getTLSUnique ctx = do
    ver <- usingState_ ctx getVersion
    if ver == TLS12
        then usingState_ ctx getFirstVerifyData
        else pure Nothing
-- | The @tls-exporter@ channel binding (RFC 9266).
--
-- Only produced when the negotiated version is TLS 1.3: 32 bytes exported
-- with label @EXPORTER-Channel-Binding@ and an empty context value. For
-- any other version, or when the exporter inputs are not available yet,
-- the result is 'Nothing'.
getTLSExporter :: Context -> IO (Maybe ByteString)
getTLSExporter ctx = do
    ver <- usingState_ ctx getVersion
    if ver == TLS13
        then exporter ctx "EXPORTER-Channel-Binding" "" 32
        else pure Nothing
-- | TLS 1.3 keying-material exporter (RFC 8446, Section 7.5):
--
-- > TLS-Exporter(label, context_value, key_length) =
-- >     HKDF-Expand-Label(Derive-Secret(Secret, label, ""),
-- >                       "exporter", Hash(context_value), key_length)
--
-- @Secret@ is the TLS 1.3 exporter secret from the TLS state and the hash
-- is that of the cipher in the receive record state. Returns 'Nothing'
-- when either the secret or the cipher is not available.
exporter :: Context -> ByteString -> ByteString -> Int -> IO (Maybe ByteString)
exporter ctx label context outlen = do
    msecret <- usingState_ ctx getTLS13ExporterSecret
    mcipher <- failOnEitherError $ runRxRecordState ctx $ gets stCipher
    pure $ case (msecret, mcipher) of
        (Just secret, Just cipher) ->
            let h = cipherHash cipher
                -- Derive-Secret(Secret, label, "") uses Transcript-Hash(""),
                -- i.e. Hash(""), not the empty string.
                -- NOTE(review): assumes deriveSecret's last argument is the
                -- already-hashed transcript; confirm against
                -- Network.TLS.KeySchedule.
                secret' = deriveSecret h secret label (hash h "")
                value' = hash h context
             in Just $ hkdfExpandLabel h secret' "exporter" value' outlen
        _ -> Nothing
-- | The @tls-server-end-point@ channel binding (RFC 5929).
--
-- Hashes 'encodeCertificate' of the server certificate chain stored in the
-- TLS state, using the hash reported by 'getRxRecordState'. Returns
-- 'Nothing' when no server certificate chain has been recorded.
--
-- NOTE(review): RFC 5929 Section 4.1 defines this binding as the hash of
-- the server's end-entity certificate, using the hash from that
-- certificate's signatureAlgorithm (MD5/SHA-1 upgraded to SHA-256).
-- Confirm that 'encodeCertificate' and the record-state hash match that
-- definition.
getTLSServerEndPoint :: Context -> IO (Maybe ByteString)
getTLSServerEndPoint ctx =
    usingState_ ctx getServerCertificateChain >>= traverse endPoint
  where
    endPoint cc = do
        (usedHash, _, _, _) <- getRxRecordState ctx
        pure $ hash usedHash $ encodeCertificate cc