{-# LANGUAGE OverloadedStrings #-}

module SAML2.WebSSO.Servant
  ( module SAML2.WebSSO.Servant,
    module SAML2.WebSSO.Servant.CPP,
  )
where

import Data.EitherR
import Data.Function
import Data.List (nubBy)
import qualified Data.Map as Map
import Data.Proxy
import Data.String.Conversions
import Network.HTTP.Media ((//))
import Network.HTTP.Types
import qualified Network.HTTP.Types.Header as HttpTypes
import Network.Wai hiding (Response)
import Network.Wai.Internal as Wai
import SAML2.WebSSO.Servant.CPP
import SAML2.WebSSO.XML
import Servant.API as Servant hiding (MkLink, URI (..))
import Text.Hamlet.XML
import Text.XML

-- | Servant 'Verb' for @GET@ endpoints that answer with HTTP status 307
-- (Temporary Redirect).
type GetRedir = Verb 'GET 307

-- | Servant 'Verb' for @POST@ endpoints that answer with HTTP status 303
-- (See Other).
type PostRedir = Verb 'POST 303

-- | Content-type tag for @application/xml@ request and response bodies.
--
-- The small @servant-xml@ package provides essentially this type together
-- with its 'MimeRender' and 'MimeUnrender' instances; inlining it here
-- seemed easier than depending on that package.
data XML

-- | 'XML' bodies are advertised and matched as @application/xml@.
instance Accept XML where
  contentType Proxy = "application" // "xml"

-- | Render any 'HasXMLRoot' value by serializing it with 'encode' (which
-- yields lazy text) and converting that text to the body bytes with 'cs'.
--
-- Marked OVERLAPPABLE so that a more specific @MimeRender XML T@ instance
-- for a particular type @T@ takes precedence over this catch-all.
instance {-# OVERLAPPABLE #-} HasXMLRoot a => MimeRender XML a where
  mimeRender Proxy = cs . encode

-- | Parse a body into any 'HasXMLRoot' value: convert the bytes to lazy text
-- with 'cs', then run 'decode'.
--
-- 'decode' reports failures as a 'String'. That string is passed through
-- 'show' before being handed back, so the resulting message is quoted and
-- escaped.
--
-- Marked OVERLAPPABLE so that a more specific @MimeUnrender XML T@ instance
-- takes precedence over this catch-all.
instance {-# OVERLAPPABLE #-} HasXMLRoot a => MimeUnrender XML a where
  mimeUnrender Proxy = fmapL show . decode . cs

-- | Content-type tag for @text/html@ responses.
data HTML

-- | 'HTML' bodies are advertised and matched as @text/html@.
instance Accept HTML where
  contentType Proxy = "text" // "html"

-- | Render a plain text message as a minimal XHTML page. The message is
-- interpolated as the content of a single @\<p\>@ inside @\<body\>@, and the
-- page is produced by 'mkHtml'.
instance MimeRender HTML ST where
  mimeRender Proxy msg =
    mkHtml
      [xml|
      <body>
        <p>
          #{msg}
    |]

-- | Wrap the given nodes in a complete XHTML 1.1 document and render it with
-- 'renderLBS' using the default render settings.
--
-- The document's prologue carries the XHTML 1.1 public doctype. The root
-- @\<html\>@ element has the attributes @xmlns="http://www.w3.org/1999/xhtml"@
-- and @xml:lang="en"@, and the given nodes become its children unchanged.
-- The document has no trailing miscellaneous nodes.
mkHtml :: [Node] -> LBS
mkHtml nodes = renderLBS def doc
  where
    doc :: Document
    doc = Document prologue root []

    prologue :: Prologue
    prologue = Prologue [] (Just doctyp) []

    doctyp :: Doctype
    doctyp =
      Doctype
        "html"
        (Just $ PublicID "-//W3C//DTD XHTML 1.1//EN" "http://www.w3.org/TR/xhtml11/DTD/xhtml11.dtd")

    root :: Element
    root = Element "html" rootattr nodes

    -- Left without a signature (as before); its type is fixed by its use in 'root'.
    rootattr = Map.fromList [("xmlns", "http://www.w3.org/1999/xhtml"), ("xml:lang", "en")]

-- | [3.5.5.1] Caching
--
-- WAI middleware that sets @Cache-Control: no-cache, no-store@ and
-- @Pragma: no-cache@ on every response produced by the wrapped application.
-- Any headers of those names that the application set itself are replaced.
-- All other headers are passed through unchanged, including repeated ones
-- such as several @Set-Cookie@ lines.
-- (The section number presumably refers to the SAML 2.0 bindings
-- specification.)
setHttpCachePolicy :: Middleware
setHttpCachePolicy ap rq respond = ap rq $ respond . addHeadersToResponse httpCachePolicy
  where
    httpCachePolicy :: HttpTypes.ResponseHeaders
    httpCachePolicy = [("Cache-Control", "no-cache, no-store"), ("Pragma", "no-cache")]

    -- Put @extraHeaders@ onto a response, overriding same-named headers.
    -- For 'ResponseRaw', the raw action is left as-is and only the fallback
    -- response is rewritten.
    addHeadersToResponse :: HttpTypes.ResponseHeaders -> Wai.Response -> Wai.Response
    addHeadersToResponse extraHeaders resp = case resp of
      ResponseFile status hdrs filepath part -> ResponseFile status (updH hdrs) filepath part
      ResponseBuilder status hdrs builder -> ResponseBuilder status (updH hdrs) builder
      ResponseStream status hdrs body -> ResponseStream status (updH hdrs) body
      ResponseRaw action resp' ->
        ResponseRaw action $
          addHeadersToResponse extraHeaders resp'
      where
        overriddenNames :: [HttpTypes.HeaderName]
        overriddenNames = map fst extraHeaders

        -- Extra headers come first (first occurrence of a name wins among
        -- them). From the application's headers, drop only those being
        -- overridden. Deduplicating the whole combined list would also
        -- collapse the application's own legitimately repeated headers.
        updH :: HttpTypes.ResponseHeaders -> HttpTypes.ResponseHeaders
        updH hdrs =
          nubBy ((==) `on` fst) extraHeaders
            ++ filter ((`notElem` overriddenNames) . fst) hdrs