{-# LANGUAGE OverloadedStrings #-}

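-- | A high-level wrapper for 'Web.Cookie.SetCookie' that interfaces with servant types. It derives
-- the cookie name from a type-level symbol and verifies that name when parsing. It also builds cookies
-- that either carry a caller-supplied value (e.g. a randomly generated token) with an expiry, or delete the cookie.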
module SAML2.WebSSO.Cookie
  ( SimpleSetCookie (..),
    cookieName,
    cookieToHeader,
    toggleCookie,
    setSimpleCookieValue,
  )
where

import Control.Lens
import Control.Monad.Except
import Data.Binary.Builder (toLazyByteString)
import qualified Data.ByteString.Builder as SBSBuilder
import Data.Proxy
import Data.String.Conversions
import qualified Data.Text as ST
import Data.Time
import GHC.TypeLits (KnownSymbol, symbolVal)
import GHC.Types
import qualified Network.HTTP.Types.Header as HttpTypes
import SAML2.WebSSO.SP
import SAML2.WebSSO.Types
import SAML2.WebSSO.XML
import Servant.API as Servant hiding (URI (..))
import Web.Cookie

-- | A 'SetCookie' tagged with a phantom type parameter @name@.
--
-- @name@ is the expected cookie name. The 'FromHttpApiData' instance in this module
-- rejects a parsed cookie whose name differs from it.
newtype SimpleSetCookie name = SimpleSetCookie
  { -- | The wrapped @Web.Cookie@ value.
    fromSimpleSetCookie :: SetCookie
  }
  deriving (Eq, Show)

-- | Renders the wrapped cookie as a complete @Set-Cookie@ header value.
--
-- 'toUrlPiece' decodes the rendered bytes to 'Text'. 'toHeader' is overridden to emit
-- the rendered bytes unchanged. Otherwise the class default (per http-api-data,
-- @encodeUtf8 . toUrlPiece@) would decode the bytes to 'Text' and immediately re-encode them.
instance KnownSymbol name => ToHttpApiData (SimpleSetCookie name) where
  toUrlPiece = cs . SBSBuilder.toLazyByteString . renderSetCookie . fromSimpleSetCookie
  toHeader = cs . SBSBuilder.toLazyByteString . renderSetCookie . fromSimpleSetCookie

-- | Parses a @Set-Cookie@ header value with 'headerValueToCookie'. Parsing fails
-- if the cookie name is empty or differs from @name@, or if the cookie value is empty.
instance KnownSymbol name => FromHttpApiData (SimpleSetCookie name) where
  parseUrlPiece headerText = headerValueToCookie headerText

-- | Turns the wrapped cookie into a @set-cookie@ header. The header value is the
-- output of 'renderSetCookie', converted to a strict byte string.
cookieToHeader :: SimpleSetCookie name -> HttpTypes.Header
cookieToHeader (SimpleSetCookie rawCookie) =
  ("set-cookie", cs (toLazyByteString (renderSetCookie rawCookie)))

-- | The cookie name: the type-level symbol @name@ converted to 'SBS'.
-- Only the proxy's type matters; the argument itself serves only to fix @name@.
cookieName :: forall (proxy :: Symbol -> Type) (name :: Symbol). KnownSymbol name => proxy name -> SBS
cookieName nameProxy = cs (symbolVal nameProxy)

-- | Parses a @Set-Cookie@ header value into a 'SimpleSetCookie' for @name@.
--
-- Every failed check is collected, and the messages are joined with @", "@ into one 'Left'. The checks are:
--
-- * the cookie name is non-empty;
-- * the cookie name equals @'cookieName' (Proxy \@name)@;
-- * the cookie value is non-empty.
headerValueToCookie :: forall name. KnownSymbol name => ST -> Either ST (SimpleSetCookie name)
headerValueToCookie txt
  | null complaints = Right (SimpleSetCookie parsedCookie)
  | otherwise = Left (ST.intercalate ", " complaints)
  where
    parsedCookie :: SetCookie
    parsedCookie = parseSetCookie (cs txt)

    actualName :: SBS
    actualName = setCookieName parsedCookie

    expectedName :: SBS
    expectedName = cookieName (Proxy @name)

    -- One entry per failed check, in a fixed order.
    complaints :: [ST]
    complaints =
      concat
        [ ["missing cookie name" | actualName == ""],
          [ cs ("wrong cookie name: got " <> actualName <> ", expected " <> expectedName)
            | actualName /= expectedName
          ],
          ["missing cookie value" | setCookieValue parsedCookie == ""]
        ]

-- | Builds the cookie named @name@, scoped to @path@, with @Secure@, @HttpOnly@ and
-- @SameSite=Strict@ set.
--
-- * @Just (value, ttl)@: the cookie carries @value@. @Expires@ is the current time
--   ('getNow') plus @ttl@, and @Max-Age@ is @ttl@.
-- * @Nothing@: a deletion cookie with an empty value, @Expires@ set to 1970-01-01T00:00:00Z
--   ('beginningOfTime') and a @Max-Age@ of -1.
toggleCookie :: forall name m. (Applicative m, SP m, KnownSymbol name) => SBS -> Maybe (ST, NominalDiffTime) -> m (SimpleSetCookie name)
toggleCookie path mbValueAndTtl =
  SimpleSetCookie <$> maybe (pure expiredCookie) liveCookie mbValueAndTtl
  where
    -- Attributes shared by both the live and the deletion cookie.
    baseCookie :: SetCookie
    baseCookie =
      defaultSetCookie
        { setCookieName = cookieName (Proxy @name),
          setCookieSecure = True,
          setCookiePath = Just path,
          setCookieHttpOnly = True,
          setCookieSameSite = Just sameSiteStrict
        }

    liveCookie :: (ST, NominalDiffTime) -> m SetCookie
    liveCookie (val, lifetime) =
      fmap
        ( \now ->
            baseCookie
              { setCookieValue = cs val,
                setCookieExpires = Just (fromTime (lifetime `addTime` now)),
                setCookieMaxAge = Just (realToFrac lifetime)
              }
        )
        getNow

    expiredCookie :: SetCookie
    expiredCookie =
      baseCookie
        { setCookieValue = "",
          setCookieExpires = Just (fromTime beginningOfTime),
          setCookieMaxAge = Just (-1)
        }

-- | 1970-01-01T00:00:00Z. Used as the @Expires@ date of deletion cookies.
beginningOfTime :: Time
beginningOfTime = unsafeReadTime unixEpoch
  where
    -- A fixed, known-good literal, so the unchecked parse is acceptable here.
    unixEpoch :: String
    unixEpoch = "1970-01-01T00:00:00Z"

-- | Reads the value of the wrapped cookie. Despite its name this is a getter:
-- it returns 'setCookieValue' of the underlying 'SetCookie' and modifies nothing.
setSimpleCookieValue :: SimpleSetCookie name -> SBS
setSimpleCookieValue (SimpleSetCookie wrapped) = setCookieValue wrapped