-- This Source Code Form is subject to the terms of
-- the Mozilla Public License, v. 2.0. If a copy of
-- the MPL was not distributed with this file, You
-- can obtain one at http://mozilla.org/MPL/2.0/.

{-# LANGUAGE OverloadedStrings #-}

module Network.Wai.Middleware.Gunzip (gunzip) where

import Control.Applicative
import Control.Exception (throwIO)
import Data.IORef
import Network.HTTP.Types (Header, hContentEncoding)
import Network.Wai (Middleware, Request, RequestBodyLength (ChunkedBody))
import Prelude

import qualified Data.ByteString     as S
import qualified Data.Streaming.Zlib as Z
import qualified Network.Wai         as Wai

-- | This WAI middleware transparently unzips HTTP request bodies if
-- a request header @Content-Encoding: gzip@ is found.
--
-- Please note that the 'requestBodyLength' is set to 'ChunkedBody'
-- if the body is unzipped since we do not know the uncompressed
-- length yet. The @Content-Encoding: gzip@ header itself is removed
-- from the request handed to the wrapped application.
--
-- Requests without that header are passed through unchanged. Errors
-- reported by zlib while inflating are rethrown with 'throwIO' when the
-- application reads the request body.
gunzip :: Middleware
gunzip app rq k
    | isGzip rq = prepare >>= flip app k
    | otherwise = app rq k
  where
    -- Build the rewritten request: a buffer for inflated chunks that
    -- have not been handed out yet, and a fresh inflater.
    prepare = do
        r <- newIORef []
        -- Per the zlib manual, windowBits 15 + 16 = 31 selects gzip
        -- (rather than raw zlib) header/trailer decoding.
        i <- Z.initInflate (Z.WindowBits 31)
        return $ rq { Wai.requestBody       = inflate r i
                    , Wai.requestBodyLength = ChunkedBody -- FIXME
                    , Wai.requestHeaders    = noGzip (Wai.requestHeaders rq)
                    }

    -- Replacement body reader: serve previously buffered output first,
    -- otherwise pull the next compressed chunk from the original body.
    inflate r i = do
        buffered <- readIORef r
        case buffered of
            []     -> Wai.requestBody rq >>= continue r i
            (x:xs) -> writeIORef r xs >> return x

    -- Inflate one compressed chunk @b@.
    --
    -- An empty @b@ means the original body is exhausted, so we return an
    -- empty chunk to signal end-of-body. Empty output pieces are dropped,
    -- because handing an empty chunk to the application would make it
    -- think the body had ended. If a compressed chunk yields no output at
    -- all (e.g. it contained only the gzip header), we read the next
    -- compressed chunk instead of returning.
    continue r i b
        | S.null b  = return S.empty
        | otherwise = do
            f <- toBytes id =<< Z.feedInflate i b
            x <- filter (not . S.null) . f . (:[]) <$> Z.finishInflate i
            case x of
                []     -> Wai.requestBody rq >>= continue r i
                (y:ys) -> writeIORef r ys >> return y

    -- Drain a zlib popper into a difference list of output chunks,
    -- rethrowing any zlib error as an IO exception.
    toBytes front p = p >>= \res -> case res of
        Z.PRDone     -> return front
        Z.PRNext out -> toBytes (front . (:) out) p
        Z.PRError e  -> throwIO e

-- | Whether the request carries a @Content-Encoding@ header whose value
-- is exactly @gzip@.
--
-- The comparison is byte-for-byte, and a missing header counts as
-- 'False'. If the header occurs several times, only the first occurrence
-- is inspected, since 'lookup' stops at the first match.
isGzip :: Request -> Bool
isGzip req =
    case lookup hContentEncoding (Wai.requestHeaders req) of
        Just enc -> enc == "gzip"
        Nothing  -> False

-- | Remove every @Content-Encoding: gzip@ header.
--
-- All other headers are kept in their original order, including any
-- @Content-Encoding@ header with a value other than exactly @gzip@.
noGzip :: [Header] -> [Header]
noGzip hdrs =
    [ h
    | h@(name, value) <- hdrs
    , not (name == hContentEncoding && value == "gzip")
    ]