{-# LANGUAGE RecordWildCards #-}

module Network.TLS.Handshake.Server.ClientHello (
    processClientHello,
) where

import Network.TLS.Context.Internal
import Network.TLS.Extension
import Network.TLS.Handshake.Common
import Network.TLS.Handshake.Process
import Network.TLS.Imports
import Network.TLS.Measurement
import Network.TLS.Parameters
import Network.TLS.State
import Network.TLS.Struct

-- | Server-side handling of an incoming 'ClientHello'.
--
-- In order, this:
--
-- 1. rejects renegotiation when the current version is TLS 1.3, and
--    rejects client-initiated renegotiation unless
--    'supportedClientInitiatedRenegotiation' is set;
-- 2. asks the 'onNewHandshake' hook whether the handshake may proceed and
--    bumps the handshake counter;
-- 3. starts the handshake (skipped when 'getTLS13HRR' reports 'True') and
--    feeds the message to 'processHandshake12';
-- 4. enforces the fallback SCSV (RFC 7507) and the legacy version
--    constraint;
-- 5. picks the protocol version, honouring 'debugVersionForced' if set;
-- 6. stores the first host name found in the SNI extension, if any.
--
-- Returns the chosen version together with the 'CH' payload of the hello.
-- Every rejection is raised through 'throwCore'; any handshake message
-- other than a 'ClientHello' is rejected with 'HandshakeFailure'.
processClientHello
    :: ServerParams -> Context -> Handshake -> IO (Version, CH)
processClientHello sparams ctx clientHello@(ClientHello legacyVersion cran compressions ch@CH{..}) = do
    established <- ctxEstablished ctx
    -- renego is not allowed in TLS 1.3
    when (established /= NotEstablished) $ do
        ver <- usingState_ ctx (getVersionWithDefault TLS12)
        when (ver == TLS13) $
            throwCore $
                Error_Protocol "renegotiation is not allowed in TLS 1.3" UnexpectedMessage
    -- rejecting client initiated renegotiation to prevent DOS.
    eof <- ctxEOF ctx
    let renegotiation = established == Established && not eof
    when
        (renegotiation && not (supportedClientInitiatedRenegotiation $ ctxSupported ctx))
        $ throwCore
        $ Error_Protocol_Warning "renegotiation is not allowed" NoRenegotiation
    -- check whether the policy allows this new handshake to happen
    handshakeAuthorized <- withMeasure ctx (onNewHandshake $ serverHooks sparams)
    unless
        handshakeAuthorized
        (throwCore $ Error_HandshakePolicy "server: handshake denied")
    updateMeasure ctx incrementNbHandshakes

    -- Handle Client hello
    hrr <- usingState_ ctx getTLS13HRR
    unless hrr $ startHandshake ctx legacyVersion cran
    processHandshake12 ctx clientHello

    -- Fallback SCSV: RFC7507
    -- TLS_FALLBACK_SCSV: {0x56, 0x00}
    --
    -- This must run before the legacy version rejection below: that
    -- rejection fires for every legacyVersion other than TLS12, so a check
    -- requiring legacyVersion < TLS12 placed after it could never trigger
    -- and a downgraded hello would get ProtocolVersion instead of the
    -- InappropriateFallback alert.
    when
        ( supportedFallbackScsv (ctxSupported ctx)
            && (0x5600 `elem` chCiphers)
            && legacyVersion < TLS12
        )
        $ throwCore
        $ Error_Protocol "fallback is not allowed" InappropriateFallback

    when (legacyVersion /= TLS12) $
        throwCore $
            Error_Protocol (show legacyVersion ++ " is not supported") ProtocolVersion

    -- choosing TLS version
    let -- Versions from the supported_versions extension.  A missing or
        -- undecodable extension yields [], which selects the legacy
        -- negotiation path below.
        clientVersions = case extensionLookup EID_SupportedVersions chExtensions
            >>= extensionDecode MsgTClientHello of
            Just (SupportedVersionsClientHello vers) -> vers -- fixme: vers == []
            _ -> []
        clientVersion = min TLS12 legacyVersion
        -- TLS 1.3 is never offered while renegotiating.
        serverVersions
            | renegotiation = filter (< TLS13) (supportedVersions $ ctxSupported ctx)
            | otherwise = supportedVersions $ ctxSupported ctx
        mVersion = debugVersionForced $ serverDebug sparams
    chosenVersion <- case mVersion of
        -- A version forced through the debug parameters wins unconditionally.
        Just cver -> return cver
        Nothing
            | TLS13 `elem` serverVersions && not (null clientVersions) ->
                case findHighestVersionFrom13 clientVersions serverVersions of
                    Nothing ->
                        throwCore $
                            Error_Protocol
                                ("client versions " ++ show clientVersions ++ " is not supported")
                                ProtocolVersion
                    Just v -> return v
            | otherwise ->
                case findHighestVersionFrom clientVersion serverVersions of
                    Nothing ->
                        throwCore $
                            Error_Protocol
                                ("client version " ++ show clientVersion ++ " is not supported")
                                ProtocolVersion
                    Just v -> return v

    -- SNI (Server Name Indication)
    let toHostName (ServerNameHostName hostName) = Just hostName
        toHostName (ServerNameOther _) = Nothing
        -- First host-name entry of the extension, if there is one.
        serverName = case extensionLookup EID_ServerName chExtensions
            >>= extensionDecode MsgTClientHello of
            Just (ServerName ns) -> listToMaybe (mapMaybe toHostName ns)
            _ -> Nothing
    when (chosenVersion == TLS13) $
        -- If this is done for TLS12, SSL Labs test does not continue, sigh.
        mapM_ ensureNullCompression compressions
    mapM_ (usingState_ ctx . setClientSNI) serverName
    return (chosenVersion, ch)
processClientHello _ _ _ =
    throwCore $
        Error_Protocol
            "unexpected handshake message received in handshakeServerWith"
            HandshakeFailure

-- | Pick the highest version in @allowedVersions@ that does not exceed
-- @clientVersion@.
--
-- Returns 'Nothing' when no allowed version is less than or equal to the
-- client's version (including when @allowedVersions@ is empty).
findHighestVersionFrom :: Version -> [Version] -> Maybe Version
findHighestVersionFrom clientVersion allowedVersions =
    -- Sort descending so the head of the filtered list is the highest match.
    listToMaybe . filter (clientVersion >=) $ sortOn Down allowedVersions

-- | Pick the highest server version that the client also lists, ignoring
-- any client-listed version below 'TLS12'.
--
-- Returns 'Nothing' when the server versions and the client's
-- 'TLS12'-or-later versions have nothing in common.
findHighestVersionFrom13 :: [Version] -> [Version] -> Maybe Version
findHighestVersionFrom13 clientVersions serverVersions =
    -- Walk the server versions from highest to lowest and take the first one
    -- the client accepts; the order of the client list does not matter.
    listToMaybe $ filter (`elem` acceptable) (sortOn Down serverVersions)
  where
    acceptable = filter (>= TLS12) clientVersions