module Network.WebSockets.Connection.PingPong
    ( withPingPong
    , PingPongOptions(..)
    , PongTimeout(..)
    , defaultPingPongOptions
    ) where 

import Control.Concurrent.Async as Async
import Control.Exception
import Control.Monad (void)
import Network.WebSockets.Connection (Connection, connectionHeartbeat, pingThread)
import Control.Concurrent.MVar (takeMVar)
import System.Timeout (timeout)


-- | Raised by 'withPingPong' to abort the application running on a
-- connection when no pong is received within the configured
-- 'pongTimeout'.
data PongTimeout = PongTimeout
    deriving (Show)

instance Exception PongTimeout


-- | Options controlling 'withPingPong'.
--
-- Keep the ping interval shorter than the pong timeout (half of the
-- timeout is a reasonable choice). Otherwise the pong timeout may expire
-- before the next ping has even been sent.
data PingPongOptions = PingPongOptions
    { pingInterval :: Int
      -- ^ Interval between pings, in seconds.
    , pongTimeout  :: Int
      -- ^ How long to wait for a pong before giving up, in seconds.
    , pingAction   :: IO ()
      -- ^ Action to perform after sending a ping.
    }

-- | Default ping-pong options: ping every 15 seconds, time out after
-- 30 seconds without a pong, and do nothing extra after each ping.
defaultPingPongOptions :: PingPongOptions
defaultPingPongOptions =
    PingPongOptions
        { pingInterval = 15
        , pongTimeout  = 30
        , pingAction   = return ()
        }

-- | Run an application with ping-pong keep-alive enabled.
--
-- @app connection@ runs concurrently with two helper threads:
--
--   * 'pingThread', configured with 'pingInterval' and 'pingAction';
--   * a watchdog that repeatedly waits on the connection's heartbeat
--     (presumably signalled whenever a pong arrives) and throws
--     'PongTimeout' the first time a wait exceeds 'pongTimeout' seconds.
--
-- As soon as any of the three threads finishes, the other two are
-- cancelled. If the finishing thread failed, its exception (for example
-- 'PongTimeout') is re-thrown to the caller.
--
-- Can be used with both client and server connections.
withPingPong :: PingPongOptions -> Connection -> (Connection -> IO ()) -> IO ()
withPingPong options connection app =
    withAsync (app connection) $ \appThread ->
    withAsync pinger $ \pingerThread ->
    withAsync watchdog $ \watchdogThread ->
        void (waitAnyCancel [appThread, pingerThread, watchdogThread])
  where
    pinger :: IO ()
    pinger = pingThread connection (pingInterval options) (pingAction options)

    -- 'timeout' takes microseconds; the option is expressed in seconds.
    pongTimeoutMicros :: Int
    pongTimeoutMicros = pongTimeout options * 1000 * 1000

    -- Keep consuming heartbeats; abort with PongTimeout the first time
    -- one does not arrive within the allowed window.
    watchdog :: IO ()
    watchdog = do
        beat <- timeout pongTimeoutMicros (takeMVar (connectionHeartbeat connection))
        case beat of
            Nothing -> throwIO PongTimeout
            Just _  -> watchdog