{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TemplateHaskell #-}

-- This file is part of the Wire Server implementation.
--
-- Copyright (C) 2022 Wire Swiss GmbH <opensource@wire.com>
--
-- This program is free software: you can redistribute it and/or modify it under
-- the terms of the GNU Affero General Public License as published by the Free
-- Software Foundation, either version 3 of the License, or (at your option) any
-- later version.
--
-- This program is distributed in the hope that it will be useful, but WITHOUT
-- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
-- FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
-- details.
--
-- You should have received a copy of the GNU Affero General Public License along
-- with this program. If not, see <https://www.gnu.org/licenses/>.

-- | A "default" module for types used in Spar, unless there's a better / more specific place
-- for them.
module Wire.API.User.Saml where

import Control.Monad.Except
import Data.Aeson hiding (fieldLabelModifier)
import Data.Aeson.TH hiding (fieldLabelModifier)
import Data.ByteString (toStrict)
import Data.ByteString.Builder qualified as Builder
import Data.Id (UserId)
import Data.OpenApi
import Data.Proxy (Proxy (Proxy))
import Data.Text qualified as T
import Data.Text.Encoding
import Data.Text.Encoding.Error
import Data.Time
import GHC.TypeLits (KnownSymbol, symbolVal)
import GHC.Types (Symbol)
import Imports
import SAML2.WebSSO
import SAML2.WebSSO.Types.TH (deriveJSONOptions)
import URI.ByteString
import Web.Cookie
import Wire.API.User.Orphans ()

----------------------------------------------------------------------------
-- Requests and verdicts

-- | The 'ID' of an 'AuthnRequest'.
type AReqId = ID AuthnRequest

-- | The 'ID' of an 'Assertion'.
type AssId = ID Assertion

-- | Clients can request different ways of receiving the final 'AccessVerdict' when fetching their
-- 'AuthnRequest'.  Web-based clients want an html page, mobile clients want to set two URIs for the
-- two resp. 'AccessVerdict' constructors.  This format is stored in cassandra under the request id
-- so that the verdict handler can act on it.
data VerdictFormat
  = -- | The client wants the verdict delivered as an html page.
    VerdictFormatWeb
  | -- | The client supplies one URI per verdict outcome.
    --
    -- NOTE(review): presumably these are the templates passed as @before@ to
    -- 'mkVerdictGrantedFormatMobile' / 'mkVerdictDeniedFormatMobile' -- confirm against the
    -- verdict handler.
    VerdictFormatMobile
      { _formatGrantedURI :: URI,
        _formatDeniedURI :: URI
      }
  deriving (Eq, Show, Generic)

-- JSON instances are generated with the options shared with the saml2-web-sso types.
deriveJSON deriveJSONOptions ''VerdictFormat

-- | Instantiate the "granted" URI template of a mobile client.
--
-- The template @before@ is rendered to text, every occurrence of the placeholders @$userid@ and
-- @$cookie@ (or their percent-encoded spellings @%24userid@ / @%24cookie@) is replaced, and the
-- result is parsed back into a 'URI' with @parseURI'@.
--
-- * @$userid@ is replaced by @show uid@.
-- * @$cookie@ is replaced by the rendered @Set-Cookie@ value of @cky@, decoded as UTF-8 with
--   'lenientDecode' (invalid bytes are replaced instead of raising).
--
-- The user id is substituted first, then the cookie, so placeholders that occur inside the
-- rendered cookie are not expanded again.
--
-- NOTE(review): failures presumably surface through @parseURI'@ in the 'MonadError' 'String'
-- context when the substituted text is not a valid URI -- confirm in saml2-web-sso.
mkVerdictGrantedFormatMobile :: (MonadError String m) => URI -> SetCookie -> UserId -> m URI
mkVerdictGrantedFormatMobile before cky uid =
  parseURI'
    . substituteVar "cookie" renderedCookie
    . substituteVar "userid" (T.pack . show $ uid)
    $ renderURI before
  where
    -- The cookie as it would appear in a @Set-Cookie@ header, as strict 'Text'.
    renderedCookie =
      decodeUtf8With lenientDecode
        . toStrict
        . Builder.toLazyByteString
        . renderSetCookie
        $ cky

-- | Instantiate the "denied" URI template of a mobile client.
--
-- The template @before@ is rendered to text, every occurrence of the placeholder @$label@ (or its
-- percent-encoded spelling @%24label@) is replaced by @lbl@, and the result is parsed back into a
-- 'URI' with @parseURI'@.
--
-- NOTE(review): @lbl@ is inserted verbatim (no escaping happens here); presumably @parseURI'@
-- reports an error in the 'MonadError' 'String' context if the result is not a valid URI --
-- confirm in saml2-web-sso.
mkVerdictDeniedFormatMobile :: (MonadError String m) => URI -> Text -> m URI
mkVerdictDeniedFormatMobile before lbl =
  parseURI'
    . substituteVar "label" lbl
    $ renderURI before

-- | @substituteVar var val txt@ replaces every occurrence of the placeholder for @var@ in @txt@
-- by @val@.
--
-- Two spellings of the placeholder are recognised: the literal @$var@ and the percent-encoded
-- @%24var@ (@%24@ is the URI encoding of @$@).  The percent-encoded spelling is replaced first,
-- then the literal one.
substituteVar :: Text -> Text -> Text -> Text
substituteVar var val =
  substituteVar' ("$" <> var) val
    . substituteVar' ("%24" <> var) val

-- | @substituteVar' var val txt@ replaces every non-overlapping occurrence of the exact text
-- @var@ in @txt@ by @val@.
--
-- An empty @var@ leaves the input unchanged: the underlying text primitives ('T.replace',
-- 'T.splitOn') raise an error on an empty pattern, so that case is short-circuited here to keep
-- the function total.
substituteVar' :: Text -> Text -> Text -> Text
substituteVar' var val
  | T.null var = id
  | otherwise = T.replace var val

-- | A time-to-live in seconds.
--
-- The phantom type parameter @tablename@ is a type-level string that tags the TTL (see
-- 'showTTL', which prints it); it has no runtime representation.  The 'Num' instance is the one
-- of the underlying 'Int32' (newtype deriving), so numeric literals can be used directly as TTLs.
newtype TTL (tablename :: Symbol) = TTL {fromTTL :: Int32}
  deriving (Eq, Ord, Show, Num)

-- | Render a 'TTL' together with its type-level tag, in the form @TTL:\<tablename\>:\<seconds\>@.
showTTL :: (KnownSymbol a) => TTL a -> String
showTTL (TTL i :: TTL a) = "TTL:" <> symbolVal (Proxy @a) <> ":" <> show i

-- | A 'TTL' is parsed from a JSON number of seconds.  Fractional values are rounded to an
-- integer with 'round'.
--
-- NOTE(review): no range check is performed here; values outside the 'Int32' range are not
-- rejected by this parser.
instance FromJSON (TTL a) where
  parseJSON = withScientific "TTL value (seconds)" (pure . TTL . round)

-- | Ways in which a 'TTL' can be invalid.
--
-- NOTE(review): the meaning of the 'String' payloads is not visible in this module; presumably
-- they are rendered TTLs (cf. 'showTTL') -- confirm against the code that constructs these
-- values.
data TTLError
  = TTLTooLong String String
  | TTLNegative String
  deriving (Eq, Show)

-- | Convert a 'TTL' (a number of seconds) into a 'NominalDiffTime' of the same number of seconds.
ttlToNominalDiffTime :: TTL a -> NominalDiffTime
ttlToNominalDiffTime (TTL i32) = fromIntegral i32

-- | SSO settings.  Serialized as a JSON object with the single key @default_sso_code@ (see the
-- 'FromJSON' / 'ToJSON' instances).
data SsoSettings = SsoSettings
  { -- | The default IdP, if any ('Nothing' is serialized as JSON @null@).
    defaultSsoCode :: !(Maybe IdPId)
  }
  deriving (Generic, Show)

-- | Parses @{"default_sso_code": \<IdPId or null\>}@.
instance FromJSON SsoSettings where
  parseJSON = withObject "SsoSettings" $ \obj ->
    -- key needs to be present, but can be null: '.:' (rather than '.:?') rejects a missing
    -- key, while the 'Maybe' result type accepts @null@.
    SsoSettings <$> obj .: "default_sso_code"

-- | Renders @{"default_sso_code": \<IdPId or null\>}@; the key is always emitted, mirroring the
-- 'FromJSON' instance which requires it to be present.
instance ToJSON SsoSettings where
  toJSON SsoSettings {defaultSsoCode} =
    object ["default_sso_code" .= defaultSsoCode]

-- Swagger instances

-- | Generic schema, with the Haskell field name @defaultSsoCode@ renamed to the JSON key
-- @default_sso_code@ used by the 'FromJSON' / 'ToJSON' instances.  Any other field label is
-- passed through unchanged.
instance ToSchema SsoSettings where
  declareNamedSchema =
    genericDeclareNamedSchema
      defaultSchemaOptions
        { fieldLabelModifier = \case
            "defaultSsoCode" -> "default_sso_code"
            other -> other
        }