{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE CPP #-}
-- |
-- Module      : Network.Socks5.Command
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
module Network.Socks5.Command
    ( establish
    , Connect(..)
    , Command(..)
    , connectIPV4
    , connectIPV6
    , connectDomainName
    -- * lowlevel interface
    , rpc
    , rpc_
    , sendSerialized
    , waitSerialized
    ) where

import Basement.Compat.Base
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as BC
import qualified Prelude
import Data.Serialize

import Network.Socket (Socket, PortNumber, HostAddress, HostAddress6)
import Network.Socket.ByteString

import Network.Socks5.Types
import Network.Socks5.Wire

-- | Perform the SOCKS5 greeting on an already-connected socket.
--
-- Sends a 'SocksHello' offering @methods@, then waits for the server's
-- hello response and returns the 'SocksMethod' it carries.
establish :: SocksVersion -> Socket -> [SocksMethod] -> IO SocksMethod
establish SocksVer5 socket methods = do
    sendSerialized socket (SocksHello methods)
    getSocksHelloResponseMethod <$> waitSerialized socket

-- | A CONNECT command: ask the SOCKS server to connect to the wrapped
-- destination (host address plus port).
newtype Connect = Connect SocksAddress
    deriving (Show, Eq, Ord)

-- | Types that can be turned into, and recovered from, a generic
-- 'SocksRequest' message.
class Command a where
    -- | Build the 'SocksRequest' that represents this command.
    toRequest :: a -> SocksRequest
    -- | Try to interpret a 'SocksRequest' as this command, yielding
    -- 'Nothing' when it does not describe this kind of command.
    fromRequest :: SocksRequest -> Maybe a

-- | A raw 'SocksRequest' is trivially a command: both conversions return
-- the request unchanged, and every request is accepted.
instance Command SocksRequest where
    toRequest req   = req
    fromRequest req = Just req

-- | 'Connect' corresponds to a request whose command is
-- 'SocksCommandConnect'; any other command is rejected by 'fromRequest'.
instance Command Connect where
    toRequest (Connect (SocksAddress host dstPort)) =
        SocksRequest
            { requestCommand = SocksCommandConnect
            , requestDstAddr = host
            , requestDstPort = Prelude.fromIntegral dstPort
            }
    fromRequest req =
        if requestCommand req == SocksCommandConnect
            then Just (Connect (SocksAddress (requestDstAddr req) (requestDstPort req)))
            else Nothing

-- | Ask the SOCKS server to CONNECT to an IPv4 address and port.
--
-- Returns the IPv4 address and port reported in the server's reply.
-- A failure reply is thrown as its 'SocksError' (via 'rpc_'). If the server
-- answers with a non-IPv4 address, an 'error' is raised.
connectIPV4 :: Socket -> HostAddress -> PortNumber -> IO (HostAddress, PortNumber)
connectIPV4 socket hostaddr port = do
    (bound, boundPort) <- rpc_ socket (Connect $ SocksAddress (SocksAddrIPV4 hostaddr) port)
    -- Check the address type inside IO. The failure then surfaces during
    -- this call, where the caller's exception handlers are active, rather
    -- than whenever the returned tuple happens to be forced later.
    case bound of
        SocksAddrIPV4 h -> return (h, boundPort)
        _               -> error "ipv4 requested, got something different"

-- | Ask the SOCKS server to CONNECT to an IPv6 address and port.
--
-- Returns the IPv6 address and port reported in the server's reply.
-- A failure reply is thrown as its 'SocksError' (via 'rpc_'). If the server
-- answers with a non-IPv6 address, an 'error' is raised.
connectIPV6 :: Socket -> HostAddress6 -> PortNumber -> IO (HostAddress6, PortNumber)
connectIPV6 socket hostaddr6 port = do
    (bound, boundPort) <- rpc_ socket (Connect $ SocksAddress (SocksAddrIPV6 hostaddr6) port)
    -- Check the address type inside IO. The failure then surfaces during
    -- this call, where the caller's exception handlers are active, rather
    -- than whenever the returned tuple happens to be forced later.
    case bound of
        SocksAddrIPV6 h -> return (h, boundPort)
        _               -> error "ipv6 requested, got something different"

-- | Ask the SOCKS server to CONNECT to a host given by domain name and port.
--
-- The name is packed with 'BC.pack', which keeps only the low 8 bits of each
-- character, and RFC 1928 encodes a domain name's length in a single octet.
-- To avoid silently asking for a different host, names containing non-ASCII
-- characters or longer than 255 characters are rejected with 'error' before
-- anything is sent.
--
-- Returns the address and port reported in the server's reply; a failure
-- reply is thrown as its 'SocksError' (via 'rpc_').
--
-- TODO: a dedicated FQDN type could enforce these rules statically.
connectDomainName :: Socket -> [Char] -> PortNumber -> IO (SocksHostAddress, PortNumber)
connectDomainName socket fqdn port
    | Prelude.any (\c -> Prelude.fromEnum c Prelude.> 127) fqdn =
        error "connectDomainName: domain name must be ASCII"
    | Prelude.length fqdn Prelude.> 255 =
        error "connectDomainName: domain name longer than 255 characters"
    | otherwise =
        rpc_ socket $ Connect $ SocksAddress (SocksAddrDomainName $ BC.pack fqdn) port

-- | Encode a value with its 'Serialize' instance and write all of the
-- resulting bytes to the socket.
sendSerialized :: Serialize a => Socket -> a -> IO ()
sendSerialized sock = sendAll sock . encode

-- | Read from the socket until one complete value has been decoded with its
-- 'Serialize' instance (see 'runGetDone' for the leftover-input rule).
waitSerialized :: Serialize a => Socket -> IO a
waitSerialized = runGetDone get . getMore

-- | Send a command as a 'SocksRequest' and wait for the server's response.
--
-- A success reply yields the bind address and port from the response; an
-- error reply yields the 'SocksError' it carries.
rpc :: Command a => Socket -> a -> IO (Either SocksError (SocksHostAddress, PortNumber))
rpc socket req = do
    sendSerialized socket (toRequest req)
    interpret <$> waitSerialized socket
  where
    interpret res =
        case responseReply res of
            SocksReplySuccess -> Right (responseBindAddr res, Prelude.fromIntegral (responseBindPort res))
            SocksReplyError e -> Left e

-- | Like 'rpc', but an error reply is thrown as a 'SocksError' exception
-- instead of being returned.
rpc_ :: Command a => Socket -> a -> IO (SocksHostAddress, PortNumber)
rpc_ socket req = do
    outcome <- rpc socket req
    case outcome of
        Left err    -> throwIO err
        Right bound -> return bound

-- | Run a parser over input pulled from an IO action, fetching another
-- chunk each time the parser asks for more.
--
-- The parser must consume every byte it is given: leftover input after a
-- successful parse raises an error. This is fine for lock-step
-- request/response exchanges, but would lose data if one party sent
-- several messages back to back. Parse failures are also raised with
-- 'error'.
runGetDone :: Serialize a => Get a -> IO ByteString -> IO a
runGetDone getter ioget = ioget >>= step . runGetPartial getter
  where
#if MIN_VERSION_cereal(0,4,0)
    step (Fail msg _)     = error msg
#else
    step (Fail msg)       = error msg
#endif
    step (Partial feed)   = ioget >>= step . feed
    step (Done val rest)
        | B.null rest     = return val
        | otherwise       = error "got too many bytes while receiving data"

-- | Read the next chunk of input from the socket, asking 'recv' for at
-- most 4096 bytes.
getMore :: Socket -> IO ByteString
getMore sock = recv sock maxChunk
  where
    maxChunk :: Int
    maxChunk = 4096