{-# language BangPatterns #-}
{-# language DataKinds #-}
{-# language DeriveAnyClass #-}
{-# language DerivingStrategies #-}
{-# language DuplicateRecordFields #-}
{-# language GADTs #-}
{-# language KindSignatures #-}
{-# language LambdaCase #-}
{-# language MagicHash #-}
{-# language NamedFieldPuns #-}
{-# language StandaloneDeriving #-}
{-# language UnboxedTuples #-}

-- | Internet datagram sockets without a fixed destination.
-- The user may spoof the source address and may specify the
-- packet ID. An application must have @CAP_NET_RAW@ or be
-- running as root to use the functions in this module.
module Socket.Datagram.IPv4.Spoof
  ( -- * Types
    Socket(..)
  , Endpoint(..)
  , Message(..)
    -- * Establish
  , withSocket
    -- * Communicate
  , sendMutableByteArray
    -- * Exceptions
  , SocketException(..)
  , SendException(..)
  ) where

import Control.Concurrent (threadWaitWrite)
import Control.Exception (Exception,throwIO,mask,onException)
import Data.Bits (unsafeShiftR,complement,(.&.))
import Data.Kind (Type)
import Data.Primitive (MutableByteArray(..))
import Data.Word (Word16,Word8,Word64,Word32)
import Foreign.C.Error (Errno(..),eWOULDBLOCK,eAGAIN,eMFILE,eNFILE,eACCES,ePERM)
import Foreign.C.Types (CInt,CSize)
import GHC.Exts (RealWorld,touch#)
import GHC.IO (IO(..))
import Net.Types (IPv4(..))
import Socket (SocketUnrecoverableException(..),Interruptibility(..))
import Socket.Datagram (SendException(..))
import Socket.Datagram.IPv4.Undestined.Internal (Message(..))
import Socket.Debug (debug,whenDebugging)
import Socket.IPv4 (Endpoint(..))
import System.Posix.Types (Fd)
import Text.Printf (printf)

import qualified Data.Primitive as PM
import qualified Linux.Socket as L
import qualified Posix.Socket as S
import qualified GHC.Exts as E
import qualified Socket as SCK

-- TODO: Something I am not sure about is whether or not it is necessary
-- to bind to a port right after creating the socket. If we defering
-- binding until the call to the time of sending, what port do we bind
-- to? Does the kernel use the one in the source port or does it choose
-- an ephemeral port? We need it to choose an ephemeral port. Otherwise,
-- we can get spurious failures.

-- | A socket that send datagrams with spoofed source IP addresses.
-- It cannot receive datagrams.
newtype Socket = Socket Fd
  deriving stock (Eq,Ord,Show)

data SocketException :: Type where
  -- | Permission to create a raw socket was denied. The process needs
  --   the capability @CAP_NET_RAW@, or it must be run as root.
  SocketPermissionDenied :: SocketException
  -- | A limit on the number of open file descriptors has been reached.
  --   This could be the per-process limit or the system limit.
  --   (@EMFILE@ and @ENFILE@)
  SocketFileDescriptorLimit :: SocketException

deriving stock instance Show SocketException
deriving anyclass instance Exception SocketException

-- | Open a socket and run the supplied callback on it. This closes the socket
-- when the callback finishes or when an exception is thrown. Do not return 
-- the socket from the callback. This leads to undefined behavior. The user
-- cannot specify an endpoint since the socket cannot receive traffic.
withSocket ::
     (Socket -> IO a) -- ^ Callback providing the socket
  -> IO (Either SocketException a)
withSocket f = mask $ \restore -> do
  debug "withSocket: opening raw socket"
  e1 <- S.uninterruptibleSocket S.internet
    (L.applySocketFlags (L.closeOnExec <> L.nonblocking) S.raw)
    S.rawProtocol
  debug "withSocket: opened raw socket"
  case e1 of
    Left err -> handleSocketException SCK.functionWithSocket err
    Right fd -> do
      a <- onException (restore (f (Socket fd))) (S.uninterruptibleErrorlessClose fd)
      S.uninterruptibleClose fd >>= \case
        Left err -> throwIO $ SocketUnrecoverableException
          moduleSocketDatagramIPv4Spoof
          SCK.functionWithSocket
          ["close",describeErrorCode err]
        Right _ -> pure (Right a)

-- | Send a slice of a bytearray to the specified endpoint.
sendMutableByteArray ::
     Socket -- ^ Socket
  -> Endpoint -- ^ Spoofed source address and port
  -> Endpoint -- ^ Remote IPv4 address and port
  -> MutableByteArray RealWorld -- ^ Buffer (will be sliced)
  -> Int -- ^ Offset into payload
  -> Int -- ^ Lenth of slice into buffer
  -> IO (Either (SendException 'Uninterruptible) ())
sendMutableByteArray (Socket !s) !theSource !theRemote !thePayload !off !len = do
  let ipHeaderSz = cintToInt L.sizeofIpHeader
  let totalHeaderSz = cintToInt (L.sizeofIpHeader + L.sizeofUdpHeader)
  let totalPacketSz = len + totalHeaderSz
  -- Why do we add one to the size? This extra byte at the end will always
  -- be zeroed out. It makes UDP checksum calculation a little easier,
  -- since we can now pull out Word16s until we reach the end. If the
  -- original length was even, this extra byte ends up unused by the
  -- checksum. But, if the original length was odd, it does get used.
  buf <- PM.newPinnedByteArray (totalPacketSz + 1)
  PM.setByteArray buf 0 (totalPacketSz + 1) (0 :: Word8)
  let addr = PM.mutableByteArrayContents buf
  L.pokeIpHeaderVersionIhl addr (4 * 16 + 5)
  L.pokeIpHeaderTypeOfService addr 0
  -- TODO: Actually check the length.
  -- NB: The packet length must be in network byte order.
  -- Expermentally, it seems that it does not.
  L.pokeIpHeaderTotalLength addr (S.hostToNetworkShort (intToWord16 totalPacketSz))
  L.pokeIpHeaderIdentifier addr 0
  L.pokeIpHeaderFragmentOffset addr 0
  L.pokeIpHeaderTimeToLive addr 64
  L.pokeIpHeaderProtocol addr (cintToWord8 (S.getProtocol S.udp))
  -- The linux kernel fills in the ip header checksum for us.
  L.pokeIpHeaderChecksum addr 0
  let src = S.hostToNetworkLong (getIPv4 (address theSource))
  L.pokeIpHeaderSourceAddress addr src
  let dst = S.hostToNetworkLong (getIPv4 (address theRemote))
  L.pokeIpHeaderDestinationAddress addr dst
  let udpAddr = PM.plusAddr addr ipHeaderSz
  L.pokeUdpHeaderSourcePort udpAddr (S.hostToNetworkShort (port theSource))
  L.pokeUdpHeaderDestinationPort udpAddr (S.hostToNetworkShort (port theRemote))
  let udpLen = cintToInt L.sizeofUdpHeader + len
  L.pokeUdpHeaderLength udpAddr (S.hostToNetworkShort (intToWord16 udpLen))
  PM.copyMutableByteArray buf totalHeaderSz thePayload off len
  L.pokeUdpHeaderChecksum udpAddr . S.hostToNetworkShort =<< udpChecksum src dst buf ipHeaderSz udpLen
  touchMutableByteArray buf
  debug ("spoof send mutable: about to send to " ++ show theRemote)
  whenDebugging $ do
    d <- PM.newByteArray totalPacketSz
    PM.copyMutableByteArray d 0 buf 0 totalPacketSz
    x <- PM.unsafeFreezeByteArray d
    debug ("raw packet: " ++ (foldMap (printf "%.2x ") (E.toList x)))
  e1 <- S.uninterruptibleSendToMutableByteArray s buf 0 (intToCSize totalPacketSz)
    mempty
    (S.encodeSocketAddressInternet (endpointToSocketAddressInternet theRemote))
  debug ("spoof send mutable: just sent to " ++ show theRemote)
  case e1 of
    Left err1 -> if err1 == eWOULDBLOCK || err1 == eAGAIN
      then do
        debug ("send mutable: waiting to for write ready to send to " ++ show theRemote)
        threadWaitWrite s
        e2 <- S.uninterruptibleSendToMutableByteArray s buf
          (intToCInt off)
          (intToCSize len)
          mempty
          (S.encodeSocketAddressInternet (endpointToSocketAddressInternet theRemote))
        case e2 of
          Left err2 -> do
            debug ("send mutable: encountered error after sending")
            handleSendException "sendMutableByteArray" err2
          Right sz -> if csizeToInt sz == totalPacketSz
            then pure (Right ())
            else pure (Left (SendTruncated (csizeToInt sz)))
      else do
        debug "spoof send mutable: sent on first try but got error code"
        handleSendException "sendMutableByteArray" err1
    Right sz -> if csizeToInt sz == totalPacketSz
      then do
        debug ("send mutable: success")
        pure (Right ())
      else pure (Left (SendTruncated (csizeToInt sz)))

-- Precondition: the mutable byte array must have an extra zeroed out
-- byte at the end. That is, at arr[offset+length], there exists a
-- zero byte. The offset must divide two evenly.
udpChecksum ::
     Word32 -- source (network byte order)
  -> Word32 -- dest (network byte order)
  -> MutableByteArray RealWorld -- payload
  -> Int -- offset (start of the udp header)
  -> Int -- length (udp header size plus payload size) 
  -> IO Word16
udpChecksum src dst payload off len = do
  let sum0 = word16ToWord64 (S.hostToNetworkShort (word32ToWord16 src))
  debug ("udp checksum source lower: " ++ printf "%.8X" sum0)
  let sum1 = sum0 + word16ToWord64 (S.hostToNetworkShort (word32ToWord16 (unsafeShiftR src 16)))
  debug ("udp checksum source lower+upper: " ++ printf "%.8X" sum1)
  let sum2 = sum1 + word16ToWord64 (S.hostToNetworkShort (word32ToWord16 dst))
      sum3 = sum2 + word16ToWord64 (S.hostToNetworkShort (word32ToWord16 (unsafeShiftR dst 16)))
  debug ("udp checksum source+dest lower+upper: " ++ printf "%.8X" sum3)
  let sum4 = sum3 + word16ToWord64 (cintToWord16 (S.getProtocol S.udp))
      sum5 = sum4 + word16ToWord64 (intToWord16 len)
  debug ("udp checksum pseudoheader without carries: " ++ printf "%.8X" sum5)
  let halfLen = unsafeShiftR (len + off) 1
  debug ("udp checksum start offset: " ++ show (unsafeShiftR off 1))
  debug ("udp checksum last offset: " ++ show halfLen)
  let go :: Int -> Word64 -> IO Word64
      go !ix !acc = if ix < halfLen
        then do
          w16 <- PM.readByteArray payload ix :: IO Word16
          debug ("udp checksum payload iteration " ++ show ix ++ ": " ++ printf "%.8X" acc)
          go (ix + 1) (word16ToWord64 (S.hostToNetworkShort w16) + acc)
        else pure acc
  r <- go (unsafeShiftR off 1) sum5
  -- We assume that the upper 16 bits in this 64-bit word are zeroes.
  -- There is no way for a datagram to be long enough to start to
  -- fill the bits beyond 48.
  pure (word64ToWord16 (complement ((r .&. 0xFFFF) + (unsafeShiftR r 16 .&. 0xFFFF) + (unsafeShiftR r 32))))

endpointToSocketAddressInternet :: Endpoint -> S.SocketAddressInternet
endpointToSocketAddressInternet (Endpoint {address, port}) = S.SocketAddressInternet
  { port = S.hostToNetworkShort port
  , address = S.hostToNetworkLong (getIPv4 address)
  }

intToCInt :: Int -> CInt
intToCInt = fromIntegral

intToCSize :: Int -> CSize
intToCSize = fromIntegral

csizeToInt :: CSize -> Int
csizeToInt = fromIntegral

cintToInt :: CInt -> Int
cintToInt = fromIntegral

cintToWord8 :: CInt -> Word8
cintToWord8 = fromIntegral

intToWord16 :: Int -> Word16
intToWord16 = fromIntegral

cintToWord16 :: CInt -> Word16
cintToWord16 = fromIntegral

word16ToWord64 :: Word16 -> Word64
word16ToWord64 = fromIntegral

word64ToWord16 :: Word64 -> Word16
word64ToWord16 = fromIntegral

word32ToWord16 :: Word32 -> Word16
word32ToWord16 = fromIntegral

touchMutableByteArray :: MutableByteArray RealWorld -> IO ()
touchMutableByteArray (MutableByteArray x) = touchMutableByteArray# x

touchMutableByteArray# :: E.MutableByteArray# RealWorld -> IO ()
touchMutableByteArray# x = IO $ \s -> case touch# x s of s' -> (# s', () #)

moduleSocketDatagramIPv4Spoof :: String
moduleSocketDatagramIPv4Spoof = "Socket.Datagram.IPv4.Spoof"

handleSocketException :: String -> Errno -> IO (Either SocketException a)
{-# INLINE handleSocketException #-}
handleSocketException func e
  | e == ePERM = pure (Left SocketPermissionDenied)
  | e == eMFILE = pure (Left SocketFileDescriptorLimit)
  | e == eNFILE = pure (Left SocketFileDescriptorLimit)
  | otherwise = throwIO $ SocketUnrecoverableException
      moduleSocketDatagramIPv4Spoof
      func
      [describeErrorCode e]

describeErrorCode :: Errno -> String
describeErrorCode (Errno e) = "error code " ++ show e

handleSendException :: String -> Errno -> IO (Either (SendException i) a)
{-# INLINE handleSendException #-}
handleSendException func e
  | e == eACCES = pure (Left SendBroadcasted)
  | otherwise = throwIO $ SocketUnrecoverableException
      moduleSocketDatagramIPv4Spoof
      func
      [describeErrorCode e]