-- This file is part of the Wire Server implementation.
--
-- Copyright (C) 2022 Wire Swiss GmbH <opensource@wire.com>
--
-- This program is free software: you can redistribute it and/or modify it under
-- the terms of the GNU Affero General Public License as published by the Free
-- Software Foundation, either version 3 of the License, or (at your option) any
-- later version.
--
-- This program is distributed in the hope that it will be useful, but WITHOUT
-- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
-- FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
-- details.
--
-- You should have received a copy of the GNU Affero General Public License along
-- with this program. If not, see <https://www.gnu.org/licenses/>.

-- Parts of this code, namely functions interpretResponse and orderSrvResult,
-- which were taken from http://hackage.haskell.org/package/pontarius-xmpp
-- are also licensed under the three-clause BSD license:
--
-- Copyright © 2005-2011 Dmitry Astapov
-- Copyright © 2005-2011 Pierre Kovalev
-- Copyright © 2010-2011 Mahdi Abdinejadi
-- Copyright © 2010-2013 Jon Kristensen
-- Copyright © 2011      IETF Trust
-- Copyright © 2012-2013 Philipp Balzarek
--
-- All rights reserved.
--
-- Pontarius XMPP is licensed under the three-clause BSD license.
--
-- Redistribution and use in source and binary forms, with or without
-- modification, are permitted provided that the following conditions are met:
--
-- - Redistributions of source code must retain the above copyright notice, this
--   list of conditions and the following disclaimer.
--
-- - Redistributions in binary form must reproduce the above copyright notice,
--   this list of conditions and the following disclaimer in the documentation
--   and/or other materials provided with the distribution.
--
-- - Neither the name of the Pontarius project nor the names of its contributors
--   may be used to endorse or promote products derived from this software without
--   specific prior written permission.
--
-- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
-- ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
-- WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-- DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR THE PONTARIUS PROJECT BE
-- LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
-- CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
-- GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
-- HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
-- LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
-- OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

module Wire.Network.DNS.SRV where

import Control.Category ((>>>))
import Data.List.NonEmpty (NonEmpty (..))
import Imports
import Network.DNS (DNSError, Domain)
import System.Random (randomRIO)

-- | A single SRV record. See 'orderSrvResult' for how the priority and
-- weight fields are used to order a set of records.
data SrvEntry = SrvEntry
  { -- | priority of this record; 'orderSrvResult' sorts lower values first
    srvPriority :: !Word16,
    -- | relative weight among records sharing the same priority
    srvWeight :: !Word16,
    -- | the host and port on which the service is offered
    srvTarget :: !SrvTarget
  }
  deriving (Ord, Eq, Show)

-- | The endpoint named by an SRV record.
data SrvTarget = SrvTarget
  { -- | hostname on which the service is offered
    srvTargetDomain :: !Domain,
    -- | port on which the service is offered
    srvTargetPort :: !Word16
  }
  deriving (Ord, Show, Eq)

-- | Classified outcome of an SRV lookup, as produced by 'interpretResponse'.
data SrvResponse
  = -- | no usable record: an empty answer, or a single record whose target is "."
    SrvNotAvailable
  | -- | at least one SRV record was returned
    SrvAvailable (NonEmpty SrvEntry)
  | -- | the lookup returned a 'DNSError'
    SrvResponseError DNSError
  deriving (Show, Eq)

-- | Classify the raw result of an SRV lookup.
--
-- * A 'Left' error is wrapped in 'SrvResponseError'.
-- * An empty answer yields 'SrvNotAvailable'.
-- * A single record whose target is @"."@ also yields 'SrvNotAvailable'.
--   RFC 2782 defines such a target as "service decidedly not available".
-- * Any other answer is converted record by record with 'toSrvEntry'.
--
-- The tuple components are destructured by 'toSrvEntry' as
-- (priority, weight, port, target).
interpretResponse :: Either DNSError [(Word16, Word16, Word16, Domain)] -> SrvResponse
interpretResponse (Left dnsErr) = SrvResponseError dnsErr
interpretResponse (Right records) = classify records
  where
    classify [] = SrvNotAvailable
    -- According to RFC2782
    classify [(_, _, _, ".")] = SrvNotAvailable
    classify (firstRecord : otherRecords) =
      SrvAvailable (toSrvEntry <$> (firstRecord :| otherRecords))

-- | Build an 'SrvEntry' from a (priority, weight, port, target) tuple.
toSrvEntry :: (Word16, Word16, Word16, Domain) -> SrvEntry
toSrvEntry (priority, weight, port, target) =
  SrvEntry
    { srvPriority = priority,
      srvWeight = weight,
      srvTarget =
        SrvTarget
          { srvTargetDomain = target,
            srvTargetPort = port
          }
    }

-- | Order SRV records in accordance with RFC 2782.
--
-- Records are sorted by ascending priority. Within each group of equal
-- priority, records are ordered by repeated weighted random selection.
-- Zero-weight records are moved to the front of their group, so such a
-- record is picked only when the random draw is 0.
--
-- Taken from http://hackage.haskell.org/package/pontarius-xmpp (BSD3 licence)
-- and refactored.
--
-- FUTUREWORK: maybe improve the sorting algorithm here (with respect to
-- performance and code style).
orderSrvResult :: [SrvEntry] -> IO [SrvEntry]
orderSrvResult =
  -- Order the result set by priority.
  sortOn srvPriority
    -- Group elements in sublists based on their priority.
    -- The result type is `[[SrvEntry]]' (nested list).
    >>> groupBy ((==) `on` srvPriority)
    -- For each sublist, put records with a weight of zero first.
    >>> map (uncurry (++) . partition ((== 0) . srvWeight))
    -- Order each sublist.
    >>> mapM orderSublist
    -- Concatenate the results.
    >>> fmap concat
  where
    orderSublist :: [SrvEntry] -> IO [SrvEntry]
    orderSublist [] = pure []
    orderSublist sublist = do
      -- Compute the running sum of weights, as well as the total sum of the
      -- sublist, and attach the running sum to each record.
      -- The sums are accumulated in 'Int', not 'Word16'. Several 16-bit
      -- weights can add up to more than 65535, and wrap-around would make
      -- the running sums non-monotonic and skew the selection.
      let (total, sublistWithRunning) =
            mapAccumL
              (\acc srv -> let acc' = acc + weightOf srv in (acc', (srv, acc')))
              0
              sublist
      -- Choose a random number between 0 and the total sum (inclusive).
      randomNumber <- randomRIO (0, total)
      let (chosen, remainingSrvs) = pick randomNumber sublistWithRunning
      -- Repeat the ordering procedure on the remaining elements.
      rest <- orderSublist remainingSrvs
      pure (chosen : rest)

    -- Widen a record's weight so that sums cannot overflow.
    weightOf :: SrvEntry -> Int
    weightOf = fromIntegral . srvWeight

    -- Select the first record whose running sum is greater than or equal to
    -- the drawn number. Return it together with all other records in their
    -- original order, with the running sums stripped.
    pick :: Int -> [(SrvEntry, Int)] -> (SrvEntry, [SrvEntry])
    pick randomNumber withRunning =
      case break (\(_, running) -> randomNumber <= running) withRunning of
        (beginning, (chosen, _) : end) -> (chosen, map fst (beginning ++ end))
        -- Unreachable for a non-empty list: the running sums are monotonic,
        -- and the last one equals the total, which is >= the drawn number.
        _ -> error "orderSrvResult: no record with running sum greater than random number"