-- This file is part of the Wire Server implementation.
--
-- Copyright (C) 2022 Wire Swiss GmbH <opensource@wire.com>
--
-- This program is free software: you can redistribute it and/or modify it under
-- the terms of the GNU Affero General Public License as published by the Free
-- Software Foundation, either version 3 of the License, or (at your option) any
-- later version.
--
-- This program is distributed in the hope that it will be useful, but WITHOUT
-- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
-- FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
-- details.
--
-- You should have received a copy of the GNU Affero General Public License along
-- with this program. If not, see <https://www.gnu.org/licenses/>.

module Galley.Cassandra.Conversation.MLS
  ( acquireCommitLock,
    releaseCommitLock,
    lookupMLSClients,
    lookupMLSClientLeafIndices,
  )
where

import Cassandra
import Cassandra.Settings
import Control.Arrow
import Data.Time
import Galley.API.MLS.Types
import Galley.Cassandra.Queries qualified as Cql
import Galley.Data.Types
import Imports
import Wire.API.MLS.Epoch
import Wire.API.MLS.Group

-- | Try to take the commit lock for @groupId@ at @epoch@.
--
-- The lock is written with 'Cql.acquireCommitLock' as a conditional write
-- ('trans') at 'LocalQuorum', with 'LocalSerialConsistency' as the serial
-- consistency, and the whole operation is retried with 'x5'. The result is
-- 'Acquired' only when 'checkTransSuccess' reports that the conditional
-- write was applied; otherwise it is 'NotAcquired'.
--
-- @ttl@ is converted with 'round' to an 'Int32' number of seconds before it
-- is passed as the third query parameter.
-- NOTE(review): 'round' into 'Int32' wraps silently for values outside the
-- 'Int32' range; presumably callers only pass small timeouts — confirm.
acquireCommitLock :: GroupId -> Epoch -> NominalDiffTime -> Client LockAcquired
acquireCommitLock groupId epoch ttl = do
  rows <-
    retry x5 $
      trans
        Cql.acquireCommitLock
        ( params
            LocalQuorum
            (groupId, epoch, round ttl)
        )
          { serialConsistency = Just LocalSerialConsistency
          }
  pure $
    if checkTransSuccess rows
      then Acquired
      else NotAcquired

-- | Release the commit lock for @groupId@ at @epoch@.
--
-- Runs 'Cql.releaseCommitLock' as a plain (non-conditional) 'write' at
-- 'LocalQuorum', retried with 'x5'. Unlike 'acquireCommitLock', nothing is
-- reported back about whether a lock row was actually present.
releaseCommitLock :: GroupId -> Epoch -> Client ()
releaseCommitLock groupId epoch =
  retry x5 $
    write
      Cql.releaseCommitLock
      (params LocalQuorum (groupId, epoch))
      )

-- | Interpret the rows returned by a conditional write ('trans').
--
-- Decodes column 0 of the first row as @Maybe Bool@. Presumably this is the
-- @[applied]@ column that Cassandra returns for conditional statements;
-- verify this against the query definitions.
--
-- Returns 'False' in any of these cases:
--
-- * there are no rows;
-- * the column fails to decode;
-- * the column is null.
--
-- Otherwise it returns the decoded flag.
checkTransSuccess :: [Row] -> Bool
checkTransSuccess [] = False
checkTransSuccess (row : _) =
  either (const False) (fromMaybe False) (fromRow 0 row)

-- | Load the MLS client entries stored for @groupId@.
--
-- Builds both the 'ClientMap' and the 'IndexMap' from the same result set.
-- Rows are read with 'Cql.lookupMLSClients' at 'LocalQuorum', retried with
-- 'x5'. Each row is a @(Domain, UserId, ClientId, Int32, Bool)@ tuple;
-- interpreting its fields is left entirely to 'mkClientMap' and
-- 'mkIndexMap'.
lookupMLSClientLeafIndices :: GroupId -> Client (ClientMap, IndexMap)
lookupMLSClientLeafIndices groupId = do
  entries <-
    retry x5 $
      query Cql.lookupMLSClients (params LocalQuorum (Identity groupId))
  -- One query, two views of the result.
  pure $ (mkClientMap &&& mkIndexMap) entries

-- | Load only the 'ClientMap' for @groupId@.
--
-- Delegates to 'lookupMLSClientLeafIndices' and discards the 'IndexMap' it
-- also produces.
lookupMLSClients :: GroupId -> Client ClientMap
lookupMLSClients = fmap fst . lookupMLSClientLeafIndices