-- This file is part of the Wire Server implementation.
--
-- Copyright (C) 2022 Wire Swiss GmbH <opensource@wire.com>
--
-- This program is free software: you can redistribute it and/or modify it under
-- the terms of the GNU Affero General Public License as published by the Free
-- Software Foundation, either version 3 of the License, or (at your option) any
-- later version.
--
-- This program is distributed in the hope that it will be useful, but WITHOUT
-- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
-- FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
-- details.
--
-- You should have received a copy of the GNU Affero General Public License along
-- with this program. If not, see <https://www.gnu.org/licenses/>.

module Galley.API.MLS.Commit.Core
  ( getCommitData,
    incrementEpoch,
    getClientInfo,
    HasProposalActionEffects,
    ProposalErrors,
    HandleMLSProposalFailures (..),
  )
where

import Control.Comonad
import Data.Id
import Data.Qualified
import Data.Time
import Galley.API.Error
import Galley.API.MLS.Conversation
import Galley.API.MLS.IncomingMessage
import Galley.API.MLS.Proposal
import Galley.API.MLS.Types
import Galley.Effects
import Galley.Effects.BrigAccess
import Galley.Effects.ConversationStore
import Galley.Effects.FederatorAccess
import Galley.Effects.SubConversationStore
import Galley.Env
import Galley.Options
import Imports
import Polysemy
import Polysemy.Error
import Polysemy.Input
import Polysemy.Internal
import Polysemy.State
import Polysemy.TinyLog
import Wire.API.Conversation.Protocol
import Wire.API.Conversation.Role
import Wire.API.Error
import Wire.API.Error.Galley
import Wire.API.Federation.API
import Wire.API.Federation.API.Brig
import Wire.API.Federation.Endpoint
import Wire.API.Federation.Error
import Wire.API.Federation.Version
import Wire.API.MLS.CipherSuite
import Wire.API.MLS.Commit
import Wire.API.MLS.Credential
import Wire.API.MLS.Serialisation
import Wire.API.MLS.SubConversation
import Wire.API.User.Client
import Wire.NotificationSubsystem

-- | Constraint synonym bundling every effect needed in this module to act on a
-- proposal action. The members are grouped by kind purely for readability;
-- the order of a constraint tuple has no semantic meaning.
type HasProposalActionEffects r =
  ( -- storage
    Member ConversationStore r,
    Member LegalHoldStore r,
    Member MemberStore r,
    Member ProposalStore r,
    Member SubConversationStore r,
    Member TeamStore r,
    -- access to other services and subsystems
    Member BackendNotificationQueueAccess r,
    Member BrigAccess r,
    Member ExternalAccess r,
    Member FederatorAccess r,
    Member NotificationSubsystem r,
    -- environment, time, randomness and logging
    Member (Input Env) r,
    Member (Input Opts) r,
    Member (Input UTCTime) r,
    Member Random r,
    Member TinyLog r,
    -- untagged errors
    Member (Error InternalError) r,
    Member (Error MLSProposalFailure) r,
    Member (Error MLSProtocolError) r,
    Member (Error NonFederatingBackends) r,
    Member (Error UnreachableBackends) r,
    -- tagged Galley errors
    Member (ErrorS 'ConvNotFound) r,
    Member (ErrorS 'MissingLegalholdConsent) r,
    Member (ErrorS 'MLSClientMismatch) r,
    Member (ErrorS 'MLSSelfRemovalNotAllowed) r,
    Member (ErrorS 'MLSUnsupportedProposal) r
  )

-- | Compute the 'ProposalAction' described by an incoming commit bundle.
--
-- Every entry of the commit's proposal list is turned into a 'Proposal' with
-- 'derefOrCheckProposal', using the given epoch, ciphersuite and the group id
-- of the (sub)conversation. The resulting proposals are then applied with
-- 'applyProposals'. Both steps run against a local copy of the
-- (sub)conversation's 'IndexMap'. That copy is threaded through 'evalState'
-- and discarded afterwards; only the resulting action is returned.
--
-- When @epoch@ is @Epoch 0@, an extra action adding the sender's own client
-- (via 'addProposedClient') is prepended to the result. Presumably this is
-- because at epoch 0 the sender is the group creator. NOTE(review): confirm.
getCommitData ::
  ( HasProposalEffects r,
    Member (ErrorS 'MLSProposalNotFound) r
  ) =>
  ClientIdentity ->
  Local ConvOrSubConv ->
  Epoch ->
  CipherSuiteTag ->
  IncomingBundle ->
  Sem r ProposalAction
getCommitData senderIdentity lConvOrSub epoch ciphersuite bundle = do
  let convOrSub = tUnqualified lConvOrSub
      groupId = cnvmlsGroupId convOrSub.mlsMeta

  evalState convOrSub.indexMap $ do
    -- only the very first commit (epoch 0) implicitly adds its sender
    creatorAction <-
      if epoch == Epoch 0
        then addProposedClient senderIdentity
        else pure mempty
    proposals <-
      traverse
        (derefOrCheckProposal epoch ciphersuite groupId)
        bundle.commit.value.proposals
    action <- applyProposals ciphersuite groupId proposals
    pure (creatorAction <> action)

-- | Advance the stored epoch of a conversation or subconversation by one
-- (using 'succ'). Then reload the entity from the store and return the fresh
-- value.
--
-- * For a conversation, the epoch is written with 'setConversationEpoch'. The
--   conversation is re-read with 'getConversation' and converted back with
--   'mkMLSConversation'.
-- * For a subconversation, the epoch is written with
--   'setSubConversationEpoch' and the subconversation is re-read with
--   'getSubConversation'. The parent 'MLSConversation' is returned unchanged.
--
-- Throws @ConvNotFound@ if any reload (or the MLS conversion) yields
-- 'Nothing'.
incrementEpoch ::
  ( Member ConversationStore r,
    Member (ErrorS 'ConvNotFound) r,
    Member MemberStore r,
    Member SubConversationStore r
  ) =>
  ConvOrSubConv ->
  Sem r ConvOrSubConv
incrementEpoch (Conv c) = do
  let epoch' = succ (cnvmlsEpoch (mcMLSData c))
  setConversationEpoch (mcId c) epoch'
  conv <- getConversation (mcId c) >>= noteS @'ConvNotFound
  fmap Conv (mkMLSConversation conv >>= noteS @'ConvNotFound)
incrementEpoch (SubConv c s) = do
  let epoch' = succ (cnvmlsEpoch (scMLSData s))
  setSubConversationEpoch (scParentConvId s) (scSubConvId s) epoch'
  -- NOTE(review): the write above is keyed on 'scParentConvId s' but the read
  -- below on 'mcId c'; this assumes both identify the same parent conversation.
  subconv <-
    getSubConversation (mcId c) (scSubConvId s) >>= noteS @'ConvNotFound
  pure (SubConv c subconv)

-- | Fetch the MLS client information of a (possibly remote) user for the given
-- ciphersuite.
--
-- The @Local x@ argument only decides whether the user is local or remote
-- (via 'foldQualified').
--
-- * Local users are looked up with 'getLocalMLSClients'. That lookup has no
--   federation failure mode, so its result is always wrapped in 'Right'.
-- * Remote users are queried over federation with 'getRemoteMLSClients'. That
--   query may return a 'FederationError' in 'Left'.
getClientInfo ::
  ( Member BrigAccess r,
    Member FederatorAccess r
  ) =>
  Local x ->
  Qualified UserId ->
  CipherSuiteTag ->
  Sem r (Either FederationError (Set ClientInfo))
getClientInfo loc =
  foldQualified
    loc
    (\lusr -> fmap Right . getLocalMLSClients lusr)
    getRemoteMLSClients

-- | Ask the remote Brig owning @rusr@ for that user's MLS clients for the given
-- ciphersuite, through the @get-mls-clients@ federation endpoint.
--
-- The current endpoint is combined with its V0 variant using '<|>'. The V0
-- variant receives the request converted by 'mlsClientsRequestToV0'.
-- Presumably this lets backends that only speak federation API V0 still
-- answer. NOTE(review): confirm the fallback semantics of '<|>' on
-- 'FederatorClient'.
--
-- Federation failures are returned in 'Left' (via 'runFederatedEither')
-- rather than thrown.
getRemoteMLSClients ::
  ( Member FederatorAccess r
  ) =>
  Remote UserId ->
  CipherSuiteTag ->
  Sem r (Either FederationError (Set ClientInfo))
getRemoteMLSClients rusr suite = do
  let mcr =
        MLSClientsRequest
          { userId = tUnqualified rusr,
            cipherSuite = tagCipherSuite suite
          }
  runFederatedEither rusr $
    fedClient @'Brig @"get-mls-clients" mcr
      <|> fedClient @'Brig @(Versioned 'V0 "get-mls-clients") (mlsClientsRequestToV0 mcr)

--------------------------------------------------------------------------------
-- Error handling of proposal execution

-- | The following errors are caught by 'executeProposalAction' and wrapped in
-- an 'MLSProposalFailure'. This way, errors caused by the execution of
-- proposals are kept separate from those caused by the commit processing
-- itself.
--
-- The list is meant to be discharged with 'handleMLSProposalFailures', which
-- handles the effects from the head of the list onwards.
-- NOTE(review): the 'ErrorS' entries rely on 'ErrorS' being a synonym for an
-- 'Error' effect, so that the 'HandleMLSProposalFailure (Error e)' instance
-- applies to them. Confirm.
-- The order of the list fixes the position of each effect in the row built by
-- 'Append', so reordering it is not a purely cosmetic change.
type ProposalErrors =
  '[ Error FederationError,
     Error InvalidInput,
     ErrorS ('ActionDenied 'AddConversationMember),
     ErrorS ('ActionDenied 'LeaveConversation),
     ErrorS ('ActionDenied 'RemoveConversationMember),
     ErrorS 'ConvAccessDenied,
     ErrorS 'InvalidOperation,
     ErrorS 'NotATeamMember,
     ErrorS 'NotConnected,
     ErrorS 'TooManyMembers
   ]

-- | Interpret a whole type-level list of effects @es@ sitting in front of the
-- row @row@, removing all of them at once. Instances recurse over the list,
-- handling each element with 'HandleMLSProposalFailure'.
class HandleMLSProposalFailures es row where
  -- | Eliminate every effect of @es@ from the front of the row.
  handleMLSProposalFailures :: Sem (Append es row) a -> Sem row a

-- | Interpret a single effect @e@ sitting at the head of the row @row@.
class HandleMLSProposalFailure e row where
  -- | Eliminate the head effect @e@ from the row.
  handleMLSProposalFailure :: Sem (e ': row) a -> Sem row a

-- | Base case: an empty list prepends nothing to the row, so there is nothing
-- to handle.
instance HandleMLSProposalFailures '[] r where
  handleMLSProposalFailures = id

-- | Inductive case: first interpret the head effect @eff@ with
-- 'handleMLSProposalFailure', then hand the remaining @effs@ to the recursive
-- instance. The type applications are needed because @effs@ and @eff@ only
-- occur under the 'Append' type family and cannot be inferred.
instance
  ( HandleMLSProposalFailures effs r,
    HandleMLSProposalFailure eff (Append effs r)
  ) =>
  HandleMLSProposalFailures (eff ': effs) r
  where
  handleMLSProposalFailures =
    handleMLSProposalFailures @effs . handleMLSProposalFailure @eff

-- | Any 'Error' effect whose payload is an 'APIError' is handled the same way.
-- The error is rendered to its API response with 'toResponse' and rethrown
-- wrapped in 'MLSProposalFailure', via 'mapError'.
instance
  (APIError e, Member (Error MLSProposalFailure) r) =>
  HandleMLSProposalFailure (Error e) r
  where
  handleMLSProposalFailure = mapError (MLSProposalFailure . toResponse)