module Testlib.MockIntegrationService
  ( withMockServer,
    CreateMock (..),
    MockServerSettings (..),

import Control.Monad.Catch
import Control.Monad.Reader
import qualified Data.Aeson
import qualified Data.ByteString.Lazy as LBS
import Data.Streaming.Network
import Data.String.Conversions (cs)
import Network.HTTP.Types
import Network.Socket
import qualified Network.Socket as Socket
import Network.Wai as Wai
import qualified Network.Wai.Handler.Warp as Warp
import qualified Network.Wai.Handler.Warp.Internal as Warp
import qualified Network.Wai.Handler.WarpTLS as Warp
import Testlib.Prelude hiding (IntegrationConfig (integrationTestHostName))
import UnliftIO (MonadUnliftIO (withRunInIO))
import UnliftIO.Async
import UnliftIO.Chan
import UnliftIO.MVar
import UnliftIO.Timeout (timeout)

-- | Bind a listening TCP socket with 'openFreePortAnyAddr', run the given
-- action with the resulting port and socket, and close the socket afterwards.
-- The release step runs under 'bracket', so the socket is also closed when
-- the action throws.
withFreePortAnyAddr :: (MonadMask m, MonadIO m) => ((Warp.Port, Socket) -> m a) -> m a
withFreePortAnyAddr = bracket openFreePortAnyAddr (liftIO . Socket.close . snd)

-- | Bind a TCP listening socket to a random free port using
-- 'bindRandomPortTCP', lifted into any 'MonadIO'; yields the chosen port and
-- the socket.
--
-- NOTE(review): as shown, this block does not compile. It carries a
-- duplicated type signature interleaved with type-annotation residue, and the
-- string passed to 'fromString' (the 'HostPreference') is missing. Presumably
-- it was a wildcard host preference, given the name "AnyAddr"; restore the
-- exact literal from version control.
openFreePortAnyAddr :: (MonadIO m) => m (Warp.Port, Socket)
openFreePortAnyAddr :: forall (m :: * -> *). MonadIO m => m (Port, Socket)
openFreePortAnyAddr = IO (Port, Socket) -> m (Port, Socket)
forall a. IO a -> m a
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO (Port, Socket) -> m (Port, Socket))
-> IO (Port, Socket) -> m (Port, Socket)
forall a b. (a -> b) -> a -> b
$ HostPreference -> IO (Port, Socket)
bindRandomPortTCP (String -> HostPreference
forall a. IsString a => String -> a
fromString String

-- | A WAI-style application whose request handler and response continuation
-- both run in 'App' instead of 'IO'.
type LiftedApplication =
  Request ->
  (Wai.Response -> App ResponseReceived) ->
  App ResponseReceived

-- | Host name handed to the test continuation of 'withMockServer'.
type Host = String

-- | Run a TLS mock service on a random free port for the duration of a test.
--
-- The service is started in a background thread. The test only runs once Warp
-- signals (through 'Warp.settingsBeforeMainLoop') that it is about to enter its
-- main loop. If that does not happen within 5 seconds, this fails with an
-- error that includes the server thread's status. The server thread is
-- cancelled when the test returns or throws. The listening socket is closed
-- afterwards by 'withFreePortAnyAddr'.
--
-- The channel exists to facilitate out-of-band communication between the test
-- and the service; it can be used e.g. for recording (request, response) pairs.
withMockServer ::
  (HasCallStack) =>
  -- | the mock server settings; 'certificate' and 'privateKey' configure TLS
  MockServerSettings ->
  -- | builds the mock application, given the channel shared with the test
  (Chan e -> LiftedApplication) ->
  -- | the test; receives the mock server's host and port, and the channel
  ((Host, Warp.Port) -> Chan e -> App a) ->
  App a
withMockServer settings mkApp go = withFreePortAnyAddr \(sPort, sock) -> do
  serverStarted <- newEmptyMVar
  host <- asks integrationTestHostName
  let tlss = Warp.tlsSettingsMemory (cs settings.certificate) (cs settings.privateKey)
      defs =
        Warp.defaultSettings
          { Warp.settingsPort = sPort,
            Warp.settingsBeforeMainLoop = putMVar serverStarted ()
          }
  buf <- newChan
  let runServer :: App ()
      runServer = withRunInIO \inIO ->
        Warp.runTLSSocket tlss defs sock \req respond ->
          inIO $ mkApp buf req (liftIO . respond)
  -- 'withAsync' (rather than a bare 'async') guarantees the server thread is
  -- cancelled on every exit path, including the start-up timeout below;
  -- previously the thread was left running when the server failed to start.
  withAsync runServer \srv -> do
    started <- UnliftIO.Timeout.timeout 5_000_000 (takeMVar serverStarted)
    case started of
      Just () -> go (host, sPort) buf
      Nothing -> do
        srvState <- poll srv
        error $
          "withMockServer: mock server did not start within 5 seconds; server status: "
            <> show srvState

-- | 'lhMockAppWithPrekeys' using the default prekey actions ('def').
lhMockApp :: Chan (Wai.Request, LBS.ByteString) -> LiftedApplication
lhMockApp = lhMockAppWithPrekeys def

-- | Key material for a mock service. 'withMockServer' uses 'certificate' and
-- 'privateKey' to set up TLS.
data MockServerSettings = MkMockServerSettings
  { -- | the certificate the mock service uses
    certificate :: String,
    -- | the private key the mock service uses
    privateKey :: String,
    -- | the public key the mock service uses
    publicKey :: String
  }

-- | Settings built from the key material defined in this module
-- ('mockServerCert', 'mockServerPrivKey' and 'mockServerPubKey').
instance Default MockServerSettings where
  def =
    MkMockServerSettings
      { certificate = mockServerCert,
        privateKey = mockServerPrivKey,
        publicKey = mockServerPubKey
      }

-- | Actions a mock LegalHold app ('lhMockAppWithPrekeys') uses to obtain
-- prekeys, parameterised over the monad @f@ they run in.
data CreateMock f = MkCreateMock
  { -- | how to obtain the next last prekey of a mock app
    --
    -- NOTE: the field is spelled @nextLastPrey@; it is exported via
    -- @CreateMock (..)@, so renaming it would break callers.
    nextLastPrey :: f Value,
    -- | how to obtain some prekeys of a mock app
    somePrekeys :: f [Value]
  }

-- | Default prekey actions for a mock LegalHold app running in 'App'.
-- 'somePrekeys' runs an 'App Value' action 3 times via 'replicateM'.
--
-- NOTE(review): as shown, this instance does not compile. The record
-- constructor is missing, as are the 'App Value' action bound to
-- 'nextLastPrey' and the one replicated by 'somePrekeys', and the lines are
-- interleaved with type-annotation residue. Restore them from version control.
instance (App ~ f) => Default (CreateMock f) where
  def :: CreateMock f
def =
      { $sel:nextLastPrey:MkCreateMock :: App Value
nextLastPrey = App Value
        $sel:somePrekeys:MkCreateMock :: App [Value]
somePrekeys = Port -> App Value -> App [Value]
forall (m :: * -> *) a. Applicative m => Port -> m a -> m [a]
replicateM Port
3 App Value

-- | LegalHold service. Just fake the API, do not maintain any internal state.
--
-- Every incoming request is written to the channel together with its full
-- (strict) body before it is answered. Both prekey actions of the
-- 'CreateMock' argument are run once for every request. Routing:
--
-- * @GET \/legalhold\/status@: 200, no @Authorization@ header required
-- * any other request without an @Authorization@ header: 400
-- * @POST \/legalhold\/initiate@: 200 with a JSON object holding
--   @prekeys@ and @last_prekey@
-- * @POST \/legalhold\/confirm@ and @POST \/legalhold\/remove@: 200
-- * anything else: 404
lhMockAppWithPrekeys ::
  CreateMock App -> Chan (Wai.Request, LBS.ByteString) -> LiftedApplication
lhMockAppWithPrekeys mks ch req cont = withRunInIO \inIO -> do
  reqBody <- Wai.strictRequestBody req
  writeChan ch (req, reqBody)
  inIO do
    (nextLastPrekey, threePrekeys) <-
      (,) <$> mks.nextLastPrey <*> mks.somePrekeys
    -- Pin the conversions to 'String' so the literal patterns below are
    -- unambiguous.
    let path = cs @_ @String <$> pathInfo req
        method = cs @_ @String (requestMethod req)
        auth = cs @_ @String <$> getRequestHeader "Authorization" req
    case (path, method, auth) of
      (["legalhold", "status"], "GET", _) -> cont respondOk
      (_, _, Nothing) -> cont missingAuth
      (["legalhold", "initiate"], "POST", Just _) -> cont (initiateResp nextLastPrekey threePrekeys)
      (["legalhold", "confirm"], "POST", Just _) -> cont respondOk
      (["legalhold", "remove"], "POST", Just _) -> cont respondOk
      _ -> cont respondBad
  where
    -- 200 with a JSON object @{"prekeys": pks, "last_prekey": npk}@.
    initiateResp :: Value -> [Value] -> Wai.Response
    initiateResp npk pks =
      responseLBS status200 [(hContentType, cs "application/json")]
        . Data.Aeson.encode
        . Data.Aeson.object
        $ [ "prekeys" .= pks,
            "last_prekey" .= npk
          ]

    -- 200, no headers, empty body.
    respondOk :: Wai.Response
    respondOk = responseLBS status200 mempty mempty

    -- 404, no headers, empty body.
    respondBad :: Wai.Response
    respondBad = responseLBS status404 mempty mempty

    -- 400 with a short plain-text explanation.
    missingAuth :: Wai.Response
    missingAuth = responseLBS status400 mempty (cs "no authorization header")

    -- Look up a request header by name (header names compare
    -- case-insensitively via 'HeaderName').
    getRequestHeader :: String -> Wai.Request -> Maybe ByteString
    getRequestHeader name = lookup (fromString name) . requestHeaders

-- | Build a JSON object describing a mock service reachable at the given host
-- and port. The object has the keys @base_url@, @public_key@ and @auth_token@.
-- @base_url@ starts with @https:\/\/HOST:PORT@, followed by a further string
-- suffix. Presumably this is the payload for registering LegalHold settings
-- against a mock server; confirm against callers.
--
-- NOTE(review): as shown, this block does not compile:
-- * the type signature is duplicated and lines carry type-annotation residue;
-- * the @[Pair] -> Value@ builder (presumably 'object') is missing;
-- * the final string literal of @base_url@ is lost;
-- * the 'String' values of @public_key@ (possibly 'mockServerPubKey'; confirm)
--   and @auth_token@ are lost.
-- Restore these from version control.
mkLegalHoldSettings :: (String, Warp.Port) -> Value
mkLegalHoldSettings :: (String, Port) -> Value
mkLegalHoldSettings (String
botHost, Port
lhPort) =
  [Pair] -> Value
    [ String
"base_url" String -> String -> Pair
forall a. ToJSON a => String -> a -> Pair
.= (String
"https://" String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
botHost String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
":" String -> String -> String
forall a. Semigroup a => a -> a -> a
<> Port -> String
forall a. Show a => a -> String
show Port
lhPort String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"public_key" String -> String -> Pair
forall a. ToJSON a => String -> a -> Pair
.= String
"auth_token" String -> String -> Pair
forall a. ToJSON a => String -> a -> Pair
.= String

-- | PEM-armoured public key (per its BEGIN/END PUBLIC KEY markers), by its
-- name intended for mock servers; presumably it pairs with
-- 'mockServerPrivKey'. Confirm this.
--
-- NOTE(review): as shown, the type signature is duplicated. The literal also
-- contains only the BEGIN and END armour lines, joined by a string gap, with
-- no key body. Restore the key material from version control.
mockServerPubKey :: String
mockServerPubKey :: String
mockServerPubKey =
"-----BEGIN PUBLIC KEY-----\n\
  \-----END PUBLIC KEY-----"

-- | PEM-armoured RSA private key (per its END RSA PRIVATE KEY marker), by its
-- name intended for mock servers.
--
-- NOTE(review): as shown, this does not compile. The type signature is
-- duplicated, and the opening quote, the BEGIN line and the key body of the
-- string literal are missing. Restore them from version control.
mockServerPrivKey :: String
mockServerPrivKey :: String
mockServerPrivKey =
  \-----END RSA PRIVATE KEY-----"

-- | PEM-armoured certificate (per its END CERTIFICATE marker), by its name
-- intended for mock servers.
--
-- NOTE(review): as shown, this does not compile. The type signature is
-- duplicated, and the opening quote, the BEGIN line and the certificate body
-- of the string literal are missing. Restore them from version control.
mockServerCert :: String
mockServerCert :: String
mockServerCert =
  \-----END CERTIFICATE-----"