{-# LANGUAGE NumericUnderscores #-}

-- This file is part of the Wire Server implementation.
--
-- Copyright (C) 2022 Wire Swiss GmbH <opensource@wire.com>
--
-- This program is free software: you can redistribute it and/or modify it under
-- the terms of the GNU Affero General Public License as published by the Free
-- Software Foundation, either version 3 of the License, or (at your option) any
-- later version.
--
-- This program is distributed in the hope that it will be useful, but WITHOUT
-- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
-- FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
-- details.
--
-- You should have received a copy of the GNU Affero General Public License along
-- with this program. If not, see <https://www.gnu.org/licenses/>.

module Network.Wai.Utilities.MockServer where

import Control.Concurrent.Async qualified as Async
import Control.Exception (throwIO)
import Control.Exception qualified as E
import Control.Monad.Catch
import Control.Monad.Codensity
import Data.Streaming.Network (bindRandomPortTCP)
import Imports
import Network.HTTP2.Client qualified as HTTP2
import Network.Wai qualified as Wai
import Network.Wai.Handler.Warp qualified as Warp
import Network.Wai.Handler.WarpTLS qualified as Warp
import System.Timeout qualified as System

-- | Thrown in IO by 'startMockServer' if the mock server did not become ready
-- within 10 seconds. Carries the port the server had been bound to.
newtype MockTimeout = MockTimeout Warp.Port
  deriving (Eq, Show)

instance Exception MockTimeout

-- | Serve the given application on a plain-HTTP (non-TLS) mock server for the
-- extent of the 'Codensity' continuation, which receives the server's port.
--
-- The server is started with 'startMockServer' and stopped with the close
-- action it returns once the continuation finishes, also when it throws
-- (courtesy of 'bracket').
withMockServer :: Wai.Application -> Codensity IO Word16
withMockServer app = Codensity $ \k ->
  bracket
    (startMockServer Nothing app)
    fst
    -- Ports handed out by the OS fit in 16 bits, so the narrowing is safe.
    (k . fromIntegral . snd)

-- | Warp exception handler that silently drops the HTTP/2
-- 'HTTP2.ConnectionIsClosed' exception, and hands every other exception on to
-- 'Warp.defaultOnException'.
ignoreHTTP2NonError :: Maybe Wai.Request -> SomeException -> IO ()
ignoreHTTP2NonError mreq exc =
  case E.fromException exc of
    Just HTTP2.ConnectionIsClosed -> pure ()
    _ -> Warp.defaultOnException mreq exc

-- | Start a mock warp server on a random port, serving the given Wai application.
--
-- If the 'Warp.TLSSettings' argument is provided, start an HTTPS server,
-- otherwise start a plain HTTP server.
--
-- Returns an action to kill the spawned server, and the port on which the
-- server is running. The kill action cancels the server thread if it is still
-- running; if the server thread has already died with an exception, that
-- exception is rethrown instead.
--
-- Throws 'MockTimeout' if the server does not become ready within 10 seconds.
-- If the server thread dies before becoming ready, its exception is rethrown
-- right away rather than after the timeout.
--
-- This function should normally be used within 'bracket', e.g.:
-- @
--     bracket (startMockServer Nothing app) fst $ \(_close, port) ->
--       makeRequest "localhost" port
-- @
startMockServer :: Maybe Warp.TLSSettings -> Wai.Application -> IO (IO (), Warp.Port)
startMockServer mtlsSettings app = do
  (port, sock) <- bindRandomPortTCP "*6"
  serverStarted <- newEmptyMVar
  let wsettings =
        Warp.defaultSettings
          & Warp.setPort port
          & Warp.setGracefulCloseTimeout2 0 -- Defaults to 2 seconds, causes server stop to take very long
          & Warp.setBeforeMainLoop (putMVar serverStarted ())
          & Warp.setOnException ignoreHTTP2NonError

  serverThread <- Async.async $ case mtlsSettings of
    Just tlsSettings -> Warp.runTLSSocket tlsSettings wsettings sock app
    Nothing -> Warp.runSettingsSocket wsettings sock app

  -- Wait for whichever happens first: the server signalling readiness, or the
  -- server thread terminating. If we are interrupted while waiting (e.g. by an
  -- asynchronous exception), make sure the server thread does not outlive us.
  startup <-
    System.timeout
      10_000_000
      (Async.race (readMVar serverStarted) (Async.waitCatch serverThread))
      `E.onException` Async.cancel serverThread

  let closeMock :: IO ()
      closeMock = do
        me <- Async.poll serverThread
        case me of
          Nothing -> Async.cancel serverThread
          Just (Left e) -> throwIO e
          Just (Right ()) -> pure ()

  case startup of
    Just (Left ()) -> pure (closeMock, port)
    -- The server died before it was ready: surface the actual cause.
    Just (Right (Left e)) -> throwIO e
    -- Timed out, or the server returned without ever becoming ready.
    _ -> do
      Async.cancel serverThread
      throwIO (MockTimeout port)