{-# LANGUAGE ScopedTypeVariables #-}
-- |
-- SAML and XML Signature Syntax and Processing
--
-- <https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf saml-core-2.0-os> §5
module SAML2.Core.Signature
  ( signSAMLProtocol
  , verifySAMLProtocol
  , verifySAMLProtocolWithKeys
  ) where

import Control.Lens ((^.), (?~))
import qualified Data.ByteString.Lazy as BSL
import Data.List.NonEmpty (NonEmpty((:|)))
import Network.URI (URI(uriFragment), nullURI)
import Text.XML.HXT.DOM.TypeDefs

import SAML2.XML
import qualified SAML2.XML.Canonical as C14N
import qualified SAML2.XML.Signature as DS
import qualified SAML2.Core.Protocols as SAMLP

-- | Sign a SAML protocol message with an enveloped XML signature
-- (see saml-core-2.0-os §5.4, XML Signature Profile).
--
-- A single @ds:Reference@ is generated with:
--
--   * URI @#\<ID\>@, where @\<ID\>@ is 'SAMLP.protocolID' of the message;
--   * transforms: enveloped-signature, then exclusive C14N 1.0
--     ('C14N.CanonicalXMLExcl10' 'False');
--   * a SHA-1 digest, computed by 'DS.generateReference' over the message
--     serialized with 'samlToDoc'.
--
-- If the message already carries a signature ('SAMLP.protocolSignature'),
-- that signature's existing @SignedInfo@ is re-signed unchanged. In that
-- case the freshly generated reference is not used. Otherwise a new
-- @SignedInfo@ is built around the new reference. Its signature method is
-- taken from the key ('DS.signingKeySignatureAlgorithm').
--
-- The resulting signature is stored in the message through 'DS.signature''.
signSAMLProtocol :: SAMLP.SAMLProtocol a => DS.SigningKey -> a -> IO a
signSAMLProtocol sk m = do
  r <- DS.generateReference reference $ samlToDoc m
  s' <- DS.generateSignature sk
    $ maybe (freshSignedInfo r) DS.signatureSignedInfo
    $ SAMLP.protocolSignature p
  pure $ DS.signature' ?~ s' $ m
  where
    -- Protocol-level view of the message (ID, existing signature, ...).
    p = m ^. SAMLP.samlProtocol'

    -- Exclusive C14N 1.0. Used both as a reference transform and as the
    -- SignedInfo canonicalization method.
    -- NOTE(review): the 'False' flag presumably selects the variant without
    -- comments; confirm against "SAML2.XML.Canonical".
    excl = C14N.CanonicalXMLExcl10 False

    -- Template reference. The digest value is a placeholder that
    -- 'DS.generateReference' is expected to fill in. Forcing it directly
    -- raises the 'error' below.
    -- NOTE(review): the digest is hard-coded to SHA-1, while the signature
    -- algorithm follows the key. Consider a stronger digest if relying
    -- parties accept it.
    reference = DS.Reference
      { DS.referenceId = Nothing
      , DS.referenceURI = Just nullURI{ uriFragment = '#' : SAMLP.protocolID p }
      , DS.referenceType = Nothing
      , DS.referenceTransforms = Just $ DS.Transforms
          $ DS.simpleTransform DS.TransformEnvelopedSignature
          :| [DS.simpleTransform (DS.TransformCanonicalization excl)]
      , DS.referenceDigestMethod = DS.simpleDigest DS.DigestSHA1
      , DS.referenceDigestValue = error "signSAMLProtocol: referenceDigestValue"
      }

    -- SignedInfo used when the message has no prior signature. It covers
    -- exactly the one freshly generated reference.
    freshSignedInfo ref = DS.SignedInfo
      { DS.signedInfoId = Nothing
      , DS.signedInfoCanonicalizationMethod = DS.simpleCanonicalization excl
      , DS.signedInfoSignatureMethod = DS.SignatureMethod
          { DS.signatureMethodAlgorithm = Identified $ DS.signingKeySignatureAlgorithm sk
          , DS.signatureMethodHMACOutputLength = Nothing
          , DS.signatureMethod = []
          }
      , DS.signedInfoReference = ref :| []
      }

-- | Parse a SAML protocol message from raw XML bytes, verify its XML
-- signature, and return the parsed message.
--
-- Fails via 'fail' (an 'IOError' in 'IO') when:
--
--   * the bytes are not well-formed XML ('xmlToDoc' returns 'Nothing'):
--     message @"invalid XML"@;
--   * the tree cannot be unpickled as @a@ ('docToSAML' returns 'Left'):
--     the unpickler's message;
--   * 'DS.verifySignatureUnenvelopedSigs' returns 'Left': the message is
--     prefixed with @"verifySAMLProtocol: invalid or missing signature: "@.
--
-- NOTE(review): verification is run with 'mempty' public keys, so the
-- caller supplies no trust anchor. Which key material is then used is
-- decided inside 'DS.verifySignatureUnenvelopedSigs'; confirm this. If it
-- falls back to keys embedded in the document, this proves integrity but
-- not authenticity. Prefer 'verifySAMLProtocolWithKeys' with trusted keys.
verifySAMLProtocol :: SAMLP.SAMLProtocol a => BSL.ByteString -> IO a
verifySAMLProtocol b = do
  x <- maybe (fail "invalid XML") pure $ xmlToDoc b
  m <- either fail pure $ docToSAML x
  -- The parsed message supplies the ID passed to the verifier; the raw tree
  -- is what the signature is checked against.
  v <- DS.verifySignatureUnenvelopedSigs mempty (DS.signedID m) x
  case v of
    Left err ->
      fail $ "verifySAMLProtocol: invalid or missing signature: " ++ show err
    Right () ->
      pure m

-- | A variant of 'verifySAMLProtocol' that is more symmetric to
-- 'signSAMLProtocol': the caller supplies the public keys to verify against.
--
-- It takes an already-parsed 'XmlTree' rather than a value of type @a@
-- because verification needs both. The tree is what the signature is
-- checked against. The value unpickled from it supplies the ID
-- ('DS.signedID') handed to the verifier.
--
-- Fails via 'fail' when the tree cannot be unpickled as @a@, or with the
-- 'show'n error when 'DS.verifySignatureUnenvelopedSigs' returns 'Left'.
verifySAMLProtocolWithKeys :: SAMLP.SAMLProtocol a => DS.PublicKeys -> XmlTree -> IO a
verifySAMLProtocolWithKeys pubkeys x = do
  m <- either fail pure $ docToSAML x
  v <- DS.verifySignatureUnenvelopedSigs pubkeys (DS.signedID m) x
  either (fail . show) (const $ pure m) v