{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- This file is part of the Wire Server implementation.
--
-- Copyright (C) 2022 Wire Swiss GmbH <opensource@wire.com>
--
-- This program is free software: you can redistribute it and/or modify it under
-- the terms of the GNU Affero General Public License as published by the Free
-- Software Foundation, either version 3 of the License, or (at your option) any
-- later version.
--
-- This program is distributed in the hope that it will be useful, but WITHOUT
-- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
-- FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
-- details.
--
-- You should have received a copy of the GNU Affero General Public License along
-- with this program. If not, see <https://www.gnu.org/licenses/>.

-- | This module re-exports components of "Database.CQL.IO" (the Cassandra client library), adding a few useful functions that are built on top of it.
module Cassandra.Exec
  ( params,
    paramsP,
    x5,
    x1,
    paginateC,
    PageWithState (..),
    paginateWithState,
    paginateWithStateC,
    paramsPagingState,
    pwsHasMore,
    module C,
  )
where

import Cassandra.CQL (Consistency, R)
import Control.Monad.Catch
import Data.Conduit
-- We only use these locally.
import Database.CQL.IO (ProtocolError (UnexpectedResponse), RetrySettings, RunQ, defRetrySettings, eagerRetrySettings, getResult, hrHost, hrResponse, runQ)
-- Things we just import and re-export.
import Database.CQL.IO as C (BatchM, Client, ClientState, MonadClient, Page (..), PrepQuery, Row, addPrepQuery, addQuery, adjustConsistency, adjustResponseTimeout, adjustSendTimeout, batch, emptyPage, init, liftClient, localState, paginate, prepared, query, query1, queryString, retry, runClient, schema, setConsistency, setSerialConsistency, setType, shutdown, trans, write)
import Database.CQL.Protocol (Error, QueryParams (QueryParams), Tuple, pagingState)
import Database.CQL.Protocol qualified as Protocol
import Imports hiding (init)

-- | Build 'QueryParams' from a consistency level and the values bound to the
-- query's placeholders.
--
-- The 'Bool' flag is set to 'False', and the optional fields are all left
-- as 'Nothing': page size, paging state, serial consistency, and the
-- trailing @Maybe Bool@. Use 'paramsP' or 'paramsPagingState' to set a page
-- size or paging state.
params :: Consistency -> a -> QueryParams a
params c p = QueryParams c False p Nothing Nothing Nothing Nothing
{-# INLINE params #-}

-- | Like 'params', but also sets the page size field to @'Just' n@.
--
-- Paging state, serial consistency and the trailing @Maybe Bool@ field are
-- left as 'Nothing'.
paramsP :: Consistency -> a -> Int32 -> QueryParams a
paramsP c p n = QueryParams c False p (Just n) Nothing Nothing Nothing
{-# INLINE paramsP #-}

-- | Alias for 'eagerRetrySettings'.
--
-- 'x5' must only be used for idempotent queries, or for cases
-- when a duplicate write has no severe consequences in
-- the context of the application's data model.
-- For more info see e.g.
-- https://docs.datastax.com/en/developer/java-driver/3.6/manual/idempotence/
--
-- The eager retry policy permits 5 retries with exponential
-- backoff (base-2) with an initial delay of 100ms, i.e. the
-- retries will be performed with 100ms, 200ms, 400ms, 800ms
-- and 1.6s delay, respectively, for a maximum delay of ~3s.
x5 :: RetrySettings
x5 = eagerRetrySettings
{-# INLINE x5 #-}

-- | Alias for 'defRetrySettings': a single, immediate retry, always safe.
--
-- The 'defRetryHandlers' used are safe also with non-idempotent queries.
x1 :: RetrySettings
x1 = defRetrySettings
{-# INLINE x1 #-}

-- | Presumably a classification of failures when talking to Cassandra.
--
-- NOTE(review): this type is neither exported from this module nor
-- referenced within it; confirm whether it is still needed.
data CassandraError
  = -- | Wraps a protocol-level 'Error' (from "Database.CQL.Protocol").
    Cassandra !Error
  | -- | Wraps an 'IOException'.
    Comm !IOException
  | -- | Carries a textual description (presumably of malformed data).
    InvalidData !Text
  | -- | Wraps any other exception.
    Other !SomeException
  deriving (Show)

-- | Stream results of a query.
--
-- The first page is fetched with 'paginate'. After that, 'nextPage' is
-- followed for as long as 'hasMore' is set. Every non-empty page of rows is
-- yielded downstream as a list. Each page fetch is wrapped in 'retry' with
-- the given 'RetrySettings'.
--
-- You can execute this conduit by doing @transPipe (runClient ...)@.
paginateC ::
  (Tuple a, Tuple b, RunQ q, MonadClient m) =>
  q R a b ->
  QueryParams a ->
  RetrySettings ->
  ConduitM () [b] m ()
paginateC q p r = lift (retry r (paginate q p)) >>= go
  where
    -- Emit the current page (skipping empty ones), then fetch and process
    -- the next page if the current one reports that more rows exist.
    go page = do
      let rows = result page
      unless (null rows) $
        yield rows
      when (hasMore page) $
        lift (retry r (liftClient (nextPage page))) >>= go

-- | One page of query results, together with the paging state that
-- 'paginateWithState' returned for it.
data PageWithState a = PageWithState
  { -- | The rows contained in this page.
    pwsResults :: [a],
    -- | Paging state for fetching the following page. 'pwsHasMore' treats
    -- 'Nothing' as "no more pages".
    pwsState :: Maybe Protocol.PagingState
  }
  deriving (Functor)

-- | Like 'paginate' but exposes the paging state. This paging state can be
-- serialised and sent to consumers of the API. The state is not good for long
-- term storage as the bytestring format may change when the schema of a table
-- changes or when cassandra is upgraded.
--
-- If the given 'QueryParams' do not set a page size, a page size of 10000 is
-- used.
--
-- Throws 'UnexpectedResponse' (via 'throwM') when the response is anything
-- other than a rows result.
paginateWithState :: (MonadClient m, Tuple a, Tuple b, RunQ q) => q R a b -> QueryParams a -> m (PageWithState b)
paginateWithState q p = do
  let p' = p {Protocol.pageSize = Protocol.pageSize p <|> Just defaultPageSize}
  r <- runQ q p'
  getResult r >>= \case
    Protocol.RowsResult meta rows ->
      pure $ PageWithState rows (pagingState meta)
    _ -> throwM $ UnexpectedResponse (hrHost r) (hrResponse r)
  where
    -- Page size used when the caller did not request one explicitly.
    defaultPageSize :: Int32
    defaultPageSize = 10000

-- | Like 'paginateWithState' but returns a conduit instead of one page.
--
-- The page-fetching function is first called with 'Nothing'. It is then
-- called with the 'pwsState' of the previous page until a page is returned
-- for which 'pwsHasMore' is 'False'. Every non-empty page is yielded
-- downstream as a list.
--
-- This can be used with 'paginateWithState' like this:
--
-- @
--   processUsers :: (MonadClient m) => m ()
--   processUsers =
--     runConduit $
--       paginateWithStateC getUsers
--       .| mapM_C doSomethingWithAPageOfUsers
--     where
--       getUsers state = paginateWithState getUsersQuery (paramsPagingState Quorum () 10000 state)
-- @
paginateWithStateC :: forall m a. (Monad m) => (Maybe Protocol.PagingState -> m (PageWithState a)) -> ConduitT () [a] m ()
paginateWithStateC getPage = lift (getPage Nothing) >>= go
  where
    -- Emit the current page (skipping empty ones), then continue with the
    -- page identified by its paging state, if there is one.
    go :: PageWithState a -> ConduitT () [a] m ()
    go page = do
      unless (null page.pwsResults) $
        yield page.pwsResults
      when (pwsHasMore page) $
        lift (getPage page.pwsState) >>= go

-- | Like 'paramsP', but also sets the paging state field to @state@.
--
-- @state@ is typically the 'pwsState' of a previously fetched
-- 'PageWithState'. Pass 'Nothing' to start without a paging state.
-- Serial consistency and the trailing @Maybe Bool@ field are left as
-- 'Nothing'.
paramsPagingState :: Consistency -> a -> Int32 -> Maybe Protocol.PagingState -> QueryParams a
paramsPagingState c p n state = QueryParams c False p (Just n) state Nothing Nothing
{-# INLINE paramsPagingState #-}

-- | 'True' when the page carries a paging state, i.e. when another page can
-- be requested after this one.
pwsHasMore :: PageWithState a -> Bool
pwsHasMore = isJust . pwsState