{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# LANGUAGE DeriveDataTypeable    #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveGeneric         #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds             #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE UndecidableInstances  #-}

module Servant.Server.Experimental.Auth where

import           Control.Monad.Trans
                 (liftIO)
import           Data.Kind
                 (Type)
import           Data.Proxy
                 (Proxy (Proxy))
import           Data.Typeable
                 (Typeable)
import           GHC.Generics
                 (Generic)
import           Network.Wai
                 (Request)

import           Servant
                 ((:>))
import           Servant.API.Experimental.Auth
import           Servant.Server.Internal
                 (DelayedIO, Handler, HasContextEntry, HasServer (..),
                 addAuthCheck, delayedFailFatal, getContextEntry, runHandler,
                 withRequest)

-- * General Auth

-- | Specify the type of data returned after we've authenticated a request.
-- Quite often this is some @User@ datatype.
--
-- Users provide an instance per @'AuthProtect' tag@, e.g.
-- @type instance AuthServerData (AuthProtect "cookie-auth") = User@.
--
-- NOTE: THIS API IS EXPERIMENTAL AND SUBJECT TO CHANGE
type family AuthServerData a :: Type

-- | Handlers for 'AuthProtect'ed resources.
--
-- Wraps a function that inspects a value of type @r@ (the server instance in
-- this module uses the WAI 'Request') and produces the authenticated data
-- @usr@ in 'Handler', or fails there with a 'Handler' error.
--
-- NOTE: THIS API IS EXPERIMENTAL AND SUBJECT TO CHANGE
newtype AuthHandler r usr = AuthHandler
  { unAuthHandler :: r -> Handler usr
    -- ^ The underlying authentication function.
  }
  deriving (Functor, Generic, Typeable)

-- | Wrap an authentication function as an 'AuthHandler'.
--
-- This is exactly the 'AuthHandler' constructor, exported as a function so
-- callers need not depend on the newtype's representation.
--
-- NOTE: THIS API IS EXPERIMENTAL AND SUBJECT TO CHANGE
mkAuthHandler :: (r -> Handler usr) -> AuthHandler r usr
mkAuthHandler = AuthHandler

-- | Known orphan instance.
--
-- A sub-API guarded by @'AuthProtect' tag@ receives the authenticated value,
-- of type @'AuthServerData' ('AuthProtect' tag)@, as an extra first argument.
-- The value is obtained by running the 'AuthHandler' stored in the server
-- 'Context' against the incoming 'Request'.
instance ( HasServer api context
         , HasContextEntry context (AuthHandler Request (AuthServerData (AuthProtect tag)))
         )
  => HasServer (AuthProtect tag :> api) context where

  -- The protected handler is a function from the authenticated value to the
  -- inner API's server.
  type ServerT (AuthProtect tag :> api) m =
    AuthServerData (AuthProtect tag) -> ServerT api m

  -- The auth value is just a function argument, so hoisting the inner API's
  -- server after applying it to that argument is sufficient.
  hoistServerWithContext _ pc nt s =
    hoistServerWithContext (Proxy :: Proxy api) pc nt . s

  -- Delegate routing to the inner API, registering an auth check that runs
  -- the context-supplied handler on each request before the endpoint.
  route Proxy context subserver =
    route (Proxy :: Proxy api) context (subserver `addAuthCheck` withRequest authCheck)
      where
        -- The user-supplied authentication function, looked up in the context.
        authHandler :: Request -> Handler (AuthServerData (AuthProtect tag))
        authHandler = unAuthHandler (getContextEntry context)

        -- Run the auth handler in IO. A 'Left' error is passed to
        -- 'delayedFailFatal'; a 'Right' value becomes the argument handed to
        -- the protected sub-server.
        authCheck :: Request -> DelayedIO (AuthServerData (AuthProtect tag))
        authCheck req =
          liftIO (runHandler (authHandler req)) >>= either delayedFailFatal pure