{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Strict #-}
{-# LANGUAGE StrictData #-}

-- FUTUREWORK: set `-XNoDeriveAnyClass`.
module SAML2.WebSSO.Config where

import Control.Exception
import Control.Lens hiding (Level, (.=))
import Control.Monad (when)
import Data.Aeson
import Data.String.Conversions
import qualified Data.Yaml as Yaml
import GHC.Generics
import SAML2.WebSSO.Types
import System.Environment
import System.FilePath
import System.IO
import URI.ByteString
import URI.ByteString.QQ

----------------------------------------------------------------------
-- data types

-- | Settings that 'readConfig' decodes from a YAML file; 'fallbackConfig' is the value used
-- when that file cannot be decoded.
data Config = Config
  { -- | Verbosity threshold.  The loaders in this module echo what they read to stderr
    -- only when this is 'Info' or lower.
    _cfgLogLevel :: Level,
    -- | Stored under the key @spHost@.  Presumably the host name the service provider
    -- runs on -- confirm against the code that consumes it.
    _cfgSPHost :: String,
    -- | Stored under the key @spPort@.  Presumably the matching TCP port -- confirm
    -- against the code that consumes it.
    _cfgSPPort :: Int,
    -- | Stored under the key @spAppUri@ (the fallback points at a @/landing@ page).
    _cfgSPAppURI :: URI,
    -- | Stored under the key @spSsoUri@ (the fallback points at an @/sso@ path).
    _cfgSPSsoURI :: URI,
    -- | Stored under the key @contacts@.
    _cfgContacts :: [ContactPerson]
  }
  deriving (Eq, Show, Generic)

-- | Log severity, listed from the most verbose ('Trace') to the most severe ('Fatal').
-- This looks exactly like tinylog's type, but we redefine it here to avoid the dependency.
--
-- The constructor order is significant: the derived 'Ord' instance is what threshold
-- comparisons such as @level <= Info@ are evaluated with.
data Level
  = Trace
  | Debug
  | Info
  | Warn
  | Error
  | Fatal
  deriving (Eq, Ord, Show, Enum, Bounded, Generic, FromJSON, ToJSON)

----------------------------------------------------------------------
-- instances

-- Template Haskell splice: derives a lens for every underscore-prefixed field of 'Config'
-- (for example @cfgLogLevel@ for @_cfgLogLevel@, which the loaders in this module use).
makeLenses ''Config

-- | Renders a 'Config' as a flat object.  The key names here are the same ones the
-- 'FromJSON' instance for 'Config' looks up.
instance ToJSON Config where
  toJSON cfg =
    object
      [ "logLevel" .= _cfgLogLevel cfg,
        "spHost" .= _cfgSPHost cfg,
        "spPort" .= _cfgSPPort cfg,
        "spAppUri" .= _cfgSPAppURI cfg,
        "spSsoUri" .= _cfgSPSsoURI cfg,
        "contacts" .= _cfgContacts cfg
      ]

-- | Reads a 'Config' from an object; every one of the six keys is required.
instance FromJSON Config where
  parseJSON = withObject "Config" fromObject
    where
      -- Positional construction: the lookups are listed in the declaration order of the
      -- fields of 'Config', so each key lands in the field of the same name.
      fromObject obj =
        Config
          <$> obj .: "logLevel"
          <*> obj .: "spHost"
          <*> obj .: "spPort"
          <*> obj .: "spAppUri"
          <*> obj .: "spSsoUri"
          <*> obj .: "contacts"

----------------------------------------------------------------------
-- default

-- | The configuration 'readConfig' returns when the config file cannot be decoded.  All
-- values are placeholders (@localhost@, @example-sp.com@).
fallbackConfig :: Config
fallbackConfig =
  Config
    { _cfgLogLevel = Debug,
      _cfgSPHost = "localhost",
      _cfgSPPort = 8081,
      _cfgSPAppURI = landingPage,
      _cfgSPSsoURI = ssoPath,
      _cfgContacts = [fallbackContact]
    }
  where
    landingPage = [uri|https://example-sp.com/landing|]
    ssoPath = [uri|https://example-sp.com/sso|]

-- | Placeholder contact person; the only entry in the contact list of 'fallbackConfig'.
fallbackContact :: ContactPerson
fallbackContact =
  ContactPerson
    ContactSupport
    (text "evil corp.")
    (text "Dr.")
    (text "Girlfriend")
    -- NOTE(review): the scheme here is @email:@, not the more common @mailto:@ --
    -- confirm whether that is intentional before copying this value anywhere.
    (Just [uri|email:president@evil.corp|])
    (text "+314159265")
  where
    -- Wrap a literal as an optional XML text field.
    text :: ST -> Maybe XmlText
    text = Just . mkXmlText

----------------------------------------------------------------------
-- IO

-- | Resolve the server config path with 'configFilePath' and load it with 'readConfig'.
configIO :: IO Config
configIO = configFilePath >>= readConfig

-- | Path of the server config file: @server.yaml@ inside the directory named by the
-- environment variable @SAML2_WEB_SSO_ROOT@.  The variable is read with 'getEnv', which
-- raises an exception when it is not set.
configFilePath :: IO FilePath
configFilePath = do
  root <- getEnv "SAML2_WEB_SSO_ROOT"
  pure (root </> "server.yaml")

-- | Decode a 'Config' from the YAML file at the given path.
--
-- * If decoding succeeds, the config is returned; when its log level is 'Info' or lower
--   it is also echoed to stderr as YAML.
-- * If decoding fails, a warning containing the parse error is written to stderr and
--   'fallbackConfig' is returned instead.  The failure is deliberately not fatal.
readConfig :: FilePath -> IO Config
readConfig filepath = do
  decoded <- Yaml.decodeFileEither filepath
  case decoded of
    Left err -> fallbackConfig <$ warn err
    Right cnf -> cnf <$ info cnf
  where
    -- Echo the loaded config, but only at log level 'Info' or more verbose.
    info :: Config -> IO ()
    info cfg =
      when (cfg ^. cfgLogLevel <= Info) $
        hPutStrLn stderr ("\n>>> server config:\n" <> cs (Yaml.encode cfg))
    -- Always printed: at this point there is no config to take a log level from.  The
    -- message names this module so the reader can find 'fallbackConfig'.
    warn :: Yaml.ParseException -> IO ()
    warn err =
      hPutStrLn stderr $
        "*** could not read config file: "
          <> show err
          <> "  using default!  see SAML2.WebSSO.Config for details!"

-- | Convenience function to write a config file if you don't already have one.  Writes to
-- @$SAML2_WEB_SSO_ROOT/server.yaml@, i.e. the path computed by 'configFilePath'.
--
-- NOTE(review): 'configFilePath' reads the root directory with 'getEnv', so an unset
-- @SAML2_WEB_SSO_ROOT@ surfaces as the exception raised by 'getEnv'; no warning is
-- printed by this function itself.
writeConfig :: Config -> IO ()
writeConfig cfg = do
  target <- configFilePath
  Yaml.encodeFile target cfg

-- | Resolve the IdP list path with 'idpConfigFilePath' and load it with 'readIdPConfig'.
-- The given 'Config' is passed through to 'readIdPConfig'.
idpConfigIO :: Config -> IO [IdPConfig_]
idpConfigIO cfg = idpConfigFilePath >>= readIdPConfig cfg

-- | Path of the IdP list: @idps.yaml@ inside the directory named by the environment
-- variable @SAML2_WEB_SSO_ROOT@.  The variable is read with 'getEnv', which raises an
-- exception when it is not set.
idpConfigFilePath :: IO FilePath
idpConfigFilePath = do
  root <- getEnv "SAML2_WEB_SSO_ROOT"
  pure (root </> "idps.yaml")

-- | Decode the list of IdP configurations from the YAML file at the given path.
--
-- Unlike 'readConfig' there is no fallback: a decoding failure is rethrown as an
-- 'ErrorCall' carrying the shown parse error.  On success the list is returned, and it is
-- also echoed to stderr as YAML when the log level of the given 'Config' is 'Info' or
-- lower.
readIdPConfig :: Config -> FilePath -> IO [IdPConfig_]
readIdPConfig cfg filepath = do
  decoded <- Yaml.decodeFileEither filepath
  case decoded of
    Left err -> throwIO (ErrorCall (show err))
    Right idps -> idps <$ announce idps
  where
    -- Echo the loaded IdPs, but only at log level 'Info' or more verbose.
    announce :: [IdPConfig_] -> IO ()
    announce idps =
      when (cfg ^. cfgLogLevel <= Info) $
        hPutStrLn stderr ("\n>>>known idps:\n" <> cs (Yaml.encode idps))

----------------------------------------------------------------------
-- class

-- | Contexts @m@ from which the 'Config' can be obtained.
class HasConfig m where
  -- | Produce the 'Config' inside @m@.
  getConfig :: m Config

-- | A plain function from 'Config' acts as a reader: the config it is applied to is the
-- config it yields.
instance HasConfig ((->) Config) where
  getConfig cfg = cfg