-- This file is part of the Wire Server implementation.
--
-- Copyright (C) 2025 Wire Swiss GmbH <opensource@wire.com>
--
-- This program is free software: you can redistribute it and/or modify it under
-- the terms of the GNU Affero General Public License as published by the Free
-- Software Foundation, either version 3 of the License, or (at your option) any
-- later version.
--
-- This program is distributed in the hope that it will be useful, but WITHOUT
-- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
-- FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
-- details.
--
-- You should have received a copy of the GNU Affero General Public License along
-- with this program. If not, see <https://www.gnu.org/licenses/>.

module Test.NginxZAuthModule where

import API.Brig
import API.Common
import qualified Control.Concurrent as Concurrent
import Control.Monad.Codensity
import Control.Monad.Reader
import qualified Data.ByteString as BS
import Data.List.Extra
import Data.Streaming.Network
import Data.UnixTime
import qualified Network.HTTP.Client as HTTP
import Network.HTTP.Types
import Network.Socket (Socket)
import qualified Network.Socket as NS
import qualified Network.Socket.ByteString as NSB
import SetupHelpers
import System.FilePath ((</>))
import System.IO (writeFile)
import System.IO.Temp
import System.Posix
import System.Process (getPid)
import Testlib.Prelude
import Text.RawString.QQ
import UnliftIO (bracket)
import UnliftIO.Async (async, waitBoth)
import qualified UnliftIO.Async as Async
import UnliftIO.Directory
import UnliftIO.Process
import UnliftIO.Timeout (timeout)

-- Happy flow: logging in yields a zauth token that the nginx zauth module
-- accepts when it is sent as an @Authorization: Bearer@ header.
--
-- This test uses 'withTestNginz', which responds with the user id and timestamp
-- from the token instead of proxying anywhere.  See also: 'testBearerToken2'
testBearerToken :: (HasCallStack) => App ()
testBearerToken = runCodensity withTestNginz $ \port -> do
  alice <- randomUser OwnDomain def
  email <- asString $ alice %. "email"
  loginResp <- login alice email defPassword >>= getJSON 200
  accessToken <- asString $ loginResp %. "access_token"

  baseReq <- HTTP.parseRequest "http://localhost"
  let bearerReq =
        baseReq
          { HTTP.port = port,
            HTTP.requestHeaders = [(hAuthorization, fromString ("Bearer " <> accessToken))]
          }

  submit "GET" bearerReq `bindResponse` \resp -> do
    resp.status `shouldMatchInt` 200
    resp.json %. "user" `shouldMatch` (alice %. "qualified_id.id")
    resp.json %. "timestamp" `shouldNotMatch` ""
    -- The echoed timestamp is parsed as whole Unix seconds and must lie in the
    -- future (presumably it is the token's expiry time -- confirm).
    rawTimestamp <- resp.json %. "timestamp" >>= asString
    let tokenTime = UnixTime (fromInteger (read rawTimestamp)) 0
    now <- liftIO getUnixTime
    assertBool "not in future" (tokenTime > now)

-- Happy flow with the zauth token encoded in an AWS4-HMAC-SHA256
-- @Authorization@ header (as @Credential=<token>@ in various positions), plus
-- one header whose credential is not a valid token.
--
-- This test uses 'withTestNginz', which responds with the user id and timestamp
-- from the token instead of proxying anywhere.  See also:
-- 'testAWS4_HMAC_SHA256_token2'
testAWS4_HMAC_SHA256_token :: (HasCallStack) => App ()
testAWS4_HMAC_SHA256_token = runCodensity withTestNginz $ \port -> do
  alice <- randomUser OwnDomain def
  email <- asString $ alice %. "email"
  loginResp <- login alice email defPassword >>= getJSON 200
  accessToken <- asString $ loginResp %. "access_token"

  baseReq <- HTTP.parseRequest "http://localhost"

  let withAuth authHeader =
        baseReq
          { HTTP.port = port,
            HTTP.requestHeaders = [(hAuthorization, authHeader)]
          }
      -- The Credential parameter appears first, alone, last, in the middle,
      -- and with no space after the separating commas.
      acceptedHeaders =
        map
          fromString
          [ "AWS4-HMAC-SHA256 Credential=" <> accessToken <> ", foo=bar",
            "AWS4-HMAC-SHA256 Credential=" <> accessToken,
            "AWS4-HMAC-SHA256 foo=bar, Credential=" <> accessToken,
            "AWS4-HMAC-SHA256 foo=bar, Credential=" <> accessToken <> ", baz=qux",
            "AWS4-HMAC-SHA256 foo=bar,Credential=" <> accessToken <> ",baz=qux"
          ]
      -- For these the test expects a 200 whose user and timestamp are empty.
      invalidHeaders = [fromString "AWS4-HMAC-SHA256 Credential=bad"]

  for_ acceptedHeaders $ \header ->
    submit "GET" (withAuth header) `bindResponse` \resp -> do
      resp.status `shouldMatchInt` 200
      resp.json %. "user" `shouldMatch` (alice %. "qualified_id.id")
      resp.json %. "timestamp" `shouldNotMatch` ""

  for_ invalidHeaders $ \header ->
    submit "GET" (withAuth header) `bindResponse` \resp -> do
      resp.status `shouldMatchInt` 200
      resp.json %. "user" `shouldMatch` ""
      resp.json %. "timestamp" `shouldMatch` ""

-- Start a throw-away nginx configured from 'nginxTestConfigTemplate' (zauth
-- module enabled) and yield a local TCP port through which it can be reached.
--
-- The temp directory, the nginx process, the listening TCP socket and the
-- forwarding thread are all acquired via 'Codensity', so they are released when
-- the surrounding 'runCodensity' finishes.
withTestNginz :: Codensity App Int
withTestNginz = do
  tmpDir <- Codensity $ withSystemTempDirectory "integration-NginxZauthModule"
  env <- ask
  -- Create config
  let (keystorePath, oauthPubKey) = case env.servicesCwdBase of
        Nothing ->
          ( "/etc/wire/nginz/secrets/zauth.conf",
            "/etc/wire/nginz/secrets/oauth_ed25519_pub.jwk"
          )
        Just basedir ->
          ( basedir </> "nginz/integration-test/resources/zauth/pubkeys.txt",
            basedir </> "nginz/integration-test/resources/oauth/ed25519_public.jwk"
          )
      unixSocketPath = tmpDir </> "sock"
      config =
        nginxTestConfigTemplate
          -- Listen on a unix socket because it's too complicated to make nginx
          -- run on a random available port.
          & replace "{socket_path}" unixSocketPath
          & replace "{pid_file}" (tmpDir </> "pid")

      configPath = tmpDir </> "nginx.conf"

  -- The config refers to these files by relative path; nginx runs with
  -- tmpDir as its working directory (see startNginx).
  copyFile keystorePath (tmpDir </> "keystore")
  copyFile oauthPubKey (tmpDir </> "oauth-pub-key")
  liftIO $ writeFile (tmpDir </> "acl") ""
  liftIO $ writeFile configPath config

  let startNginx = do
        -- NOTE(review): stdout/stderr are piped but never read. If nginx ever
        -- logs more than the OS pipe buffer holds it may block; drain the
        -- handles (e.g. with the logToConsole lines below) if that happens.
        (_, Just stdoutHdl, Just stderrHdl, processHandle) <-
          createProcess
            (proc "nginx" ["-c", configPath, "-g", "daemon off;", "-e", "/dev/stdout"])
              { cwd = Just tmpDir,
                std_out = CreatePipe,
                std_err = CreatePipe
              }
        -- Enable this when debugging
        -- liftIO $ void $ forkIO $ logToConsole id "nginx-zauth-module" stdoutHdl
        -- liftIO $ void $ forkIO $ logToConsole id "nginx-zauth-module" stderrHdl
        pure (stdoutHdl, stderrHdl, processHandle)

      -- Ask nginx to stop with SIGINT; if it has not exited within 50ms,
      -- SIGKILL it and wait for it to be reaped.
      stopNginx (_, _, processHandle) = do
        mPid <- liftIO $ getPid processHandle
        liftIO $ for_ mPid (signalProcess keyboardSignal)
        timeout 50000 (waitForProcess processHandle) >>= \case
          Just _ -> pure ()
          Nothing -> do
            liftIO $ for_ mPid (signalProcess killProcess)
            void $ waitForProcess processHandle
  _ <- Codensity $ bracket startNginx stopNginx

  -- nginx creates its unix socket some time after the process starts. Wait
  -- for it before handing out the port, so early requests do not race nginx's
  -- start-up and a nginx that fails to start is reported here, clearly.
  liftIO (timeout (nginxStartupTimeoutSeconds * 1000000) (waitForPath unixSocketPath)) >>= \case
    Just () -> pure ()
    Nothing ->
      liftIO . ioError . userError $
        "nginx did not create its socket "
          <> unixSocketPath
          <> " within "
          <> show nginxStartupTimeoutSeconds
          <> " seconds"

  -- The http-client package doesn't support connecting to servers running on a
  -- unix domain socket. So, we bind to a random TCP port and forward the
  -- requests and responses to and from the unix domain socket of nginx.
  (port, sock) <- Codensity $ bracket (liftIO $ bindRandomPortTCP (fromString "*6")) (liftIO . NS.close . snd)
  _ <- Codensity $ bracket (async $ forwardToUnixDomain sock unixSocketPath) Async.cancel
  pure port
  where
    -- Upper bound on how long to wait for nginx to create its unix socket.
    nginxStartupTimeoutSeconds :: Int
    nginxStartupTimeoutSeconds = 5

    -- Poll every 10ms until something exists at the given path.
    waitForPath :: FilePath -> IO ()
    waitForPath path = do
      exists <- doesFileExist path
      if exists
        then pure ()
        else Concurrent.threadDelay 10000 >> waitForPath path

-- Accept TCP connections on @tcpSock@ forever and relay each one, in both
-- directions, to a fresh connection to the unix domain socket at
-- @unixSockAddr@.
--
-- Each accepted connection is served by its own 'async' thread. Both sockets
-- of a relayed connection are closed once relaying ends or fails (e.g. when
-- connecting to the unix socket fails), so the TCP client sees the connection
-- go away instead of hanging on a socket nobody services.
--
-- NOTE: per-connection threads are detached (@void . async@), so cancelling
-- the accept loop does not stop connections that are already being relayed.
forwardToUnixDomain :: (MonadIO m) => Socket -> FilePath -> m ()
forwardToUnixDomain tcpSock unixSockAddr = liftIO . forever $ do
  (conn, _) <- NS.accept tcpSock
  void . async $ relayConnection conn
  where
    -- Relay one accepted TCP connection, closing both sockets afterwards.
    relayConnection :: Socket -> IO ()
    relayConnection conn =
      bracket (pure conn) NS.close $ \tcpConn ->
        bracket (NS.socket NS.AF_UNIX NS.Stream NS.defaultProtocol) NS.close $ \unixSock -> do
          NS.connect unixSock (NS.SockAddrUnix unixSockAddr)
          -- If either direction throws, the other is cancelled.
          Async.concurrently_ (pump tcpConn unixSock) (pump unixSock tcpConn)

    -- Copy src -> dst until src reaches EOF, then half-close dst so the peer
    -- on that side also sees end-of-stream for this direction.
    pump :: Socket -> Socket -> IO ()
    pump src dst = do
      forward src dst
      NS.shutdown dst NS.ShutdownSend

-- Copy everything received on @src@ to @dst@, in chunks of at most 4096
-- bytes, returning once a receive on @src@ yields an empty chunk (EOF).
forward :: Socket -> Socket -> IO ()
forward src dst = pumpChunks
  where
    pumpChunks :: IO ()
    pumpChunks = do
      chunk <- NSB.recv src 4096
      if not (BS.null chunk)
        then do
          NSB.sendAll dst chunk
          pumpChunks
        else pure ()

-- nginx configuration used by 'withTestNginz'.
--
-- Placeholders substituted by 'withTestNginz':
--
--   * @{socket_path}@ - the unix socket nginx listens on
--   * @{pid_file}@    - where nginx writes its pid
--
-- The zauth keystore, ACL and OAuth public key are referenced by relative
-- paths; 'withTestNginz' copies those files into its temp directory and runs
-- nginx with that directory as its working directory. Every request is answered
-- with a 200 JSON body echoing @$zauth_user@ and @$zauth_timestamp@.
nginxTestConfigTemplate :: String
nginxTestConfigTemplate = [r|
    events {
       worker_connections 128;
    }

    error_log /dev/stderr info;
    pid {pid_file};

    http {
       server {
          listen unix:{socket_path};
          zauth_keystore "./keystore";
          zauth_acl "./acl";
          oauth_pub_key "./oauth-pub-key";

          access_log /dev/stdout combined;

          location / {
            default_type application/json;
            return 200 '{"user":"$zauth_user", "timestamp": "$zauth_timestamp"}';
          }
       }
    }
 |]