{-# LANGUAGE OverloadedStrings #-}
module Network.Wai.Middleware.Gunzip (gunzip) where
import Control.Applicative
import Control.Exception (throwIO)
import Data.IORef
import Network.HTTP.Types (Header, hContentEncoding)
import Network.Wai (Middleware, Request, RequestBodyLength (ChunkedBody))
import Prelude
import qualified Data.ByteString as S
import qualified Data.Streaming.Zlib as Z
import qualified Network.Wai as Wai
-- | WAI middleware that transparently inflates gzip-encoded request bodies.
--
-- If the request's @Content-Encoding@ header is @gzip@ (see 'isGzip'), the
-- wrapped application receives a modified request:
--
-- * 'Wai.requestBody' yields the decompressed data chunk by chunk,
-- * 'Wai.requestBodyLength' is set to 'ChunkedBody', and
-- * the @Content-Encoding: gzip@ header is removed (see 'noGzip').
--
-- Any other request is handed to the application unchanged.
--
-- Errors reported by zlib while inflating are re-thrown with 'throwIO' when
-- the application reads the body.
gunzip :: Middleware
gunzip app rq k
    | isGzip rq = prepare >>= flip app k
    | otherwise = app rq k
  where
    -- Build the replacement request. The 'IORef' buffers inflated chunks
    -- that have been produced but not yet handed to the application.
    prepare :: IO Request
    prepare = do
        r <- newIORef []
        -- 31 = 15 + 16: per zlib's inflateInit2 this selects the maximum
        -- window size and gzip (rather than raw zlib) framing.
        i <- Z.initInflate (Z.WindowBits 31)
        return $ rq { Wai.requestBody       = inflate r i
                    , Wai.requestBodyLength = ChunkedBody
                    , Wai.requestHeaders    = noGzip (Wai.requestHeaders rq)
                    }

    -- Produce the next inflated chunk: serve from the buffer if possible,
    -- otherwise read and inflate the next compressed chunk of the original
    -- request body.
    inflate :: IORef [S.ByteString] -> Z.Inflate -> IO S.ByteString
    inflate r i = do
        buffered <- readIORef r
        case buffered of
            []     -> Wai.requestBody rq >>= continue r i
            (x:xs) -> writeIORef r xs >> return x

    -- Inflate one compressed chunk. An empty input chunk means the original
    -- body is exhausted, which is propagated as an empty output chunk.
    continue
        :: IORef [S.ByteString] -> Z.Inflate -> S.ByteString -> IO S.ByteString
    continue r i b
        | S.null b  = return S.empty
        | otherwise = do
            f    <- toBytes id =<< Z.feedInflate i b
            rest <- Z.finishInflate i
            -- Empty chunks must never reach the consumer while input is
            -- still available: an empty chunk signals end of body, so a
            -- compressed chunk that yields no output (e.g. one containing
            -- only the gzip header) would otherwise truncate the body.
            case filter (not . S.null) (f [rest]) of
                []     -> inflate r i
                (y:ys) -> writeIORef r ys >> return y

    -- Drain a popper, accumulating its chunks in order in a difference
    -- list. zlib errors are thrown as exceptions.
    toBytes
        :: ([S.ByteString] -> c) -> IO Z.PopperRes -> IO ([S.ByteString] -> c)
    toBytes front p = do
        res <- p
        case res of
            Z.PRDone    -> return front
            Z.PRNext b  -> toBytes (front . (:) b) p
            Z.PRError e -> throwIO e
-- | Check whether a request declares a gzip-encoded body.
--
-- Only the first @Content-Encoding@ header is inspected, and its value must
-- be exactly @gzip@.
isGzip :: Request -> Bool
isGzip rq = lookup hContentEncoding (Wai.requestHeaders rq) == Just "gzip"
-- | Remove every @Content-Encoding: gzip@ header from the list.
--
-- Headers with a different name, and @Content-Encoding@ headers with a
-- value other than exactly @gzip@, are kept in their original order.
noGzip :: [Header] -> [Header]
noGzip = filter (/= (hContentEncoding, "gzip"))