{-# LANGUAGE TypeApplications #-}

-- This file is part of the Wire Server implementation.
--
-- Copyright (C) 2022 Wire Swiss GmbH <opensource@wire.com>
--
-- This program is free software: you can redistribute it and/or modify it under
-- the terms of the GNU Affero General Public License as published by the Free
-- Software Foundation, either version 3 of the License, or (at your option) any
-- later version.
--
-- This program is distributed in the hope that it will be useful, but WITHOUT
-- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
-- FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
-- details.
--
-- You should have received a copy of the GNU Affero General Public License along
-- with this program. If not, see <https://www.gnu.org/licenses/>.

module Ssl.Util
  ( -- * Public Key Pinning
    verifyFingerprint,

    -- ** RSA-specific
    rsaFingerprint,
    verifyRsaFingerprint,

    -- * Cipher suites
    rsaCiphers,

    -- * Network
    withVerifiedSslConnection,
  )
where

import Control.Exception
import Data.ByteString.Builder
import Data.Byteable (constEqBytes)
import Data.Dynamic (fromDynamic)
import Data.Time.Clock (getCurrentTime)
import Imports
import Network.HTTP.Client.Internal
import OpenSSL.BN (integerToMPI)
import OpenSSL.EVP.Digest (Digest, digestLBS)
import OpenSSL.EVP.PKey (SomePublicKey, toPublicKey)
import OpenSSL.EVP.Verify (VerifyStatus (..))
import OpenSSL.RSA
import OpenSSL.Session as SSL
import OpenSSL.X509 as X509

-- Cipher Suites ------------------------------------------------------------

-- | A small list of strong cipher suites for use with 'contextSetCiphers'
-- that includes only a selected subset of those based on RSA signatures over
-- ephemeral DH key exchanges (for perfect forward secrecy) and are thus
-- compatible with the RSA public key pinning implemented by the functions
-- 'rsaFingerprint' and 'verifyRsaFingerprint'.
--
-- In the spirit of TLS 1.3 [1], only AEAD cipher suites are included,
-- specifically only AES-GCM and CHACHA20-POLY1305. Preference is applied as
-- follows:
--
--  * Elliptic curve DH variants are preferred over "classic" finite
--    field variants for efficiency.
--  * AES variants are preferred over ChaCha20 variants for performance,
--    assuming AES-NI support [2].
--  * AES-256 is preferred over AES-128 "because we can" and performance
--    is not significantly worse, though the key sizes needed for RSA and DH
--    to achieve a level of security comparable to 256 bit symmetric keys
--    are typically not used (see [3]).
--
-- All entries are TLS 1.2 cipher suites, given by their OpenSSL names:
--
--  * The AES-GCM suites are defined in RFC 5288 [4] (DHE) and RFC 5289
--    (ECDHE), e.g. OpenSSL 1.0.1+.
--  * The ChaCha20-Poly1305 suites are defined in RFC 7905, e.g.
--    OpenSSL 1.1.0+.
--
-- A TLS 1.2 connection therefore needs both ends to support at least one of
-- them. TLS 1.3 cipher suites are not chosen by this kind of list: OpenSSL
-- 1.1.1+ configures them separately. This assumes 'contextSetCiphers' wraps
-- @SSL_CTX_set_cipher_list@ (to be confirmed).
--
-- Entries are separated by commas, which OpenSSL accepts as an alternative
-- to the more common colon. For a list of OpenSSL cipher suites and how they
-- map to TLS names, see also [5].
--
-- References:
--
-- [1] https://www.rfc-editor.org/rfc/rfc8446#appendix-B.4
-- [2] https://calomel.org/aesni_ssl_performance.html
-- [3] https://www.keylength.com/en/3/
-- [4] https://tools.ietf.org/html/rfc5288#section-3
-- [5] https://www.openssl.org/docs/manmaster/man1/openssl-ciphers.html
rsaCiphers :: String
rsaCiphers =
  showString "ECDHE-RSA-AES256-GCM-SHA384," -- TLS 1.2, RFC 5289
    . showString "ECDHE-RSA-AES128-GCM-SHA256," -- TLS 1.2, RFC 5289
    . showString "ECDHE-RSA-CHACHA20-POLY1305," -- TLS 1.2, RFC 7905
    . showString "DHE-RSA-AES256-GCM-SHA384," -- TLS 1.2, RFC 5288
    . showString "DHE-RSA-AES128-GCM-SHA256," -- TLS 1.2, RFC 5288
    . showString "DHE-RSA-CHACHA20-POLY1305" -- TLS 1.2, RFC 7905
    $ ""

-- Public Key Pinning ----------------------------------------------------
--
-- Overview: https://www.owasp.org/index.php/Certificate_and_Public_Key_Pinning

-- | Failure reasons reported by 'verifyFingerprint' via 'throwIO'.
data PinPubKeyException
  = PinMissingCert
  -- ^ The peer did not present a certificate.
  | PinInvalidCert
  -- ^ The peer certificate failed validation (e.g. signature or expiry).
  | PinInvalidPubKey
  -- ^ No fingerprint could be computed from the peer certificate's
  -- public key.
  | PinFingerprintMismatch
  -- ^ The fingerprint of the peer's public key matched none of the
  -- pinned fingerprints.
  deriving (Eq, Show)

-- | Default 'Exception' instance, so values can be thrown with 'throwIO'
-- and caught specifically with 'catch' or 'try'.
instance Exception PinPubKeyException

-- | Check the public key in the peer certificate of an 'SSL' connection
-- against a list of /pinned/ fingerprints.
--
-- To use this function with 'opensslManagerSettingsWith'', the
-- 'VerificationMode' must be set to 'VerifyNone'. OpenSSL still validates
-- the certificate but does not abort the TLS handshake early. This gives
-- this function a chance to accept a self-signed certificate after
-- inspecting OpenSSL's verdict via 'getVerifyResult'.
--
-- The checks run in this order, and the first failure throws a
-- 'PinPubKeyException':
--
--  * 'PinMissingCert': the peer presented no certificate.
--  * 'PinInvalidPubKey': the fingerprint function returned 'Nothing'.
--  * 'PinFingerprintMismatch': the fingerprint equals none of the pins.
--  * 'PinInvalidCert': OpenSSL's verification failed, and the certificate
--    either does not verify against its own public key or the current
--    time lies outside its validity period.
verifyFingerprint ::
  -- | Compute the fingerprint of the peer's public key.
  (SomePublicKey -> IO (Maybe ByteString)) ->
  -- | The list of /pinned/ fingerprints.
  [ByteString] ->
  -- | The 'SSL' connection from which to obtain the peer
  -- certificate and public key.
  SSL ->
  IO ()
verifyFingerprint computeFp pins ssl = do
  cert <- SSL.getPeerCertificate ssl >>= orPinFailure PinMissingCert
  pkey <- X509.getPublicKey cert
  fp <- computeFp pkey >>= orPinFailure PinInvalidPubKey
  -- 'constEqBytes' is a constant-time comparison, so checking a candidate
  -- does not leak timing information about the pinned bytes.
  unless (any (constEqBytes fp) pins) $
    throwIO PinFingerprintMismatch
  chainOk <- SSL.getVerifyResult ssl
  unless chainOk $ do
    -- OpenSSL rejected the certificate. Accept it only if it is
    -- self-signed, i.e. its signature verifies against its own key ...
    selfSigned <- verifyX509 cert pkey
    unless (selfSigned == VerifySuccess) $
      throwIO PinInvalidCert
    -- ... and, since OpenSSL's verdict is being overridden, only while
    -- the current time is within the certificate's validity period.
    now <- getCurrentTime
    validFrom <- X509.getNotBefore cert
    validUntil <- X509.getNotAfter cert
    unless (validFrom <= now && now <= validUntil) $
      throwIO PinInvalidCert
  where
    -- Unwrap a 'Just', or throw the given pinning failure on 'Nothing'.
    orPinFailure :: PinPubKeyException -> Maybe a -> IO a
    orPinFailure failure = maybe (throwIO failure) pure

-- [Note: Hostname verification]

-- RSA ------------------------------------------------------------------------

-- | Compute a simple, non-standard fingerprint of an RSA public key for use
-- with 'verifyRsaFingerprint'.
--
-- The fingerprint is the given 'Digest' applied to the concatenation of:
--
--  1. the decimal rendering of the key's 'rsaSize',
--  2. the MPI encoding ('integerToMPI') of the modulus ('rsaN'), and
--  3. the MPI encoding of the public exponent ('rsaE').
--
-- Because the format is ad hoc, pinned fingerprints must be produced by this
-- same function (with the same 'Digest').
rsaFingerprint :: (RSAKey k) => Digest -> k -> IO ByteString
rsaFingerprint dgst key = digestLBS dgst . toLazyByteString <$> encodeKey
  where
    -- Serialise the key components; the modulus is encoded before the
    -- exponent, and the resulting Builder is forced before returning.
    encodeKey = do
      modBytes <- integerToMPI (rsaN key)
      expBytes <- integerToMPI (rsaE key)
      pure $! intDec (rsaSize key) <> byteString modBytes <> byteString expBytes

-- | 'verifyFingerprint' specialised to 'RSAPubKey's using 'rsaFingerprint'
-- with the given 'Digest'.
--
-- If the peer's public key is not an RSA key, no fingerprint is produced.
-- 'verifyFingerprint' then rejects the connection with 'PinInvalidPubKey'.
verifyRsaFingerprint :: Digest -> [ByteString] -> SSL -> IO ()
verifyRsaFingerprint dgst = verifyFingerprint rsaFp
  where
    -- 'traverse' over 'Maybe': a non-RSA key gives @pure Nothing@,
    -- an RSA key gives its fingerprint wrapped in 'Just'.
    rsaFp :: SomePublicKey -> IO (Maybe ByteString)
    rsaFp pk = traverse (rsaFingerprint dgst) (toPublicKey pk :: Maybe RSAPubKey)

-- [Note: Hostname verification]
-- Ideally, we would like to perform proper hostname verification, which
-- is not done automatically by OpenSSL [1]. However, the necessary APIs
-- are not yet available via HsOpenSSL. Note though that public key pinning
-- is already supposed to thwart attacks based on a lack of or incorrect
-- hostname verification (see [2] for many common attacks and mistakes).
--
-- [1] https://wiki.openssl.org/index.php/Hostname_validation
-- [2] https://www.cs.utexas.edu/~shmat/shmat_ccs12.pdf

-- Utilities -----------------------------------------------------------------

-- | Run an action with a pooled connection whose peer has been checked by
-- the given verification function.
--
-- The request is obtained by applying the request builder to
-- 'defaultRequest'. A connection is taken from the 'Manager' via
-- 'withConnection''. The verification function runs only when the
-- connection is fresh ('managedReused' is 'False'). The callback then
-- receives the request with 'connectionOverride' set to that connection,
-- which is returned to the pool afterwards (that's what 'Reuse' is for).
--
-- NOTE(review): reused connections are never re-verified. This assumes
-- every connection pooled for this host by the same 'Manager' went through
-- this function (or was otherwise verified). Confirm at the call sites.
--
-- Calls 'error' if the underlying connection is not an 'SSL' connection.
withVerifiedSslConnection ::
  -- | A function to verify fingerprints given an SSL connection
  (SSL -> IO ()) ->
  Manager ->
  -- | Request builder
  (Request -> Request) ->
  -- | This callback will be passed a modified
  --   request that always uses the verified
  --   connection
  (Request -> IO a) ->
  IO a
withVerifiedSslConnection verify man reqBuilder act =
  withConnection' req man Reuse $ \mConn -> do
    -- Only a connection seen for the first time needs its fingerprints checked.
    unless (managedReused mConn) $
      maybe notSsl verify (fromDynamic @SSL (connectionRaw (managedResource mConn)))
    act req {connectionOverride = Just mConn}
  where
    req :: Request
    req = reqBuilder defaultRequest

    -- Raised when the raw connection is not an OpenSSL 'SSL' value.
    notSsl :: IO ()
    notSsl = error ("withVerifiedSslConnection: only SSL allowed: " <> show req)