{-# LANGUAGE ScopedTypeVariables #-}
-- |
-- SAML and XML Signature Syntax and Processing
--
-- <https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf saml-core-2.0-os> §5
module SAML2.Core.Signature
  ( signSAMLProtocol
  , verifySAMLProtocol
  , verifySAMLProtocolWithKeys
  ) where

import Control.Lens ((^.), (?~))
import qualified Data.ByteString.Lazy as BSL
import Data.List.NonEmpty (NonEmpty((:|)))
import Network.URI (URI(uriFragment), nullURI)
import Text.XML.HXT.DOM.TypeDefs

import SAML2.XML
import qualified SAML2.XML.Canonical as C14N
import qualified SAML2.XML.Signature as DS
import qualified SAML2.Core.Protocols as SAMLP

-- | Sign a SAML protocol message, returning the message with its signature set.
--
-- The function proceeds in three steps:
--
-- 1. A 'DS.Reference' is generated (via 'DS.generateReference') over the
--    document produced by @samlToDoc m@.  The reference URI is a fragment
--    built from @SAMLP.protocolID p@, and its transforms are the
--    enveloped-signature transform followed by
--    @C14N.CanonicalXMLExcl10 False@ canonicalization.
--
-- 2. A signature is generated with 'DS.generateSignature' and the signing key
--    @sk@.  If the message already carries a signature
--    (@SAMLP.protocolSignature p@ is 'Just'), that signature's
--    'DS.signatureSignedInfo' is signed as-is; otherwise a fresh
--    'DS.SignedInfo' containing the reference from step 1 is used.
--
-- 3. The resulting signature is stored in the message through the
--    @DS.signature'@ setter.
signSAMLProtocol :: SAMLP.SAMLProtocol a => DS.SigningKey -> a -> IO a
signSAMLProtocol sk m = do
  r <- DS.generateReference DS.Reference
    { DS.referenceId = Nothing
      -- Same-document reference: "#" followed by the message's protocol ID.
    , DS.referenceURI = Just nullURI{ uriFragment = '#' : SAMLP.protocolID p }
    , DS.referenceType = Nothing
    , DS.referenceTransforms = Just $ DS.Transforms
      $  DS.simpleTransform DS.TransformEnvelopedSignature
      :| [DS.simpleTransform (DS.TransformCanonicalization $ C14N.CanonicalXMLExcl10 False)]
      -- NOTE(review): SHA-1 is hard-coded as the digest algorithm; changing it
      -- would alter the produced signatures, so it is left as is.
    , DS.referenceDigestMethod = DS.simpleDigest DS.DigestSHA1
      -- Placeholder only: presumably 'DS.generateReference' replaces this field
      -- with the computed digest without forcing it -- TODO confirm.
    , DS.referenceDigestValue = error "signSAMLProtocol: referenceDigestValue"
    } $ samlToDoc m
  -- Reuse the SignedInfo of an existing signature when there is one; the
  -- freshly generated reference @r@ is only used in the fallback below.
  s' <- DS.generateSignature sk $ maybe DS.SignedInfo
    { DS.signedInfoId = Nothing
    , DS.signedInfoCanonicalizationMethod = DS.simpleCanonicalization $ C14N.CanonicalXMLExcl10 False
    , DS.signedInfoSignatureMethod = DS.SignatureMethod
      { DS.signatureMethodAlgorithm = Identified $ DS.signingKeySignatureAlgorithm sk
      , DS.signatureMethodHMACOutputLength = Nothing
      , DS.signatureMethod = []
      }
    , DS.signedInfoReference = r :| []
    } DS.signatureSignedInfo $ SAMLP.protocolSignature p
  return $ DS.signature' ?~ s' $ m
  where
    -- Protocol-level view of the message (source of the ID and any existing signature).
    p = m ^. SAMLP.samlProtocol'

-- | Parse a serialized SAML protocol message and check its signature.
--
-- The input bytes are parsed with @xmlToDoc@, decoded into the requested
-- message type with @docToSAML@, and then checked with
-- 'DS.verifySignatureUnenvelopedSigs' against the element identified by
-- @DS.signedID m@.
--
-- The action fails (via 'fail' in 'IO') when
--
-- * the bytes cannot be parsed (message @invalid XML@),
-- * the document cannot be decoded (message taken from @docToSAML@), or
-- * verification returns 'Left' (message prefixed with
--   @verifySAMLProtocol: invalid or missing signature: @).
--
-- NOTE(review): an empty key set ('mempty') is passed to the verifier, so no
-- caller-supplied keys take part in the check.  Which keys are trusted in that
-- case is decided by 'DS.verifySignatureUnenvelopedSigs' -- confirm before
-- relying on this for authentication; 'verifySAMLProtocolWithKeys' accepts
-- explicit keys.
verifySAMLProtocol :: SAMLP.SAMLProtocol a => BSL.ByteString -> IO a
verifySAMLProtocol b = do
  x <- maybe (fail "invalid XML") return $ xmlToDoc b
  m <- either fail return $ docToSAML x
  -- Both the decoded message (for its ID) and the raw tree are needed here.
  v <- DS.verifySignatureUnenvelopedSigs mempty (DS.signedID m) x
  case v of
    Left msg -> fail $ "verifySAMLProtocol: invalid or missing signature: " ++ show msg
    Right () -> return m

-- | A variant of 'verifySAMLProtocol' that is more symmetric to 'signSAMLProtocol'.  The reason it
-- takes an 'XmlTree' and not an @a@ is that signature verification needs both: the decoded
-- message supplies the signed ID (@DS.signedID m@), and the tree is what is actually verified.
--
-- The tree is decoded with @docToSAML@ and then checked with
-- 'DS.verifySignatureUnenvelopedSigs' using the supplied @pubkeys@.  On success the decoded
-- message is returned.  The action fails (via 'fail' in 'IO') if decoding fails, with the
-- decoder's message, or if verification returns 'Left', with the 'show'n verification error.
verifySAMLProtocolWithKeys :: SAMLP.SAMLProtocol a => DS.PublicKeys -> XmlTree -> IO a
verifySAMLProtocolWithKeys pubkeys x = do
  m <- either fail return $ docToSAML x
  v <- DS.verifySignatureUnenvelopedSigs pubkeys (DS.signedID m) x
  either (fail . show) (const $ pure m) v