-- This file is part of the Wire Server implementation.
--
-- Copyright (C) 2025 Wire Swiss GmbH <opensource@wire.com>
--
-- This program is free software: you can redistribute it and/or modify it under
-- the terms of the GNU Affero General Public License as published by the Free
-- Software Foundation, either version 3 of the License, or (at your option) any
-- later version.
--
-- This program is distributed in the hope that it will be useful, but WITHOUT
-- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
-- FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
-- details.
--
-- You should have received a copy of the GNU Affero General Public License along
-- with this program. If not, see <https://www.gnu.org/licenses/>.

module Hasql.Pool.Extended where

import Data.Aeson
import Data.Map as Map
import Data.Misc
import Data.Set qualified as Set
import Data.UUID
import Hasql.Connection.Setting qualified as HasqlSetting
import Hasql.Connection.Setting.Connection qualified as HasqlConn
import Hasql.Connection.Setting.Connection.Param qualified as HasqlConfig
import Hasql.Pool as HasqlPool
import Hasql.Pool.Config qualified as HasqlPool
import Hasql.Pool.Observation
import Imports
import Prometheus
import Util.Options

-- | Sizing and timeout settings for the hasql connection pool. Each field is
-- handed to the matching 'HasqlPool' setting by 'initPostgresPool'.
data PoolConfig = PoolConfig
  { -- | Passed to 'HasqlPool.size'.
    size :: Int,
    -- | Passed (as its 'duration') to 'HasqlPool.acquisitionTimeout'.
    acquisitionTimeout :: Duration,
    -- | Passed (as its 'duration') to 'HasqlPool.agingTimeout'.
    agingTimeout :: Duration,
    -- | Passed (as its 'duration') to 'HasqlPool.idlenessTimeout'.
    idlenessTimeout :: Duration
  }
  deriving (Show, Eq)

-- | Parses a JSON object that must contain all four keys: @size@,
-- @acquisitionTimeout@, @agingTimeout@ and @idlenessTimeout@. Parsing fails
-- (inside the "PoolConfig" context) if any key is missing or has the wrong type.
instance FromJSON PoolConfig where
  parseJSON = withObject "PoolConfig" $ \obj -> do
    poolSize <- obj .: "size"
    acquireAfter <- obj .: "acquisitionTimeout"
    maxAge <- obj .: "agingTimeout"
    maxIdle <- obj .: "idlenessTimeout"
    pure (PoolConfig poolSize acquireAfter maxAge maxIdle)

-- | Creates a hasql connection pool from Postgres connection parameters.
--
-- Every entry of @pgConfig@ becomes a connection parameter via
-- 'HasqlConfig.other'. When a secrets file is given, the password is read from
-- it with 'initCredentials' and stored under the @password@ key, replacing any
-- password already present in @pgConfig@.
--
-- How the pool is put together:
--
-- * 'HasqlConn.params' turns the parameters into a connection description
--   (just the connection string; no connection is opened here).
-- * 'HasqlSetting.connection' wraps that description as a connection setting.
-- * 'HasqlPool.staticConnectionSettings' makes it the pool's connection settings.
-- * 'HasqlPool.settings' combines all pool settings into a pool config.
-- * 'HasqlPool.acquire' creates the pool.
--
-- Size and timeouts come from the 'PoolConfig'. Pool events are reported to
-- 'observationHandler', which updates freshly registered metrics
-- ('initHasqlPoolMetrics').
initPostgresPool :: PoolConfig -> Map Text Text -> Maybe FilePathSecrets -> IO HasqlPool.Pool
initPostgresPool config pgConfig mFpSecrets = do
  mPassword <- for mFpSecrets initCredentials
  metrics <- initHasqlPoolMetrics
  connsRef <- newIORef (Connections Set.empty Set.empty Set.empty)
  let connectionSetting =
        HasqlSetting.connection . HasqlConn.params $
          [HasqlConfig.other key value | (key, value) <- Map.toList (withPassword mPassword)]
  HasqlPool.acquire . HasqlPool.settings $
    [ HasqlPool.staticConnectionSettings [connectionSetting],
      HasqlPool.size config.size,
      HasqlPool.acquisitionTimeout config.acquisitionTimeout.duration,
      HasqlPool.agingTimeout config.agingTimeout.duration,
      HasqlPool.idlenessTimeout config.idlenessTimeout.duration,
      HasqlPool.observationHandler (observationHandler connsRef metrics)
    ]
  where
    -- Overlay the password from the secrets file (if any) onto the raw config.
    withPassword :: Maybe Text -> Map Text Text
    withPassword = maybe pgConfig (\pw -> Map.insert "password" pw pgConfig)

-- | Prometheus metrics describing a hasql connection pool. They are updated by
-- 'observationHandler'.
data HasqlPoolMetrics = HasqlPoolMetrics
  { -- | Size of the set of connections that are ready for use.
    readyForUseGauge :: Gauge,
    -- | Size of the set of connections that are in use.
    inUseGauge :: Gauge,
    -- | Incremented whenever a new connection becomes ready for the first time.
    establishedCounter :: Counter,
    -- | Incremented whenever a connection is terminated.
    terminationCounter :: Counter,
    -- | Incremented whenever a session finishes with a failure.
    sessionFailureCounter :: Counter,
    -- | Incremented whenever a connection is taken into use for a session.
    sessionCounter :: Counter
  }

-- | Registers the Prometheus metrics used to monitor the hasql connection pool.
--
-- The two gauges report how many pool connections are ready for use and in
-- use. The counters count established connections, terminated connections,
-- sessions started on a connection, and failed sessions (see
-- 'observationHandler').
--
-- NOTE(review): 'register' presumably adds to a process-global registry, so
-- calling this (and therefore 'initPostgresPool') more than once registers the
-- same metric names again. Confirm this only happens once per process.
initHasqlPoolMetrics :: IO HasqlPoolMetrics
initHasqlPoolMetrics =
  HasqlPoolMetrics
    <$> register (gauge $ Info "wire_hasql_pool_ready_for_use" "Number of hasql pool connections ready for use")
    -- Help text previously duplicated the ready-for-use description.
    <*> register (gauge $ Info "wire_hasql_pool_in_use" "Number of hasql pool connections in use")
    <*> register (counter $ Info "wire_hasql_pool_connection_established_count" "Number of established connections")
    <*> register (counter $ Info "wire_hasql_pool_connection_terminated_count" "Number of terminated connections")
    <*> register (counter $ Info "wire_hasql_pool_session_failure_count" "Number of times a session has failed")
    <*> register (counter $ Info "wire_hasql_pool_session_count" "Number of times a session was created")

-- | Identifiers of the pool's connections, grouped by the last state reported
-- for them. Maintained by 'observationHandler'; a connection is removed from
-- every set once it is terminated.
data Connections = Connections
  { -- | Connections whose establishment has started but not yet completed.
    connecting :: Set UUID,
    -- | Connections currently taken into use for a session.
    inUse :: Set UUID,
    -- | Connections that are ready to be handed out for the next session.
    readyForUse :: Set UUID
  }

-- | Tracks the state of every pool connection and mirrors it into Prometheus.
--
-- Each 'ConnectionObservation' moves the connection with the given UUID
-- between the 'connecting', 'readyForUse' and 'inUse' sets of the shared
-- 'Connections' state:
--
-- * connecting: added to 'connecting';
-- * ready for use: removed from 'connecting' (newly established) or from
--   'inUse' (a session finished), then added to 'readyForUse';
-- * in use: moved from 'readyForUse' to 'inUse';
-- * terminated: removed from all three sets.
--
-- After every change to 'readyForUse' or 'inUse', 'readyForUseGauge' and
-- 'inUseGauge' are set to the new set sizes. 'establishedCounter',
-- 'sessionCounter', 'sessionFailureCounter' and 'terminationCounter' are
-- incremented on the corresponding events.
--
-- Every state change goes through 'atomicModifyIORef''. The handler may run on
-- several threads at once (presumably each pool event is reported from the
-- thread that triggered it; confirm against the hasql-pool docs). Mixing a
-- plain read-then-write 'modifyIORef'' with atomic updates could silently
-- drop a concurrent change.
observationHandler :: IORef Connections -> HasqlPoolMetrics -> Observation -> IO ()
observationHandler connsRef metrics (ConnectionObservation connId status) =
  case status of
    ConnectingConnectionStatus ->
      -- 'connecting' is not exported as a gauge, so no metric update is needed.
      atomicModifyIORef' connsRef $ \conns ->
        (conns {connecting = Set.insert connId conns.connecting}, ())
    ReadyForUseConnectionStatus reason -> do
      leavePreviousState <- case reason of
        SessionFailedConnectionReadyForUseReason _ -> do
          incCounter metrics.sessionFailureCounter
          pure leaveInUse
        SessionSucceededConnectionReadyForUseReason ->
          pure leaveInUse
        EstablishedConnectionReadyForUseReason -> do
          incCounter metrics.establishedCounter
          pure $ \conns -> conns {connecting = Set.delete connId conns.connecting}
      updateAndReport $ \conns ->
        let changed = leavePreviousState conns
         in changed {readyForUse = Set.insert connId changed.readyForUse}
    InUseConnectionStatus -> do
      incCounter metrics.sessionCounter
      updateAndReport $ \conns ->
        conns
          { readyForUse = Set.delete connId conns.readyForUse,
            inUse = Set.insert connId conns.inUse
          }
    TerminatedConnectionStatus _ -> do
      incCounter metrics.terminationCounter
      updateAndReport $ \conns ->
        conns
          { connecting = Set.delete connId conns.connecting,
            readyForUse = Set.delete connId conns.readyForUse,
            inUse = Set.delete connId conns.inUse
          }
  where
    -- A session on this connection has ended, whether it succeeded or failed.
    leaveInUse :: Connections -> Connections
    leaveInUse conns = conns {inUse = Set.delete connId conns.inUse}

    -- Atomically applies the change, then publishes the resulting set sizes.
    -- The gauges are written outside the atomic section, so under concurrency
    -- they may briefly show a slightly stale value until the next update.
    updateAndReport :: (Connections -> Connections) -> IO ()
    updateAndReport change = do
      (inUseSize, readyForUseSize) <- atomicModifyIORef' connsRef $ \conns ->
        let newConns = change conns
         in (newConns, (Set.size newConns.inUse, Set.size newConns.readyForUse))
      setGauge metrics.readyForUseGauge (fromIntegral readyForUseSize)
      setGauge metrics.inUseGauge (fromIntegral inUseSize)