{-# LANGUAGE CPP #-}

-----------------------------------------------------------------

-----------------------------------------------------------------

-- | Module : Network.Wai.Middleware.MethodOverridePost
--
-- Changes the request method to the value of the @_method@ parameter when it
-- is the first parameter of a URL-encoded POST body.
module Network.Wai.Middleware.MethodOverridePost (
    methodOverridePost,
) where

import qualified Data.ByteString.Char8 as B8
import Data.ByteString.Lazy (toChunks)
import Data.Char (isSpace, toLower)
import Data.IORef (atomicModifyIORef, atomicModifyIORef', newIORef)
#if __GLASGOW_HASKELL__ < 710
import Data.Monoid (mconcat, mempty)
#endif
import Network.HTTP.Types (hContentType, parseQuery)
import Network.Wai

-- | Allows overriding of the HTTP request method via the @_method@ POST
-- parameter.
--
-- * This middleware only applies when the initial request method is @POST@.
--
-- * The @Content-Type@ request header must name the media type
--   @application/x-www-form-urlencoded@. Media-type parameters (for example
--   @; charset=UTF-8@) are ignored, and the comparison is case-insensitive.
--
-- * If the first POST parameter is @_method@, the request method is changed
--   to the value of that parameter.
--
-- All other requests are passed to the wrapped application unchanged.
methodOverridePost :: Middleware
methodOverridePost app req send =
    case (requestMethod req, lookup hContentType (requestHeaders req)) of
        ("POST", Just ct)
            | isUrlEncodedForm ct -> setPost req >>= flip app send
        -- Non-POST requests, a missing Content-Type, or a different media
        -- type: leave the request untouched.
        _ -> app req send
  where
    -- A Content-Type value is "type/subtype" optionally followed by
    -- ";"-separated parameters, and the type/subtype is case-insensitive.
    -- Exact byte equality would therefore reject e.g.
    -- "application/x-www-form-urlencoded; charset=UTF-8".
    isUrlEncodedForm ct =
        B8.map toLower (mediaType ct) == "application/x-www-form-urlencoded"
    -- Strip leading whitespace, then keep everything up to the first ';'
    -- or whitespace (media-type tokens never contain either).
    mediaType =
        B8.takeWhile (\c -> c /= ';' && not (isSpace c)) . B8.dropWhile isSpace

-- | Buffers the whole request body so it can be inspected, then returns a
-- request whose body reader hands out the buffered bytes on the first read
-- and an empty chunk on every later read (WAI treats an empty chunk as the
-- end of the body).
--
-- If the first parameter parsed from the body is @_method@ with a value, the
-- returned request's method is replaced by that value. Otherwise only the
-- body reader is replaced.
--
-- NOTE(review): the entire body is read into memory with no size limit.
setPost :: Request -> IO Request
setPost req = do
    body <- (mconcat . toChunks) `fmap` lazyRequestBody req
    ref <- newIORef body
    -- Use the strict variant so the stored 'mempty' is evaluated immediately.
    -- The lazy 'atomicModifyIORef' would leave a thunk in the IORef that
    -- still references the previously stored buffer.
    let rb = atomicModifyIORef' ref $ \bs -> (mempty, bs)
        req' = setRequestBodyChunks rb req
    case parseQuery body of
        -- Only a leading @_method@ parameter triggers the override.
        (("_method", Just newmethod) : _) ->
            return req'{requestMethod = newmethod}
        _ -> return req'

{- HLint ignore setPost "Use tuple-section" -}