{-# LANGUAGE OverloadedStrings #-}

module System.Linux.Proc.Tcp
  ( TcpSocket (..)
  , TcpState (..)
  , readProcTcpSockets
  )
  where

import           Control.Error (runExceptT, throwE)
import           Control.Monad (replicateM, void)

import           Data.Attoparsec.ByteString.Char8 (Parser)
import qualified Data.Attoparsec.ByteString.Char8 as Atto
import           Data.Bits ((.|.), shiftL)
import           Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as BS
import qualified Data.List as List
import qualified Data.Text as Text

import           System.Linux.Proc.Errors (ProcError (..))
import           System.Linux.Proc.Process (ProcessId (..))
import           System.Linux.Proc.IO (readProcFile)



-- | State of a TCP socket, as given by the @st@ column of a @net/tcp@ file.
--
-- The column holds a two-character hexadecimal code; the code that maps to
-- each constructor is listed next to it.
data TcpState
  = TcpEstablished   -- ^ @01@
  | TcpSynSent       -- ^ @02@
  | TcpSynReceive    -- ^ @03@
  | TcpFinWait1      -- ^ @04@
  | TcpFinWait2      -- ^ @05@
  | TcpTimeWait      -- ^ @06@
  | TcpClose         -- ^ @07@
  | TcpCloseWait     -- ^ @08@
  | TcpLastAck       -- ^ @09@
  | TcpListen        -- ^ @0A@
  | TcpClosing       -- ^ @0B@
  | TcpNewSynReceive -- ^ @0C@
  deriving (Show, Eq)

-- | A TCP socket used by a process, as listed in the process's
-- @\/proc\/\<pid\>\/net\/tcp@ file. Only the non-debug columns are parsed and
-- stored in this record; all other columns are skipped by the parser.
data TcpSocket = TcpSocket
  { tcpLocalAddr  :: !(ByteString, Int)
    -- ^ Local address rendered as dotted-decimal text, and the local port.
  , tcpRemoteAddr :: !(ByteString, Int)
    -- ^ Remote address rendered as dotted-decimal text, and the remote port.
  , tcpTcpState   :: !TcpState
    -- ^ Connection state (the @st@ column).
  , tcpUid        :: !Int
    -- ^ Value of the @uid@ column.
  , tcpInode      :: !Int
    -- ^ Value of the @inode@ column.
  } deriving Show


-- | Read the @\/proc\/\<pid\>\/net\/tcp@ file of a process and parse it into
-- a list of 'TcpSocket's.
--
-- Errors are returned instead of thrown. A failure from 'readProcFile' is
-- passed through unchanged. A parse failure (including unconsumed trailing
-- input) becomes a 'ProcParseError' carrying the file path and attoparsec's
-- error message.
readProcTcpSockets :: ProcessId -> IO (Either ProcError [TcpSocket])
readProcTcpSockets pid = runExceptT $ do
  contents <- readProcFile path
  either (throwE . ProcParseError path . Text.pack) pure
    (Atto.parseOnly (pTcpSocketList <* Atto.endOfInput) contents)
  where
    path :: FilePath
    path = mkNetTcpPath pid


-- -----------------------------------------------------------------------------
-- Internals.

-- | Location of the @net/tcp@ file of the given process:
-- @\/proc\/\<pid\>\/net\/tcp@.
mkNetTcpPath :: ProcessId -> FilePath
mkNetTcpPath (ProcessId pid) = concat ["/proc/", show pid, "/net/tcp"]

-- -----------------------------------------------------------------------------
-- Parsers.

-- | Parse the body of a @net/tcp@ file: the column header line followed by
-- zero or more socket lines.
pTcpSocketList :: Parser [TcpSocket]
pTcpSocketList = do
  pHeaders
  Atto.many' pTcpSocket

-- | Parse exactly one ASCII space character.
--
-- The @net/tcp@ file does not use tabs. Attoparsec's own 'Atto.space' also
-- accepts tab, newline, carriage return and other whitespace, which would
-- capture too much here (for example, line endings).
pSpace :: Parser Char
pSpace = Atto.char ' '

-- | Skip one or more ASCII spaces (but no other whitespace).
pMany1Space :: Parser ()
pMany1Space = () <$ Atto.many1 pSpace

-- | Match the literal @s@, then skip the one or more spaces that must
-- follow it.
pStringSpace :: ByteString -> Parser ()
pStringSpace s = do
  _ <- Atto.string s
  pMany1Space

-- | Parse the column header line of a @net/tcp@ file. This covers the
-- leading spaces, every column name together with the spaces after it, and
-- the line ending.
--
-- @timeout inode@ is matched as a single literal, so exactly one space is
-- expected between those two names.
pHeaders :: Parser ()
pHeaders = do
  pMany1Space
  mapM_ pStringSpace columnNames
  Atto.endOfLine
  where
    columnNames :: [ByteString]
    columnNames =
      [ "sl", "local_address", "rem_address", "st", "tx_queue", "rx_queue"
      , "tr", "tm->when", "retrnsmt", "uid", "timeout inode"
      ]

-- | Parse one socket line of a @net/tcp@ file, including its line ending.
--
-- Only these columns are kept:
--
-- * the local and remote address/port,
-- * the state,
-- * the uid,
-- * the inode.
--
-- The slot number is checked but dropped. The queue, timer and retransmit
-- columns, the timeout column, and everything after the inode are skipped.
pTcpSocket :: Parser TcpSocket
pTcpSocket = do
  -- The slot number is right-aligned in a four-character field (the kernel
  -- prints it with @%4d:@). From slot 1000 onwards the line therefore has no
  -- leading space at all. Requiring at least one space here would end the
  -- socket list at slot 1000 and turn the whole read into a parse error.
  Atto.skipWhile (== ' ')
  _          <- Atto.many1 Atto.digit *> Atto.char ':' <* pMany1Space -- kernel slot
  localAddr  <- pAddressPort <* pMany1Space
  remoteAddr <- pAddressPort <* pMany1Space
  tcpState   <- pTcpState <* pMany1Space
  pInternalData                       -- tx_queue:rx_queue tr:tm->when retrnsmt
  uid        <- Atto.decimal <* pMany1Space
  _          <- (Atto.hexadecimal :: Parser Int) <* pMany1Space -- timeout (discarded)
  inode      <- Atto.decimal <* pMany1Space
  _          <- Atto.many1 (Atto.satisfy (/= '\n')) -- remaining internal state
  Atto.endOfLine
  pure $ TcpSocket localAddr remoteAddr tcpState uid inode

-- | Skip the columns between @st@ and @uid@. In header order these are
-- @tx_queue:rx_queue@, @tr:tm->when@ and @retrnsmt@, each followed by one
-- or more spaces. Every value is read as hexadecimal and discarded.
pInternalData :: Parser ()
pInternalData = do
  skipHexPair   -- tx_queue:rx_queue
  skipSpaces
  skipHexPair   -- tr:tm->when
  skipSpaces
  skipHex       -- retrnsmt
  skipSpaces
  where
    skipHex :: Parser ()
    skipHex = () <$ (Atto.hexadecimal :: Parser Int)

    skipHexPair :: Parser ()
    skipHexPair = skipHex *> Atto.char ':' *> skipHex

    skipSpaces :: Parser ()
    skipSpaces = () <$ Atto.many1 pSpace

-- | Parse an @address:port@ column such as @0100007F:0CEA@.
--
-- The address is written as eight hexadecimal digits, read here as four
-- octets. The port is written as four hexadecimal digits. The octets appear
-- in reverse order (@0100007F@ is 127.0.0.1), so they are reversed before
-- being rendered as dotted-decimal text. The port is used as written.
--
-- NOTE(review): the octet reversal presumably comes from the kernel printing
-- a network-order 32-bit value as a host integer. If so, it is only correct
-- on little-endian hosts -- confirm before relying on it elsewhere.
pAddressPort :: Parser (ByteString, Int)
pAddressPort = do
  octets <- replicateM 4 (pHexadecimalOfLength 2)
  _      <- Atto.char ':'
  port   <- pHexadecimalOfLength 4
  pure (renderDotted (reverse octets), port)
  where
    renderDotted :: [Int] -> ByteString
    renderDotted = BS.intercalate "." . map (BS.pack . show)

-- | Parse the @st@ column: a @0@ followed by one uppercase hexadecimal digit
-- naming the state.
--
-- See include/net/tcp_states.h of your kernel's source code for all possible
-- states. An unrecognised code makes this parser fail. The problem is then
-- reported through the parse result instead of as an exception escaping
-- from pure code via 'error'.
pTcpState :: Parser TcpState
pTcpState =
  (Atto.char '0' *> Atto.anyChar) >>= lookupState
  where
    lookupState :: Char -> Parser TcpState
    lookupState c =
      case c of
        '1' -> pure TcpEstablished
        '2' -> pure TcpSynSent
        '3' -> pure TcpSynReceive
        '4' -> pure TcpFinWait1
        '5' -> pure TcpFinWait2
        '6' -> pure TcpTimeWait
        '7' -> pure TcpClose
        '8' -> pure TcpCloseWait
        '9' -> pure TcpLastAck
        'A' -> pure TcpListen
        'B' -> pure TcpClosing
        'C' -> pure TcpNewSynReceive
        _   -> fail $ "System.Linux.Proc.Tcp.pTcpState: " ++ show c

-- | Parse exactly @n@ hexadecimal digits (either case) into an 'Int'.
--
-- Attoparsec's 'Atto.hexadecimal' keeps consuming digits for as long as it
-- can, covering inputs like @1@, @AB2@ or @deadbeef@. The address and port
-- columns are fixed-width with no separator between octets. They therefore
-- need a parser that stops after a known number of digits.
pHexadecimalOfLength :: Int -> Parser Int
pHexadecimalOfLength n = do
  ds <- Atto.count n (Atto.satisfy (isHexDigit . fromEnum))
  -- Strict left fold: the accumulator is forced at every step instead of
  -- building a chain of thunks as lazy 'foldl' would.
  pure $ List.foldl' step 0 (fmap (fromEnum :: Char -> Int) ds)
  where
    -- Code-point test for '0'..'9', 'a'..'f' and 'A'..'F'.
    isHexDigit :: Int -> Bool
    isHexDigit w =
      (w >= 48 && w <= 57) || (w >= 97 && w <= 102) || (w >= 65 && w <= 70)

    -- Shift the accumulator left by one nibble and add the value of the
    -- digit with code point @w@. Only code points accepted by 'isHexDigit'
    -- reach this function, so the last guard covers 'A'..'F'.
    step :: Int -> Int -> Int
    step a w
      | w >= 48 && w <= 57 = (a `shiftL` 4) .|. (w - 48) -- '0'..'9'
      | w >= 97            = (a `shiftL` 4) .|. (w - 87) -- 'a'..'f'
      | otherwise          = (a `shiftL` 4) .|. (w - 55) -- 'A'..'F'