-- This file is part of the Wire Server implementation.
--
-- Copyright (C) 2022 Wire Swiss GmbH <opensource@wire.com>
--
-- This program is free software: you can redistribute it and/or modify it under
-- the terms of the GNU Affero General Public License as published by the Free
-- Software Foundation, either version 3 of the License, or (at your option) any
-- later version.
--
-- This program is distributed in the hope that it will be useful, but WITHOUT
-- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
-- FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
-- details.
--
-- You should have received a copy of the GNU Affero General Public License along
-- with this program. If not, see <https://www.gnu.org/licenses/>.

module Cassandra.Util
  ( defInitCassandra,
    initCassandraForService,
    initCassandra,
    Writetime (..),
    writetimeToInt64,
  )
where

import Cassandra.CQL
import Cassandra.Options
import Cassandra.Schema
import Cassandra.Settings (dcFilterPolicyIfConfigured, initialContactsDisco, initialContactsPlain, mkLogger)
import Data.Aeson
import Data.Fixed
import Data.List.NonEmpty qualified as NE
import Data.Text (pack, unpack)
import Data.Time (UTCTime, nominalDiffTimeToSeconds)
import Data.Time.Clock (secondsToNominalDiffTime)
import Data.Time.Clock.POSIX
import Database.CQL.IO
import Database.CQL.IO.Tinylog qualified as CT
import Imports hiding (init)
import OpenSSL.Session qualified as OpenSSL
import System.Logger qualified as Log

-- | Create a Cassandra 'ClientState' from 'CassandraOpts' using a single
-- contact point: the configured endpoint host and port.
--
-- The keyspace comes from the options and protocol version 4 is used. TLS is
-- handled by 'initCassandra' based on @opts.tlsCa@. Unlike
-- 'initCassandraForService', no load-balancing policy, pool sizes or timeouts
-- are set here, so the 'defSettings' values apply for those.
defInitCassandra :: CassandraOpts -> Log.Logger -> IO ClientState
defInitCassandra opts logger =
  initCassandra settings opts.tlsCa logger
  where
    settings :: Settings
    settings =
      setLogger (CT.mkLogger logger) $
        setPortNumber (fromIntegral opts.endpoint.port) $
          setContacts (unpack opts.endpoint.host) [] $
            setKeyspace (Keyspace opts.keyspace) $
              setProtocolVersion V4 defSettings

-- | Create a Cassandra 'ClientState' (i.e. a connection pool) for the service
-- named @serviceName@.
--
-- Initial contact points:
--
-- * without a discovery URL, they come from 'initialContactsPlain' applied to
--   the configured endpoint host;
-- * with a discovery URL, they come from 'initialContactsDisco' applied to
--   @\"cassandra_\" ++ serviceName@ and the URL.
--
-- The driver logger is a clone of @logger@ tagged @cassandra.\<serviceName\>@.
-- The pool uses 4 connections, 4 stripes, a 3 second send timeout and a
-- 10 second response timeout. The policy comes from
-- 'dcFilterPolicyIfConfigured'. If a schema version is given, 'versionCheck'
-- runs against the new client state before it is returned.
initCassandraForService ::
  CassandraOpts ->
  String ->
  Maybe Text ->
  Maybe Int32 ->
  Log.Logger ->
  IO ClientState
initCassandraForService opts serviceName discoUrl mbSchemaVersion logger = do
  contacts <- case discoUrl of
    Nothing -> initialContactsPlain opts.endpoint.host
    Just url -> initialContactsDisco ("cassandra_" ++ serviceName) (unpack url)
  let serviceLogger = Log.clone (Just (pack ("cassandra." ++ serviceName))) logger
      -- Modifiers are applied right-to-left onto 'defSettings', i.e. the first
      -- list entry is applied last (same as composing them with '.').
      settings :: Settings
      settings =
        foldr
          ($)
          defSettings
          [ setLogger (mkLogger serviceLogger),
            setContacts (NE.head contacts) (NE.tail contacts),
            setPortNumber (fromIntegral opts.endpoint.port),
            setKeyspace (Keyspace opts.keyspace),
            setMaxConnections 4,
            setPoolStripes 4,
            setSendTimeout 3,
            setResponseTimeout 10,
            setProtocolVersion V4,
            setPolicy (dcFilterPolicyIfConfigured logger opts.filterNodesByDatacentre)
          ]
  clientState <- initCassandra settings opts.tlsCa logger
  mapM_ (runClient clientState . versionCheck) mbSchemaVersion
  pure clientState

-- | Create a Cassandra 'ClientState' from the given 'Settings'.
--
-- When a CA certificate file path is supplied, an OpenSSL context is built
-- with that file as its CA file and verification mode 'OpenSSL.VerifyPeer'.
-- The context is installed into the settings via 'setSSLContext' before
-- connecting. Without a path, the settings are used unchanged and a debug
-- message records that no TLS CA file is configured.
initCassandra :: Settings -> Maybe FilePath -> Log.Logger -> IO ClientState
initCassandra settings mbTlsCaPath logger = do
  settings' <- case mbTlsCaPath of
    Just tlsCaPath -> do
      sslContext <- createSSLContext tlsCaPath
      pure (setSSLContext sslContext settings)
    Nothing -> do
      -- We are already in IO: no need for 'liftIO', and 'Log.debug' returns
      -- '()' so 'void' was redundant too.
      Log.debug logger (Log.msg ("No TLS cert file path configured." :: Text))
      pure settings
  init settings'
  where
    -- Build an OpenSSL context using @certFile@ as the CA file, with peer
    -- verification enabled.
    createSSLContext :: FilePath -> IO OpenSSL.SSLContext
    createSSLContext certFile = do
      Log.debug logger (Log.msg ("TLS cert file path: " <> show certFile))
      sslContext <- OpenSSL.context
      OpenSSL.contextSetCAFile sslContext certFile
      OpenSSL.contextSetVerificationMode
        sslContext
        OpenSSL.VerifyPeer
          { vpFailIfNoPeerCert = True,
            vpClientOnce = True,
            vpCallback = Nothing
          }
      pure sslContext

-- | A Cassandra write timestamp, as produced by @WRITETIME()@
-- (<https://docs.datastax.com/en/dse/5.1/cql/cql/cql_using/useWritetime.html>),
-- stored as a 'UTCTime' so that decoding loses no precision.
--
-- The type parameter @a@ is phantom: no field uses it. Presumably it tags
-- which column or value the write time belongs to (TODO: confirm with callers).
newtype Writetime a = Writetime
  { writetimeToUTC :: UTCTime
  }
  deriving (Functor, Show, Eq)

-- | Writetimes travel over CQL as @bigint@: whole microseconds since the Unix
-- epoch (the value @WRITETIME()@ returns). Decoding is exact. Encoding goes
-- through 'writetimeToInt64', which floors away any sub-microsecond part.
instance Cql (Writetime a) where
  ctype = Tagged BigIntColumn

  toCql wt = CqlBigInt (writetimeToInt64 wt)

  fromCql (CqlBigInt micros) = Right (Writetime utc)
    where
      -- 'Pico' counts picoseconds, and one microsecond is 10^6 picoseconds.
      picos = fromIntegral @Int64 @Integer micros * 1_000_000
      utc = posixSecondsToUTCTime (secondsToNominalDiffTime (MkFixed picos))
  fromCql _ = Left "Writetime: bigint expected"

-- | Convert a 'Writetime' back into the integer @WRITETIME()@ returns:
-- microseconds since the Unix epoch.
--
-- Sub-microsecond precision is dropped by flooring ('div'). The final
-- 'fromIntegral' into 'Int64' wraps rather than failing if the value is out of
-- range.
writetimeToInt64 :: Writetime a -> Int64
writetimeToInt64 (Writetime utc) =
  let MkFixed picos = nominalDiffTimeToSeconds (utcTimeToPOSIXSeconds utc)
   in fromIntegral @Integer @Int64 (picos `div` 1_000_000)

-- | A 'Writetime' is rendered in JSON as the same integer that @WRITETIME()@
-- yields, i.e. the result of 'writetimeToInt64'.
instance ToJSON (Writetime a) where
  toJSON wt = toJSON (writetimeToInt64 wt)