-- This file is part of the Wire Server implementation.
--
-- Copyright (C) 2022 Wire Swiss GmbH <opensource@wire.com>
--
-- This program is free software: you can redistribute it and/or modify it under
-- the terms of the GNU Affero General Public License as published by the Free
-- Software Foundation, either version 3 of the License, or (at your option) any
-- later version.
--
-- This program is distributed in the hope that it will be useful, but WITHOUT
-- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
-- FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
-- details.
--
-- You should have received a copy of the GNU Affero General Public License along
-- with this program. If not, see <https://www.gnu.org/licenses/>.

module Wire.API.MLS.Group.Serialisation
  ( GroupIdParts (..),
    groupIdParts,
    convToGroupId,
    groupIdToConv,
    nextGenGroupId,
  )
where

import Data.Bifunctor
import Data.Binary.Get
import Data.Binary.Put
import Data.ByteString.Conversion
import Data.ByteString.Lazy qualified as L
import Data.Domain
import Data.Id
import Data.Qualified
import Data.Text qualified as T
import Data.Text.Encoding qualified as T
import Data.UUID qualified as UUID
import Imports
import Web.HttpApiData (FromHttpApiData (parseHeader))
import Wire.API.Conversation
import Wire.API.MLS.Group
import Wire.API.MLS.SubConversation

-- | The components that 'convToGroupId' packs into an MLS group ID and that
-- 'groupIdToConv' unpacks again.
data GroupIdParts = GroupIdParts
  { -- | Type of the conversation the group belongs to.
    convType :: ConvType,
    -- | Qualified id of the conversation or subconversation.
    qConvId :: Qualified ConvOrSubConvId,
    -- | Generation counter of the group. Only written into the group ID
    -- when 'qConvId' refers to a subconversation.
    gidGen :: GroupIdGen
  }
  deriving (Eq, Show)

-- | Assemble the group-ID parts for a (sub)conversation of the given type,
-- with the generation counter set to 0.
groupIdParts :: ConvType -> Qualified ConvOrSubConvId -> GroupIdParts
groupIdParts ty qid =
  GroupIdParts
    { gidGen = GroupIdGen 0,
      qConvId = qid,
      convType = ty
    }

-- | Return the group ID associated to a conversation ID. Note that it is not
-- assumed to be stable over time or even consistent among different backends.
--
-- Layout of the encoded group ID (integers are big-endian):
--
-- * format version ('Word16', always 1)
-- * conversation type ('Word16', 'fromEnum' of the 'ConvType')
-- * conversation UUID (16 bytes)
-- * length of the subconversation id in bytes ('Word8', 0 for a plain
--   conversation)
-- * subconversation id (UTF-8)
-- * generation ('Word32'), present only for subconversations
-- * domain (all remaining bytes, via 'toByteString')
convToGroupId :: GroupIdParts -> GroupId
convToGroupId parts = GroupId . L.toStrict . runPut $ do
  let cs = qUnqualified parts.qConvId
      -- Encode the subconversation id once, so that the length prefix
      -- counts UTF-8 bytes. The decoder reads that many bytes back, so a
      -- character count would be wrong for non-ASCII ids.
      subIdBytes = L.fromStrict . T.encodeUtf8 $ foldMap unSubConvId cs.subconv
  putWord16be 1 -- Version 1 of the GroupId format
  putWord16be (fromIntegral $ fromEnum parts.convType)
  putLazyByteString . UUID.toByteString . toUUID $ cs.conv
  -- NOTE(review): lengths above 255 bytes would be truncated by this
  -- conversion; presumably subconversation ids are length-limited
  -- elsewhere -- confirm.
  putWord8 $ fromIntegral (L.length subIdBytes)
  putLazyByteString subIdBytes
  -- The generation is only serialised for subconversations.
  maybe (pure ()) (const $ putWord32be (unGroupIdGen parts.gidGen)) cs.subconv
  putLazyByteString . toByteString $ qDomain parts.qConvId

-- | Decode a group ID in the format written by 'convToGroupId' back into its
-- parts.
--
-- Returns 'Left' with a message when any of the following holds:
--
-- * the version is not 1
-- * the input is truncated
-- * the UUID or subconversation id is invalid
-- * the conversation type is out of range
-- * the trailing domain is not valid UTF-8
groupIdToConv :: GroupId -> Either String GroupIdParts
groupIdToConv gid = do
  (rest, _, (ct, conv, gen)) <-
    first (\(_, _, msg) -> msg) $
      runGetOrFail readConv (L.fromStrict (unGroupId gid))
  ty <- decodeConvType ct
  domain <- first displayException . T.decodeUtf8' . L.toStrict $ rest
  pure
    GroupIdParts
      { convType = ty,
        qConvId = Qualified conv (Domain domain),
        gidGen = gen
      }
  where
    -- Total inverse of @fromEnum@ on 'ConvType'. A bare 'toEnum' would
    -- hide an error inside a 'Right' for out-of-range input.
    -- NOTE(review): relies on 'ConvType' having a stock-derived 'Enum'
    -- instance, so that @[toEnum 0 ..]@ is the finite list of all
    -- constructors. Confirm this, or derive 'Bounded' and use
    -- @[minBound ..]@.
    decodeConvType :: Word16 -> Either String ConvType
    decodeConvType w =
      case [c | c <- [toEnum 0 ..], fromEnum c == fromIntegral w] of
        (c : _) -> Right c
        [] -> Left "unsupported conversation type in groupId"

    -- Parses everything up to (but not including) the trailing domain.
    readConv :: Get (Word16, ConvOrSubChoice (Id a) SubConvId, GroupIdGen)
    readConv = do
      version <- getWord16be
      -- Check the version first, so that unknown formats are reported as
      -- such even when the rest of the input is too short.
      unless (version == 1) $ fail "unsupported groupId version"
      ct <- getWord16be
      mUUID <- UUID.fromByteString . L.fromStrict <$> getByteString 16
      uuid <- maybe (fail "invalid conversation UUID in groupId") pure mUUID
      n <- getWord8
      if n == 0
        then pure (ct, Conv (Id uuid), GroupIdGen 0)
        else do
          subConvIdBS <- getByteString (fromIntegral n)
          subConvId <- either (fail . T.unpack) pure (parseHeader subConvIdBS)
          gen <- getWord32be
          pure (ct, SubConv (Id uuid) (SubConvId subConvId), GroupIdGen gen)

-- | Compute the group ID for the next generation of a group. It decodes the
-- given group ID, increments its generation counter and re-encodes it.
--
-- Returns 'Left' if the group ID cannot be decoded. It also returns 'Left'
-- if the generation counter is already at 'maxBound', rather than throwing
-- from 'succ'.
--
-- NOTE(review): 'convToGroupId' only serialises the generation for
-- subconversations. For a plain conversation, the increment therefore does
-- not show up in the resulting bytes. Presumably intended -- confirm.
nextGenGroupId :: GroupId -> Either String GroupId
nextGenGroupId gid = do
  parts <- groupIdToConv gid
  let gen = unGroupIdGen parts.gidGen
  if gen == maxBound
    then Left "groupId generation overflow"
    else pure (convToGroupId (parts {gidGen = GroupIdGen (gen + 1)}))