{-# LANGUAGE OverloadedStrings #-}

module Text.XML.Util where

import Control.Monad.Except
import Data.ByteString.Lazy qualified as BSL
import Data.ByteString.Lazy.UTF8 qualified as BSLUTF8
import Data.Char (isSpace)
import Data.Generics.Uniplate.Data qualified as Uniplate
import Data.Kind (Type)
import Data.Map as Map
import Data.Proxy
import Data.String.Conversions
import Data.Text qualified as ST
import Data.Tree.NTree.TypeDefs qualified as HXT
import Data.Typeable
import GHC.Stack
import SAML2.XML qualified as HS
import Text.XML
import Text.XML.HXT.Arrow.Pickle.Xml qualified as XP
import Text.XML.HXT.Core qualified as HXT

-- | Fail in 'MonadError' with a "could not parse" message that names the type @a@ (taken from
-- the proxy) and shows the offending value.  This is 'die'' without any extra detail appended.
die :: forall (a :: Type) b c m. (HasCallStack, Typeable a, Show b, MonadError String m) => Proxy a -> b -> m c
die = die' Nothing

-- | Fail in 'MonadError' with a message of the shape
--
-- > HasXML: could not parse <type>: <shown value>
--
-- optionally followed by @"; "@ and the extra detail given as first argument.
--
-- The type name is obtained from the proxy via 'typeRep', so no value of type @a@ (and no
-- 'undefined') is needed.
die' :: forall (a :: Type) b c m. (HasCallStack, Typeable a, Show b, MonadError String m) => Maybe String -> Proxy a -> b -> m c
die' mextra proxy msg =
  throwError $
    "HasXML: could not parse "
      <> show (typeRep proxy)
      <> ": "
      <> show msg
      <> maybe "" ("; " <>) mextra

-- | Map from attribute 'Name' to attribute value.
type Attrs = Map.Map Name ST

-- | Wrap an 'Element' as a singleton list containing the corresponding 'NodeElement'.
elemToNodes :: (HasCallStack) => Element -> [Node]
elemToNodes el = [NodeElement el]

-- | Extract the element from a node list that consists of exactly one 'NodeElement'.
--
-- Partial: for any other input (empty list, more than one node, a single non-element node) this
-- calls 'error' with the shown node list.
nodesToElem :: (HasCallStack) => [Node] -> Element
nodesToElem = \case
  [NodeElement el] -> el
  bad -> error $ show bad

-- | Turn a 'Document' into a singleton node list holding its root element.  The prologue and
-- the trailing miscellaneous items are dropped.
docToNodes :: (HasCallStack) => Document -> [Node]
docToNodes (Document _ root _) = elemToNodes root

-- | Build a 'Document' around the given root element, using 'defPrologue' and
-- 'defMiscellaneous' for the parts before and after it.
mkDocument :: Element -> Document
mkDocument root = Document defPrologue root defMiscellaneous

-- | Empty prologue: no miscellaneous items before or after the doctype, and no doctype.
defPrologue :: Prologue
defPrologue = Prologue [] Nothing []

-- | Empty list of miscellaneous items, used by 'mkDocument' as the document epilogue.
defMiscellaneous :: [Miscellaneous]
defMiscellaneous = []

-- | Convert an HXT tree into an xml-conduit 'Document' by rendering it with
-- 'ourDocToXMLWithRoot' and parsing the resulting bytes with 'parseLBS'.
--
-- Throws (via 'MonadError') a message prefixed with @hxtToConduit: parseLBS failed: @ if the
-- rendered bytes do not parse.
hxtToConduit :: (MonadError String m) => HXT.XmlTree -> m Document
hxtToConduit tree =
  case parseLBS def (ourDocToXMLWithRoot tree) of
    Left err -> throwError $ "hxtToConduit: parseLBS failed: " <> show err
    Right doc -> pure doc

-- | Convert an xml-conduit 'Document' into an HXT tree by rendering it (without the XML
-- declaration) and re-reading the bytes with 'xmlToDoc''.
--
-- Throws (via 'MonadError') a message prefixed with @conduitToHxt: xmlToDoc' failed: @ if
-- re-reading fails.
conduitToHxt :: (MonadError String m) => Document -> m HXT.XmlTree
conduitToHxt doc =
  case xmlToDoc' (renderLBS def {rsXMLDeclaration = False} doc) of
    Left err -> throwError $ "conduitToHxt: xmlToDoc' failed: " <> err
    Right tree -> pure tree

-- | Serialize a picklable SAML value with 'ourSamlToXML' and parse the bytes into an
-- xml-conduit 'Document'.
--
-- Throws (via 'MonadError') a message prefixed with @samlToConduit: parseLBS failed: @ if the
-- serialized bytes do not parse.
samlToConduit :: (MonadError String m, HXT.XmlPickler a) => a -> m Document
samlToConduit saml =
  case parseLBS def (ourSamlToXML saml) of
    Left err -> throwError $ "samlToConduit: parseLBS failed: " <> show err
    Right doc -> pure doc

-- | Pickle a SAML value into an HXT tree with 'HS.samlToDoc' and render that tree to bytes with
-- 'ourDocToXMLWithoutRoot'.
ourSamlToXML :: (XP.XmlPickler a) => a -> BSL.ByteString
ourSamlToXML = ourDocToXMLWithoutRoot . HS.samlToDoc

-- | Direct usage of `xshowBlob` breaks non-Latin-1 encodings (e.g. UTF-8,
-- Unicode)! This helper function works around these issues.
--
-- Renders the tree to a 'String' with 'HXT.writeDocumentToString' and encodes that string as
-- UTF-8.  The arrow is expected to yield exactly one result; any other number of results is
-- reported via 'error'.
--
-- NOTE(review): judging by the name and by the dummy root added in 'ourDocToXMLWithRoot', the
-- root node of the given tree is presumably not part of the output -- confirm against the HXT
-- documentation.
ourDocToXMLWithoutRoot :: (HasCallStack) => HXT.XmlTree -> BSL.ByteString
ourDocToXMLWithoutRoot t =
  case HXT.runLA (HXT.writeDocumentToString []) t of
    [xmlContent] -> BSLUTF8.fromString xmlContent
    other -> error $ "Expected one element. Got: " ++ show other

-- | Direct usage of `xshowBlob` breaks non-Latin-1 encodings (e.g. UTF-8,
-- Unicode)! This helper function works around these issues.
--
-- Puts the given tree underneath a dummy text node (@"throw-me-away"@) and renders the result
-- with 'ourDocToXMLWithoutRoot'.
ourDocToXMLWithRoot :: HXT.XmlTree -> BSL.ByteString
ourDocToXMLWithRoot t = ourDocToXMLWithoutRoot dummyRooted
  where
    -- The dummy node only serves as a parent for @t@.
    dummyRooted = HXT.NTree (HXT.XText "throw-me-away") [t]

-- | This is subtly different from HS.xmlToDoc' and should probably be moved to hsaml2.
--
-- Reads the bytes with 'HXT.xread' and requires exactly one resulting tree.  Fails (via
-- 'MonadError') with
--
-- * the HXT error message, if the single result is an 'HXT.XError' node;
-- * @"no root elements"@, if there is no result;
-- * @"more than one root element: N"@, if there are @N >= 2@ results.
xmlToDoc' :: (MonadError String m) => BSL.ByteString -> m HXT.XmlTree
xmlToDoc' xml =
  case HXT.runLA HXT.xread (cs xml) of
    [HXT.NTree (HXT.XError _errcode errmsg) _] -> throwError errmsg
    [t] -> pure t
    [] -> throwError "no root elements"
    bad@(_ : _ : _) -> throwError $ "more than one root element: " <> show (length bad)

-- | Remove all whitespace in the text nodes of the xml document.
--
-- Two rewrites are handed to 'Uniplate.transformBis':
--
-- * every 'NodeContent' has all characters satisfying 'isSpace' removed (not just leading and
--   trailing ones);
-- * every 'Element' has children equal to @NodeContent ""@ removed.
--
-- NOTE(review): whether content nodes that become empty through the first rewrite are also
-- dropped by the second depends on the order in which 'Uniplate.transformBis' applies the two
-- groups -- confirm against the uniplate documentation.
stripWhitespace :: Document -> Document
stripWhitespace =
  Uniplate.transformBis
    [ [ Uniplate.transformer $ \case
          NodeContent txt -> NodeContent (ST.filter (not . isSpace) txt)
          other -> other
      ],
      [ Uniplate.transformer $ \case
          Element nm attrs nodes -> Element nm attrs (Prelude.filter (/= NodeContent "") nodes)
      ]
    ]

-- | If two content nodes are next to each other, concatenate them into one.  NB: if you call
-- 'stripWhitespace' it should be called *after* 'mergeContentSiblings', or some two words will be
-- merged into one.
--
-- The merge is applied to the child list of every 'Element' in the document.
mergeContentSiblings :: Document -> Document
mergeContentSiblings =
  Uniplate.transformBis
    [ [ Uniplate.transformer $ \case
          Element nm attrs nodes -> Element nm attrs (go nodes)
      ]
    ]
  where
    -- Collapse each run of adjacent 'NodeContent's into a single node.  After merging a pair,
    -- the merged node is re-examined so that runs longer than two are fully collapsed.
    go :: [Node] -> [Node]
    go [] = []
    go (NodeContent s : NodeContent t : rest) = go (NodeContent (s <> t) : rest)
    go (x : rest) = x : go rest

-- | Normalize a document: first merge adjacent content nodes ('mergeContentSiblings'), then
-- strip whitespace ('stripWhitespace').
normalizeDoc :: Document -> Document
normalizeDoc = stripWhitespace . mergeContentSiblings

-- | Apply 'repairNamespacesEl' to every element node of the list; all other nodes are left
-- untouched.  See https://github.com/snoyberg/xml/issues/137
repairNamespaces :: (HasCallStack) => [Node] -> [Node]
repairNamespaces = fmap repairNode
  where
    repairNode :: Node -> Node
    repairNode = \case
      NodeElement el -> NodeElement (repairNamespacesEl el)
      other -> other

-- | Round-trip an element through rendering ('renderText') and parsing ('parseText') and return
-- the root element of the re-parsed document.  See https://github.com/snoyberg/xml/issues/137
--
-- Partial: if the rendered text fails to parse, this calls 'error' with the shown parse
-- exception.
repairNamespacesEl :: (HasCallStack) => Element -> Element
repairNamespacesEl el = unwrap reparsed
  where
    reparsed = parseText def (renderText def (mkDocument el))

    unwrap :: (Show e) => Either e Document -> Element
    unwrap (Right (Document _ el' _)) = el'
    unwrap (Left msg) = error $ show msg