{-# LANGUAGE OverloadedStrings #-}

module System.Linux.Proc.Tcp
  ( TcpSocket (..)
  , TcpState (..)
  , readProcTcpSockets
  )
  where

import           Control.Error (runExceptT, throwE)
import           Control.Monad (replicateM, void)

import           Data.Attoparsec.ByteString.Char8 (Parser)
import qualified Data.Attoparsec.ByteString.Char8 as Atto
import           Data.Bits ((.|.), shiftL)
import           Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as BS
import qualified Data.List as List
import qualified Data.Text as Text

import           System.Linux.Proc.Errors (ProcError (..))
import           System.Linux.Proc.Process (ProcessId (..))
import           System.Linux.Proc.IO (readProcFile)



-- | Connection state of a TCP socket, decoded from the two-character
-- hexadecimal @st@ column of the @net\/tcp@ file. The note on each constructor
-- is the column value that 'pTcpState' maps to it.
data TcpState
  = TcpEstablished   -- ^ @01@
  | TcpSynSent       -- ^ @02@
  | TcpSynReceive    -- ^ @03@
  | TcpFinWait1      -- ^ @04@
  | TcpFinWait2      -- ^ @05@
  | TcpTimeWait      -- ^ @06@
  | TcpClose         -- ^ @07@
  | TcpCloseWait     -- ^ @08@
  | TcpLastAck       -- ^ @09@
  | TcpListen        -- ^ @0A@
  | TcpClosing       -- ^ @0B@
  | TcpNewSynReceive -- ^ @0C@
  deriving (Eq, Show)

-- | One socket entry from the @\/proc\/\<pid\>\/net\/tcp@ file of a process.
--
-- Only the non-debug columns are kept; the remaining kernel-internal columns
-- are parsed and then discarded.
--
-- NOTE(review): the kernel appears to list every IPv4 TCP socket of the
-- process's network namespace in this file, not just the sockets the process
-- itself owns — confirm against proc(5) before relying on per-process meaning.
data TcpSocket = TcpSocket
  { tcpLocalAddr  :: !(ByteString, Int)
    -- ^ Local address as a dotted-quad string, paired with the local port.
  , tcpRemoteAddr :: !(ByteString, Int)
    -- ^ Remote address as a dotted-quad string, paired with the remote port.
  , tcpTcpState   :: !TcpState
    -- ^ Connection state taken from the @st@ column.
  , tcpUid        :: !Int
    -- ^ Value of the @uid@ column.
  , tcpInode      :: !Int
    -- ^ Value of the @inode@ column.
  } deriving (Show)


-- | Read and parse the @\/proc\/\<pid\>\/net\/tcp@ file of the given process.
--
-- Failures are returned as a 'Left' instead of being thrown. Errors reported
-- by 'readProcFile' propagate unchanged. A parse failure becomes a
-- 'ProcParseError' that carries the file path and Attoparsec's error message.
-- The parser must consume the whole file.
readProcTcpSockets :: ProcessId -> IO (Either ProcError [TcpSocket])
readProcTcpSockets pid =
  runExceptT $ do
    contents <- readProcFile path
    either (throwE . ProcParseError path . Text.pack) pure $
      Atto.parseOnly (pTcpSocketList <* Atto.endOfInput) contents
  where
    path = mkNetTcpPath pid


-- -----------------------------------------------------------------------------
-- Internals.

-- Build the path of the net/tcp file for a process: /proc/<pid>/net/tcp.
mkNetTcpPath :: ProcessId -> FilePath
mkNetTcpPath (ProcessId pid) = concat ["/proc/", show pid, "/net/tcp"]

-- -----------------------------------------------------------------------------
-- Parsers.

-- Parse the file body: exactly one header line, then zero or more socket
-- lines. Parsing stops at the first line that is not a valid socket entry.
pTcpSocketList :: Parser [TcpSocket]
pTcpSocketList = do
  pHeaders
  Atto.many' pTcpSocket

-- Parse exactly one ASCII space character. The net/tcp file separates its
-- columns with spaces only. Attoparsec's own 'space' parser also accepts tabs,
-- newlines and carriage returns, which would let it run past the end of a line.
pSpace :: Parser Char
pSpace = Atto.char ' '

-- Consume a run of one or more spaces and discard it.
pMany1Space :: Parser ()
pMany1Space = void spaces
  where
    spaces = Atto.many1 pSpace

-- Match the literal string @s@ and require at least one space right after it.
pStringSpace :: ByteString -> Parser ()
pStringSpace s = do
  _ <- Atto.string s
  pMany1Space

-- Parse the column-title line at the top of the file.
--
-- The line starts with spaces, and each title must be followed by at least one
-- space. This includes the last title, so a space is expected before the line
-- break; presumably that is the kernel's trailing padding. The line must end
-- immediately after that.
pHeaders :: Parser ()
pHeaders = do
  pMany1Space
  mapM_ pStringSpace
    [ "sl", "local_address", "rem_address", "st", "tx_queue", "rx_queue"
    , "tr", "tm->when", "retrnsmt", "uid", "timeout inode"
    ]
  Atto.endOfLine

-- Parse one socket line of the file. Only the addresses, the state, the uid and
-- the inode are kept; every other column is parsed loosely and discarded.
pTcpSocket :: Parser TcpSocket
pTcpSocket = do
  -- The slot number ("sl") is right-aligned in a four-character field. Once
  -- it reaches four digits (slot 1000 and up), the line has no leading space.
  -- Requiring one here would reject that line, and with it the whole file.
  Atto.skipMany pSpace
  _          <- (Atto.many1 Atto.digit *> Atto.char ':') <* pMany1Space -- Parse kernel slot
  localAddr  <- pAddressPort <* pMany1Space
  remoteAddr <- pAddressPort <* pMany1Space
  tcpState   <- pTcpState <* pMany1Space
  _          <- pInternalData
  uid        <- Atto.decimal <* pMany1Space
  _          <- Atto.hexadecimal <* pMany1Space :: Parser Int -- internal kernel state
  inode      <- Atto.decimal <* pMany1Space :: Parser Int
  _          <- Atto.many1 (Atto.satisfy (/= '\n')) -- remaining internal state
  Atto.endOfLine
  pure $ TcpSocket localAddr remoteAddr tcpState uid inode

-- Skip the kernel-internal columns between the state and the uid:
-- "tx_queue:rx_queue", "tr:tm->when" and "retrnsmt". Each group must be
-- followed by at least one space. All values are hexadecimal and discarded.
pInternalData :: Parser ()
pInternalData = do
  hexPair      -- outgoing data queue : incoming data queue
  pMany1Space
  hexPair      -- internal kernel state (tr:tm->when)
  pMany1Space
  void hex     -- internal kernel state (retrnsmt)
  pMany1Space
  where
    hex :: Parser Int
    hex = Atto.hexadecimal

    hexPair :: Parser ()
    hexPair = void (hex *> Atto.char ':' *> hex)

-- Parse an "ADDRESS:PORT" column of the net/tcp file into a dotted-quad address
-- string and a port number.
--
-- Both parts are hexadecimal. The address is four two-digit octets in reverse
-- order: 127.0.0.1 is written as 0100007F. The port is plain hex and is not
-- reordered: 0CEA is port 3306.
--
-- NOTE(review): the reversal presumably comes from the kernel printing a
-- network-order address as a host integer on a little-endian machine. On a
-- big-endian host the octets may not be reversed — confirm before relying on it.
pAddressPort :: Parser (ByteString, Int)
pAddressPort = do
  octets <- replicateM 4 (pHexadecimalOfLength 2)
  _      <- Atto.char ':'
  port   <- pHexadecimalOfLength 4
  let rendered = map (BS.pack . show) (reverse octets)
      dotted   = BS.concat (List.intersperse "." rendered)
  pure (dotted, port)

-- Parse the two-character hexadecimal "st" column into a 'TcpState'.
-- See include/net/tcp_states.h of your kernel's source code for all possible states.
--
-- An unrecognised code makes this parser fail rather than calling 'error'.
-- The problem then surfaces as a parse error inside the 'Either' that
-- 'readProcTcpSockets' returns. Previously it was an exception thrown lazily,
-- outside that error handling, when the socket's strict state field was forced.
pTcpState :: Parser TcpState
pTcpState = do
  code <- Atto.char '0' *> Atto.anyChar
  case lookupState code of
    Just st -> pure st
    Nothing -> fail $ "System.Linux.Proc.Tcp.pTcpState: " ++ show code
  where
    lookupState :: Char -> Maybe TcpState
    lookupState c = case c of
      '1' -> Just TcpEstablished
      '2' -> Just TcpSynSent
      '3' -> Just TcpSynReceive
      '4' -> Just TcpFinWait1
      '5' -> Just TcpFinWait2
      '6' -> Just TcpTimeWait
      '7' -> Just TcpClose
      '8' -> Just TcpCloseWait
      '9' -> Just TcpLastAck
      'A' -> Just TcpListen
      'B' -> Just TcpClosing
      'C' -> Just TcpNewSynReceive
      _   -> Nothing

-- Parse a hexadecimal number of exactly @n@ digits, upper or lower case.
-- Attoparsec's 'hexadecimal' keeps consuming digits for as long as it can
-- (e.g. '1', 'AB2', 'deadbeef'). Fixed-width fields such as address octets
-- and port numbers must stop after a known number of digits.
pHexadecimalOfLength :: Int -> Parser Int
pHexadecimalOfLength n = do
  ds <- Atto.count n (Atto.satisfy (isHexDigit . fromEnum))
  -- Strict left fold: the accumulator is forced at each step instead of
  -- building a chain of thunks.
  return $ List.foldl' step 0 (fmap (fromEnum :: Char -> Int) ds)
 where
  -- True for the ASCII code points of '0'-'9', 'a'-'f' and 'A'-'F'.
  isHexDigit :: Int -> Bool
  isHexDigit w =
    (w >= 48 && w <= 57) || (w >= 97 && w <= 102) || (w >= 65 && w <= 70)

  -- Shift the accumulator one hex digit left and add the digit whose code
  -- point is w. The input is already restricted to hex digits by isHexDigit.
  step :: Int -> Int -> Int
  step a w
    | w >= 48 && w <= 57 = (a `shiftL` 4) .|. (w - 48) -- '0'..'9'
    | w >= 97            = (a `shiftL` 4) .|. (w - 87) -- 'a'..'f'
    | otherwise          = (a `shiftL` 4) .|. (w - 55) -- 'A'..'F'
55)