{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TemplateHaskell #-}
module Wire.API.User.Saml where
import Control.Monad.Except
import Data.Aeson hiding (fieldLabelModifier)
import Data.Aeson.TH hiding (fieldLabelModifier)
import Data.ByteString (toStrict)
import Data.ByteString.Builder qualified as Builder
import Data.Id (UserId)
import Data.OpenApi
import Data.Proxy (Proxy (Proxy))
import Data.Text qualified as T
import Data.Text.Encoding
import Data.Text.Encoding.Error
import Data.Time
import GHC.TypeLits (KnownSymbol, symbolVal)
import GHC.Types (Symbol)
import Imports
import SAML2.WebSSO
import SAML2.WebSSO.Types.TH (deriveJSONOptions)
import URI.ByteString
import Web.Cookie
import Wire.API.User.Orphans ()
-- | Shorthand for @'ID' 'AuthnRequest'@.
type AReqId = ID AuthnRequest
-- | Shorthand for @'ID' 'Assertion'@.
type AssId = ID Assertion
-- | Presumably selects how the verdict of a SAML login is handed back to the
-- client: 'VerdictFormatWeb' carries no data, 'VerdictFormatMobile' carries
-- two URIs.
--
-- NOTE(review): the field names suggest one URI is used when access is
-- granted and the other when it is denied; confirm against the callers.
data VerdictFormat
  = VerdictFormatWeb
  | VerdictFormatMobile
      { _formatGrantedURI :: URI,
        _formatDeniedURI :: URI
      }
  deriving (Generic, Show, Eq)

-- JSON instances are generated via Template Haskell with the options from
-- "SAML2.WebSSO.Types.TH".
deriveJSON deriveJSONOptions ''VerdictFormat
-- | Instantiate a mobile "granted" URI template for a user.
--
-- The function works in three steps:
--
-- 1. In the rendered @before@ URI, every @$userid@ (or percent-encoded
--    @%24userid@) is replaced by the user id, rendered with 'show'.
-- 2. Every @$cookie@ / @%24cookie@ is then replaced by the cookie. The cookie
--    is rendered with 'renderSetCookie' and decoded as UTF-8 with
--    'lenientDecode', so invalid bytes never cause a failure.
-- 3. The resulting text is parsed back with 'parseURI'', which may fail in @m@.
--
-- NOTE(review): the cookie text is substituted verbatim, without
-- percent-encoding. Whether characters such as spaces or @;@ survive the
-- re-parse depends on 'parseURI'' — confirm.
mkVerdictGrantedFormatMobile :: (MonadError String m) => URI -> SetCookie -> UserId -> m URI
mkVerdictGrantedFormatMobile before cky uid =
  parseURI' (fillCookie (fillUserId (renderURI before)))
  where
    -- The user id is substituted first, the cookie second.
    fillUserId = substituteVar "userid" (T.pack (show uid))
    fillCookie = substituteVar "cookie" cookieText
    cookieText =
      decodeUtf8With
        lenientDecode
        (toStrict (Builder.toLazyByteString (renderSetCookie cky)))
-- | Instantiate a mobile "denied" URI template.
--
-- Every @$label@ (or percent-encoded @%24label@) in the rendered @before@ URI
-- is replaced by @lbl@. The result is parsed back with 'parseURI'', which may
-- fail in @m@.
mkVerdictDeniedFormatMobile :: (MonadError String m) => URI -> Text -> m URI
mkVerdictDeniedFormatMobile before lbl = parseURI' labelled
  where
    labelled = substituteVar "label" lbl (renderURI before)
-- | Replace the template variable @var@ by @val@.
--
-- Both the literal form @$var@ and the percent-encoded form @%24var@ are
-- matched, presumably because the @$@ may have been escaped when the URI was
-- rendered. The @%24@ form is replaced first, then the @$@ form.
substituteVar :: Text -> Text -> Text -> Text
substituteVar var val = replaceMarked "$" . replaceMarked "%24"
  where
    replaceMarked marker = substituteVar' (marker <> var) val
-- | Replace every non-overlapping occurrence of @var@ in the input with
-- @val@, scanning left to right.
--
-- An empty @var@ means "nothing to substitute", so the input is returned
-- unchanged. This avoids the 'error' that 'T.replace' and 'T.splitOn' raise
-- for an empty needle.
substituteVar' :: Text -> Text -> Text -> Text
substituteVar' var val
  | T.null var = id
  | otherwise = T.replace var val
-- | A time-to-live, in seconds, stored as an 'Int32'.
--
-- The seconds unit is used by the 'FromJSON' instance's label and by
-- 'ttlToNominalDiffTime'.
--
-- The 'Symbol' index @tablename@ is a phantom tag: it has no runtime
-- representation and is only surfaced by 'showTTL'. 'Num' is derived through
-- the newtype (GeneralizedNewtypeDeriving), so arithmetic is plain 'Int32'
-- arithmetic.
newtype TTL (tablename :: Symbol) = TTL
  { fromTTL :: Int32
  }
  deriving (Num, Show, Ord, Eq)
-- | Render a 'TTL' as @TTL:<tablename>:<value>@.
--
-- @<tablename>@ is the type-level 'Symbol' tag, and @<value>@ is the stored
-- 'Int32' rendered with 'show'.
showTTL :: (KnownSymbol a) => TTL a -> String
showTTL (TTL secs :: TTL tag) =
  mconcat ["TTL:", symbolVal (Proxy @tag), ":", show secs]
-- | Parse a 'TTL' from a JSON number of seconds.
--
-- Non-integral numbers are rounded with 'round'. Numbers outside the 'Int32'
-- range are rejected rather than wrapped by the narrowing conversion; for
-- example, @3000000000@ would otherwise turn into a negative TTL.
--
-- The range check runs on the 'Scientific' before rounding, so an input with
-- a huge exponent is never expanded into a big integer. Negative values inside
-- the range are still accepted here.
instance FromJSON (TTL a) where
  parseJSON = withScientific "TTL value (seconds)" $ \secs ->
    if secs < lowest || secs > highest
      then fail ("TTL value (seconds) out of Int32 range: " <> show secs)
      else pure (TTL (round secs))
    where
      lowest = fromIntegral (minBound :: Int32)
      highest = fromIntegral (maxBound :: Int32)
-- | Errors about a 'TTL', presumably raised when validating one.
--
-- NOTE(review): what the 'String' payloads hold is decided by the code that
-- constructs these values, which is not visible here; document it once known.
data TTLError
  = TTLTooLong String String
  | TTLNegative String
  deriving (Show, Eq)
-- | Convert a 'TTL' to a 'NominalDiffTime'.
--
-- 'fromIntegral' into 'NominalDiffTime' interprets the number as whole
-- seconds.
ttlToNominalDiffTime :: TTL a -> NominalDiffTime
ttlToNominalDiffTime = fromIntegral . fromTTL
-- | SSO settings holding a single, optional default SSO code (an 'IdPId').
--
-- The code is serialised under the JSON key @default_sso_code@.
data SsoSettings = SsoSettings
  { -- | Strict field; 'Nothing' when no default is configured.
    defaultSsoCode :: !(Maybe IdPId)
  }
  deriving (Show, Generic)
-- | Parse from a JSON object.
--
-- Uses '(.:)', so the @default_sso_code@ key must be present. Its value may be
-- @null@, which decodes to 'Nothing'.
instance FromJSON SsoSettings where
  parseJSON = withObject "SsoSettings" (fmap SsoSettings . (.: "default_sso_code"))
-- | Encode as an object with the single key @default_sso_code@.
--
-- The key is always emitted; 'Nothing' becomes JSON @null@.
instance ToJSON SsoSettings where
  toJSON (SsoSettings code) = object ["default_sso_code" .= code]
-- | Generic OpenAPI schema for 'SsoSettings'.
--
-- The Haskell field name @defaultSsoCode@ is renamed to @default_sso_code@ so
-- the schema matches the JSON instances; all other labels are kept as-is.
--
-- NOTE(review): the generic schema presumably marks the 'Maybe' field as
-- optional, while the 'FromJSON' instance requires the key to be present.
-- Confirm which contract is intended.
instance ToSchema SsoSettings where
  declareNamedSchema = genericDeclareNamedSchema schemaOptions
    where
      schemaOptions = defaultSchemaOptions {fieldLabelModifier = renameField}
      renameField :: String -> String
      renameField "defaultSsoCode" = "default_sso_code"
      renameField other = other