-- This file is part of the Wire Server implementation.
--
-- Copyright (C) 2022 Wire Swiss GmbH <opensource@wire.com>
--
-- This program is free software: you can redistribute it and/or modify it under
-- the terms of the GNU Affero General Public License as published by the Free
-- Software Foundation, either version 3 of the License, or (at your option) any
-- later version.
--
-- This program is distributed in the hope that it will be useful, but WITHOUT
-- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
-- FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
-- details.
--
-- You should have received a copy of the GNU Affero General Public License along
-- with this program. If not, see <https://www.gnu.org/licenses/>.

module Wire.API.Routes.Version.Wai (versionMiddleware) where

import Control.Monad.Except (throwError)
import Data.ByteString.Conversion
import Data.EitherR (fmapL)
import Data.Text qualified as T
import Data.Text.Lazy (fromStrict)
import Imports
import Network.HTTP.Types qualified as HTTP
import Network.Wai
import Network.Wai.Middleware.Rewrite
import Network.Wai.Utilities.Error
import Network.Wai.Utilities.Response
import Web.HttpApiData (parseUrlPiece, toUrlPiece)
import Wire.API.Routes.Version

-- | Strip off the version prefix (e.g. @/v3/...@) from the request path and
-- record the parsed version in the version header via 'addVersionHeader'.
--
-- Any version header supplied by the client is discarded first, whether or
-- not the path carries a version prefix, so the version can only be chosen
-- through the path.
--
-- Responds with 404 @unsupported-version@ when:
--
--   * the first path segment looks like a version (@v<digits>@) but does not
--     parse as a 'Version';
--   * the version is in @disabledAPIVersions@ and the endpoint is not one of
--     the always-available swagger endpoints (see 'requestIsDisableable');
--   * a version prefix is followed by an internal API path (@i@ or
--     @api-internal@), since internal APIs are unversioned.
--
-- Requests without a version prefix are passed on to @app@ with only the
-- client-supplied version header removed.
versionMiddleware :: Set Version -> Middleware
versionMiddleware disabledAPIVersions app req k =
  case parseVersion strippedReq of
    Right (versionedReq, v)
      | v `elem` disabledAPIVersions && requestIsDisableable versionedReq ->
          err (toUrlPiece v)
      | otherwise ->
          app (addVersionHeader v versionedReq) k
    Left (BadVersion v) -> err v
    -- Forward the header-stripped request here as well: the version header is
    -- how the version is communicated downstream, so letting the client's own
    -- value through on an unprefixed path would let it select any version
    -- (including a disabled one) and bypass the check above.
    Left NoVersion -> app strippedReq k
    Left InternalApisAreUnversioned -> errint
  where
    -- The request with any client-supplied version header removed.
    strippedReq :: Request
    strippedReq = removeVersionHeader req

    -- Shared 404 @unsupported-version@ response with the given message.
    notFound :: LText -> IO ResponseReceived
    notFound = k . errorRs . mkError HTTP.status404 "unsupported-version"

    err :: Text -> IO ResponseReceived
    err v = notFound $ "Version " <> fromStrict v <> " is not supported"

    errint :: IO ResponseReceived
    errint = notFound "Internal APIs (`/i/...`) are not under version control"

-- | Reasons why 'parseVersion' did not produce a version.
data ParseVersionError
  = -- | The path has no version prefix: it is empty, or its first segment is
    -- not of the form @v<digits>@.
    NoVersion
  | -- | The first segment looks like a version but could not be parsed as
    -- one; carries that segment verbatim.
    BadVersion Text
  | -- | A version prefix was followed by an internal API path.
    InternalApisAreUnversioned

-- | Split a request path of the form @/v<N>/rest...@ into the request
-- rewritten to @/rest...@ (query string kept as is) and the parsed 'Version'.
--
-- Fails with:
--
--   * 'NoVersion' when the path is empty or its first segment is not of the
--     form @v<digits>@;
--   * 'InternalApisAreUnversioned' when the version segment is followed by
--     @i@ or @api-internal@;
--   * 'BadVersion' (carrying the raw segment) when 'parseUrlPiece' rejects
--     a segment that looked like a version.
parseVersion :: Request -> Either ParseVersionError (Request, Version)
parseVersion req = do
  (segment, rest) <- splitVersionSegment (pathInfo req)
  version <- fmapL (const (BadVersion segment)) (parseUrlPiece segment)
  let dropVersionSegment (_, query) _headers = (rest, query)
  pure (rewriteRequestPure dropVersionSegment req, version)
  where
    -- Separate the leading version-like segment from the remaining path.
    splitVersionSegment :: [Text] -> Either ParseVersionError (Text, [Text])
    splitVersionSegment [] = throwError NoVersion
    splitVersionSegment (x : xs)
      | not (looksLikeVersion x) = throwError NoVersion
      | isInternalPath xs = throwError InternalApisAreUnversioned
      | otherwise = pure (x, xs)

    -- Paths that belong to the (unversioned) internal APIs.
    isInternalPath :: [Text] -> Bool
    isInternalPath ("i" : _) = True
    isInternalPath ("api-internal" : _) = True
    isInternalPath _ = False

-- | Whether a path segment has the shape of a version prefix: the letter @v@
-- followed by zero or more ASCII digits (@v@, @v0@, @v12@, ...).
--
-- This is only a syntactic check; whether the segment names an actual
-- 'Version' is decided later by 'parseUrlPiece' in 'parseVersion'.
looksLikeVersion :: Text -> Bool
looksLikeVersion version =
  maybe False (T.all isDigit) (T.stripPrefix "v" version)

-- | Whether a request may be rejected because its API version is disabled.
--
-- The swagger-delivering end-points (@api/swagger-ui@ and @api/swagger.json@)
-- are exempt: they should work for all versions. The comparison is against
-- the exact path segments of the request (in 'versionMiddleware' this is the
-- path after the version prefix has been stripped).
requestIsDisableable :: Request -> Bool
requestIsDisableable req = pathInfo req `notElem` alwaysAvailable
  where
    alwaysAvailable :: [[Text]]
    alwaysAvailable =
      [ ["api", "swagger-ui"],
        ["api", "swagger.json"]
      ]

-- | Drop every request header whose name equals 'versionHeader'; all other
-- headers are kept in their original order.
removeVersionHeader :: Request -> Request
removeVersionHeader req = req {requestHeaders = kept}
  where
    kept = [hdr | hdr@(hdrName, _) <- requestHeaders req, hdrName /= versionHeader]

-- | Prepend a 'versionHeader' header whose value is the integer form of the
-- given version ('versionInt', rendered with 'toByteString''). Headers that
-- are already present are left untouched.
addVersionHeader :: Version -> Request -> Request
addVersionHeader v req = req {requestHeaders = versionHdr : requestHeaders req}
  where
    versionHdr = (versionHeader, toByteString' (versionInt v :: Int))