{- This file is part of time-out.
 -
 - Written in 2016 by fr33domlover <fr33domlover@riseup.net>.
 -
 - ♡ Copying is an act of love. Please copy, reuse and share.
 -
 - The author(s) have dedicated all copyright and related and neighboring
 - rights to this software to the public domain worldwide. This software is
 - distributed without any warranty.
 -
 - You should have received a copy of the CC0 Public Domain Dedication along
 - with this software. If not, see
 - <http://creativecommons.org/publicdomain/zero/1.0/>.
 -}

{-# LANGUAGE MultiParamTypeClasses #-}

module Control.Timeout
    ( timeout
    , delay
    )
where

import Control.Concurrent
import Control.Monad (when)
import Control.Monad.Catch
import Control.Monad.IO.Class
import Control.Monad.Timeout.Class
import Data.List (genericReplicate)
import Data.Maybe (isJust)
import Data.Time.Units

-- | Exception used inside this module to signal that a time limit expired.
-- It is thrown to the thread running the action by the timer thread that
-- 'timeout' forks, and by 'timeoutThrow' when the limit is exceeded.
data Timeout' = Timeout' deriving Show

-- All 'Exception' methods use their defaults; the instance exists so that
-- 'Timeout'' can be raised with 'throwM' / 'throwTo' and handled with 'catch'.
instance Exception Timeout'

-- | 'MonadTimeout' instance for plain 'IO', implemented on top of 'timeout'.
instance MonadTimeout IO IO where
    -- Run the action under the time limit. A result is returned as is; if
    -- the limit expires ('timeoutCatch' yields 'Nothing'), 'Timeout'' is
    -- thrown instead.
    timeoutThrow t act =
        timeoutCatch t act >>= maybe (throwM Timeout') return

    -- Expiry is reported as 'Nothing', which is exactly what 'timeout' does.
    timeoutCatch = timeout

-- | Run the action and wrap its result in 'Just'. If a 'Timeout'' exception
-- is thrown while the action runs, catch it and return 'Nothing' instead.
-- Exceptions of any other type aren't caught and propagate to the caller.
catchTimeout :: (MonadIO m, MonadCatch m) => m a -> m (Maybe a)
catchTimeout action =
    -- Matching on the 'Timeout'' constructor fixes the handler's exception
    -- type, so only that exception is intercepted.
    catch (Just <$> action) $ \ Timeout' -> return Nothing

-- | Run a monadic action with a time limit. If it finishes before that time
-- passes and returns value @x@, then @Just x@ is returned. If the timeout
-- passes, the action is aborted and @Nothing@ is returned. If the action
-- throws an exception, it is aborted and the exception is rethrown.
--
-- >>> timeout (3 :: Second) $ delay (1 :: Second) >> return "hello"
-- Just "hello"
--
-- >>> timeout (3 :: Second) $ delay (5 :: Second) >> return "hello"
-- Nothing
--
-- >>> timeout (1 :: Second) $ error "hello"
-- *** Exception: hello
timeout :: (TimeUnit t, MonadIO m, MonadCatch m) => t -> m a -> m (Maybe a)
timeout time action = do
    tidMain <- liftIO myThreadId
    -- Timer thread: sleep for the whole limit, then interrupt the calling
    -- thread by throwing 'Timeout'' at it.
    tidTemp <- liftIO $ forkIO $ delay time >> throwTo tidMain Timeout'
    -- 'catchTimeout' turns the timer's exception into 'Nothing'. Any other
    -- exception escapes, so stop the timer before it is rethrown.
    result <- catchTimeout action `onException` liftIO (killThread tidTemp)
    -- The action finished in time, so the timer is no longer needed. On
    -- 'Nothing' the timer has already fired and needs no cleanup.
    --
    -- NOTE(review): the timer is cancelled only after 'catchTimeout' has
    -- returned. If the limit expires between the action finishing and this
    -- 'killThread', 'Timeout'' arrives outside the handler and propagates to
    -- the caller. Closing that window appears to need masking (a 'MonadMask'
    -- constraint), which would change the signature.
    when (isJust result) $ liftIO $ killThread tidTemp
    return result

-- | Suspend the current thread for the given number of microseconds. This is
-- 'threadDelay' lifted into any 'MonadIO' monad.
delayInt :: MonadIO m => Int -> m ()
delayInt = liftIO . threadDelay

-- | Suspend the current thread for the given number of microseconds, which
-- may be larger than what fits in an 'Int'. A value of zero or less returns
-- immediately without delaying.
delayInteger :: MonadIO m => Integer -> m ()
delayInteger usec =
    when (usec > 0) $ do
        let maxInt = maxBound :: Int
            -- 'delayInt' takes an 'Int', so split the wait into @times@
            -- chunks of 'maxBound' microseconds plus a final remainder that
            -- is guaranteed to fit in an 'Int'.
            (times, rest) = usec `divMod` toInteger maxInt
        sequence_ $ genericReplicate times $ delayInt maxInt
        delayInt $ fromInteger rest

-- | Suspend the current thread for the given amount of time.
--
-- The duration is converted to microseconds with 'toMicroseconds' and the
-- wait is performed on the resulting 'Integer', so it is not limited to the
-- range of 'Int'.
--
-- Example:
--
-- > delay (5 :: Second)
delay :: (TimeUnit t, MonadIO m) => t -> m ()
delay = delayInteger . toMicroseconds