{-# LANGUAGE OverloadedStrings #-}

-- This file is part of the Wire Server implementation.
--
-- Copyright (C) 2022 Wire Swiss GmbH <opensource@wire.com>
--
-- This program is free software: you can redistribute it and/or modify it under
-- the terms of the GNU Affero General Public License as published by the Free
-- Software Foundation, either version 3 of the License, or (at your option) any
-- later version.
--
-- This program is distributed in the hope that it will be useful, but WITHOUT
-- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
-- FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
-- details.
--
-- You should have received a copy of the GNU Affero General Public License along
-- with this program. If not, see <https://www.gnu.org/licenses/>.

module Network.Wai.Utilities.ZAuth
  ( ZAuthType (..),
    (<&.),
    (.&>),
  )
where

import Data.ByteString.Conversion
import Imports
import Network.Wai.Predicate

-- ZAuth headers --------------------------------------------------------------

-- | Identifies the type of token used in an authenticated request.
--
-- The constructors are declared in the order access, user, bot, provider;
-- the derived 'Enum', 'Bounded' and 'Ord' instances follow that order.
data ZAuthType
  = -- | (Typically short-lived) access token.
    ZAuthAccess
  | -- | A user (aka refresh) token that can itself be used to
    -- obtain (short-lived) access tokens.
    ZAuthUser
  | -- | A bot token scoped to a specific bot and conversation,
    -- and issued to a certain service provider.
    ZAuthBot
  | -- | A provider token scoped to the provider management API.
    ZAuthProvider
  deriving (Eq, Show, Enum, Bounded, Ord)

-- | Parses the textual name of a token type.
--
-- The input is first parsed as a 'ByteString' and then matched against the
-- literals @access@, @user@, @bot@ and @provider@. Any other value makes the
-- parser fail with a message that includes the offending input.
instance FromByteString ZAuthType where
  parser = do
    t <- parser
    -- The annotation pins the inner 'parser' call to the 'ByteString'
    -- instance; without it the type of @t@ would be ambiguous.
    case (t :: ByteString) of
      "access" -> pure ZAuthAccess
      "user" -> pure ZAuthUser
      "bot" -> pure ZAuthBot
      "provider" -> pure ZAuthProvider
      _ -> fail $ "Invalid ZAuth type: " ++ show t

-- Extra Predicate Combinators ------------------------------------------------

-- Variations of '.&.' that keep only the result of the left or right
-- predicate, respectively. These might be useful to add upstream
-- in 'wai-predicates'.

-- Fixity: both combinators are right-associative at precedence level 3.
infixr 3 <&.

infixr 3 .&>

-- | Conjunction of two predicates that keeps only the left result.
--
-- Both predicates are combined with '.&.'; on success the combined
-- @t ::: t'@ value is projected to its first component with 'hd', so the
-- result of the right-hand predicate is discarded.
(<&.) :: Predicate a f t -> Predicate a f t' -> Predicate a f t
(<&.) a b = fmap (fmap hd) (a .&. b)

-- | Conjunction of two predicates that keeps only the right result.
--
-- Both predicates are combined with '.&.'; on success the combined
-- @t ::: t'@ value is projected to its second component with 'tl', so the
-- result of the left-hand predicate is discarded.
(.&>) :: Predicate a f t -> Predicate a f t' -> Predicate a f t'
(.&>) a b = fmap (fmap tl) (a .&. b)