module Hasql.Comms.Session
  ( Session,

    -- * Constructors
    cleanUpAfterInterruption,

    -- * Executors
    toHandler,
  )
where

import Hasql.Comms.Roundtrip qualified as Roundtrip
import Hasql.Platform.Prelude
import Hasql.Pq qualified as Pq

-- | Serial execution of commands in the scope of a connection.
--
-- A session is a function from a connection to an 'IO' action that either
-- fails with an 'Error' message or yields a value. All of its instances are
-- borrowed from @ExceptT Error (ReaderT Pq.Connection IO)@, so sequencing
-- stops at the first failure.
newtype Session a = Session (Pq.Connection -> IO (Either Error a))
  deriving (Functor, Applicative, Monad) via (ExceptT Error (ReaderT Pq.Connection IO))
  deriving (MonadError Error) via (ExceptT Error (ReaderT Pq.Connection IO))

-- | Human-readable description of why a session failed.
type Error = Text

-- * Constructors

-- | Bring the connection to a clean state after an interruption.
--
-- The steps, in order:
--
-- - Cancelling whatever command the server may still be executing, then
--   draining the pending results, so nothing from the interrupted work is
--   left on the connection.
-- - Leaving pipeline mode if we are in it.
-- - Bringing the transaction status to idle if we are in a transaction.
-- - Deallocating all prepared statements.
--
-- The first step that fails aborts the remaining ones and its error is
-- returned.
cleanUpAfterInterruption :: Session ()
cleanUpAfterInterruption = do
  -- Cancel before draining: draining blocks in 'Pq.getResult' until the
  -- running command finishes, so draining first would wait out the very
  -- command we want to abort and make the cancel request pointless.
  cancel
  drainResults
  -- Ensure we are out of pipeline mode.
  leavePipeline
  -- Ensure we are in idle transaction state.
  bringTransactionStatusToIdle
  deallocateAllPreparedStatements

-- | Drive the connection's transaction status back to 'Pq.TransIdle'.
--
-- An open or failed transaction is rolled back with @ABORT@. If a command is
-- reported as still active, pending results are drained once and the status
-- is examined again; a command that is still active after that is reported as
-- an error. An unknown status is always reported as an error.
bringTransactionStatusToIdle :: Session ()
bringTransactionStatusToIdle =
  getTransactionStatus >>= settle recheckAfterDraining
  where
    -- Used when a command is active on the first inspection.
    recheckAfterDraining :: Session ()
    recheckAfterDraining = do
      -- A command is still in progress.
      drainResults
      -- Check status again after draining.
      getTransactionStatus >>= settle stillActive

    -- Used when a command is still active even after draining.
    -- The connection is probably in a bad state.
    stillActive :: Session ()
    stillActive =
      throwError "Failed to bring transaction status to idle after draining results"

    -- React to one observed status; the first argument handles 'Pq.TransActive'.
    settle :: Session () -> Pq.TransactionStatus -> Session ()
    settle onActive status = case status of
      Pq.TransIdle -> pure ()
      Pq.TransInTrans -> runScript "ABORT"
      Pq.TransActive -> onActive
      -- Transaction is in error state, we need to abort it.
      Pq.TransInError -> runScript "ABORT"
      -- Unknown state (connection issue), there's not much we can do.
      Pq.TransUnknown -> throwError "Transaction status is unknown, connection is corrupted"

-- | Make sure the connection is no longer in pipeline mode.
--
-- Nothing happens unless the pipeline status is 'Pq.PipelineOn'. In that
-- case a pipeline sync and then a flush request are sent. Results are drained
-- after each one that reports success. Exiting pipeline mode is then
-- attempted; if that fails, results are drained once more and the exit is
-- retried. A second failure is thrown as an error that includes the
-- connection's error message when one is available.
leavePipeline :: Session ()
leavePipeline = do
  inPipeline <- (== Pq.PipelineOn) <$> getPipelineStatus
  when inPipeline do
    -- Synchronize the pipeline and push any queued commands to the server.
    sendPipelineSync >>= \synced -> when synced drainResults
    sendFlushRequest >>= \flushed -> when flushed drainResults
    -- Exiting can fail while results are still pending, so drain and retry once.
    exited <- exitPipelineMode
    unless exited do
      drainResults
      exitedOnRetry <- exitPipelineMode
      unless exitedOnRetry do
        -- The connection is probably in a bad state; report why if we can.
        details <- getErrorMessage
        throwError (describeExitFailure details)
  where
    describeExitFailure :: Maybe ByteString -> Error
    describeExitFailure =
      maybe
        "Failed to exit pipeline mode after draining results"
        (\details -> "Failed to exit pipeline mode after draining results: " <> decodeUtf8Lenient details)

-- | Drop every prepared statement on the server side of the connection by
-- running @DEALLOCATE ALL@ as a single roundtrip.
deallocateAllPreparedStatements :: Session ()
deallocateAllPreparedStatements =
  runRoundtrip (Roundtrip.query () "DEALLOCATE ALL")

-- | Send a cancel request for whatever the connection is currently executing.
--
-- When 'Pq.getCancel' cannot produce a cancel handle, this succeeds without
-- doing anything. A failure to dispatch the request becomes an error carrying
-- libpq's message after the prefix @"Failed to cancel: "@.
cancel :: Session ()
cancel = Session \connection ->
  Pq.getCancel connection
    >>= maybe (pure (Right ())) (fmap (either (Left . describeFailure) Right) . Pq.cancel)
  where
    describeFailure details = "Failed to cancel: " <> decodeUtf8Lenient details

-- | The connection's most recent libpq error message, if any. Never fails.
getErrorMessage :: Session (Maybe ByteString)
getErrorMessage = Session (fmap Right . Pq.errorMessage)

-- | The connection's current transaction status as reported by libpq. Never fails.
getTransactionStatus :: Session Pq.TransactionStatus
getTransactionStatus = Session (fmap Right . Pq.transactionStatus)

-- | The connection's current pipeline status as reported by libpq. Never fails.
getPipelineStatus :: Session Pq.PipelineStatus
getPipelineStatus = Session (fmap Right . Pq.pipelineStatus)

-- | Try to leave pipeline mode; the 'Bool' is libpq's success flag.
-- The session itself never fails.
exitPipelineMode :: Session Bool
exitPipelineMode = Session (fmap Right . Pq.exitPipelineMode)

-- | Send a pipeline sync point; the 'Bool' is libpq's success flag.
-- The session itself never fails.
sendPipelineSync :: Session Bool
sendPipelineSync = Session (fmap Right . Pq.pipelineSync)

-- | Send a flush request to the server; the 'Bool' is libpq's success flag.
-- The session itself never fails.
sendFlushRequest :: Session Bool
sendFlushRequest = Session (fmap Right . Pq.sendFlushRequest)

-- | Discard results from 'Pq.getResult' until it yields 'Nothing'. Never fails.
--
-- NOTE(review): in libpq pipeline mode, PQgetResult returns NULL after each
-- query's results, so there this stops at the first query boundary rather
-- than draining the whole pipeline; confirm this is intended.
drainResults :: Session ()
drainResults = Session \connection ->
  let loop = Pq.getResult connection >>= maybe (pure ()) (const loop)
   in Right () <$ loop

-- | Execute a raw SQL script as a single roundtrip that yields no value.
runScript :: ByteString -> Session ()
runScript = runRoundtrip . Roundtrip.query ()

-- | Run a roundtrip serially on the connection, turning any roundtrip failure
-- into a textual session 'Error'.
runRoundtrip :: Roundtrip.Roundtrip () a -> Session a
runRoundtrip roundtrip =
  Session (fmap (either (Left . render) Right) . Roundtrip.toSerialIO roundtrip)
  where
    render failure = case failure of
      Roundtrip.ClientError () Nothing ->
        "Unknown client error occurred"
      Roundtrip.ClientError () (Just details) ->
        "Client error occurred: " <> decodeUtf8Lenient details
      Roundtrip.ServerError recvError ->
        "Server error occurred: " <> fromString (show recvError)

-- * Executors

-- | Unwrap a session into the connection-consuming action it denotes.
toHandler :: Session a -> Pq.Connection -> IO (Either Text a)
toHandler = \(Session run) -> run