{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- This file is part of the Wire Server implementation.
--
-- Copyright (C) 2025 Wire Swiss GmbH <opensource@wire.com>
--
-- This program is free software: you can redistribute it and/or modify it under
-- the terms of the GNU Affero General Public License as published by the Free
-- Software Foundation, either version 3 of the License, or (at your option) any
-- later version.
--
-- This program is distributed in the hope that it will be useful, but WITHOUT
-- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
-- FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
-- details.
--
-- You should have received a copy of the GNU Affero General Public License along
-- with this program. If not, see <https://www.gnu.org/licenses/>.

module Cassandra.MigrateSchema (migrateSchema) where

import Cassandra (Client, Consistency (All, One), Keyspace (Keyspace), PrepQuery, QueryString (QueryString), R, S, Version (V4), W, params, query, query1, retry, runClient, write, x1)
import Cassandra.Schema
import Cassandra.Settings (Policy, defSettings, initialContactsPlain, setConnectTimeout, setContacts, setLogger, setMaxConnections, setPolicy, setPoolStripes, setPortNumber, setProtocolVersion, setResponseTimeout, setSendTimeout)
import Cassandra.Util (initCassandra)
import Control.Retry
import Data.List.NonEmpty qualified as NonEmpty
import Data.Text (pack)
import Data.Text.Lazy (fromStrict)
import Data.Time.Clock
import Data.UUID (UUID)
import Database.CQL.IO (Policy (Policy, acceptable, current, display, hostCount, onEvent, select, setup), schema)
import Database.CQL.IO.Tinylog qualified as CT
import Imports hiding (All, fromString, init, intercalate, log)
import System.Logger qualified as Log

-- FUTUREWORK: We could use the System.Logger.Class here in the future, but we don't have a ReaderT IO here (yet)

-- | Apply all pending schema migrations to the keyspace named in the options.
--
-- Procedure:
--
-- 1. Resolve the contact hosts from 'migHost' and connect on 'migPort' with a
--    single connection / single pool stripe and 'migrationPolicy', so that every
--    query goes to the same node ('waitForSchemaConsistency' depends on this).
--    Connect/send/response timeouts are raised (20s/20s/50s) to make it less
--    likely that a migration or its bookkeeping insert times out half-way.
-- 2. If 'migReset' is set, drop the keyspace (@drop keyspace if exists@).
-- 3. Call 'createKeyspace' with 'migRepl', switch to the keyspace and create the
--    @meta@ table if it does not exist.
-- 4. Select the migrations newer than 'schemaVersion' (see @newer@), and for
--    each one in ascending 'migVersion' order: run its 'migAction', insert a
--    row into @meta@ (id 1, version, description, current time), then wait for
--    schema agreement across peers.
--
-- Progress is logged at 'Log.Info' level on the given logger.
migrateSchema :: Log.Logger -> MigrationOpts -> [Migration] -> IO ()
migrateSchema l o ms = do
  hosts <- initialContactsPlain $ pack (migHost o)
  let cqlSettings =
        setLogger (CT.mkLogger l)
          . setContacts (NonEmpty.head hosts) (NonEmpty.tail hosts)
          . setPortNumber (fromIntegral $ migPort o)
          . setMaxConnections 1
          . setPoolStripes 1
          -- 'migrationPolicy' ensures we only talk to one host for all queries
          -- required for correct functioning of 'waitForSchemaConsistency'
          . setPolicy migrationPolicy
          -- use higher timeouts on schema migrations to reduce the probability
          -- of a timeout happening during 'migAction' or 'metaInsert',
          -- as that can lead to a state where schema migrations cannot be re-run
          -- without manual action.
          -- (due to e.g. "cannot create table X, already exists" errors)
          . setConnectTimeout 20
          . setSendTimeout 20
          . setResponseTimeout 50
          . setProtocolVersion V4
          $ defSettings
  cas <- initCassandra cqlSettings o.migTlsCa l
  runClient cas $ do
    let keyspace = Keyspace (migKeyspace o)
    when (migReset o) $ do
      info "Dropping keyspace."
      void $ schema (dropKeyspace keyspace) (params All ())
    createKeyspace keyspace (migRepl o)
    useKeyspace keyspace
    void $ schema metaCreate (params All ())
    migrations <- newer <$> schemaVersion
    if null migrations
      then info "No new migrations."
      else info "New migrations found."
    forM_ migrations $ \Migration {..} -> do
      info $ "[" <> pack (show migVersion) <> "] " <> migText
      migAction
      now <- liftIO getCurrentTime
      write metaInsert (params All (migVersion, migText, now))
      info "Waiting for schema version consistency across peers..."
      waitForSchemaConsistency
      info "... done waiting."
  where
    -- Migrations still to run: all of 'ms' sorted by ascending version, with
    -- the leading ones at or below the current schema version dropped.
    -- With no current version ('Nothing') every migration is pending.
    newer :: Maybe Int32 -> [Migration]
    newer current =
      dropWhile (\m -> maybe False (migVersion m <=) current)
        . sortBy (\x y -> compare (migVersion x) (migVersion y))
        $ ms

    -- Log an informational message from inside the Cassandra client monad.
    info :: Text -> Client ()
    info = liftIO . Log.log l Log.Info . Log.msg

    -- The keyspace name is double-quoted so it is used verbatim as a CQL identifier.
    dropKeyspace :: Keyspace -> QueryString S () ()
    dropKeyspace (Keyspace k) =
      QueryString $ "drop keyspace if exists \"" <> fromStrict k <> "\""

    -- Bookkeeping table recording which migrations have been applied.
    metaCreate :: QueryString S () ()
    metaCreate = "create columnfamily if not exists meta (id int, version int, descr text, date timestamp, primary key (id, version))"

    -- Parameters: (migration version, migration description, time applied).
    metaInsert :: QueryString W (Int32, Text, UTCTime) ()
    metaInsert = "insert into meta (id, version, descr, date) values (1,?,?,?)"

-- | Wait until every peer reports the same schema version as the connected node.
--
-- Reads @schema_version@ from @system.local@ and @system.peers@ and compares
-- them. While any peer disagrees, it polls again once per second, for at most
-- 30 retries (about 30 seconds). After that it returns normally even if the
-- versions still disagree; the last result is discarded.
--
-- Both reads must hit the same node. Comparing node A's local version with node
-- B's view of its peers would be meaningless. The caller is expected to connect
-- with 'migrationPolicy', which pins all queries to a single host.
waitForSchemaConsistency :: Client ()
waitForSchemaConsistency =
  void $ retryWhileN 30 inDisagreement getSystemVersions
  where
    -- (local schema version, schema versions reported for all peers)
    getSystemVersions :: Client (UUID, [UUID])
    getSystemVersions = do
      -- These two sub-queries must be made to the same node.
      -- (comparing local from node A and peers from node B wouldn't be correct)
      -- using the custom 'migrationPolicy' when connecting to cassandra ensures this.
      mbLocalVersion <- systemLocalVersion
      peers <- systemPeerVersions
      case mbLocalVersion of
        Just localVersion -> pure (localVersion, peers)
        Nothing -> error "No system_version in system.local (should never happen)"

    -- True while at least one peer reports a different schema version.
    inDisagreement :: (UUID, [UUID]) -> Bool
    inDisagreement (localVersion, peers) = not $ all (== localVersion) peers

    systemLocalVersion :: Client (Maybe UUID)
    systemLocalVersion = fmap runIdentity <$> qry
      where
        qry :: Client (Maybe (Identity UUID))
        qry = retry x1 (query1 cql (params One ()))
        cql :: PrepQuery R () (Identity UUID)
        cql = "select schema_version from system.local"

    systemPeerVersions :: Client [UUID]
    systemPeerVersions = fmap runIdentity <$> qry
      where
        qry :: Client [Identity UUID]
        qry = retry x1 (query cql (params One ()))
        cql :: PrepQuery R () (Identity UUID)
        cql = "select schema_version from system.peers"

-- | @retryWhileN n f m@ runs @m@, and re-runs it while its result satisfies @f@.
--
-- It waits one second between attempts and retries at most @n@ times, so @m@
-- runs at most @n + 1@ times. It returns the result of the last attempt,
-- whether or not that result still satisfies @f@.
retryWhileN :: (MonadIO m) => Int -> (a -> Bool) -> m a -> m a
retryWhileN n f m =
  retrying
    (constantDelay oneSecondInMicros <> limitRetries n)
    (\_status result -> pure (f result))
    (const m)
  where
    -- 'constantDelay' takes its delay in microseconds.
    oneSecondInMicros :: Int
    oneSecondInMicros = 1000000

-- | A load-balancing 'Policy' that always selects one and the same host.
--
-- The host is the first element of the first host list given to 'setup'
-- (presumably the hosts considered up; confirm against cql-io's 'Policy' docs).
-- Host events are ignored and every host is deemed acceptable. Once chosen, the
-- selection therefore never changes. If 'setup' receives an empty first list,
-- no host is selected and 'select' yields 'Nothing'.
--
-- Pinning to one node is what lets 'waitForSchemaConsistency' compare
-- @system.local@ and @system.peers@ from the same node.
migrationPolicy :: IO Policy
migrationPolicy = do
  h <- newIORef Nothing
  pure $
    Policy
      { setup = setHost h,
        onEvent = const $ pure (),
        select = readIORef h,
        acceptable = const $ pure True,
        hostCount = fromIntegral . length . maybeToList <$> readIORef h,
        display = ("migrationPolicy: " ++) . show <$> readIORef h,
        current = maybeToList <$> readIORef h
      }
  where
    -- Remember the first host of the first list; the second list is ignored.
    setHost ref (a : _) _ = writeIORef ref (Just a)
    setHost _ [] _ = pure ()