-- This file is part of the Wire Server implementation.
--
-- Copyright (C) 2025 Wire Swiss GmbH <opensource@wire.com>
--
-- This program is free software: you can redistribute it and/or modify it under
-- the terms of the GNU Affero General Public License as published by the Free
-- Software Foundation, either version 3 of the License, or (at your option) any
-- later version.
--
-- This program is distributed in the hope that it will be useful, but WITHOUT
-- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
-- FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
-- details.
--
-- You should have received a copy of the GNU Affero General Public License along
-- with this program. If not, see <https://www.gnu.org/licenses/>.

module Testlib.Mock (startMockServer, MockServerConfig (..), codensityApp) where

import Control.Arrow ((>>>))
import Control.Concurrent.Async
import Control.Concurrent.MVar
import Control.Exception
import Control.Monad.Codensity
import Control.Monad.Reader
import Data.Streaming.Network
import qualified Network.Socket as Socket
import qualified Network.Wai as Wai
import qualified Network.Wai.Handler.Warp as Warp
import qualified Network.Wai.Handler.WarpTLS as Warp
import Testlib.Prelude

-- | Turn a request handler written in @'Codensity' 'IO'@ into a WAI
-- 'Wai.Application'.
--
-- The handler's 'Codensity' computation is run directly with WAI's response
-- continuation. A handler that acquires a resource bracket-style inside
-- 'Codensity' therefore keeps it until that continuation has returned.
codensityApp :: (Wai.Request -> Codensity IO Wai.Response) -> Wai.Application
codensityApp f req = runCodensity (f req)

-- | Configuration for 'startMockServer'.
data MockServerConfig = MockServerConfig
  { -- | Port to listen on; 'Nothing' binds a random free port.
    port :: Maybe Warp.Port,
    -- | Serve HTTPS instead of plain HTTP.
    tls :: Bool
  }

-- | Plain HTTP on a randomly chosen free port.
instance Default MockServerConfig where
  def = MockServerConfig {port = Nothing, tls = False}

-- | Serve @app@ over plain HTTP on the already-bound socket @sock@.
--
-- This runs Warp's server loop in the calling thread, so it does not return
-- while the server is running.
spawnServer :: Warp.Settings -> Socket.Socket -> Wai.Application -> App ()
spawnServer wsettings sock app =
  liftIO $ Warp.runSettingsSocket wsettings sock app

-- | Serve @app@ over HTTPS on the already-bound socket @sock@.
--
-- The certificate and private key come from one of two places:
--
-- * If 'servicesCwdBase' is set, the federator's integration test resources
--   under that directory are used.
-- * Otherwise the files under @/etc/wire/federator/secrets@ are used.
--
-- Like the plain-HTTP variant, this runs Warp's server loop in the calling
-- thread.
spawnTLSServer :: Warp.Settings -> Socket.Socket -> Wai.Application -> App ()
spawnTLSServer wsettings sock app = do
  (cert, key) <- asks (servicesCwdBase >>> tlsFiles)
  liftIO $ Warp.runTLSSocket (Warp.tlsSettings cert key) wsettings sock app
  where
    -- (certificate path, private key path), chosen from the services base dir.
    tlsFiles :: Maybe String -> (String, String)
    tlsFiles = \case
      Nothing ->
        ( "/etc/wire/federator/secrets/tls.crt",
          "/etc/wire/federator/secrets/tls.key"
        )
      Just base ->
        ( base <> "/federator/test/resources/integration-leaf.pem",
          base <> "/federator/test/resources/integration-leaf-key.pem"
        )

-- | Start a mock HTTP(S) server running @app@ and return the port it listens
-- on.
--
-- The listening socket is bound first. The server is then started in a
-- separate thread, and this action waits until Warp runs its before-main-loop
-- hook. If the server fails before reaching that point, the failure is
-- rethrown here.
--
-- When the surrounding 'Codensity' scope ends, the socket is closed and the
-- server thread is cancelled.
startMockServer :: MockServerConfig -> Wai.Application -> Codensity App Warp.Port
startMockServer config app = do
  -- Best-effort close: any exception raised while closing is ignored, so
  -- cleanup never replaces the exception that triggered it.
  let closeSocket :: Socket.Socket -> IO ()
      closeSocket s = catch (Socket.close s) (\(_ :: SomeException) -> pure ())

  -- Bind the listening socket, either on the requested port or on a random
  -- free one. "*6" is the streaming-commons host preference for any address,
  -- with IPv6 preferred.
  (boundPort, sock) <-
    hoistCodensity
      $ Codensity
      $ bracket
        ( case config.port of
            Nothing -> bindRandomPortTCP (fromString "*6")
            Just n -> (n,) <$> bindPortTCP n (fromString "*6")
        )
        (\(_, s) -> closeSocket s)

  -- Startup handshake. Filled with 'Nothing' when Warp enters its main loop,
  -- or with the exception that stopped the server before it got there.
  serverStarted <- liftIO newEmptyMVar
  let wsettings =
        Warp.defaultSettings
          & Warp.setPort boundPort
          & Warp.setGracefulCloseTimeout2 0
          & Warp.setBeforeMainLoop (putMVar serverStarted Nothing)

  -- Action to start server in a separate thread.
  startServer <- lift . appToIO $ (if config.tls then spawnTLSServer else spawnServer) wsettings sock app
  let startServerAsync = do
        a <- async $
          catch startServer $ \(e :: SomeException) ->
            -- tryPutMVar only succeeds while the handshake MVar is still
            -- empty. Exceptions raised after startup was signalled (e.g.
            -- cancellation) are therefore dropped here.
            void $ tryPutMVar serverStarted (Just e)
        mException <- readMVar serverStarted
        -- throwIO, not throw: raise the startup failure as an ordered IO effect.
        traverse_ throwIO mException
        pure a

  void
    $ hoistCodensity
    $ Codensity
    $ bracket
      startServerAsync
      ( \serverAsync -> do
          closeSocket sock
          -- kill the thread running the server
          cancel serverAsync
      )

  pure boundPort