-- This file is part of the Wire Server implementation.
--
-- Copyright (C) 2022 Wire Swiss GmbH <opensource@wire.com>
--
-- This program is free software: you can redistribute it and/or modify it under
-- the terms of the GNU Affero General Public License as published by the Free
-- Software Foundation, either version 3 of the License, or (at your option) any
-- later version.
--
-- This program is distributed in the hope that it will be useful, but WITHOUT
-- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
-- FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
-- details.
--
-- You should have received a copy of the GNU Affero General Public License along
-- with this program. If not, see <https://www.gnu.org/licenses/>.

module Galley.API.MLS.Util where

import Control.Comonad
import Data.Hex
import Data.Id
import Data.Qualified
import Data.Set qualified as Set
import Data.Text qualified as T
import Galley.Data.Conversation.Types hiding (Conversation)
import Galley.Data.Conversation.Types qualified as Data
import Galley.Data.Types
import Galley.Effects
import Galley.Effects.ConversationStore
import Galley.Effects.MemberStore
import Galley.Effects.ProposalStore
import Galley.Effects.SubConversationStore
import Imports
import Polysemy
import Polysemy.Error
import Polysemy.Resource (Resource, bracket)
import Polysemy.TinyLog (TinyLog)
import Polysemy.TinyLog qualified as TinyLog
import System.Logger qualified as Log
import Wire.API.Conversation hiding (Member)
import Wire.API.Error
import Wire.API.Error.Galley
import Wire.API.MLS.Epoch
import Wire.API.MLS.Group.Serialisation
import Wire.API.MLS.LeafNode
import Wire.API.MLS.Proposal
import Wire.API.MLS.Serialisation
import Wire.API.MLS.SubConversation

-- | Fetch a local conversation and check that the given (possibly remote)
-- user is a member of it.
--
-- Throws 'ConvNotFound' when the conversation does not exist. A user who is
-- neither a local nor a remote member of the conversation gets the same
-- 'ConvNotFound' error.
getLocalConvForUser ::
  ( Member (ErrorS 'ConvNotFound) r,
    Member ConversationStore r,
    Member MemberStore r
  ) =>
  Qualified UserId ->
  Local ConvId ->
  Sem r Data.Conversation
getLocalConvForUser qusr lcnv = do
  conv <- getConversation (tUnqualified lcnv) >>= noteS @'ConvNotFound

  -- check that sender is part of conversation
  isMember' <-
    foldQualified
      lcnv
      -- local user: look them up among the local members
      ( fmap isJust
          . getLocalMember (convId conv)
          . tUnqualified
      )
      -- remote user: look them up among the remote members
      (fmap isJust . getRemoteMember (convId conv))
      qusr
  unless isMember' $ throwS @'ConvNotFound

  pure conv

-- | Collect the leaf indices targeted by pending remove proposals that were
-- issued by the backend for the given group and epoch.
--
-- * Backend-originated proposals that are not remove proposals are skipped.
-- * Client-originated proposals are skipped.
-- * Proposals with no recorded origin are skipped, and a warning is logged
--   for each one.
--
-- If several proposals target the same leaf index, they collapse into one
-- entry of the result set, and a warning is logged with the group id (hex)
-- and the epoch number.
getPendingBackendRemoveProposals ::
  ( Member ProposalStore r,
    Member TinyLog r
  ) =>
  GroupId ->
  Epoch ->
  Sem r (Set LeafIndex)
getPendingBackendRemoveProposals gid epoch = do
  proposals <- getAllPendingProposals gid epoch
  indexList <-
    catMaybes
      <$> for
        proposals
        ( \case
            (Just ProposalOriginBackend, proposal) -> case proposal.value of
              RemoveProposal i -> pure (Just i)
              _ -> pure Nothing
            (Just ProposalOriginClient, _) -> pure Nothing
            (Nothing, _) -> do
              TinyLog.warn $ Log.msg ("found pending proposal without origin, ignoring" :: ByteString)
              pure Nothing
        )

  let indexSet = Set.fromList indexList
  -- a size mismatch means at least one leaf index was proposed more than once
  when (length indexList /= length indexSet) $ do
    TinyLog.warn $
      Log.msg ("found duplicate proposals" :: ByteString)
        . Log.field "groupId" ("0x" <> hex (unGroupId gid))
        . Log.field "epoch" (epochNumber epoch)
  pure indexSet

-- | Run an action while holding the commit lock for the given group id and
-- epoch.
--
-- Throws 'MLSStaleMessage' in two cases:
--
-- * the lock could not be acquired ('acquireCommitLock' returned
--   'NotAcquired');
-- * the epoch currently stored for the conversation or subconversation differs
--   from the expected @epoch@. A missing stored epoch is treated as @Epoch 0@.
--
-- The lock is released with 'releaseCommitLock' through 'bracket'. The release
-- step only runs once acquisition has succeeded.
withCommitLock ::
  forall r a.
  ( Members
      '[ Resource,
         ConversationStore,
         ErrorS 'MLSStaleMessage,
         SubConversationStore
       ]
      r
  ) =>
  Local ConvOrSubConvId ->
  GroupId ->
  Epoch ->
  Sem r a ->
  Sem r a
withCommitLock lConvOrSubId gid epoch action =
  bracket
    ( acquireCommitLock gid epoch ttl >>= \lockAcquired ->
        when (lockAcquired == NotAcquired) $
          throwS @'MLSStaleMessage
    )
    (const $ releaseCommitLock gid epoch)
    $ \_ -> do
      -- re-check the stored epoch now that we hold the lock
      actualEpoch <-
        fromMaybe (Epoch 0) <$> case tUnqualified lConvOrSubId of
          Conv cnv -> getConversationEpoch cnv
          SubConv cnv sub -> getSubConversationEpoch cnv sub
      unless (actualEpoch == epoch) $ throwS @'MLSStaleMessage
      action
  where
    -- lock TTL handed to 'acquireCommitLock'; presumably the lock's expiry
    -- time (NOTE(review): confirm semantics in the store implementation)
    ttl = fromIntegral (600 :: Int) -- 10 minutes

-- | Decode the conversation type and the qualified conversation or
-- subconversation id that are encoded in an MLS group id.
--
-- If the group id cannot be decoded, the parser's message is wrapped in an
-- 'MLSProtocolError' and thrown.
getConvFromGroupId ::
  (Member (Error MLSProtocolError) r) =>
  GroupId ->
  Sem r (ConvType, Qualified ConvOrSubConvId)
getConvFromGroupId gid = case groupIdToConv gid of
  Left e -> throw (mlsProtocolError (T.pack e))
  Right parts -> pure (parts.convType, parts.qConvId)