{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-}

module SAML2.WebSSO.Test.MockResponse where

import Control.Lens
import Control.Monad.IO.Class
import Data.Generics.Uniplate.Data
import Data.String.Conversions
import Data.UUID as UUID
import GHC.Stack
import SAML2.Util
import SAML2.WebSSO
import Text.Hamlet.XML (xml)
import Text.XML
import Text.XML.DSig

-- | A complete authentication response 'Document', as produced by the
-- @mkAuthnResponse*@ functions in this module.  Unwrap it with
-- 'fromSignedAuthnResponse'.
newtype SignedAuthnResponse = SignedAuthnResponse {fromSignedAuthnResponse :: Document}
  deriving (Eq, Show)

-- | Build a mock authentication response whose subject is a 'NameID' made by
-- 'unspecifiedNameID' from the text of a freshly created UUID.  All other
-- arguments are passed on unchanged to 'mkAuthnResponseWithSubj'.
--
-- See tests on how this is used.
mkAuthnResponse ::
  (HasCallStack, HasMonadSign m, HasLogger m, HasCreateUUID m, HasNow m) =>
  SignPrivCreds ->
  IdPConfig extra ->
  SPMetadata ->
  AuthnRequest ->
  Bool ->
  m SignedAuthnResponse
mkAuthnResponse creds idp spmeta areq grant = do
  subj <- unspecifiedNameID . UUID.toText <$> createUUID
  mkAuthnResponseWithSubj subj creds idp spmeta areq grant

-- | Replace the 'NameID' child of the 'Subject' with a given one.
--
-- The unsigned assertion is rewritten before signing: an element named
-- @{urn:oasis:names:tc:SAML:2.0:assertion}Subject@ is parsed into a 'Subject',
-- its 'NameID' is swapped for the given one (the subject confirmations are
-- kept), and the result is rendered back into an element.  If the element
-- cannot be parsed as a 'Subject', this calls 'error' with the shown parse
-- error.  The full response is left unmodified (@id@).
--
-- (There is some code sharing between this and 'mkAuthnResponseWithRawSubj', but reducing it would
-- make both functions more complex.)
mkAuthnResponseWithSubj ::
  forall extra m.
  (HasCallStack, HasMonadSign m, HasCreateUUID m, HasNow m) =>
  NameID ->
  SignPrivCreds ->
  IdPConfig extra ->
  SPMetadata ->
  AuthnRequest ->
  Bool ->
  m SignedAuthnResponse
mkAuthnResponseWithSubj subj = mkAuthnResponseWithModif modif id
  where
    -- Rewrites 'Subject' elements; every other element passes through untouched.
    modif :: [Node] -> [Node]
    modif =
      transformBis
        [ [ transformer $ \case
              el@(Element "{urn:oasis:names:tc:SAML:2.0:assertion}Subject" _ _) ->
                case parse [NodeElement el] of
                  Right (Subject _ sc) -> nodesToElem . render $ Subject subj sc
                  Left bad -> error $ show bad
              other -> other
          ]
        ]

-- | Delete all children of 'Subject' and insert some new ones.
--
-- The unsigned assertion is rewritten before signing: an element named
-- @{urn:oasis:names:tc:SAML:2.0:assertion}Subject@ keeps its name and
-- attributes, but its child nodes are replaced wholesale by the given list.
-- The children are not parsed or checked here.  The full response is left
-- unmodified (@id@).
mkAuthnResponseWithRawSubj ::
  forall extra m.
  (HasCallStack, HasMonadSign m, HasCreateUUID m, HasNow m) =>
  [Node] ->
  SignPrivCreds ->
  IdPConfig extra ->
  SPMetadata ->
  AuthnRequest ->
  Bool ->
  m SignedAuthnResponse
mkAuthnResponseWithRawSubj subj = mkAuthnResponseWithModif modif id
  where
    -- Swaps the children of 'Subject' elements; every other element passes through untouched.
    modif :: [Node] -> [Node]
    modif =
      transformBis
        [ [ transformer $ \case
              (Element tag@"{urn:oasis:names:tc:SAML:2.0:assertion}Subject" attrs _) ->
                Element tag attrs subj
              other -> other
          ]
        ]

-- | Build a mock authentication response, with hooks for tampering with it.
--
-- Arguments, in order:
--
-- * a modification applied to the assertion nodes /before/ they are signed;
-- * a modification applied to the nodes of the complete response, i.e. after
--   the signed assertion has been embedded;
-- * the credentials handed to 'signElementIOAt' for signing the assertion;
-- * the IdP config, from which the issuer is read;
-- * the SP metadata, from which the response URL (recipient / destination) is read;
-- * the authentication request being answered (its issuer becomes the
--   audience, its ID the @InResponseTo@ value);
-- * whether to grant access: 'True' yields status @Success@, 'False' yields
--   status @Requester@.
--
-- The assertion and the response each get a fresh ID of the form @_\<uuid\>@.
-- The second modification must return exactly one element node; anything
-- else makes the pattern binding on its result fail once the document is
-- forced.
mkAuthnResponseWithModif ::
  (HasCallStack, HasMonadSign m, HasCreateUUID m, HasNow m) =>
  ([Node] -> [Node]) ->
  ([Node] -> [Node]) ->
  SignPrivCreds ->
  IdPConfig extra ->
  SPMetadata ->
  AuthnRequest ->
  Bool ->
  m SignedAuthnResponse
mkAuthnResponseWithModif modifUnsignedAssertion modifAll creds idp sp authnreq grantAccess = do
  -- IDs are UUID texts prefixed with an underscore.
  let freshNCName = ("_" <>) . UUID.toText <$> createUUID
  assertionUuid <- freshNCName
  respUuid <- freshNCName
  now <- getNow
  let issueInstant = renderTime now
      -- Expiry is 3600 added to the current time.
      expires = renderTime $ 3600 `addTime` now
      idpissuer :: ST = idp ^. idpMetadata . edIssuer . fromIssuer . to renderURI
      recipient :: ST = sp ^. spResponseURL . to renderURI
      spissuer :: ST = authnreq ^. rqIssuer . fromIssuer . to renderURI
      inResponseTo = escapeXmlText . fromID $ authnreq ^. rqID
      status
        | grantAccess = "urn:oasis:names:tc:SAML:2.0:status:Success"
        | otherwise = "urn:oasis:names:tc:SAML:2.0:status:Requester"
  -- Pipeline (right to left): repair namespaces, apply the caller's
  -- modification, then sign.
  -- NOTE(review): the meaning of the literal 1 passed to 'signElementIOAt'
  -- is not visible here (presumably the position of the signature) — confirm.
  assertion :: [Node] <-
    liftIO $
      signElementIOAt 1 creds . modifUnsignedAssertion . repairNamespaces $
        [xml|
        <Assertion
          xmlns="urn:oasis:names:tc:SAML:2.0:assertion"
          Version="2.0"
          ID="#{assertionUuid}"
          IssueInstant="#{issueInstant}">
            <Issuer>
                #{idpissuer}
            <Subject>
                <NameID Format="urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress">
                    #{"emil@email.com"}
                <SubjectConfirmation Method="urn:oasis:names:tc:SAML:2.0:cm:bearer">
                    <SubjectConfirmationData
                      InResponseTo="#{inResponseTo}"
                      NotOnOrAfter="#{expires}"
                      Recipient="#{recipient}">
            <Conditions NotBefore="#{issueInstant}" NotOnOrAfter="#{expires}">
                <AudienceRestriction>
                    <Audience>
                        #{spissuer}
            <AuthnStatement AuthnInstant="#{issueInstant}" SessionIndex="_e9ae1025-bc03-4b5a-943c-c9fcb8730b21">
                <AuthnContext>
                    <AuthnContextClassRef>
                        urn:oasis:names:tc:SAML:2.0:ac:classes:Password
      |]
  -- Deliberately partial pattern binding (the module disables the
  -- incomplete-uni-patterns warning): the result must be a single element node.
  let authnResponse :: Element
      [NodeElement authnResponse] =
        modifAll . repairNamespaces $
          [xml|
          <samlp:Response
            xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
            ID="#{respUuid}"
            Version="2.0"
            Destination="#{recipient}"
            InResponseTo="#{inResponseTo}"
            IssueInstant="#{issueInstant}">
              <Issuer xmlns="urn:oasis:names:tc:SAML:2.0:assertion">
                  #{idpissuer}
              <samlp:Status>
                  <samlp:StatusCode Value="#{status}">
              ^{assertion}
        |]
  pure . SignedAuthnResponse $ mkDocument authnResponse