{-# LANGUAGE ScopedTypeVariables #-}
module SAML2.Core.Signature
( signSAMLProtocol
, verifySAMLProtocol
, verifySAMLProtocolWithKeys
) where
import Control.Lens ((^.), (?~))
import qualified Data.ByteString.Lazy as BSL
import Data.List.NonEmpty (NonEmpty((:|)))
import Network.URI (URI(uriFragment), nullURI)
import Text.XML.HXT.DOM.TypeDefs
import SAML2.XML
import qualified SAML2.XML.Canonical as C14N
import qualified SAML2.XML.Signature as DS
import qualified SAML2.Core.Protocols as SAMLP
-- | Sign a SAML protocol message with an enveloped XML-DSig signature.
--
-- The function:
--
--   * builds a 'DS.Reference' to @#\<protocolID\>@ (the message's own ID),
--     with the enveloped-signature transform followed by exclusive
--     canonicalization ('C14N.CanonicalXMLExcl10' @False@) and a SHA-1 digest;
--   * lets 'DS.generateReference' digest the serialized message;
--   * signs a 'DS.SignedInfo' with the signature algorithm of the given key;
--   * returns the input message with its signature field set to the result.
--
-- If the message already carries a signature, that signature's existing
-- 'DS.SignedInfo' is re-signed instead of the freshly built one.
signSAMLProtocol :: SAMLP.SAMLProtocol a => DS.SigningKey -> a -> IO a
signSAMLProtocol sk m = do
  r <- DS.generateReference reference $ samlToDoc m
  -- NOTE(review): when a signature is already present, its SignedInfo
  -- (including its references and their digests) is reused as-is, and the
  -- reference 'r' computed above is discarded. If the message changed after
  -- that signature was made, the reused digest would be stale. Confirm this
  -- is intended.
  s' <- DS.generateSignature sk
    $ maybe (defaultSignedInfo r) DS.signatureSignedInfo
    $ SAMLP.protocolSignature p
  return $ DS.signature' ?~ s' $ m
  where
    p = m ^. SAMLP.samlProtocol'

    -- Exclusive canonicalization is used both as a reference transform and
    -- as the SignedInfo canonicalization method.
    excl = C14N.CanonicalXMLExcl10 False

    -- Reference template: enveloped-signature, then exclusive C14N, then a
    -- SHA-1 digest.
    -- NOTE(review): SHA-1 is a weak digest; switching it affects interop
    -- with relying parties, so it is left unchanged here.
    reference = DS.Reference
      { DS.referenceId = Nothing
      , DS.referenceURI = Just nullURI{ uriFragment = '#' : SAMLP.protocolID p }
      , DS.referenceType = Nothing
      , DS.referenceTransforms = Just $ DS.Transforms
          $ DS.simpleTransform DS.TransformEnvelopedSignature
          :| [DS.simpleTransform (DS.TransformCanonicalization excl)]
      , DS.referenceDigestMethod = DS.simpleDigest DS.DigestSHA1
        -- Placeholder: 'DS.generateReference' is expected to supply the real
        -- digest (forcing this field would raise the error below).
      , DS.referenceDigestValue = error "signSAMLProtocol: referenceDigestValue"
      }

    -- SignedInfo used when the message carries no signature yet.
    defaultSignedInfo ref = DS.SignedInfo
      { DS.signedInfoId = Nothing
      , DS.signedInfoCanonicalizationMethod = DS.simpleCanonicalization excl
      , DS.signedInfoSignatureMethod = DS.SignatureMethod
        { DS.signatureMethodAlgorithm = Identified $ DS.signingKeySignatureAlgorithm sk
        , DS.signatureMethodHMACOutputLength = Nothing
        , DS.signatureMethod = []
        }
      , DS.signedInfoReference = ref :| []
      }
-- | Parse a SAML protocol message from raw XML bytes and verify its signature.
--
-- Fails in 'IO' (via 'fail') in three cases:
--
--   * the bytes are not well-formed XML (@"invalid XML"@);
--   * the document does not decode as the requested protocol type;
--   * signature verification reports an error.
--
-- NOTE(review): verification is called with 'mempty' as the key set.
-- Presumably this means only key material found in the document itself is
-- used, which would not authenticate the sender. Confirm against
-- 'DS.verifySignatureUnenvelopedSigs', and prefer
-- 'verifySAMLProtocolWithKeys' with pinned keys for untrusted input.
verifySAMLProtocol :: SAMLP.SAMLProtocol a => BSL.ByteString -> IO a
verifySAMLProtocol b = do
  x <- maybe (fail "invalid XML") return $ xmlToDoc b
  m <- either fail return $ docToSAML x
  v <- DS.verifySignatureUnenvelopedSigs mempty (DS.signedID m) x
  case v of
    Left msg -> fail $ "verifySAMLProtocol: invalid or missing signature: " ++ show msg
    Right () -> return m
-- | Decode a SAML protocol message from an already-parsed XML tree and
-- verify its signature against the given public keys.
--
-- Fails in 'IO' (via 'fail') in two cases:
--
--   * the tree does not decode as the requested protocol type;
--   * signature verification with @pubkeys@ reports an error.
--
-- Signature errors are prefixed with the function name, matching
-- 'verifySAMLProtocol'.
verifySAMLProtocolWithKeys :: SAMLP.SAMLProtocol a => DS.PublicKeys -> XmlTree -> IO a
verifySAMLProtocolWithKeys pubkeys x = do
  m <- either fail return $ docToSAML x
  v <- DS.verifySignatureUnenvelopedSigs pubkeys (DS.signedID m) x
  either (fail . sigFailure) (const $ pure m) v
  where
    sigFailure e =
      "verifySAMLProtocolWithKeys: invalid or missing signature: " ++ show e