-- This file is part of the Wire Server implementation.
--
-- Copyright (C) 2022 Wire Swiss GmbH <opensource@wire.com>
--
-- This program is free software: you can redistribute it and/or modify it under
-- the terms of the GNU Affero General Public License as published by the Free
-- Software Foundation, either version 3 of the License, or (at your option) any
-- later version.
--
-- This program is distributed in the hope that it will be useful, but WITHOUT
-- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
-- FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
-- details.
--
-- You should have received a copy of the GNU Affero General Public License along
-- with this program. If not, see <https://www.gnu.org/licenses/>.

module Wire.API.Routes.MultiTablePaging.State
  ( MultiTablePagingState (..),
    PagingTable (..),
  )
where

import Data.Aeson (FromJSON (..), ToJSON (..))
import Data.Attoparsec.ByteString qualified as AB
import Data.ByteString qualified as BS
import Data.ByteString.Base64.URL qualified as Base64Url
import Data.OpenApi qualified as S
import Data.Proxy
import Data.Schema
import Data.Text qualified as Text
import Data.Text.Encoding qualified as Text
import GHC.TypeLits
import Imports
import Servant (FromHttpApiData (..), ToHttpApiData (..))

-- | The state of a multi-table paginated query. It is made of a reference to
-- the table currently being paginated, as well as an opaque token returned by
-- Cassandra.
--
-- In this module the type-level @name@ is only used to name the OpenAPI
-- schema of the token, and @tables@ must be a 'PagingTable' so that the
-- current table can be packed into the token.
data MultiTablePagingState (name :: Symbol) tables = MultiTablePagingState
  { -- | The table currently being paginated.
    mtpsTable :: tables,
    -- | The opaque paging state for 'mtpsTable'. 'Nothing' is encoded as an
    -- empty suffix after the table byte.
    mtpsState :: Maybe ByteString
  }
  deriving stock (Show, Eq)
  -- JSON and OpenAPI instances are derived from the 'ToSchema' instance, so
  -- the JSON representation is the base64url-encoded token string.
  deriving (ToJSON, FromJSON, S.ToSchema) via Schema (MultiTablePagingState name tables)

-- | Serialise a paging state to raw bytes: a single byte identifying the
-- table (from 'encodePagingTable'), followed by the opaque state bytes.
--
-- A missing state and an empty state produce the same output (just the table
-- byte), so a @'Just' ""@ state decodes back as 'Nothing'.
encodePagingState :: (PagingTable tables) => MultiTablePagingState name tables -> ByteString
encodePagingState (MultiTablePagingState table state) =
  BS.cons (encodePagingTable table) (fromMaybe BS.empty state)

-- | Decode raw paging-state bytes, the inverse of 'encodePagingState'.
--
-- Any parser failure is returned as a 'Left' with attoparsec's error
-- message. This includes empty input and a table byte rejected by
-- 'decodePagingTable'.
parsePagingState :: (PagingTable tables) => ByteString -> Either String (MultiTablePagingState name tables)
parsePagingState = AB.parseOnly pagingStateParser

-- | Parser for the byte format produced by 'encodePagingState'.
--
-- The format is one table byte, decoded with 'decodePagingTable' (a 'fail'
-- there becomes a parse error), followed by the remaining input as the
-- opaque state. An empty remainder yields 'Nothing'.
pagingStateParser :: (PagingTable tables) => AB.Parser (MultiTablePagingState name tables)
pagingStateParser = do
  table <- AB.anyWord8 >>= decodePagingTable
  -- 'AB.takeByteString' consumes everything that is left, so no separate
  -- 'AB.endOfInput' check is needed to guarantee the whole token was read.
  rest <- AB.takeByteString
  pure $ MultiTablePagingState table (if BS.null rest then Nothing else Just rest)

-- | Paging states travel as query parameters in base64url-encoded form.
instance (PagingTable tables) => ToHttpApiData (MultiTablePagingState name tables) where
  -- Base64url output is pure ASCII, so 'Text.decodeUtf8' cannot throw here.
  toQueryParam = Text.decodeUtf8 . Base64Url.encode . encodePagingState

-- | Inverse of the 'ToHttpApiData' instance. The parameter is
-- base64url-decoded and the resulting bytes are parsed with
-- 'parsePagingState'. Errors from either step are returned as 'Left' text.
instance (PagingTable tables) => FromHttpApiData (MultiTablePagingState name tables) where
  parseQueryParam =
    mapLeft Text.pack
      . (parsePagingState <=< (Base64Url.decode . Text.encodeUtf8))

-- | Types whose values fit in one byte. The paging state is prefixed with
-- this byte so that the table being paginated can be recovered from a paging
-- token.
--
-- Tokens only round-trip when decoding the byte produced by
-- 'encodePagingTable' gives back the original table.
class PagingTable t where
  -- | The byte identifying a table. 'Word8' is used because 256 tables
  -- ought to be enough.
  encodePagingTable :: t -> Word8

  -- | Recover a table from its byte. The 'MonadFail' constraint lets an
  -- instance reject bytes it does not recognise.
  decodePagingTable :: (MonadFail m) => Word8 -> m t

-- | Schema for a paging state: a text value holding the base64url encoding
-- of 'encodePagingState'. The schema is named @<name>_PagingState@, where
-- @<name>@ is the type-level @name@ parameter.
instance (PagingTable tables, KnownSymbol name) => ToSchema (MultiTablePagingState name tables) where
  schema =
    (Text.decodeUtf8 . Base64Url.encode . encodePagingState)
      .= parsedText
        (Text.pack (symbolVal (Proxy @name)) <> "_PagingState")
        (parsePagingState <=< (Base64Url.decode . Text.encodeUtf8))