-- This file is part of the Wire Server implementation.
--
-- Copyright (C) 2022 Wire Swiss GmbH <opensource@wire.com>
--
-- This program is free software: you can redistribute it and/or modify it under
-- the terms of the GNU Affero General Public License as published by the Free
-- Software Foundation, either version 3 of the License, or (at your option) any
-- later version.
--
-- This program is distributed in the hope that it will be useful, but WITHOUT
-- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
-- FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
-- details.
--
-- You should have received a copy of the GNU Affero General Public License along
-- with this program. If not, see <https://www.gnu.org/licenses/>.

module Galley.External.LegalHoldService.Internal
  ( makeVerifiedRequest,
    makeVerifiedRequestFreshManager,
  )
where

import Bilge qualified
import Bilge.Retry
import Control.Lens (view)
import Control.Monad.Catch
import Control.Retry
import Data.ByteString qualified as BS
import Data.ByteString.Lazy.Char8 qualified as LC8
import Data.Misc
import Galley.API.Error
import Galley.Env
import Galley.Monad
import Imports
import Network.HTTP.Client qualified as Http
import OpenSSL.Session qualified as SSL
import Ssl.Util
import System.Logger.Class qualified as Log
import URI.ByteString (uriPath)

-- | Check that the given fingerprint is valid and make the request over ssl.
-- If the team has a device registered use 'makeLegalHoldServiceRequest' instead.
--
-- The supplied manager is used both to open the verified connection and to
-- send the request. Before the caller's @reqBuilder@ result is sent, the
-- request is pointed at the service described by the 'HttpsUrl': its host (if
-- the URL carries one), its port (443 when the URL has none), TLS, and the
-- URL's path prepended to the request path.
--
-- The request is retried according to 'httpHandlers', at most three times with
-- exponential backoff. Any synchronous exception that still escapes is logged
-- and rethrown as 'legalHoldServiceUnavailable'; asynchronous exceptions are
-- rethrown untouched.
makeVerifiedRequestWithManager ::
  Http.Manager ->
  ([Fingerprint Rsa] -> SSL.SSL -> IO ()) ->
  Fingerprint Rsa ->
  HttpsUrl ->
  (Http.Request -> Http.Request) ->
  App (Http.Response LC8.ByteString)
makeVerifiedRequestWithManager mgr verifyFingerprints fpr (HttpsUrl url) reqBuilder =
  reportFailures $
    recovering retryPolicy httpHandlers $ \_retryStatus ->
      liftIO $
        withVerifiedSslConnection
          (verifyFingerprints [fpr])
          mgr
          (addressService . reqBuilder)
          (flip Http.httpLbs mgr)
  where
    -- Run the action; let asynchronous exceptions through unchanged, and turn
    -- every other exception into a logged 'legalHoldServiceUnavailable'.
    -- The async handler must come first so it wins over the catch-all.
    reportFailures :: App a -> App a
    reportFailures action =
      catches
        action
        [ Handler $ \(asyncEx :: SomeAsyncException) -> throwM asyncEx,
          Handler $ \(ex :: SomeException) -> do
            Log.info . Log.msg $ "error making request to legalhold service: " <> show ex
            throwM (legalHoldServiceUnavailable ex)
        ]

    -- Up to three retries, backing off exponentially from a base of 100000
    -- (the delay unit is whatever 'exponentialBackoff' expects).
    retryPolicy :: RetryPolicy
    retryPolicy = limitRetries 3 <> exponentialBackoff 100000

    -- Direct an already-built request at the service behind @url@.
    -- Applied after the caller's @reqBuilder@, so these settings win.
    addressService :: Http.Request -> Http.Request
    addressService request =
      setHost . Bilge.port servicePort . Bilge.secure $
        request {Http.path = joinPath (uriPath url) (Http.path request)}

    -- Leave the host alone when the URL does not provide one.
    setHost :: Http.Request -> Http.Request
    setHost = maybe id Bilge.host (Bilge.extHost url)

    servicePort = fromMaybe 443 (Bilge.extPort url)

    -- Glue two path fragments together with a single "/" between them: at
    -- most one trailing slash is dropped from the first fragment and at most
    -- one leading slash from the second.
    joinPath :: ByteString -> ByteString -> ByteString
    joinPath front back =
      BS.intercalate
        "/"
        [ fromMaybe front (BS.stripSuffix "/" front),
          fromMaybe back (BS.stripPrefix "/" back)
        ]

-- | Make a verified request using the HTTP manager and fingerprint verifier
-- stored in the application environment (read through @extEnv . extGetManager@).
-- Everything else is delegated to 'makeVerifiedRequestWithManager'.
makeVerifiedRequest ::
  Fingerprint Rsa ->
  HttpsUrl ->
  (Http.Request -> Http.Request) ->
  App (Http.Response LC8.ByteString)
makeVerifiedRequest fpr url reqBuilder =
  view (extEnv . extGetManager) >>= \(mgr, verifyFingerprints) ->
    makeVerifiedRequestWithManager mgr verifyFingerprints fpr url reqBuilder

-- | Make a verified request with a freshly initialised 'ExtEnv' (and therefore
-- a brand-new manager) instead of the one held in the environment.
--
-- NOTE: Use this function wisely - a new manager is created on _every_ call.
--   For the time being it should really _only_ be used in
--   `checkLegalHoldServiceStatus`, because that is where signatures etc. are
--   checked. A reused manager is likely to reuse an existing connection, in
--   which case the new public key would _not_ be verified.
makeVerifiedRequestFreshManager ::
  Fingerprint Rsa ->
  HttpsUrl ->
  (Http.Request -> Http.Request) ->
  App (Http.Response LC8.ByteString)
makeVerifiedRequestFreshManager fpr url reqBuilder =
  liftIO initExtEnv >>= \(ExtEnv (mgr, verifyFingerprints)) ->
    makeVerifiedRequestWithManager mgr verifyFingerprints fpr url reqBuilder