{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE OverloadedStrings #-}
module Servant.Server.Internal.BasicAuth where
import Control.Monad (guard)
import Control.Monad.Trans (liftIO)
import qualified Data.ByteString as BS
import Data.ByteString.Base64 (decodeLenient)
import Data.Typeable (Typeable)
import Data.Word8 (isSpace, toLower, _backslash, _colon, _quotedbl)
import GHC.Generics
import Network.HTTP.Types (Header)
import Network.Wai (Request, requestHeaders)
import Servant.API.BasicAuth (BasicAuthData (BasicAuthData))
import Servant.Server.Internal.DelayedIO
import Servant.Server.Internal.ServerError
-- | Outcome of checking a set of Basic-Auth credentials.
--
-- How 'runBasicAuth' reacts to each constructor:
--
--   * 'BadPassword', 'NoSuchUser' -> 401 with a @WWW-Authenticate@ challenge
--   * 'Unauthorized'              -> 403
--   * 'Authorized' @usr@          -> request proceeds with @usr@
data BasicAuthResult usr
  = Unauthorized
  | BadPassword
  | NoSuchUser
  | Authorized usr
  deriving (Eq, Show, Read, Generic, Typeable, Functor)
-- | Wraps the action that validates decoded Basic-Auth credentials.
--
-- The wrapped function receives the username/password pair extracted from the
-- @Authorization@ header and reports a 'BasicAuthResult'. 'fmap' post-processes
-- the user value of an 'Authorized' result.
newtype BasicAuthCheck usr = BasicAuthCheck
  { unBasicAuthCheck :: BasicAuthData -> IO (BasicAuthResult usr)
  }
  deriving (Generic, Typeable, Functor)
-- | Build the @WWW-Authenticate: Basic realm="..."@ challenge header.
--
-- The realm is emitted as an HTTP quoted-string. Any @"@ or @\\@ byte in it is
-- escaped as a quoted-pair (prefixed with a backslash), so that an arbitrary
-- realm still produces a well-formed header that clients parse back to the
-- original realm text.
mkBAChallengerHdr :: BS.ByteString -> Header
mkBAChallengerHdr realm =
  ("WWW-Authenticate", "Basic realm=\"" <> escapeQuoted realm <> "\"")
  where
    -- Bytes that must be escaped inside a quoted-string.
    needsEscape w = w == _quotedbl || w == _backslash

    -- Fast path: most realms contain no special bytes and are returned as-is.
    escapeQuoted s
      | BS.any needsEscape s = BS.concatMap escapeByte s
      | otherwise            = s

    escapeByte w
      | needsEscape w = BS.pack [_backslash, w]
      | otherwise     = BS.singleton w
-- | Find and decode the request's @Authorization@ header as Basic Auth.
--
-- Returns 'Nothing' in any of these cases:
--
--   * the header is absent
--   * its scheme token (compared case-insensitively) is not @basic@
--   * the decoded payload has no @:@ separating username and password
--
-- The credentials are decoded with 'decodeLenient', which does not reject
-- malformed base64.
decodeBAHdr :: Request -> Maybe BasicAuthData
decodeBAHdr req = do
  ah <- lookup "Authorization" $ requestHeaders req
  -- Split "<scheme> <credentials>" at the first whitespace byte.
  let (b, rest) = BS.break isSpace ah
  guard (BS.map toLower b == "basic")
  let decoded = decodeLenient (BS.dropWhile isSpace rest)
  -- Split at the FIRST colon only: the password itself may contain colons.
  let (username, passWithColonAtHead) = BS.break (== _colon) decoded
  -- Drop the colon; fails (-> Nothing) when no colon was present.
  (_, password) <- BS.uncons passWithColonAtHead
  return (BasicAuthData username password)
-- | Run Basic-Auth checking for a request and map the outcome to a response.
--
-- Outcomes:
--
--   * missing or undecodable credentials, 'BadPassword', 'NoSuchUser':
--     fatal 401, carrying a @WWW-Authenticate@ challenge for @realm@
--   * 'Unauthorized': fatal 403
--   * 'Authorized' @usr@: returns @usr@
runBasicAuth :: Request -> BS.ByteString -> BasicAuthCheck usr -> DelayedIO usr
runBasicAuth req realm (BasicAuthCheck ba) =
  case decodeBAHdr req of
    Nothing -> plzAuthenticate
    Just e  -> liftIO (ba e) >>= \res -> case res of
      BadPassword    -> plzAuthenticate
      NoSuchUser     -> plzAuthenticate
      Unauthorized   -> delayedFailFatal err403
      Authorized usr -> return usr
  where
    -- 401 that invites the client to (re)send credentials for this realm.
    plzAuthenticate :: DelayedIO a
    plzAuthenticate =
      delayedFailFatal err401 { errHeaders = [mkBAChallengerHdr realm] }