{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}

-- |
-- Sometimes incoming requests don't stick to the
-- \"no duplicate headers\" invariant, for a number
-- of possible reasons (e.g. proxy servers blindly
-- adding headers), or your application (or other
-- middleware) blindly adds headers to the response.
--
-- In those cases, you can use this 'Middleware'
-- to make sure that headers that /can/ be combined
-- /are/ combined. (e.g. applications might only
-- check the first \"Accept\" header and fail, while
-- there might be another one that would match)
module Network.Wai.Middleware.CombineHeaders (
    combineHeaders,
    CombineSettings,
    defaultCombineSettings,
    HeaderMap,
    HandleType,
    defaultHeaderMap,

    -- * Adjusting the settings
    setHeader,
    removeHeader,
    setHeaderMap,
    regular,
    keepOnly,
    setRequestHeaders,
    setResponseHeaders,
) where

import qualified Data.ByteString as B
import qualified Data.List as L (foldl', reverse)
import qualified Data.Map.Strict as M
import Data.Word8 (_comma, _space, _tab)
import Network.HTTP.Types (Header, HeaderName, RequestHeaders)
import qualified Network.HTTP.Types.Header as H
import Network.Wai (Middleware, mapResponseHeaders, requestHeaders)
import Network.Wai.Util (dropWhileEnd)

-- | Lookup table from a 'HeaderName' to the 'HandleType' that decides
-- how duplicate occurrences of that header are merged. Headers that are
-- absent from the table are never combined.
type HeaderMap =
    M.Map HeaderName HandleType

-- | Configuration for 'combineHeaders': which headers get merged, and
-- whether merging is applied to incoming (request) headers, outgoing
-- (response) headers, or both.
--
-- Every header listed in the 'HeaderMap' /will/ have its duplicate
-- occurrences joined with commas. No check is made that the header
-- actually permits a comma-separated list, so if you want to combine
-- custom headers, go ahead.
--
-- (The documentation of 'defaultCombineSettings' lists the standard
-- headers that are specified to be combinable.)
--
-- @since 3.1.13.0
data CombineSettings = CombineSettings
    { combineHeaderMap :: HeaderMap
    -- ^ The headers to combine, and the strategy for each (see 'HandleType').
    , combineRequestHeaders :: Bool
    -- ^ Whether incoming request headers are combined.
    , combineResponseHeaders :: Bool
    -- ^ Whether outgoing response headers are combined.
    }
    deriving (Show, Eq)

-- | Settings that combine duplicate request headers, but leave
-- response headers alone.
--
-- Every header that /can/ be combined (as defined in the spec)
-- /will/ be combined. The full list is:
--
-- * Accept
-- * Accept-CH
-- * Accept-Charset
-- * Accept-Encoding
-- * Accept-Language
-- * Accept-Post
-- * Access-Control-Allow-Headers
-- * Access-Control-Allow-Methods
-- * Access-Control-Expose-Headers
-- * Access-Control-Request-Headers
-- * Allow
-- * Alt-Svc @(KeepOnly \"clear\")@
-- * Cache-Control
-- * Clear-Site-Data @(KeepOnly \"*\")@
-- * Connection
-- * Content-Encoding
-- * Content-Language
-- * Digest
-- * If-Match
-- * If-None-Match @(KeepOnly \"*\")@
-- * Link
-- * Permissions-Policy
-- * TE
-- * Timing-Allow-Origin @(KeepOnly \"*\")@
-- * Trailer
-- * Transfer-Encoding
-- * Upgrade
-- * Via
-- * Vary @(KeepOnly \"*\")@
-- * Want-Digest
--
-- N.B. A header marked with \"KeepOnly\" is combined like any other,
-- unless one of its values is the value shown (\"*\" most of the time);
-- in that case only that value is kept and all others are dropped.
--
-- @since 3.1.13.0
defaultCombineSettings :: CombineSettings
defaultCombineSettings =
    CombineSettings
        { combineRequestHeaders = True
        , combineResponseHeaders = False
        , combineHeaderMap = defaultHeaderMap
        }

-- | Replace the 'HeaderMap' used by the given 'CombineSettings'.
-- (default: 'defaultHeaderMap')
--
-- @since 3.1.13.0
setHeaderMap :: HeaderMap -> CombineSettings -> CombineSettings
setHeaderMap newMap settings =
    settings{combineHeaderMap = newMap}

-- | Choose whether 'combineHeaders' should merge the headers of
-- incoming requests. (default: True)
--
-- @since 3.1.13.0
setRequestHeaders :: Bool -> CombineSettings -> CombineSettings
setRequestHeaders enabled settings =
    settings{combineRequestHeaders = enabled}

-- | Choose whether 'combineHeaders' should merge the headers of
-- outgoing responses. (default: False)
--
-- @since 3.1.13.0
setResponseHeaders :: Bool -> CombineSettings -> CombineSettings
setResponseHeaders enabled settings =
    settings{combineResponseHeaders = enabled}

-- | Add a header to the settings' 'HeaderMap', or, when the header is
-- already present, replace its 'HandleType' with the given one.
--
-- @since 3.1.13.0
setHeader :: HeaderName -> HandleType -> CombineSettings -> CombineSettings
setHeader name typ settings =
    let updated = M.insert name typ (combineHeaderMap settings)
     in settings{combineHeaderMap = updated}

-- | Remove a header from the settings' 'HeaderMap', so that its
-- duplicates are no longer combined. Removing a header that is not in
-- the map leaves the map as it was.
--
-- @since 3.1.13.0
removeHeader :: HeaderName -> CombineSettings -> CombineSettings
removeHeader name settings =
    let remaining = M.delete name (combineHeaderMap settings)
     in settings{combineHeaderMap = remaining}

-- | Middleware that merges duplicates of headers which may carry a
-- comma-separated list of values, on the request side, the response
-- side, or both (see 'CombineSettings'). Headers that are not in the
-- 'HeaderMap' are never merged.
--
-- This middleware WILL change the global order of headers: they come
-- out sorted alphabetically by header name. The relative order of
-- headers with the same name is preserved, though. E.g. if there are
-- three \"Set-Cookie\" headers, the first one is still first, the second
-- still second, and so on, but they are now guaranteed to be next to
-- each other.
--
-- N.B. This 'Middleware' assumes the headers it combines are correctly
-- formatted. If one of the to-be-combined headers is malformed, the
-- combined header will (probably) be malformed as well.
--
-- @since 3.1.13.0
combineHeaders :: CombineSettings -> Middleware
combineHeaders CombineSettings{..} app req respond =
    app req' (respond . onResponse)
  where
    req'
        | combineRequestHeaders =
            req{requestHeaders = rebuild (requestHeaders req)}
        | otherwise = req
    onResponse
        | combineResponseHeaders = mapResponseHeaders rebuild
        | otherwise = id
    -- Group all values by header name, then turn every group back
    -- into one or more headers via 'finishHeaders'.
    rebuild hdrs =
        M.foldrWithKey' finishHeaders [] (L.foldl' collect M.empty hdrs)
    collect groups hdr@(name, _) =
        M.alter (addValue hdr) name groups
    addValue :: Header -> Maybe HeaderHandling -> Maybe HeaderHandling
    addValue (name, val) existing =
        Just $ case existing of
            Nothing -> (M.lookup name combineHeaderMap, [val])
            -- New values are prepended, so the list ends up newest-first;
            -- 'finishHeaders' reverses it back into arrival order.
            Just (handling, vals) -> (handling, val : vals)

-- | Unpack 'HeaderHandling' back into 'Header's again, prepending the
-- result onto the given list of already-finished headers.
--
-- * With a 'HandleType' (the header is in the 'HeaderMap'), all values
--   are merged into one header. Each value is first stripped of leading
--   and trailing commas, spaces and tabs. Values that end up empty are
--   dropped, so the merged list never contains empty elements. For
--   'KeepOnly', the special value is matched against these stripped
--   values.
-- * Without a 'HandleType', every value is emitted unchanged as its own
--   header, in the order the values originally arrived.
--
-- The values in 'HeaderHandling' are expected newest-first, i.e.
-- reversed relative to arrival order.
finishHeaders
    :: HeaderName -> HeaderHandling -> RequestHeaders -> RequestHeaders
finishHeaders name (shouldCombine, xs) hdrs =
    case shouldCombine of
        Just typ -> (name, combinedHeader typ) : hdrs
        Nothing ->
            -- Yes, this reverses the headers, but they
            -- were already reversed while being collected.
            L.foldl' (\acc el -> (name, el) : acc) hdrs xs
  where
    combinedHeader :: HandleType -> B.ByteString
    combinedHeader Regular = combineHdrs
    combinedHeader (KeepOnly val)
        | val `elem` cleaned = val
        | otherwise = combineHdrs
    -- The values restored to arrival order, trimmed, with empty ones
    -- removed (RFC 9110 §5.6.1.1: a sender MUST NOT generate empty list
    -- elements, which e.g. "a, , b" would contain).
    cleaned :: [B.ByteString]
    cleaned = filter (not . B.null) . fmap clean $ L.reverse xs
    combineHdrs :: B.ByteString
    combineHdrs = B.intercalate ", " cleaned
    clean :: B.ByteString -> B.ByteString
    clean = B.dropWhile isSep . dropWhileEnd isSep
    isSep w = w == _comma || w == _space || w == _tab

-- | Intermediate per-header-name state built up by 'combineHeaders'.
-- The first component is the 'HandleType' looked up in the 'HeaderMap'
-- ('Nothing' when the header should not be combined). The second is
-- every value seen for that name, most recent first.
type HeaderHandling =
    ( Maybe HandleType
    , [B.ByteString]
    )

-- | How the values of a combinable header are merged. Both strategies
-- join the values with @,@ (commas). 'KeepOnly' additionally drops every
-- other value when the given special value is present (e.g. in case of
-- wildcards/special values).
--
-- For example: if there are multiple @\"Clear-Site-Data\"@ headers and
-- one of them is the wildcard @\"*\"@, using @'KeepOnly' \"*\"@ causes all
-- other values to be dropped, so only the wildcard remains.
-- (Here the @\"*\"@ wildcard means /ALL site data/ should be cleared,
-- so there is no need to include more.)
--
-- @since 3.1.13.0
data HandleType
    = -- | Merge all values into one comma-separated header.
      Regular
    | -- | Like 'Regular', but if this value is among the values,
      -- keep only it.
      KeepOnly B.ByteString
    deriving (Show, Eq)

-- | The plain merge strategy: all values end up in a single header,
-- separated by commas.
--
-- @since 3.1.13.0
regular :: HandleType
regular =
    Regular

-- | Merge strategy that joins values with commas like 'regular', but
-- if the supplied value appears among the header values, all other
-- values are discarded and only that one is kept.
--
-- e.g. @keepOnly \"*\"@ drops every other value once a @\"*\"@ is seen.
--
-- @since 3.1.13.0
keepOnly :: B.ByteString -> HandleType
keepOnly special = KeepOnly special

-- | The standard HTTP headers whose duplicates (in a request or a
-- response) may be merged into a single comma-separated header, each
-- paired with its 'HandleType'.
--
-- The same list, with annotations, is documented on
-- 'defaultCombineSettings'.
--
-- @since 3.1.13.0
defaultHeaderMap :: HeaderMap
defaultHeaderMap =
    M.fromList $
        fmap (\name -> (name, Regular)) plainHeaders
            <> specialValueHeaders
  where
    -- Headers whose values are simply joined together.
    plainHeaders :: [HeaderName]
    plainHeaders =
        [ H.hAccept
        , "Accept-CH"
        , H.hAcceptCharset
        , H.hAcceptEncoding
        , H.hAcceptLanguage
        , "Accept-Post"
        , -- Allow-Headers, Allow-Methods and Expose-Headers accept a
          -- wildcard, but it can just be added to the list as well.
          "Access-Control-Allow-Headers"
        , "Access-Control-Allow-Methods"
        , "Access-Control-Expose-Headers"
        , "Access-Control-Request-Headers"
        , H.hAllow
        , H.hCacheControl
        , -- If "close" is combined with anything else the header is
          -- already broken, so just combine them.
          H.hConnection
        , H.hContentEncoding
        , H.hContentLanguage
        , "Digest"
        , -- "Feature-Policy" (semicolon-separated) is deliberately left
          -- out: it is experimental and will be replaced by
          -- "Permissions-Policy".
          H.hIfMatch
        , "Link"
        , "Permissions-Policy"
        , H.hTE
        , H.hTrailer
        , H.hTransferEncoding
        , H.hUpgrade
        , H.hVia
        , "Want-Digest"
        ]
    -- Headers with a special value that, when present, replaces all
    -- other values.
    specialValueHeaders :: [(HeaderName, HandleType)]
    specialValueHeaders =
        [ ("Alt-Svc", KeepOnly "clear")
        , ("Clear-Site-Data", KeepOnly "*")
        , (H.hIfNoneMatch, KeepOnly "*")
        , ("Timing-Allow-Origin", KeepOnly "*")
        , (H.hVary, KeepOnly "*")
        ]