-- This file is part of the Wire Server implementation.
--
-- Copyright (C) 2025 Wire Swiss GmbH <opensource@wire.com>
--
-- This program is free software: you can redistribute it and/or modify it under
-- the terms of the GNU Affero General Public License as published by the Free
-- Software Foundation, either version 3 of the License, or (at your option) any
-- later version.
--
-- This program is distributed in the hope that it will be useful, but WITHOUT
-- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
-- FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
-- details.
--
-- You should have received a copy of the GNU Affero General Public License along
-- with this program. If not, see <https://www.gnu.org/licenses/>.

module Test.OAuth where

import API.Brig
import API.BrigInternal
import API.Common (defPassword)
import Data.String.Conversions
import Network.HTTP.Types
import Network.URI
import SetupHelpers
import Testlib.Prelude

-- | Create two OAuth sessions for the same client and revoke one of them.
--
-- Checks that:
--
-- * revoking with a wrong password is rejected with 403 and removes nothing,
-- * revoking with the correct password leaves exactly one session,
-- * the revoked session's refresh token can no longer mint access tokens (403),
-- * the remaining session's refresh token still works.
testOAuthRevokeSession :: (HasCallStack) => App ()
testOAuthRevokeSession = do
  user <- randomUser OwnDomain def
  let uri = "https://example.com"
  cid <- createOAuthClient user "foobar" uri >>= getJSON 200 >>= flip (%.) "client_id"
  let scopes = ["write:conversations"]

  -- create a session that will be revoked later
  (tokenToBeRevoked, sessionToBeRevoked) <- do
    token <- generateAccessToken user cid scopes uri
    [session] <- getSessions user
    pure (token, session)

  -- create another session and assert that there are two sessions
  validToken <- do
    token <- generateAccessToken user cid scopes uri
    sessions <- getSessions user
    length sessions `shouldMatchInt` 2
    pure token

  -- attempt to revoke a session with a wrong password should fail ...
  sessionToBeRevoked
    %. "refresh_token_id"
    >>= asString
    >>= deleteOAuthSession user cid "foobar"
    >>= assertStatus 403
  -- ... and must not have removed anything
  sessionsAfterFailedRevoke <- getSessions user
  length sessionsAfterFailedRevoke `shouldMatchInt` 2

  -- revoke the first session and assert that there is only one session left
  sessionToBeRevoked
    %. "refresh_token_id"
    >>= asString
    >>= deleteOAuthSession user cid defPassword
    >>= assertSuccess
  remainingSessions <- getSessions user
  length remainingSessions `shouldMatchInt` 1

  -- try to use the revoked token and assert that it fails
  tokenToBeRevoked
    %. "refresh_token"
    >>= asString
    >>= createOAuthAccessTokenWithRefreshToken user cid
    >>= assertStatus 403

  -- try to use the valid token and assert that it works
  validToken
    %. "refresh_token"
    >>= asString
    >>= createOAuthAccessTokenWithRefreshToken user cid
    >>= assertSuccess
  where
    -- List the user's OAuth applications (expecting HTTP 200), require that
    -- there is exactly one, and return its "sessions" array.
    getSessions :: (HasCallStack) => Value -> App [Value]
    getSessions u = do
      [app] <- getOAuthApplications u >>= getJSON 200 >>= asList
      app %. "sessions" >>= asList

-- | Authorize three OAuth clients for one user, then revoke them one at a
-- time through the V6 endpoint (which is called without a password). After
-- each revocation, check that exactly the remaining applications are listed.
testRevokeApplicationAccountAccessV6 :: (HasCallStack) => App ()
testRevokeApplicationAccountAccessV6 = do
  user <- randomUser OwnDomain def
  -- a fresh user has no authorized applications
  assertAppIds user []
  let uri = "https://example.com"
  let scopes = ["write:conversations"]
  -- register three clients and authorize each one by obtaining a token
  replicateM_ 3 $ do
    cid <- createOAuthClient user "foobar" uri >>= getJSON 200 >>= flip (%.) "client_id"
    generateAccessToken user cid scopes uri
  [cid1, cid2, cid3] <- getOAuthApplications user >>= getJSON 200 >>= asList >>= mapM (%. "id")
  revokeApplicationAccessV6 user cid1 >>= assertSuccess
  assertAppIds user [cid2, cid3]
  revokeApplicationAccessV6 user cid2 >>= assertSuccess
  assertAppIds user [cid3]
  revokeApplicationAccessV6 user cid3 >>= assertSuccess
  assertAppIds user []
  where
    -- Assert that listing the user's OAuth applications returns HTTP 200 and
    -- exactly the given application ids, in any order.
    assertAppIds :: (HasCallStack) => Value -> [Value] -> App ()
    assertAppIds u expected =
      bindResponse (getOAuthApplications u) $ \resp -> do
        resp.status `shouldMatchInt` 200
        apps <- resp.json >>= asList
        length apps `shouldMatchInt` length expected
        ids <- for apps (%. "id")
        ids `shouldMatchSet` expected

-- | Authorize three OAuth clients for one user, then revoke them one at a
-- time with the user's password. A wrong password is rejected with 403 and
-- must not revoke anything. After each successful revocation, check that
-- exactly the remaining applications are listed.
testRevokeApplicationAccountAccess :: (HasCallStack) => App ()
testRevokeApplicationAccountAccess = do
  user <- randomUser OwnDomain def
  -- a fresh user has no authorized applications
  assertAppIds user []
  let uri = "https://example.com"
  let scopes = ["write:conversations"]
  -- register three clients and authorize each one by obtaining a token
  replicateM_ 3 $ do
    cid <- createOAuthClient user "foobar" uri >>= getJSON 200 >>= flip (%.) "client_id"
    generateAccessToken user cid scopes uri
  [cid1, cid2, cid3] <- getOAuthApplications user >>= getJSON 200 >>= asList >>= mapM (%. "id")
  -- wrong password: rejected, and all three applications remain
  revokeApplicationAccess user cid1 "foobar" >>= assertStatus 403
  assertAppIds user [cid1, cid2, cid3]
  revokeApplicationAccess user cid1 defPassword >>= assertSuccess
  assertAppIds user [cid2, cid3]
  revokeApplicationAccess user cid2 defPassword >>= assertSuccess
  assertAppIds user [cid3]
  revokeApplicationAccess user cid3 defPassword >>= assertSuccess
  assertAppIds user []
  where
    -- Assert that listing the user's OAuth applications returns HTTP 200 and
    -- exactly the given application ids, in any order.
    assertAppIds :: (HasCallStack) => Value -> [Value] -> App ()
    assertAppIds u expected =
      bindResponse (getOAuthApplications u) $ \resp -> do
        resp.status `shouldMatchInt` 200
        apps <- resp.json >>= asList
        length apps `shouldMatchInt` length expected
        ids <- for apps (%. "id")
        ids `shouldMatchSet` expected

-- | Run the OAuth authorization-code flow for @user@ against client @cid@ and
-- exchange the resulting code for an access token.
--
-- Requests an authorization code for @scopes@ and redirect @uri@. It then
-- reads the @code@ query parameter from the @Location@ header of that
-- response and exchanges the code via 'createOAuthAccessToken', expecting
-- HTTP 200. Returns the JSON body of the token response.
--
-- Fails the test with a call stack if the Location header is missing or
-- unparsable, or if it carries no @code@ parameter.
generateAccessToken :: (HasCallStack, MakesValue cid, MakesValue user) => user -> cid -> [String] -> String -> App Value
generateAccessToken user cid scopes uri = do
  authCodeResponse <- generateOAuthAuthorizationCode user cid scopes uri
  location <-
    maybe (assertFailure "no location header") pure $
      parseURI . cs . snd =<< locationHeader authCodeResponse
  -- A missing code must fail here. Substituting a placeholder string would
  -- only produce a confusing error from the token endpoint later.
  code <-
    maybe (assertFailure "no code query param") (pure . cs) $
      join $
        lookup (cs "code") $
          parseQuery $
            cs location.uriQuery
  createOAuthAccessToken user cid code uri >>= getJSON 200