-- This file is part of the Wire Server implementation.
--
-- Copyright (C) 2022 Wire Swiss GmbH <opensource@wire.com>
--
-- This program is free software: you can redistribute it and/or modify it under
-- the terms of the GNU Affero General Public License as published by the Free
-- Software Foundation, either version 3 of the License, or (at your option) any
-- later version.
--
-- This program is distributed in the hope that it will be useful, but WITHOUT
-- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
-- FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
-- details.
--
-- You should have received a copy of the GNU Affero General Public License along
-- with this program. If not, see <https://www.gnu.org/licenses/>.

-- | Temporary exclusive claims on 'Text'ual values which may be subject
-- to contention, i.e. where strong guarantees on uniqueness are desired.
module Wire.UserStore.Unique
  ( withClaim,
    deleteClaim,
    lookupClaims,

    -- * Re-exports
    Timeout,
    TimeoutUnit (..),
    (#),
  )
where

import Cassandra as C
import Control.Concurrent.Timeout
import Data.Id
import Data.Timeout
import Imports

-- | Obtain a (temporary) exclusive claim on a 'Text' value for some
-- 'Id'entifier. The claim expires when its row TTL runs out, whether it was
-- successful or not, so contention can leave the value unavailable until
-- then if no contender succeeds. The given 'Client' computation is run only
-- when the claim succeeded and is responsible for turning the temporary
-- claim into permanent ownership, if desired.
--
-- The claim's timeout is the requested one raised to at least 'minTtl'
-- seconds. The computation is cut off after that timeout, and the row in
-- @unique_claims@ is written with a TTL of twice that value
-- (see [Note: Guarantees]).
withClaim ::
  -- | The 'Id' associated with the claim.
  Id a ->
  -- | The value on which to acquire the claim.
  Text ->
  -- | The minimum timeout (i.e. duration) of the claim.
  Timeout ->
  -- | The computation to run with a successful claim.
  Client b ->
  -- | 'Just b' if the claim was successful and the computation completed
  --   within the claim's timeout; 'Nothing' otherwise.
  Client (Maybe b)
withClaim u v t act = lookupClaims v >>= decide
  where
    -- Decide from the current claimants whether to attempt the claim.
    decide [] = claim -- Free.
    decide [holder] | holder == u = claim -- Already ours; retries are allowed.
    decide _ = pure Nothing -- Conflicting claims; their TTL must expire.

    -- [Note: Guarantees]
    -- Register 'u' as a claimant, then re-read: the claim holds only if 'u'
    -- turns out to be the sole claimant.
    claim = do
      let ttlSecs = max minTtl (fromIntegral (t #> Second))
      retry x5 . write addClaimQuery $
        params LocalQuorum (ttlSecs * 2, C.Set [u], v)
      soleClaimant <- (== [u]) <$> lookupClaims v
      if not soleClaimant
        then pure Nothing
        else do
          runAct <- clientToIO act
          liftIO (timeout (fromIntegral ttlSecs # Second) runAct)

    addClaimQuery :: PrepQuery W (Int32, C.Set (Id a), Text) ()
    addClaimQuery = "UPDATE unique_claims USING TTL ? SET claims = claims + ? WHERE value = ?"

-- | Release a claim by removing the given 'Id' from the set of claimants
-- recorded for the value. The update is issued at 'LocalQuorum' and
-- retried with 'x5'.
deleteClaim ::
  -- | The 'Id' associated with the claim.
  Id a ->
  -- | The value on which to release the claim.
  Text ->
  -- | The minimum timeout (i.e. duration) of the rest of the claim.  (Each
  --   claim can have more than one claimer (even though this is a feature we
  --   never use), so removing a claim is an update operation on the database.
  --   Therefore, we reset the TTL the same way we reset it in 'withClaim'.)
  Timeout ->
  Client ()
deleteClaim u v t =
  retry x5 $
    write removeClaimQuery (params LocalQuorum (2 * ttlSecs, C.Set [u], v))
  where
    -- Requested timeout in seconds, raised to at least 'minTtl'; the row TTL
    -- written above is twice this value.
    ttlSecs :: Int32
    ttlSecs = max minTtl (fromIntegral (t #> Second))

    removeClaimQuery :: PrepQuery W (Int32, C.Set (Id a), Text) ()
    removeClaimQuery = "UPDATE unique_claims USING TTL ? SET claims = claims - ? WHERE value = ?"

-- | Look up the current claimants of a value. If the query returns no row,
-- the result is the empty list. The read is issued at 'LocalQuorum' and
-- retried with 'x1'.
lookupClaims :: (MonadClient m) => Text -> m [Id a]
lookupClaims v =
  claimantsOf <$> retry x1 (query1 selectClaims (params LocalQuorum (Identity v)))
  where
    -- No row means no recorded claimants.
    claimantsOf Nothing = []
    claimantsOf (Just (Identity claimSet)) = fromSet claimSet

    selectClaims :: PrepQuery R (Identity Text) (Identity (C.Set (Id a)))
    selectClaims = "SELECT claims FROM unique_claims WHERE value = ?"

-- | Capture the current 'ClientState' and return an 'IO' action that runs
-- the given 'Client' computation against it. This lets the computation be
-- handed to functions that only accept 'IO' (such as 'timeout' in
-- 'withClaim').
clientToIO :: Client a -> Client (IO a)
clientToIO act = runWith <$> ask
  where
    runWith st = runClient st act

-- | Lower bound, in seconds, on the timeout used for a claim. Shorter
-- requested timeouts are raised to this value.
minTtl :: Int32
minTtl = 60

-- [Note: Guarantees]
-- ~~~~~~~~~~~~~~~~~~
-- Correct operation (i.e. uniqueness of claims) rests on the following
-- properties of the implementation, which must have a negligible probability
-- of failure:
--
-- 1. CRDT properties of CQL Sets under element addition (the only set
--    operation used when acquiring a claim), in particular that all element
--    additions are preserved in a concurrent setting
--    (cf. https://aphyr.com/posts/294-jepsen-cassandra).
--
-- 2. Strong read consistency (a LOCAL_QUORUM write followed by a LOCAL_QUORUM
--    read, whose replica sets overlap within the local datacenter), combined
--    with the conflict-free property of Set element insertions (1), ensures
--    that, of any two concurrent claims, at least one of them is bound to see
--    both inserted elements, hence failing the claim.
--
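-- 3. The 'Client' computation that is run while holding a claim must complete
--    within the claim's timeout. That timeout is the requested one raised to
--    at least 'minTtl' seconds, and the computation is cut off when it
--    elapses. The row TTL is set to twice that timeout, for an extra safety
--    margin.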