-- |
-- This module provides a low-level effectful API dealing with the connections to the database.
module Hasql.Connection
  ( Connection,
    acquire,
    release,
    use,
  )
where

import Data.Text qualified as Text
import Hasql.Comms.Session qualified as Comms.Session
import Hasql.Connection.Config qualified as Config
import Hasql.Connection.ServerVersion qualified as ServerVersion
import Hasql.Connection.Settings qualified as Settings
import Hasql.Engine.Contexts.Session qualified as Session
import Hasql.Engine.Errors
import Hasql.Engine.Structures.ConnectionState qualified as ConnectionState
import Hasql.Engine.Structures.StatementCache qualified as StatementCache
import Hasql.Platform.Prelude
import Hasql.Pq qualified as Pq

-- |
-- A single connection to the database.
--
-- Wraps an 'MVar' holding the connection state. 'use' takes the state out of
-- the 'MVar' for the duration of a session and puts it back afterwards, which
-- is what gives a session exclusive access to the connection.
newtype Connection
  = Connection (MVar ConnectionState.ConnectionState)

-- |
-- Establish a connection according to the provided settings.
--
-- Opens a libpq connection using the connection string derived from the
-- settings, checks the connection status and the server version, runs the
-- session initialization statements and wraps the resulting state in a
-- 'Connection'.
--
-- Returns 'Left' with a 'ConnectionError' when the connection status is not
-- 'Pq.ConnectionOk' or when the server version is below
-- 'ServerVersion.minimum'. In both cases the underlying libpq connection is
-- finished before the error is returned, so a failed attempt does not leak it.
acquire ::
  Settings.Settings ->
  IO (Either ConnectionError Connection)
acquire settings =
  {-# SCC "acquire" #-}
  runExceptT do
    let config = Config.construct settings

    -- Connect:
    pqConnection <- lift (Pq.connectdb (Config.connectionString config))

    -- Check status:
    status <- lift (Pq.status pqConnection)
    case status of
      Pq.ConnectionOk -> pure ()
      _ -> do
        -- The error message has to be read before the connection is finished.
        errorMessage <- lift (Pq.errorMessage pqConnection)
        -- libpq expects the connection object to be finished even when the
        -- connection attempt itself has failed.
        lift (Pq.finish pqConnection)
        throwError (interpretConnectionError errorMessage)

    -- Check version:
    version <- lift (ServerVersion.load pqConnection)
    when (version < ServerVersion.minimum) do
      -- The connection is established at this point, so it must be closed
      -- before it gets rejected.
      lift (Pq.finish pqConnection)
      throwError (CompatibilityConnectionError ("Server version is lower than 9: " <> ServerVersion.toText version))

    -- Initialize:
    -- The result of the initialization statements is not inspected.
    _ <- lift do
      Pq.exec pqConnection do
        "SET client_encoding = 'UTF8';\n\
        \SET client_min_messages TO WARNING;"

    let connectionState =
          ConnectionState.ConnectionState
            { ConnectionState.preparedStatements = not (Config.noPreparedStatements config),
              ConnectionState.statementCache = StatementCache.empty,
              ConnectionState.oidCache = mempty,
              ConnectionState.connection = pqConnection
            }
    connectionRef <- lift (newMVar connectionState)
    pure (Connection connectionRef)
  where
    -- Classify the error message reported by libpq into a 'ConnectionError'.
    -- Matching is done on a lower-cased copy of the message, while the error
    -- itself carries the message in its original casing.
    interpretConnectionError :: Maybe ByteString -> ConnectionError
    interpretConnectionError Nothing =
      OtherConnectionError "Unknown connection error"
    interpretConnectionError (Just msg)
      | mentionsAnyOf networkingErrors = NetworkingConnectionError msgText
      | mentionsAnyOf authenticationErrors = AuthenticationConnectionError msgText
      | otherwise = OtherConnectionError msgText
      where
        msgText = decodeUtf8Lenient msg
        msgLower = Text.toLower msgText
        mentionsAnyOf fragments = any (`Text.isInfixOf` msgLower) fragments

    -- Lower-case message fragments treated as networking failures.
    networkingErrors :: [Text]
    networkingErrors =
      [ "could not connect to server",
        "no such file or directory",
        "connection refused",
        "timeout expired",
        "host not found",
        "could not translate host name"
      ]

    -- Lower-case message fragments treated as authentication failures.
    authenticationErrors :: [Text]
    authenticationErrors =
      [ "authentication failed",
        "password authentication failed",
        "no password supplied",
        "peer authentication failed"
      ]

-- |
-- Release the connection.
--
-- Reads the current connection state and finishes the underlying libpq
-- connection, with asynchronous exceptions masked for the duration. The state
-- is read with 'readMVar', so it is left in place in the 'MVar'.
release :: Connection -> IO ()
release (Connection stateVar) =
  mask_ (readMVar stateVar >>= Pq.finish . ConnectionState.connection)

-- |
-- Execute a sequence of operations with exclusive access to the connection.
--
-- Blocks until the connection is available when there is another session running upon the connection on a different thread.
--
-- When the session raises an exception, the connection is cleaned up and the
-- exception is rethrown. If that cleanup fails, the connection is closed and a
-- 'DriverSessionError' describing both failures is returned instead.
use :: Connection -> Session.Session a -> IO (Either SessionError a)
use (Connection stateVar) session =
  mask $ \unmask -> do
    oldState@ConnectionState.ConnectionState {ConnectionState.connection = pqConnection} <- takeMVar stateVar
    -- Puts the pre-session state back with its prepared statements cache reset.
    let putBackWithResetCache =
          putMVar stateVar (ConnectionState.resetPreparedStatementsCache oldState)
    -- Only the session itself runs with asynchronous exceptions unmasked.
    outcome <- try @SomeException (unmask (Session.run session oldState))
    case outcome of
      Right (sessionResult, !nextState) -> do
        putMVar stateVar nextState
        pure sessionResult
      Left exception -> do
        -- If an exception happened, we need to bring the connection back to idle
        -- without resetting (to preserve session state).
        cleanupOutcome <- Comms.Session.toHandler Comms.Session.cleanUpAfterInterruption pqConnection
        case cleanupOutcome of
          Right () -> do
            putBackWithResetCache
            throwIO exception
          Left cleanupError -> do
            -- If cleanup failed, we have to close the connection.
            -- There's not much else we can do.
            Pq.finish pqConnection
            putBackWithResetCache
            pure (Left (DriverSessionError (cleanupFailureMessage cleanupError exception)))
  where
    -- Combine the cleanup error with the exception that interrupted the session.
    cleanupFailureMessage :: Text -> SomeException -> Text
    cleanupFailureMessage cleanupError exception =
      "Failed to clean up after interruption.\n"
        <> cleanupError
        <> "\n"
        <> "The following exception was raised during the operation:\n"
        <> Text.pack (displayException exception)