--------------------------------------------------------------------------------
-- | Lightweight abstraction over an input/output stream.
{-# LANGUAGE CPP #-}
module Network.WebSockets.Stream
    ( Stream
    , makeStream
    , makeSocketStream
    , makeEchoStream
    , parse
    , parseBin
    , write
    , close
    ) where

import           Control.Concurrent.MVar        (MVar, newEmptyMVar, newMVar,
                                                 putMVar, takeMVar, withMVar)
import           Control.Exception              (SomeException, SomeAsyncException, throwIO, catch, fromException)
import           Control.Monad                  (forM_)
import qualified Data.Attoparsec.ByteString     as Atto
import qualified Data.Binary.Get                as BIN
import qualified Data.ByteString                as B
import qualified Data.ByteString.Lazy           as BL
import           Data.IORef                     (IORef, atomicModifyIORef',
                                                 newIORef, readIORef,
                                                 writeIORef)
import qualified Network.Socket                 as S
import qualified Network.Socket.ByteString      as SB (recv)

#if !defined(mingw32_HOST_OS)
import qualified Network.Socket.ByteString.Lazy as SBL (sendAll)
#else
import qualified Network.Socket.ByteString      as SB (sendAll)
#endif

import           Network.WebSockets.Types


--------------------------------------------------------------------------------
-- | State of the stream
data StreamState
    = Closed !B.ByteString
      -- ^ Remainder: input that was still unconsumed when the stream was
      -- marked closed.  'parse' and 'parseBin' still run a parser over a
      -- non-empty remainder.
    | Open   !B.ByteString
      -- ^ Buffer: input left over by the previous parser run, handed to the
      -- next parser before any new input is requested.


--------------------------------------------------------------------------------
-- | Lightweight abstraction over an input/output stream.
data Stream = Stream
    { streamIn    :: IO (Maybe B.ByteString)
      -- ^ Fetch the next chunk of input.  The parsers in this module treat
      -- 'Nothing' as end of input.
    , streamOut   :: (Maybe BL.ByteString -> IO ())
      -- ^ Emit a chunk of output.  'write' passes 'Just', 'close' passes
      -- 'Nothing'.
    , streamState :: !(IORef StreamState)
      -- ^ Open/closed flag together with the input that no parser has
      -- consumed yet.
    }


--------------------------------------------------------------------------------
-- | Create a stream from a "receive" and "send" action. The following
-- properties apply:
--
-- - Regardless of the provided "receive" and "send" functions, reading and
--   writing from the stream will be thread-safe, i.e. this function will create
--   a receive and write lock to be used internally.
--
-- - Reading from or writing to a closed 'Stream' will always throw an
--   exception, even if the underlying "receive" and "send" functions do not
--   (we do the bookkeeping).
--
-- - Streams should always be closed.
makeStream
    :: IO (Maybe B.ByteString)         -- ^ Reading
    -> (Maybe BL.ByteString -> IO ())  -- ^ Writing
    -> IO Stream                       -- ^ Resulting stream
makeStream receive send = do
    ref         <- newIORef (Open B.empty)
    receiveLock <- newMVar ()
    sendLock    <- newMVar ()
    return $ Stream (receive' ref receiveLock) (send' ref sendLock) ref
  where
    -- Mark the stream as closed, keeping whatever input is still buffered.
    closeRef :: IORef StreamState -> IO ()
    closeRef ref = atomicModifyIORef' ref $ \state -> case state of
        Open   buf -> (Closed buf, ())
        Closed buf -> (Closed buf, ())

    -- Throw a 'ConnectionClosed' if the connection is not 'Open'.
    assertOpen :: IORef StreamState -> IO ()
    assertOpen ref = do
        state <- readIORef ref
        case state of
            Closed _ -> throwIO ConnectionClosed
            Open   _ -> return ()

    -- Locked wrapper around the user-supplied @receive@.  The stream is
    -- marked closed when @receive@ reports end of input ('Nothing') or
    -- throws a synchronous exception.
    receive' :: IORef StreamState -> MVar () -> IO (Maybe B.ByteString)
    receive' ref lock = withMVar lock $ \() -> do
        assertOpen ref
        mbBs <- onSyncException receive (closeRef ref)
        case mbBs of
            Nothing -> closeRef ref >> return Nothing
            Just bs -> return (Just bs)

    -- Locked wrapper around the user-supplied @send@.  Sending 'Nothing'
    -- marks the stream closed before the value is forwarded; sending data
    -- requires the stream to still be open.
    send' :: IORef StreamState -> MVar () -> (Maybe BL.ByteString -> IO ())
    send' ref lock mbBs = withMVar lock $ \() -> do
        case mbBs of
            Nothing -> closeRef ref
            Just _  -> assertOpen ref
        onSyncException (send mbBs) (closeRef ref)

    -- Run @io@; if it throws, run @what@ first unless the exception is a
    -- 'SomeAsyncException', then rethrow the original exception.
    onSyncException :: IO a -> IO b -> IO a
    onSyncException io what =
        catch io $ \e -> do
            case fromException (e :: SomeException) :: Maybe SomeAsyncException of
                Just _  -> pure ()
                Nothing -> what *> pure ()
            throwIO e


--------------------------------------------------------------------------------
-- | Create a 'Stream' that reads from and writes to the given socket.
--
-- Input is fetched with 'SB.recv' in chunks of at most 8192 bytes.  Closing
-- the resulting stream does not close the socket: the send action ignores
-- 'Nothing'.
makeSocketStream :: S.Socket -> IO Stream
makeSocketStream socket = makeStream receive send
  where
    -- An empty read is reported as 'Nothing', i.e. end of input.
    receive :: IO (Maybe B.ByteString)
    receive = do
        bs <- SB.recv socket 8192
        return $ if B.null bs then Nothing else Just bs

    send :: Maybe BL.ByteString -> IO ()
    send Nothing   = return ()
    send (Just bs) = do
#if !defined(mingw32_HOST_OS)
        SBL.sendAll socket bs
#else
        forM_ (BL.toChunks bs) (SB.sendAll socket)
#endif


--------------------------------------------------------------------------------
-- | Create a 'Stream' that reads back what was written to it.
--
-- Every strict chunk of a written lazy 'BL.ByteString' is passed to the
-- reading side through a single 'MVar'; closing the stream passes 'Nothing'
-- the same way.
makeEchoStream :: IO Stream
makeEchoStream = do
    mvar <- newEmptyMVar
    makeStream (takeMVar mvar) $ \mbBs -> case mbBs of
        Nothing -> putMVar mvar Nothing
        Just bs -> forM_ (BL.toChunks bs) $ \c -> putMVar mvar (Just c)


--------------------------------------------------------------------------------
-- | Run a "Data.Binary.Get" parser on the input of a 'Stream'.
--
-- Buffered input is used first; more input is requested through 'streamIn'
-- whenever the parser needs it.  Input the parser did not consume is stored
-- in the stream state for the next call.
--
-- Returns 'Nothing' if the stream is closed with no input left, or if the
-- stream ends before any input for this parser was available.  Throws a
-- 'ParseException' if the parser fails.
parseBin :: Stream -> BIN.Get a -> IO (Maybe a)
parseBin stream parser = do
    state <- readIORef (streamState stream)
    case state of
        Closed remainder
            | B.null remainder -> return Nothing
            | otherwise        -> start remainder True
        Open buffer
            | B.null buffer -> do
                mbBs <- streamIn stream
                case mbBs of
                    Nothing -> do
                        writeIORef (streamState stream) (Closed B.empty)
                        return Nothing
                    Just bs -> start bs False
            | otherwise     -> start buffer False
  where
    -- Feed the first chunk to a fresh incremental decoder and drive it.
    start chunk = go (BIN.runGetIncremental parser `BIN.pushChunk` chunk)

    -- Drive the decoder to completion.  All buffered input has already been
    -- handed to the decoder when this is entered; the 'Bool' records whether
    -- the end of the input has been reached.
    go :: BIN.Decoder b -> Bool -> IO (Maybe b)
    go (BIN.Done remainder _ x) closed = do
        writeIORef (streamState stream) $
            if closed then Closed remainder else Open remainder
        return (Just x)
    go (BIN.Partial f) closed
        -- No more input will arrive: tell the decoder so.
        | closed    = go (f Nothing) True
        | otherwise = do
            mbBs <- streamIn stream
            case mbBs of
                Nothing -> go (f Nothing) True
                Just bs -> go (f (Just bs)) False
    go (BIN.Fail _ _ err) _ = throwIO (ParseException err)


-- | Run an attoparsec parser on the input of a 'Stream'.
--
-- Buffered input is used first; more input is requested through 'streamIn'
-- whenever the parser needs it.  Input the parser did not consume is stored
-- in the stream state for the next call.
--
-- Returns 'Nothing' if the stream is closed with no input left, or if the
-- stream ends before any input for this parser was available.  Throws a
-- 'ParseException' if the parser fails.
parse :: Stream -> Atto.Parser a -> IO (Maybe a)
parse stream parser = do
    state <- readIORef (streamState stream)
    case state of
        Closed remainder
            | B.null remainder -> return Nothing
            | otherwise        -> start remainder True
        Open buffer
            | B.null buffer -> do
                mbBs <- streamIn stream
                case mbBs of
                    Nothing -> do
                        writeIORef (streamState stream) (Closed B.empty)
                        return Nothing
                    Just bs -> start bs False
            | otherwise     -> start buffer False
  where
    -- Start the parser on the first chunk and drive it.
    start chunk = go (Atto.parse parser chunk)

    -- Drive the parser to completion.  All buffered input has already been
    -- handed to the parser when this is entered; the 'Bool' records whether
    -- the end of the input has been reached.
    go :: Atto.Result b -> Bool -> IO (Maybe b)
    go (Atto.Done remainder x) closed = do
        writeIORef (streamState stream) $
            if closed then Closed remainder else Open remainder
        return (Just x)
    go (Atto.Partial f) closed
        -- Feeding an empty chunk tells attoparsec that the input has ended.
        | closed    = go (f B.empty) True
        | otherwise = do
            mbBs <- streamIn stream
            case mbBs of
                Nothing -> go (f B.empty) True
                Just bs -> go (f bs) False
    go (Atto.Fail _ _ err) _ = throwIO (ParseException err)


--------------------------------------------------------------------------------
-- | Write a lazy 'BL.ByteString' to the stream by passing it, wrapped in
-- 'Just', to the stream's output action.
write :: Stream -> BL.ByteString -> IO ()
write stream = streamOut stream . Just


--------------------------------------------------------------------------------
-- | Close the stream by passing 'Nothing' to the stream's output action.
close :: Stream -> IO ()
close stream = streamOut stream Nothing