{-# LANGUAGE TemplateHaskell #-}

-- This file is part of the Wire Server implementation.
--
-- Copyright (C) 2022 Wire Swiss GmbH <opensource@wire.com>
--
-- This program is free software: you can redistribute it and/or modify it under
-- the terms of the GNU Affero General Public License as published by the Free
-- Software Foundation, either version 3 of the License, or (at your option) any
-- later version.
--
-- This program is distributed in the hope that it will be useful, but WITHOUT
-- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
-- FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
-- details.
--
-- You should have received a copy of the GNU Affero General Public License along
-- with this program. If not, see <https://www.gnu.org/licenses/>.

module Wire.Network.DNS.Effect where

import Data.IP qualified as IP
import Imports
import Network.DNS (Domain, Resolver)
import Network.DNS qualified as DNS
import Polysemy
import Wire.Network.DNS.SRV qualified as SRV

-- | Polysemy effect describing the DNS queries this module can perform.
data DNSLookup m a where
  -- | Query the SRV records of the given domain; the result is an
  -- 'SRV.SrvResponse'.
  LookupSRV :: Domain -> DNSLookup m SRV.SrvResponse
  -- | Query the A (IPv4) records of the given domain, yielding either a
  -- 'DNS.DNSError' or the list of addresses.
  LookupA :: Domain -> DNSLookup m (Either DNS.DNSError [IP.IPv4])

-- Template Haskell: generate the 'Sem' smart constructors for 'DNSLookup'
-- (per polysemy's convention, the lowercased constructor names).
makeSem ''DNSLookup

-- | Interpret 'DNSLookup' in 'IO' using 'DNS.defaultResolvConf'.
--
-- Each individual lookup builds its own 'DNS.ResolvSeed' (via
-- 'DNS.makeResolvSeed') and its own 'Resolver' (via 'DNS.withResolver'); no
-- seed or resolver is shared between lookups. Prefer
-- 'runDNSLookupWithResolver' when one 'Resolver' should be reused.
--
-- NOTE(review): building a seed per lookup presumably re-reads the system
-- resolver configuration every time; confirm this is intended.
runDNSLookupDefault :: (Member (Embed IO) r) => Sem (DNSLookup ': r) a -> Sem r a
runDNSLookupDefault = interpret $ \action ->
  embed $
    DNS.makeResolvSeed DNS.defaultResolvConf >>= \seed ->
      DNS.withResolver seed (`runLookupIO` action)

-- | Interpret 'DNSLookup' in 'IO' against a caller-supplied 'Resolver'.
-- The same 'Resolver' is used for every lookup in the interpreted
-- computation.
runDNSLookupWithResolver :: (Member (Embed IO) r) => Resolver -> Sem (DNSLookup ': r) a -> Sem r a
runDNSLookupWithResolver resolver =
  interpret $ \action -> embed (runLookupIO resolver action)

-- | Execute a single 'DNSLookup' request directly in 'IO' with the given
-- 'Resolver'.
--
-- * 'LookupSRV' runs 'DNS.lookupSRV' and converts the raw answer with
--   'SRV.interpretResponse'.
-- * 'LookupA' returns the result of 'DNS.lookupA' unchanged.
runLookupIO :: Resolver -> DNSLookup m a -> IO a
runLookupIO resolver (LookupSRV domain) =
  fmap SRV.interpretResponse (DNS.lookupSRV resolver domain)
runLookupIO resolver (LookupA domain) =
  DNS.lookupA resolver domain