{-# LANGUAGE OverloadedStrings #-}
module Network.Wai.Handler.WebSockets
    ( websocketsOr
    , websocketsApp
    , isWebSocketsReq
    , getRequestHead
    , runWebSockets
    ) where

import              Control.Exception               (bracket, tryJust)
import              Data.ByteString                 (ByteString)
import qualified    Data.ByteString.Char8           as BC
import qualified    Data.ByteString.Lazy            as BL
import qualified    Data.CaseInsensitive            as CI
import              Network.HTTP.Types              (status500)
import qualified    Network.Wai                     as Wai
import qualified    Network.WebSockets              as WS
import qualified    Network.WebSockets.Connection   as WS
import qualified    Network.WebSockets.Stream       as WS

--------------------------------------------------------------------------------
-- | Returns whether or not the given 'Wai.Request' is a WebSocket request.
--
-- A request qualifies when it carries an @Upgrade@ header whose value is
-- @websocket@. The value is compared case-insensitively (via 'CI.mk'), and the
-- header name lookup is case-insensitive because 'Wai.requestHeaders' keys are
-- case-insensitive. The whole header value is compared, so a value listing
-- several protocols does not match.
isWebSocketsReq :: Wai.Request -> Bool
isWebSocketsReq req =
    fmap CI.mk (lookup "upgrade" $ Wai.requestHeaders req) == Just "websocket"

--------------------------------------------------------------------------------
-- | Upgrade a @websockets@ 'WS.ServerApp' to a @wai@ 'Wai.Application'. Uses
-- the given backup 'Wai.Application' to handle 'Wai.Request's that are not
-- WebSocket requests.
--
-- @
-- websocketsOr opts ws_app backup_app = \\req respond ->
--     __case__ 'websocketsApp' opts ws_app req __of__
--         'Nothing'  -> backup_app req respond
--         'Just' res -> respond res
-- @
--
-- For example, below is an 'Wai.Application' that sends @"Hello, client!"@ to
-- each connected client.
--
-- @
-- app :: 'Wai.Application'
-- app = 'websocketsOr' 'WS.defaultConnectionOptions' wsApp backupApp
--   __where__
--     wsApp :: 'WS.ServerApp'
--     wsApp pending_conn = do
--         conn <- 'WS.acceptRequest' pending_conn
--         'WS.sendTextData' conn ("Hello, client!" :: 'Data.Text.Text')
--
--     backupApp :: 'Wai.Application'
--     backupApp _ respond = respond $ 'Wai.responseLBS' 'Network.HTTP.Types.status400' [] "Not a WebSocket request"
-- @
websocketsOr :: WS.ConnectionOptions
             -> WS.ServerApp
             -> Wai.Application
             -> Wai.Application
websocketsOr opts app backup req sendResponse =
    case websocketsApp opts app req of
        -- Not a WebSocket request: delegate entirely to the backup app.
        Nothing  -> backup req sendResponse
        -- WebSocket request: send the raw upgrade response.
        Just res -> sendResponse res

--------------------------------------------------------------------------------
-- | Handle a single @wai@ 'Wai.Request' with the given @websockets@
-- 'WS.ServerApp'. Returns 'Nothing' if the 'Wai.Request' is not a WebSocket
-- request, 'Just' otherwise.
--
-- The response is built with 'Wai.responseRaw'. Its raw source and sink are
-- handed to 'runWebSockets'. The second argument of 'Wai.responseRaw' is a
-- fallback response: a @500@ with a plain-text body explaining that the WAI
-- handler does not support WebSockets.
--
-- Usually, 'websocketsOr' is more convenient.
websocketsApp :: WS.ConnectionOptions
              -> WS.ServerApp
              -> Wai.Request
              -> Maybe Wai.Response
websocketsApp opts app req
    | isWebSocketsReq req =
        Just $ flip Wai.responseRaw backup $ \src sink ->
            runWebSockets opts req' app src sink
    | otherwise = Nothing
  where
    -- The websockets-side view of this request.
    req' = getRequestHead req
    -- Fallback used by 'Wai.responseRaw' when raw responses are unavailable.
    backup = Wai.responseLBS status500 [("Content-Type", "text/plain")]
                "The web application attempted to send a WebSockets response, but WebSockets are not supported by your WAI handler."

--------------------------------------------------------------------------------
-- | Convert a @wai@ 'Wai.Request' into the 'WS.RequestHead' used by
-- @websockets@.
--
-- * The path is 'Wai.rawPathInfo' followed directly by 'Wai.rawQueryString'.
-- * The headers are 'Wai.requestHeaders', unchanged.
-- * The secure flag is 'Wai.isSecure'.
getRequestHead :: Wai.Request -> WS.RequestHead
getRequestHead req = WS.RequestHead
    (Wai.rawPathInfo req `BC.append` Wai.rawQueryString req)
    (Wai.requestHeaders req)
    (Wai.isSecure req)

--------------------------------------------------------------------------------
-- | Internal function to run a @websockets@ application over a raw
-- connection, given as an action that reads 'ByteString' chunks (@src@) and
-- an action that writes them (@sink@).
--
-- A 'WS.Stream' is built from @src@/@sink@ with 'WS.makeStream' and wrapped
-- in a 'WS.PendingConnection' that is passed to @app@. The stream is closed
-- with 'WS.close' when @app@ finishes or throws (via 'bracket').
runWebSockets :: WS.ConnectionOptions
              -> WS.RequestHead
              -> (WS.PendingConnection -> IO a)
              -> IO ByteString
              -> (ByteString -> IO ())
              -> IO a
runWebSockets opts req app src sink = bracket mkStream ensureClose (app . pc)
  where
    -- Close the stream on exit. If closing raises 'WS.ConnectionClosed' it is
    -- swallowed; any other 'WS.ConnectionException' is rethrown by 'tryJust'.
    ensureClose = tryJust onConnectionException . WS.close

    onConnectionException :: WS.ConnectionException -> Maybe ()
    onConnectionException WS.ConnectionClosed = Just ()
    onConnectionException _                   = Nothing

    -- Reader: an empty chunk from @src@ is reported as end of input
    -- ('Nothing'). Writer: lazy output is forwarded chunk by chunk to @sink@;
    -- a 'Nothing' write is ignored here.
    mkStream =
        WS.makeStream
            (do
                bs <- src
                return $ if BC.null bs then Nothing else Just bs)
            (\mbBl -> case mbBl of
                Nothing -> return ()
                Just bl -> mapM_ sink (BL.toChunks bl))

    -- Pending connection over the given stream; accepting it runs no extra
    -- action.
    pc stream = WS.PendingConnection
        { WS.pendingOptions     = opts
        , WS.pendingRequest     = req
        , WS.pendingOnAccept    = \_ -> return ()
        , WS.pendingStream      = stream
        }