{-# LANGUAGE OverloadedStrings #-}

-- This file is part of the Wire Server implementation.
--
-- Copyright (C) 2025 Wire Swiss GmbH <opensource@wire.com>
--
-- This program is free software: you can redistribute it and/or modify it under
-- the terms of the GNU Affero General Public License as published by the Free
-- Software Foundation, either version 3 of the License, or (at your option) any
-- later version.
--
-- This program is distributed in the hope that it will be useful, but WITHOUT
-- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
-- FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
-- details.
--
-- You should have received a copy of the GNU Affero General Public License along
-- with this program. If not, see <https://www.gnu.org/licenses/>.

module Testlib.Env where

import Control.Concurrent.MVar
import qualified Control.Exception as E
import Control.Monad.Codensity
import Control.Monad.IO.Class
import Control.Monad.Reader
import Data.Foldable
import Data.Function ((&))
import Data.Functor
import Data.IORef
import qualified Data.Map as Map
import Data.Maybe (fromMaybe)
import Data.Traversable (for)
import qualified Data.Yaml as Yaml
import qualified Database.CQL.IO as Cassandra
import GHC.Stack
import qualified Network.HTTP.Client as HTTP
import qualified OpenSSL.Session as OpenSSL
import System.Directory
import System.Environment (lookupEnv)
import System.Exit
import System.FilePath
import System.IO
import System.IO.Temp
import Testlib.Prekeys
import Testlib.ResourcePool
import Testlib.Types
import Text.Read (readMaybe)
import Prelude

-- | Look up the 'HostPort' under which a given 'Service' is reachable in a
-- backend's 'ServiceMap'.
--
-- Each 'Service' constructor handled here is mapped to the field of the
-- 'ServiceMap' record that carries that service's address.
serviceHostPort :: ServiceMap -> Service -> HostPort
serviceHostPort m Brig = m.brig
serviceHostPort m Galley = m.galley
serviceHostPort m Cannon = m.cannon
serviceHostPort m Gundeck = m.gundeck
serviceHostPort m Cargohold = m.cargohold
serviceHostPort m Nginz = m.nginz
-- The record field for 'WireProxy' is named @proxy@, not @wireProxy@.
serviceHostPort m WireProxy = m.proxy
serviceHostPort m Spar = m.spar
serviceHostPort m BackgroundWorker = m.backgroundWorker
serviceHostPort m Stern = m.stern
serviceHostPort m FederatorInternal = m.federatorInternal
serviceHostPort m WireServerEnterprise = m.wireServerEnterprise

-- | Build the process-wide 'GlobalEnv' from the integration configuration
-- file at @cfgFile@.
--
-- The action runs in 'Codensity' 'IO' so that resources acquired here (the
-- system temp directory and the cells event queue watchers) are released once
-- the continuation consuming the 'GlobalEnv' has finished.
--
-- If the configuration file cannot be parsed, an error is written to 'stderr'
-- and the process terminates via 'exitFailure'.
mkGlobalEnv :: FilePath -> Codensity IO GlobalEnv
mkGlobalEnv cfgFile = do
  eith <- liftIO $ Yaml.decodeFileEither cfgFile
  intConfig <- liftIO $ case eith of
    Left err -> do
      hPutStrLn stderr $ "Could not parse " <> cfgFile <> ": " <> Yaml.prettyPrintParseException err
      exitFailure
    Right (intConfig :: IntegrationConfig) -> pure intConfig

  let -- The project root is only known when the config file lives in a
      -- directory whose last path component is "services"; the root is then
      -- everything before that component.
      devEnvProjectRoot = case splitPath (takeDirectory cfgFile) of
        [] -> Nothing
        ps ->
          -- 'last' and 'init' are safe here: this branch only matches a
          -- non-empty list.
          if last ps == "services"
            then Just (joinPath (init ps))
            else Nothing
      -- Resolve the configured Cassandra CA file, if any. Absolute paths are
      -- used as-is; relative paths are resolved against the project root.
      -- NOTE: a relative path combined with an unknown project root yields
      -- 'Nothing', i.e. no TLS context will be created.
      getCassCertFilePath :: IO (Maybe FilePath) =
        maybe
          (pure Nothing)
          ( \certFilePath ->
              if isAbsolute certFilePath
                then pure $ Just certFilePath
                else for devEnvProjectRoot $ \projectRoot -> makeAbsolute $ combine projectRoot certFilePath
          )
          intConfig.cassandra.cassTlsCa

  manager <- liftIO $ HTTP.newManager HTTP.defaultManagerSettings

  mbCassCertFilePath <- liftIO $ getCassCertFilePath
  mbSSLContext <- liftIO $ createSSLContext mbCassCertFilePath
  let basicCassSettings =
        Cassandra.defSettings
          & Cassandra.setContacts intConfig.cassandra.cassHost []
          & Cassandra.setPortNumber (fromIntegral intConfig.cassandra.cassPort)
      -- Only attach an SSL context when one could be created above.
      cassSettings = maybe basicCassSettings (\sslCtx -> Cassandra.setSSLContext sslCtx basicCassSettings) mbSSLContext
  cassClient <- Cassandra.init cassSettings
  let resources = backendResources (Map.elems intConfig.dynamicBackends)
  resourcePool <-
    liftIO $
      createBackendResourcePool
        resources
        intConfig.rabbitmq
        cassClient
  let -- Service maps keyed by domain: the statically configured backends
      -- followed by one entry per dynamic backend resource.
      sm =
        Map.fromList $
          [ (intConfig.backendOne.originDomain, intConfig.backendOne.beServiceMap),
            (intConfig.backendTwo.originDomain, intConfig.backendTwo.beServiceMap),
            (intConfig.federationV0.originDomain, intConfig.federationV0.beServiceMap),
            (intConfig.federationV1.originDomain, intConfig.federationV1.beServiceMap),
            (intConfig.federationV2.originDomain, intConfig.federationV2.beServiceMap)
          ]
            <> [(berDomain resource, resourceServiceMap resource) | resource <- resources]
  tempDir <- Codensity $ withSystemTempDirectory "test"
  -- Falls back to 10 when TEST_TIMEOUT_SECONDS is unset or is not a valid Int.
  timeOutSeconds <-
    liftIO $
      fromMaybe 10 . (readMaybe @Int =<<) <$> lookupEnv "TEST_TIMEOUT_SECONDS"
  gCellsEventWatchersLock <- liftIO newEmptyMVar
  gCellsEventWatchers <- liftIO $ newIORef mempty
  -- Register a finalizer: once the rest of the computation has finished
  -- (normally or with an exception), stop every watcher that is registered
  -- in the IORef at that point.
  Codensity $ \k -> do
    E.finally (k ()) $ do
      watchers <- readIORef gCellsEventWatchers
      traverse_ stopQueueWatcher watchers
  pure
    GlobalEnv
      { gServiceMap = sm,
        gDomain1 = intConfig.backendOne.originDomain,
        gDomain2 = intConfig.backendTwo.originDomain,
        gIntegrationTestHostName = intConfig.integrationTestHostName,
        gFederationV0Domain = intConfig.federationV0.originDomain,
        gFederationV1Domain = intConfig.federationV1.originDomain,
        gFederationV2Domain = intConfig.federationV2.originDomain,
        gDynamicDomains = (.domain) <$> Map.elems intConfig.dynamicBackends,
        gDefaultAPIVersion = 14,
        gManager = manager,
        gServicesCwdBase = devEnvProjectRoot <&> (</> "services"),
        gBackendResourcePool = resourcePool,
        gRabbitMQConfig = intConfig.rabbitmq,
        gRabbitMQConfigV0 = intConfig.rabbitmqV0,
        gRabbitMQConfigV1 = intConfig.rabbitmqV1,
        gTempDir = tempDir,
        gTimeOutSeconds = timeOutSeconds,
        gDNSMockServerConfig = intConfig.dnsMockServer,
        gCellsEventQueue = intConfig.cellsEventQueue,
        gCellsEventWatchersLock,
        gCellsEventWatchers
      }
  where
    -- Create an OpenSSL context that verifies the peer against the given CA
    -- file. Without a CA file no context is created.
    createSSLContext :: Maybe FilePath -> IO (Maybe OpenSSL.SSLContext)
    createSSLContext (Just certFilePath) = do
      -- NOTE: 'print' renders the message through 'show', so it appears
      -- quoted on stdout.
      print ("TLS: Connecting to Cassandra with TLS. Provided CA path:" ++ certFilePath)
      sslContext <- OpenSSL.context
      OpenSSL.contextSetCAFile sslContext certFilePath
      OpenSSL.contextSetVerificationMode
        sslContext
        OpenSSL.VerifyPeer
          { vpFailIfNoPeerCert = True,
            vpClientOnce = True,
            vpCallback = Nothing
          }
      pure $ Just sslContext
    createSSLContext Nothing = pure Nothing

-- | Build the per-test 'Env' from the shared 'GlobalEnv'.
--
-- Most fields are copied from the 'GlobalEnv'. Fresh for every call are the
-- MLS state (created via 'mkMLSState' and stored in an 'IORef') and the
-- prekey / last-prekey pools, which are new 'IORef's initialised from
-- 'somePrekeys' and 'someLastPrekeys'.
--
-- @currentTestName@ is stored in the resulting 'Env' unchanged.
mkEnv :: Maybe String -> GlobalEnv -> Codensity IO Env
mkEnv currentTestName ge = do
  mls <- liftIO . newIORef =<< mkMLSState
  liftIO $ do
    -- Prekeys are paired with consecutive ids starting at 1.
    pks <- newIORef (zip [1 ..] somePrekeys)
    lpks <- newIORef someLastPrekeys
    pure
      Env
        { serviceMap = gServiceMap ge,
          domain1 = gDomain1 ge,
          domain2 = gDomain2 ge,
          integrationTestHostName = gIntegrationTestHostName ge,
          federationV0Domain = gFederationV0Domain ge,
          federationV1Domain = gFederationV1Domain ge,
          federationV2Domain = gFederationV2Domain ge,
          dynamicDomains = gDynamicDomains ge,
          defaultAPIVersion = gDefaultAPIVersion ge,
          -- hardcode API versions for federated domains because they don't have
          -- latest things. Ensure we do not use development API versions in
          -- those domains.
          apiVersionByDomain =
            Map.fromList
              [ (gFederationV0Domain ge, 4),
                (gFederationV1Domain ge, 5),
                (gFederationV2Domain ge, 8)
              ],
          manager = gManager ge,
          servicesCwdBase = gServicesCwdBase ge,
          prekeys = pks,
          lastPrekeys = lpks,
          mls = mls,
          resourcePool = ge.gBackendResourcePool,
          rabbitMQConfig = ge.gRabbitMQConfig,
          timeOutSeconds = ge.gTimeOutSeconds,
          currentTestName,
          dnsMockServerConfig = ge.gDNSMockServerConfig,
          cellsEventQueue = ge.gCellsEventQueue,
          cellsEventWatchersLock = ge.gCellsEventWatchersLock,
          cellsEventWatchers = ge.gCellsEventWatchers
        }

-- | The ciphersuites listed for use in tests, each given as a hexadecimal
-- code string wrapped in 'Ciphersuite'.
allCiphersuites :: [Ciphersuite]
-- FUTUREWORK: add 0x0005 to this list once openmls supports it
allCiphersuites = map Ciphersuite ["0x0001", "0xf031", "0x0002", "0x0007"]

-- | Create an empty 'MLSState' backed by a fresh system temp directory.
--
-- The directory is created with 'withSystemTempDirectory' (name template
-- @"mls"@) and stored as 'baseDir'; its lifetime is scoped to the
-- continuation that receives the state. Both 'convs' and 'clientGroupState'
-- start out empty.
mkMLSState :: Codensity IO MLSState
mkMLSState = Codensity $ \k ->
  withSystemTempDirectory "mls" $ \tmp -> do
    k
      MLSState
        { baseDir = tmp,
          convs = mempty,
          clientGroupState = mempty
        }

-- | Fetch the 'MLSConv' registered for @convId@ in the current MLS state.
--
-- Fails the test via 'assertFailure' (message includes the shown @convId@)
-- when no conversation is stored under that id.
getMLSConv :: (HasCallStack) => ConvId -> App MLSConv
getMLSConv convId = do
  mConv <- Map.lookup convId . (.convs) <$> getMLSState
  case mConv of
    Just conv -> pure conv
    Nothing -> do
      assertFailure $ "MLSConv not found, convId=" <> show convId

-- | Run an 'App' action with 'defaultAPIVersion' overridden to @v@.
--
-- The override is applied with 'local', so it is only visible inside the
-- given action; the surrounding environment is left untouched.
withAPIVersion :: Int -> App a -> App a
withAPIVersion v = local $ \e -> e {defaultAPIVersion = v}