{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE OverloadedStrings #-}

-- This file is part of the Wire Server implementation.
--
-- Copyright (C) 2022 Wire Swiss GmbH <opensource@wire.com>
--
-- This program is free software: you can redistribute it and/or modify it under
-- the terms of the GNU Affero General Public License as published by the Free
-- Software Foundation, either version 3 of the License, or (at your option) any
-- later version.
--
-- This program is distributed in the hope that it will be useful, but WITHOUT
-- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
-- FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
-- details.
--
-- You should have received a copy of the GNU Affero General Public License along
-- with this program. If not, see <https://www.gnu.org/licenses/>.

module Bilge.RPC
  ( HasRequestId (..),
    RPCException (..),
    rpc,
    rpc',
    parseResponse,
    rpcExceptionMsg,
  )
where

import Bilge.IO
import Bilge.Request
import Bilge.Response
import Control.Error hiding (err)
import Control.Monad.Catch (MonadCatch, MonadThrow (..), try)
import Data.Aeson (FromJSON, eitherDecode')
import Data.CaseInsensitive (original)
import Data.Text.Lazy (pack)
import Data.Text.Lazy qualified as T
import Imports hiding (log)
import Network.HTTP.Client qualified as HTTP
import System.Logger.Class
import Wire.OpenTelemetry (withClientInstrumentation)

-- | Monads that can supply a 'RequestId' for the current computation.
class HasRequestId m where
  -- | Obtain the 'RequestId' associated with the current monadic context.
  getRequestId :: m RequestId

-- | A 'ReaderT' whose environment is the 'RequestId' itself: the request id
-- is simply read from the reader environment.
instance (Monad m) => HasRequestId (ReaderT RequestId m) where
  getRequestId = ask

-- | Exception describing a failed remote call: which remote was addressed,
-- which request was attempted, and the underlying exception.
data RPCException = RPCException
  { -- | Label of the remote system the call was addressed to.
    rpceRemote :: !LText,
    -- | The request that was being performed when the failure occurred.
    rpceRequest :: !Request,
    -- | The underlying exception that caused the failure.
    rpceCause :: !SomeException
  }
  deriving (Typeable)

instance Exception RPCException

-- | Hand-written 'Show' instance. Of the request, only the path and the
-- request headers are rendered; the cause is shown via the 'Show' instance
-- of the exception wrapped inside 'SomeException'. The precedence argument
-- is ignored.
instance Show RPCException where
  showsPrec _ (RPCException r rq (SomeException c)) =
    showString "RPCException {"
      . showString "remote = "
      . shows r
      . showString ", path = "
      . shows (HTTP.path rq)
      . showString ", headers = "
      . shows (HTTP.requestHeaders rq)
      . showString ", cause = "
      . shows c
      . showString "}"

-- | Like 'rpc'', but starting from the 'empty' request: the given function
-- is the only customisation applied on top of it (besides the request id
-- that 'rpc'' sets).
rpc ::
  (MonadUnliftIO m, MonadCatch m, MonadHttp m, HasRequestId m) =>
  -- | A label for the remote system in case of 'RPCException's.
  LText ->
  -- | Modifications to apply to the 'empty' request.
  (Request -> Request) ->
  m (Response (Maybe LByteString))
rpc sys = rpc' sys empty

-- | Perform an HTTP request and return the response, thereby
-- forwarding the @Request-Id@ header from the current monadic
-- context.
--
-- The request id obtained from 'getRequestId' is set on the base request
-- first; the caller's modification function is applied afterwards. The
-- resulting request is sent with 'httpLbs' inside
-- 'withClientInstrumentation', labelled @intra-call-to-\<remote\>@.
--
-- Note: the action performing the request is run under 'try', and any
-- exception caught is re-thrown wrapped in an 'RPCException' that carries
-- the remote label, the final request and the original exception.
--
-- NOTE(review): 'try' is used at type 'SomeException', so presumably
-- asynchronous exceptions are wrapped as well — confirm this is intended.
rpc' ::
  (MonadUnliftIO m, MonadCatch m, MonadHttp m, HasRequestId m) =>
  -- | A label for the remote system in case of 'RPCException's.
  LText ->
  -- | The base request.
  Request ->
  -- | Modifications applied to the base request (after the request id has
  -- been set on it).
  (Request -> Request) ->
  m (Response (Maybe LByteString))
rpc' sys r f = do
  rId <- getRequestId
  let rq = f $ requestId rId r
  res <- try $ withClientInstrumentation ("intra-call-to-" <> T.toStrict sys) \k -> do
    k rq \r' -> httpLbs r' id
  case res of
    Left x -> throwM $ RPCException sys rq x
    Right x -> pure x

-- | Render an 'RPCException' as log message fields: a @remote@ field with
-- the remote label, a @path@ field with the request path, one field per
-- request header (keyed by the header's original, non-case-folded name),
-- and the shown cause as the message text.
--
-- NOTE(review): all request headers are emitted verbatim; verify that no
-- sensitive header values can end up in the logs this way.
rpcExceptionMsg :: RPCException -> Msg -> Msg
rpcExceptionMsg (RPCException sys req ex) =
  "remote" .= sys ~~ "path" .= HTTP.path req ~~ headers ~~ msg (show ex)
  where
    -- Fold every request header into a single message builder.
    headers = foldr hdr id (HTTP.requestHeaders req)
    -- Append one header as a field to the builder accumulated so far.
    hdr (k, v) x = x ~~ original k .= v

-- | Decode the JSON body of a response.
--
-- Throws (via 'throwM') the exception built by the given function when
--
-- * the response carries no body (message @"no response body"@), or
-- * the body cannot be decoded by 'eitherDecode'' (message is the decoder's
--   error text).
--
-- Only the response body is inspected here; the status code is not checked.
parseResponse ::
  (Exception e, MonadThrow m, FromJSON a) =>
  -- | Builds the exception to throw from an error description.
  (LText -> e) ->
  -- | The response whose body should be decoded.
  Response (Maybe LByteString) ->
  m a
parseResponse f r = either throwM pure $ do
  b <- note (f "no response body") (responseBody r)
  fmapL (f . pack) (eitherDecode' b)