{-# LANGUAGE ForeignFunctionInterface, OverloadedStrings #-}
{-# LANGUAGE CPP #-}

module Network.Socket.BufferPool.Recv (
    receive
  , makeRecvN
  ) where

import qualified Data.ByteString as BS
import Data.ByteString.Internal (ByteString(..), unsafeCreate)
import Data.IORef
import Foreign.C.Error (Errno, eAGAIN, eINTR, eWOULDBLOCK, getErrno, throwErrno)
import Foreign.C.Types
import Foreign.Ptr (Ptr, castPtr)
import GHC.Conc (threadWaitRead)
import Network.Socket (Socket, withFdSocket)
import System.Posix.Types (Fd(..))

#ifdef mingw32_HOST_OS
import GHC.IO.FD (FD(..), readRawBufferPtr)
import Network.Socket.BufferPool.Windows
#endif

import Network.Socket.BufferPool.Types
import Network.Socket.BufferPool.Buffer

----------------------------------------------------------------

-- | Receive from @sock@ into a buffer borrowed from @pool@.
--
--   Buffer handling is delegated to 'withBufferPool'; this function only
--   supplies the low-level read of at most the requested number of bytes
--   into the given buffer, returning how many bytes were read.
receive :: Socket -> BufferPool -> Recv
receive sock pool = withBufferPool pool recvInto
  where
    recvInto :: Buffer -> Int -> IO Int
    recvInto buf len =
#if MIN_VERSION_network(3,1,0)
      withFdSocket sock $ \fd ->
        fromIntegral <$> tryReceive fd buf (fromIntegral len)
#elif MIN_VERSION_network(3,0,0)
      do fd <- fdSocket sock
         fromIntegral <$> tryReceive fd buf (fromIntegral len)
#else
      fromIntegral <$> tryReceive (fdSocket sock) buf (fromIntegral len)
#endif

----------------------------------------------------------------

-- | Read at most @size@ bytes from descriptor @sock@ into @ptr@.
--
--   Returns the byte count reported by the underlying call; per POSIX
--   @recv()@, 0 on a stream socket indicates an orderly shutdown by the
--   peer.
--
--   Error handling (mirrors base's @throwErrnoIfMinus1RetryMayBlock@):
--
--   * 'eINTR' – the call was interrupted before transferring data, so it
--     is retried immediately.
--   * 'eAGAIN' / 'eWOULDBLOCK' – no data is available yet; the calling
--     Haskell thread waits with 'threadWaitRead' and then retries.
--   * anything else – raised via 'throwErrno' with location "tryReceive".
tryReceive :: CInt -> Buffer -> CSize -> IO CInt
tryReceive sock ptr size = go
  where
    go :: IO CInt
    go = do
#ifdef mingw32_HOST_OS
      bytes <- windowsThreadBlockHack $ fromIntegral <$> readRawBufferPtr "tryReceive" (FD sock 1) (castPtr ptr) 0 size
#else
      bytes <- c_recv sock (castPtr ptr) size 0
#endif
      if bytes == -1
        then getErrno >>= onError
        else return bytes

    -- EAGAIN and EWOULDBLOCK are the same value on most systems but POSIX
    -- allows them to differ, so both are treated as "would block".
    onError :: Errno -> IO CInt
    onError errno
      | errno == eINTR = go
      | errno == eAGAIN || errno == eWOULDBLOCK = do
          threadWaitRead (Fd sock)
          go
      | otherwise = throwErrno "tryReceive"

----------------------------------------------------------------

-- | Build a 'RecvN' (receive exactly N bytes) on top of a 'Recv'.
--
--   The first argument is data that has already been received; it is
--   consumed before the 'Recv' action is called. Each call of the
--   resulting function returns exactly the requested number of bytes and
--   keeps any surplus from the last 'Recv' call for the next request.
--   If 'Recv' returns an empty 'ByteString' (end of input) before enough
--   bytes have arrived, the request yields an empty 'ByteString' and the
--   bytes gathered for it are discarded.
--
--   NOTE(review): the surplus lives in a plain 'IORef' that is read and
--   then written without locking, so the returned function presumably
--   must not be called from several threads at once.
makeRecvN :: ByteString -> Recv -> IO RecvN
makeRecvN bs0 recv = do
    ref <- newIORef bs0
    pure (recvN ref recv)

-- | Serve one fixed-size request of @size@ bytes (the third argument).
--
--   Bytes cached in @ref@ are used first, then @recv@ is called as often
--   as needed; whatever the final receive produced beyond @size@ is stored
--   back into @ref@ for the next request.
recvN :: IORef ByteString -> Recv -> RecvN
recvN ref recv size = readIORef ref >>= serve
  where
    serve cached = do
      (wanted, surplus) <- tryRecvN cached size recv
      writeIORef ref surplus
      pure wanted

----------------------------------------------------------------

-- | Collect exactly @siz0@ bytes, starting with the already-received
--   @init0@ and calling @recv@ for more chunks as needed.
--
--   Returns the requested bytes paired with the surplus of the last chunk.
--   If @init0@ alone is long enough, @recv@ is not called at all. If
--   @recv@ returns an empty chunk before enough bytes have arrived, the
--   result is @("", "")@ and the bytes gathered so far are dropped.
tryRecvN :: ByteString -> Int -> Recv -> IO (ByteString, ByteString)
tryRecvN init0 siz0 recv
  | BS.length init0 >= siz0 = pure (BS.splitAt siz0 init0)
  | otherwise               = loop [init0] (siz0 - BS.length init0)
  where
    -- Chunks are accumulated newest-first and reversed once at the end,
    -- so the final concatenation sees them in arrival order.
    loop :: [ByteString] -> Int -> IO (ByteString, ByteString)
    loop chunks need = recv >>= step
      where
        step chunk
          | got == 0   = pure ("", "")
          | got < need = loop (chunk : chunks) (need - got)
          | otherwise  = pure (concatN siz0 (reverse (final : chunks)), rest)
          where
            got           = BS.length chunk
            (final, rest) = BS.splitAt need chunk

-- | Allocate a 'ByteString' of exactly @total@ bytes and copy the given
--   chunks into it, in list order.
--
--   NOTE(review): there is no check that the chunk lengths add up to
--   @total@; callers must guarantee they do not exceed it ('tryRecvN'
--   passes chunks whose lengths sum to its requested size). Assumes
--   'copy' writes a chunk at the pointer and returns the position just
--   past it — confirm in "Network.Socket.BufferPool.Buffer".
concatN :: Int -> [ByteString] -> ByteString
concatN total chunks = unsafeCreate total (foldr step done chunks)
  where
    -- Copy one chunk, then hand the advanced pointer to the rest.
    step :: ByteString -> (Buffer -> IO ()) -> Buffer -> IO ()
    step chunk next dst = copy dst chunk >>= next

    done :: Buffer -> IO ()
    done _ = pure ()

#ifndef mingw32_HOST_OS
-- | Raw binding to POSIX @recv()@: descriptor, destination buffer,
--   buffer length, flags.
--
--   FIXME: POSIX @recv()@ returns @ssize_t@, but the result is declared
--   here as 'CInt'. 'tryReceive' is typed against this 'CInt', so the two
--   must be changed together.
--
--   NOTE(review): the call is @unsafe@, which presumably relies on the
--   socket being non-blocking ('tryReceive' waits with 'threadWaitRead'
--   instead of blocking in C); a blocking call here would stall the RTS
--   capability — confirm the socket is always non-blocking.
foreign import ccall unsafe "recv"
    c_recv :: CInt -> Ptr CChar -> CSize -> CInt -> IO CInt
#endif