-- This file is part of the Wire Server implementation.
--
-- Copyright (C) 2022 Wire Swiss GmbH <opensource@wire.com>
--
-- This program is free software: you can redistribute it and/or modify it under
-- the terms of the GNU Affero General Public License as published by the Free
-- Software Foundation, either version 3 of the License, or (at your option) any
-- later version.
--
-- This program is distributed in the hope that it will be useful, but WITHOUT
-- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
-- FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
-- details.
--
-- You should have received a copy of the GNU Affero General Public License along
-- with this program. If not, see <https://www.gnu.org/licenses/>.

-- | Code from "Servant.Server.Internal", modified slightly to allow returning JSON
-- errors instead of plain text.
module Servant.API.Extended where

import Data.ByteString
import Data.ByteString.Lazy qualified as BL
import Data.EitherR (fmapL)
import Data.Kind
import Data.Metrics.Servant
import Data.Typeable
import GHC.TypeLits
import Imports
import Network.HTTP.Types hiding (Header, ResponseHeaders)
import Network.Wai
import Servant.API
import Servant.API.ContentTypes
import Servant.API.Modifiers
import Servant.OpenApi
import Servant.Server.Internal
import Prelude ()

-- | Like 'ReqBody'', but a body that fails to decode is reported with a custom
-- 'ServerError' instead of servant's default plain-text error.  The 'String' produced by
-- the content-type decoder is passed to 'makeCustomError', and @tag@ selects the
-- 'MakeCustomError' instance to use.
--
-- The remaining parameters mirror 'ReqBody'':
--
-- * @mods@ are the modifiers.  Of these, the 'HasServer' instance only consults
--   leniency ('FoldLenient').
-- * @list@ holds the accepted content types.
-- * @a@ is the type of the decoded body.
--
-- FUTUREWORK: the API description does not reflect the changes we make to the error
-- responses wrt. the 'ReqBody'' instance.  The 'HasOpenApi' instance simply reuses the one
-- for 'ReqBody''.  Fixing that would require getting more information out of the
-- 'MakeCustomError' instance and into 'ReqBodyCustomError''.  Perhaps something like
-- @data ReqBody (mods :: [*]) (headers :: ...) (status :: ...) (list :: [ct]) (tag ::
-- Symbol) (a :: *)@.  We would then trip over issues similar to this one:
-- https://github.com/wireapp/servant-uverb/blob/3647c488a88137d3ec2583b518bda59ee7072278/servant-uverb/src/Servant/API/UVerb.hs#L33-L57
--
-- FUTUREWORK: this approach is not ideal because it makes it hard to avoid orphan instances.
--
-- FUTUREWORK: parser failures currently can't have custom monad effects like logging, since
-- they are run inside 'DelayedIO'.  There are two ways around this:
--
-- * Write a middleware that inspects the response and conditionally logs what it finds in
--   the body.  This is bad for streaming and performance.
-- * Re-wire more of the servant internals.  It is unclear how hard that will be.
--
-- See also: https://github.com/haskell-servant/servant/issues/353
data ReqBodyCustomError' (mods :: [Type]) (list :: [ct]) (tag :: Symbol) (a :: Type)

-- | 'ReqBodyCustomError'' with the modifiers @'[Required, Strict]@.  Because these
-- modifiers are not lenient, a body that fails to decode makes the request fail with the
-- error produced by 'makeCustomError'.  The handler never receives an 'Either'.
--
-- Left partially applied (no explicit parameters) so it can be used unsaturated.
type ReqBodyCustomError = ReqBodyCustomError' '[Required, Strict]

-- | Custom parse error for bad request bodies.
--
-- @tag@ selects the instance, and @a@ is the type of the body being decoded.  Neither
-- appears in the method's type, so the instance is chosen with type applications, e.g.
-- @makeCustomError \@tag \@a@.
class MakeCustomError (tag :: Symbol) (a :: Type) where
  -- | Turn the 'String' error produced by the content-type decoder into a 'ServerError'.
  makeCustomError :: String -> ServerError

-- | Variant of the 'ReqBody'' instance that takes a 'ServerError' as argument instead of a
-- 'String'.  This gives the caller more control over error responses.
--
-- Request handling follows servant's own 'ReqBody'' instance:
--
--   * If the @Content-Type@ header matches none of the content types in @list@, the
--     request fails with 'err415' via the non-fatal 'delayedFail'.  A missing header is
--     treated as @application/octet-stream@.
--
--   * Otherwise the body is decoded, and a decoding error is turned into a 'ServerError'
--     by 'makeCustomError'.  If @mods@ is lenient, the handler receives the resulting
--     @'Either' 'ServerError' a@.  Otherwise the request fails with that error via
--     'delayedFailFatal'.
instance
  ( MakeCustomError tag a,
    AllCTUnrender list a,
    HasServer api context,
    SBoolI (FoldLenient mods)
  ) =>
  HasServer (ReqBodyCustomError' mods list tag a :> api) context
  where
  type
    ServerT (ReqBodyCustomError' mods list tag a :> api) m =
      If (FoldLenient mods) (Either ServerError a) a -> ServerT api m

  -- Hoisting leaves the body argument alone and only hoists the rest of the API.
  hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy :: Proxy api) pc nt . s

  route Proxy context subserver =
    route (Proxy :: Proxy api) context $
      addBodyCheck subserver ctCheck bodyCheck
    where
      -- Content-Type check: we only go on to parse the request body if we have a decoder
      -- for the request's Content-Type.
      ctCheck :: DelayedIO (BL.ByteString -> Either String a)
      ctCheck = withRequest $ \request -> do
        -- See HTTP RFC 2616, section 7.2.1
        -- http://www.w3.org/Protocols/rfc2616/rfc2616-sec7.html#sec7.2.1
        -- See also "W3C Internet Media Type registration, consistency of use"
        -- http://www.w3.org/2001/tag/2002/0129-mime
        let contentTypeH =
              fromMaybe "application/octet-stream" $
                lookup hContentType $
                  requestHeaders request
        case canHandleCTypeH (Proxy :: Proxy list) (fromStrict contentTypeH) :: Maybe (BL.ByteString -> Either String a) of
          Nothing -> delayedFail err415
          Just f -> pure f

      -- Body check: read the request body and decode it with the parser chosen by
      -- 'ctCheck'.  Decoding errors are converted with 'makeCustomError'.
      bodyCheck ::
        (BL.ByteString -> Either String a) ->
        DelayedIO (If (FoldLenient mods) (Either ServerError a) a)
      bodyCheck f = withRequest $ \request -> do
        mrqbody <- fmapL (makeCustomError @tag @a) . f <$> liftIO (lazyRequestBody request)
        case sbool :: SBool (FoldLenient mods) of
          -- Lenient: the handler gets the (possibly failed) decoding result.
          STrue -> pure mrqbody
          -- Not lenient: abort the request with the custom error.
          SFalse -> case mrqbody of
            Left e -> delayedFailFatal e
            Right v -> pure v

-- | The OpenAPI description is the one for a required, strict 'ReqBody'' with the same
-- content types and body type.  The custom error responses selected by @tag@ are not
-- reflected in it.
instance
  (HasOpenApi (ReqBody' '[Required, Strict] cts a :> api)) =>
  HasOpenApi (ReqBodyCustomError cts tag a :> api)
  where
  toOpenApi Proxy = toOpenApi (Proxy @(ReqBody' '[Required, Strict] cts a :> api))

-- | A request body contributes no path segments, so the routes are exactly those of @rest@.
instance (RoutesToPaths rest) => RoutesToPaths (ReqBodyCustomError' mods list tag a :> rest) where
  getRoutes = getRoutes @rest