{-# LANGUAGE TupleSections #-}
module Foundation.Monad.State
    ( -- * MonadState
      MonadState(..)
    , get
    , put

    , -- * StateT
      StateT
    , runStateT
    ) where

import Basement.Compat.Bifunctor (first)
import Basement.Compat.Base (($), (.), const)
import Foundation.Monad.Base
import Control.Monad ((>=>))

-- | Monads that carry a piece of state which can be read and replaced.
class Monad m => MonadState m where
    -- | The type of the state carried by @m@.
    type State m
    -- | Run a pure state transition: the function receives the current
    -- state and returns a result paired with the state to continue with.
    withState :: (State m -> (a, State m)) -> m a

-- | Return the current state, leaving it unchanged.
get :: MonadState m => m (State m)
get = withState $ \s -> (s, s)

-- | Replace the current state with the given value, discarding the old one.
put :: MonadState m => State m -> m ()
put s = withState $ const ((), s)

-- | State transformer.
--
-- A @StateT s m a@ wraps a function that takes a state of type @s@ and
-- produces, inside @m@, a result of type @a@ paired with the next state.
-- 'runStateT' unwraps that function.
newtype StateT s m a = StateT { runStateT :: s -> m (a, s) }

-- | Map over the result of a stateful computation; the state produced by
-- the computation is passed through untouched.
instance Functor m => Functor (StateT s m) where
    -- 'first' applies @f@ to the result half of the @(result, state)@ pair.
    fmap f m = StateT $ \s1 -> first f `fmap` runStateT m s1
    {-# INLINE fmap #-}

-- | Sequential application that threads the state left to right.
instance (Applicative m, Monad m) => Applicative (StateT s m) where
    -- The incoming state is returned unchanged alongside the value.
    pure a     = StateT $ \s -> pure (a, s)
    {-# INLINE pure #-}
    -- Run the function computation first, then the argument computation
    -- with the state the first one produced; the final state is the one
    -- left by the argument computation.
    fab <*> fa = StateT $ \s1 -> do
        (ab, s2) <- runStateT fab s1
        (a, s3)  <- runStateT fa s2
        return (ab a, s3)
    {-# INLINE (<*>) #-}

-- | Monadic sequencing: the state produced by the first computation is the
-- state the continuation starts from.
instance (Functor m, Monad m) => Monad (StateT s m) where
    return = pure
    {-# INLINE return #-}
    -- Kleisli-compose "run @ma@" with "run the continuation on its result
    -- and updated state".
    ma >>= mab = StateT $ runStateT ma >=> (\(a, s2) -> runStateT (mab a) s2)
    {-# INLINE (>>=) #-}

-- | Fixed point over the result of a stateful computation.
instance (Functor m, MonadFix m) => MonadFix (StateT s m) where
    -- The pattern is lazy (@~@) so the pair is not forced before the inner
    -- 'mfix' has produced it. Only the result component is fed back; @f@ is
    -- always run from the incoming state @s@.
    mfix f = StateT $ \s -> mfix $ \ ~(a, _) -> runStateT (f a) s
    {-# INLINE mfix #-}

-- | Lift an action of the underlying monad; the state is left unchanged.
instance MonadTrans (StateT s) where
    -- Run the inner action and pair its result with the incoming state.
    lift f = StateT $ \s -> f >>= return . (,s)
    {-# INLINE lift #-}

-- | Run an 'IO' action by lifting it into the underlying monad first and
-- then into 'StateT' via 'lift'.
instance (Functor m, MonadIO m) => MonadIO (StateT s m) where
    liftIO f = lift (liftIO f)
    {-# INLINE liftIO #-}

-- | Failure is delegated to the underlying monad, so the failure type is
-- the same as that of @m@.
instance (Functor m, MonadFailure m) => MonadFailure (StateT s m) where
    type Failure (StateT s m) = Failure m
    -- Call the underlying 'mFail' and pair its unit result with the
    -- incoming state.
    mFail e = StateT $ \s -> ((,s) `fmap` mFail e)

-- | Throwing is delegated to the underlying monad.
instance (Functor m, MonadThrow m) => MonadThrow (StateT s m) where
    -- The incoming state is ignored: the exception is thrown directly in @m@.
    throw e = StateT $ \_ -> throw e

-- | Catching is delegated to the underlying monad.
instance (Functor m, MonadCatch m) => MonadCatch (StateT s m) where
    -- The handler is run from @s1@, the state at the point 'catch' was
    -- entered, so any state changes made by the failed action are not
    -- visible to it.
    catch (StateT m) c = StateT $ \s1 -> m s1 `catch` (\e -> runStateT (c e) s1)

-- | 'StateT' carries a state of type @s@.
instance (Functor m, Monad m) => MonadState (StateT s m) where
    type State (StateT s m) = s
    -- Apply the pure transition to the incoming state and inject the
    -- resulting @(result, state)@ pair into the underlying monad.
    withState f = StateT $ return . f