{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-}

-- This file is part of the Wire Server implementation.
--
-- Copyright (C) 2025 Wire Swiss GmbH <opensource@wire.com>
--
-- This program is free software: you can redistribute it and/or modify it under
-- the terms of the GNU Affero General Public License as published by the Free
-- Software Foundation, either version 3 of the License, or (at your option) any
-- later version.
--
-- This program is distributed in the hope that it will be useful, but WITHOUT
-- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
-- FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
-- details.
--
-- You should have received a copy of the GNU Affero General Public License along
-- with this program. If not, see <https://www.gnu.org/licenses/>.

module SAML2.WebSSO.Test.MockResponse where

import Control.Lens
import Control.Monad.IO.Class
import Data.Generics.Uniplate.Data
import Data.String.Conversions
import Data.UUID as UUID
import GHC.Stack
import SAML2.Util
import SAML2.WebSSO
import Text.Hamlet.XML (xml)
import Text.XML
import Text.XML.DSig

-- | An XML 'Document' holding a SAML authentication response.
--
-- In this module such values are built by 'mkAuthnResponseWithModif' (and the functions that
-- delegate to it).  Note that there only the embedded assertion is passed through
-- 'signElementIOAt'; the surrounding response element itself is not signed.
newtype SignedAuthnResponse = SignedAuthnResponse
  { fromSignedAuthnResponse :: Document
  }
  deriving (Show, Eq)

-- | Build an authentication response for a fresh subject.
--
-- The subject is an unspecified-format 'NameID' whose text is a UUID obtained from
-- 'createUUID'.  All other arguments are passed unchanged to 'mkAuthnResponseWithSubj'.
-- See the tests for usage examples.
--
-- NOTE(review): the body does not use 'HasLogger'; the constraint is presumably kept for
-- interface stability.
mkAuthnResponse ::
  (HasCallStack, HasMonadSign m, HasLogger m, HasCreateUUID m, HasNow m) =>
  SignPrivCreds ->
  IdPConfig extra ->
  SPMetadata ->
  Maybe AuthnRequest ->
  Bool ->
  m SignedAuthnResponse
mkAuthnResponse creds idp spmeta mbareq grant =
  createUUID >>= \uuid ->
    let freshSubject = unspecifiedNameID (UUID.toText uuid)
     in mkAuthnResponseWithSubj freshSubject creds idp spmeta mbareq grant

-- | Replace the 'NameID' child of the 'Subject' with a given one.
--
-- This is 'mkAuthnResponseWithModif' with an unsigned-assertion modifier that parses each
-- @\<Subject\>@ element of the assertion and re-renders it with @subj@ as its 'NameID'.
-- The original subject confirmations are kept.  The complete response is not modified
-- further ('id').
--
-- (There is some code sharing between this and 'mkAuthnResponseWithRawSubj', but reducing it would
-- make both functions more complex.)
mkAuthnResponseWithSubj ::
  forall extra m.
  (HasCallStack, HasMonadSign m, HasCreateUUID m, HasNow m) =>
  NameID ->
  SignPrivCreds ->
  IdPConfig extra ->
  SPMetadata ->
  Maybe AuthnRequest ->
  Bool ->
  m SignedAuthnResponse
mkAuthnResponseWithSubj subj = mkAuthnResponseWithModif modif id
  where
    -- Passed as the unsigned-assertion modifier, i.e. it runs before 'signElementIOAt'.
    modif :: [Node] -> [Node]
    modif =
      transformBis
        [ [ transformer $ \case
              el@(Element "{urn:oasis:names:tc:SAML:2.0:assertion}Subject" _ _) ->
                case parse [NodeElement el] of
                  -- Keep the parsed subject confirmations; only swap the 'NameID'.
                  Right (Subject _ sc) -> nodesToElem . render $ Subject subj sc
                  -- The Subject is generated by 'mkAuthnResponseWithModif', so this should not
                  -- happen; fail with context rather than a bare, re-quoted parser message.
                  Left bad ->
                    error $
                      "mkAuthnResponseWithSubj: could not parse <Subject> element: " <> bad
              other -> other
          ]
        ]

-- | Delete all children of 'Subject' and insert some new ones.
--
-- This is 'mkAuthnResponseWithModif' with an unsigned-assertion modifier.  The modifier keeps
-- the tag and attributes of every assertion @\<Subject\>@ element and replaces its child nodes
-- wholesale with @subj@.  No validation is done, so this can produce arbitrary, even invalid,
-- subjects.  The complete response is not modified further ('id').
mkAuthnResponseWithRawSubj ::
  forall extra m.
  (HasCallStack, HasMonadSign m, HasCreateUUID m, HasNow m) =>
  [Node] ->
  SignPrivCreds ->
  IdPConfig extra ->
  SPMetadata ->
  Maybe AuthnRequest ->
  Bool ->
  m SignedAuthnResponse
mkAuthnResponseWithRawSubj subj = mkAuthnResponseWithModif replaceSubjectChildren id
  where
    -- Rewrite every 'Element' in the node forest with 'swapChildren'.
    replaceSubjectChildren :: [Node] -> [Node]
    replaceSubjectChildren = transformBis [[transformer swapChildren]]

    -- Subject elements get @subj@ as their children; everything else is left alone.
    swapChildren :: Element -> Element
    swapChildren (Element name attrs _)
      | name == "{urn:oasis:names:tc:SAML:2.0:assertion}Subject" = Element name attrs subj
    swapChildren el = el

-- | Build a SAML @\<samlp:Response\>@ document that embeds one signed @\<Assertion\>@.
--
-- * @modifUnsignedAssertion@ is applied to the assertion nodes after 'repairNamespaces' and
--   /before/ they are signed with @creds@ via 'signElementIOAt'.
-- * @modifAll@ is applied to the complete response nodes (with the signed assertion already
--   embedded), after 'repairNamespaces'.  Its result must consist of exactly one element;
--   otherwise forcing the resulting document fails with a descriptive 'error'.
-- * @idp@ supplies the @\<Issuer\>@ of both the response and the assertion.  @sp@ supplies
--   the response @Destination@ and the subject-confirmation @Recipient@.
-- * If @mbauthnreq@ is 'Just', its ID becomes @InResponseTo@ on @\<SubjectConfirmationData\>@
--   and its issuer becomes the @\<Audience\>@.  With 'Nothing', both are omitted.
-- * @grantAccess@ selects the status code: @...:status:Success@ when 'True',
--   @...:status:Requester@ otherwise.
--
-- Issue instant and @NotBefore@ are the current time ('getNow').  @NotOnOrAfter@ is
-- @3600@ ('NominalDiffTime') later.
mkAuthnResponseWithModif ::
  (HasCallStack, HasMonadSign m, HasCreateUUID m, HasNow m) =>
  ([Node] -> [Node]) ->
  ([Node] -> [Node]) ->
  SignPrivCreds ->
  IdPConfig extra ->
  SPMetadata ->
  Maybe AuthnRequest ->
  Bool ->
  m SignedAuthnResponse
mkAuthnResponseWithModif modifUnsignedAssertion modifAll creds idp sp mbauthnreq grantAccess = do
  -- XML IDs must be NCNames, which may not start with a digit (a UUID may); hence the "_".
  let freshNCName = ("_" <>) . UUID.toText <$> createUUID
  assertionUuid <- freshNCName
  respUuid <- freshNCName
  now <- getNow
  let issueInstant = renderTime now
      expires = renderTime $ 3600 `addTime` now
      idpissuer :: ST = idp ^. idpMetadata . edIssuer . fromIssuer . to renderURI
      recipient :: ST = sp ^. spResponseURL . to renderURI
      mbspissuer :: Maybe ST = (^. rqIssuer . fromIssuer . to renderURI) <$> mbauthnreq
      mbinResponseTo :: Maybe ST = fromID . (^. rqID) <$> mbauthnreq
      status
        | grantAccess = "urn:oasis:names:tc:SAML:2.0:status:Success"
        | otherwise = "urn:oasis:names:tc:SAML:2.0:status:Requester"
  -- NOTE(review): the '1' passed to 'signElementIOAt' is presumably the child position of the
  -- inserted <Signature> (right after <Issuer>); confirm against its definition.
  assertion :: [Node] <-
    liftIO $
      signElementIOAt 1 creds . modifUnsignedAssertion . repairNamespaces $
        [xml|
        <Assertion
          xmlns="urn:oasis:names:tc:SAML:2.0:assertion"
          Version="2.0"
          ID="#{assertionUuid}"
          IssueInstant="#{issueInstant}">
            <Issuer>
                #{idpissuer}
            <Subject>
                <NameID Format="urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress">
                    #{"emil@email.com"}
                <SubjectConfirmation Method="urn:oasis:names:tc:SAML:2.0:cm:bearer">
                    $maybe inResponseTo <- mbinResponseTo
                      <SubjectConfirmationData
                        InResponseTo="#{inResponseTo}"
                        NotOnOrAfter="#{expires}"
                        Recipient="#{recipient}">
                    $nothing
                      <SubjectConfirmationData
                        NotOnOrAfter="#{expires}"
                        Recipient="#{recipient}">
            <Conditions NotBefore="#{issueInstant}" NotOnOrAfter="#{expires}">
                $maybe spissuer <- mbspissuer
                    <AudienceRestriction>
                        <Audience>
                            #{spissuer}
            <AuthnStatement AuthnInstant="#{issueInstant}" SessionIndex="_e9ae1025-bc03-4b5a-943c-c9fcb8730b21">
                <AuthnContext>
                    <AuthnContextClassRef>
                        urn:oasis:names:tc:SAML:2.0:ac:classes:Password
      |]
  let -- Total replacement for the former partial pattern binding.  It stays lazy (it fails
      -- only when the document is forced, as before) but says what went wrong.
      authnResponse :: Element
      authnResponse =
        case responseNodes of
          [NodeElement el] -> el
          unexpected ->
            error $
              "mkAuthnResponseWithModif: expected exactly one element after modifAll, but got: "
                <> show unexpected
      responseNodes :: [Node]
      responseNodes =
        -- it's safe to skip the `inResponseTo` attribute here because it is optional and
        -- wire-server ignores it: it's not signed, and shouldn't be considered trustworthy no
        -- matter what it contains.
        modifAll
          . repairNamespaces
          $ [xml|
          <samlp:Response
            xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
            ID="#{respUuid}"
            Version="2.0"
            Destination="#{recipient}"
            IssueInstant="#{issueInstant}">
              <Issuer xmlns="urn:oasis:names:tc:SAML:2.0:assertion">
                  #{idpissuer}
              <samlp:Status>
                  <samlp:StatusCode Value="#{status}">
              ^{assertion}
        |]
  pure . SignedAuthnResponse $ mkDocument authnResponse