{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}

module Hans.Layer.Tcp.Types where

import Hans.Address.IP4
import Hans.Layer.Tcp.Window
import Hans.Message.Tcp
import Hans.Ports

import Control.Exception (Exception,SomeException,toException)
import Control.Monad ( guard )
import Data.Time.Clock.POSIX (POSIXTime)
import Data.Int (Int64)
import Data.Maybe (fromMaybe)
import Data.Word (Word16,Word32)
import qualified Data.Map.Strict as Map
import qualified Data.Sequence as Seq


-- Hosts Information -----------------------------------------------------------

data Host = Host
  { hostConnections   :: !Connections
  , hostTimeWaits     :: !TimeWaitConnections
  , hostInitialSeqNum :: {-# UNPACK #-} !TcpSeqNum
  , hostPorts         :: !(PortManager TcpPort)
  , hostLastUpdate    :: !POSIXTime
  }

emptyHost :: POSIXTime -> Host
emptyHost start = Host
  { hostConnections   = Map.empty
  , hostTimeWaits     = Map.empty
    -- XXX what should we seed this with?
  , hostInitialSeqNum = 0
  , hostPorts         = emptyPortManager [32768 .. 61000]
  , hostLastUpdate    = start
  }

takePort :: Host -> Maybe (TcpPort,Host)
takePort host = do
  (p,ps) <- nextPort (hostPorts host)
  return (p, host { hostPorts = ps })

releasePort :: TcpPort -> Host -> Host
releasePort p host = fromMaybe host $ do
  ps <- unreserve p (hostPorts host)
  return host { hostPorts = ps }


-- Connections -----------------------------------------------------------------

type Connections = Map.Map SocketId TcpSocket

removeClosed :: Connections -> Connections
removeClosed  =
  Map.filter (\tcp -> tcpState tcp /= Closed || not (tcpUserClosed tcp))

data SocketId = SocketId
  { sidLocalPort  :: {-# UNPACK #-} !TcpPort
  , sidRemotePort :: {-# UNPACK #-} !TcpPort
  , sidRemoteHost :: {-# UNPACK #-} !IP4
  } deriving (Eq,Show,Ord)

emptySocketId :: SocketId
emptySocketId  = SocketId
  { sidLocalPort  = TcpPort 0
  , sidRemotePort = TcpPort 0
  , sidRemoteHost = IP4 0 0 0 0
  }

listenSocketId :: TcpPort -> SocketId
listenSocketId port = emptySocketId { sidLocalPort = port }

incomingSocketId :: IP4 -> TcpHeader -> SocketId
incomingSocketId remote hdr = SocketId
  { sidLocalPort  = tcpDestPort hdr
  , sidRemotePort = tcpSourcePort hdr
  , sidRemoteHost = remote
  }

data SocketResult a
  = SocketResult a
  | SocketError SomeException
    deriving (Show)

socketError :: Exception e => e -> SocketResult a
socketError  = SocketError . toException


-- Connections in TimeWait -----------------------------------------------------

-- | The socket that's in TimeWait, plus its current 2MSL value.
type TimeWaitConnections = Map.Map SocketId TimeWaitSock

data TimeWaitSock = TimeWaitSock { tw2MSL      :: {-# UNPACK #-} !SlowTicks
                                   -- ^ The current 2MSL value
                                 , twInit2MSL  :: {-# UNPACK #-} !SlowTicks
                                   -- ^ The initial 2MSL value
                                 , twSeqNum    :: {-# UNPACK #-} !TcpSeqNum
                                   -- ^ The sequence number to use when
                                   -- responding to messages
                                 , twTimestamp :: !(Maybe Timestamp)
                                   -- ^ The last timestamp used
                                 } deriving (Show)

twReset2MSL :: TimeWaitSock -> TimeWaitSock
twReset2MSL tw = tw { tw2MSL = twInit2MSL tw }

mkTimeWait :: TcpSocket -> TimeWaitSock
mkTimeWait TcpSocket { .. } =
  TimeWaitSock { tw2MSL      = timeout
               , twInit2MSL  = timeout
               , twSeqNum    = tcpSndNxt
               , twTimestamp = tcpTimestamp
               }
  where
  timeout | tt2MSL tcpTimers <= 0 = 2 * mslTimeout
          | otherwise             = tt2MSL tcpTimers

-- | Add a socket to the TimeWait map.
addTimeWait :: TcpSocket -> TimeWaitConnections -> TimeWaitConnections
addTimeWait tcp socks = Map.insert (tcpSocketId tcp) (mkTimeWait tcp) socks

-- | Take one step, collecting any connections whose 2MSL timer goes to 0.
stepTimeWaitConnections :: TimeWaitConnections -> TimeWaitConnections
stepTimeWaitConnections  = Map.mapMaybe $ \ tw ->
  do guard (tw2MSL tw > 0)
     return tw { tw2MSL = tw2MSL tw - 1 }


-- Timers ----------------------------------------------------------------------

type SlowTicks = Int

-- | MSL is 60 seconds, which is slightly more aggressive than the 2 minutes
-- from the original RFC.
mslTimeout :: SlowTicks
mslTimeout = 2 * 60

data TcpTimers = TcpTimers
  { ttDelayedAck :: !Bool

    -- 2MSL
  , tt2MSL       :: {-# UNPACK #-} !SlowTicks

    -- retransmit timer
  , ttRTO        :: {-# UNPACK #-} !SlowTicks
  , ttSRTT       :: !POSIXTime
  , ttRTTVar     :: !POSIXTime

    -- idle timer
  , ttMaxIdle    :: {-# UNPACK #-} !SlowTicks
  , ttIdle       :: {-# UNPACK #-} !SlowTicks
  }

emptyTcpTimers :: TcpTimers
emptyTcpTimers  = TcpTimers
  { ttDelayedAck = False
  , tt2MSL       = 0
  , ttRTO        = 2 -- one second
  , ttSRTT       = 0
  , ttRTTVar     = 0
  , ttMaxIdle    = 10 * 60 * 2
  , ttIdle       = 0
  }


-- Timestamp Management --------------------------------------------------------

-- | Manage the timestamp values that are in flight between two hosts.
data Timestamp = Timestamp
  { tsTimestamp     :: {-# UNPACK #-} !Word32
  , tsLastTimestamp :: {-# UNPACK #-} !Word32
  , tsGranularity   :: !POSIXTime -- ^ Hz
  , tsLastUpdate    :: !POSIXTime
  } deriving (Show)

emptyTimestamp :: POSIXTime -> Timestamp
emptyTimestamp start = Timestamp
  { tsTimestamp     = 0
  , tsLastTimestamp = 0
  , tsGranularity   = 200
  , tsLastUpdate    = start
  }

-- | Update the timestamp value, advancing based on the timestamp granularity.
-- If the number of ticks to advance is 0, don't advance the timestamp.
stepTimestamp :: POSIXTime -> Timestamp -> Timestamp
stepTimestamp now ts
  | diff == 0 = ts
  | otherwise = ts
    { tsLastUpdate = now
    , tsTimestamp  = tsTimestamp ts + diff
    }
  where
  diff = ceiling (tsGranularity ts * (now - tsLastUpdate ts))

-- | Generate timestamp option for an outgoing packet.
mkTimestamp :: Timestamp -> TcpOption
mkTimestamp ts = OptTimestamp (tsTimestamp ts) (tsLastTimestamp ts)


-- Internal Sockets ------------------------------------------------------------

type Acceptor = SocketId -> IO ()

type Notify = Bool -> IO ()

type Close = IO ()

data TcpSocket = TcpSocket
  { tcpParent      :: !(Maybe SocketId)
  , tcpSocketId    :: {-# UNPACK #-} !SocketId
  , tcpState       :: !ConnState
  , tcpAcceptors   :: !(Seq.Seq Acceptor)
  , tcpNotify      :: !(Maybe Notify)
  , tcpIss         :: {-# UNPACK #-} !TcpSeqNum
  , tcpSndNxt      :: {-# UNPACK #-} !TcpSeqNum
  , tcpSndUna      :: {-# UNPACK #-} !TcpSeqNum

  , tcpUserClosed  :: !Bool
  , tcpOut         :: !RemoteWindow
  , tcpOutBuffer   :: !(Buffer Outgoing)
  , tcpOutMSS      :: {-# UNPACK #-} !Int64
  , tcpIn          :: !LocalWindow
  , tcpInBuffer    :: !(Buffer Incoming)
  , tcpInMSS       :: {-# UNPACK #-} !Int64

  , tcpTimers      :: {-# UNPACK #-} !TcpTimers
  , tcpTimestamp   :: !(Maybe Timestamp)

  , tcpSack        :: !Bool
  , tcpWindowScale :: !Bool
  }

emptyTcpSocket :: Word16 -> Int -> TcpSocket
emptyTcpSocket sendWindow sendScale = TcpSocket
  { tcpParent      = Nothing
  , tcpSocketId    = emptySocketId
  , tcpState       = Closed
  , tcpAcceptors   = Seq.empty
  , tcpNotify      = Nothing
  , tcpIss         = 0
  , tcpSndNxt      = 0
  , tcpSndUna      = 0

  , tcpUserClosed  = False
  , tcpOut         = emptyRemoteWindow sendWindow sendScale
  , tcpOutBuffer   = emptyBuffer 16384
  , tcpOutMSS      = defaultMSS
  , tcpIn          = emptyLocalWindow 0 14600 0
  , tcpInBuffer    = emptyBuffer 16384
  , tcpInMSS       = defaultMSS

  , tcpTimers      = emptyTcpTimers
  , tcpTimestamp   = Nothing

  , tcpSack        = True
  , tcpWindowScale = True
  }

defaultMSS :: Int64
defaultMSS  = 1460

nothingOutstanding :: TcpSocket -> Bool
nothingOutstanding TcpSocket { .. } = tcpSndUna == tcpSndNxt

tcpRcvNxt :: TcpSocket -> TcpSeqNum
tcpRcvNxt = lwRcvNxt . tcpIn

inRcvWnd :: TcpSeqNum -> TcpSocket -> Bool
inRcvWnd (TcpSeqNum sn) TcpSocket { .. } =
  rcvNxt <= sn && sn < rcvNxt + fromIntegral (lwRcvWind tcpIn)
  where
  TcpSeqNum rcvNxt = lwRcvNxt tcpIn

nextSegSize :: TcpSocket -> Int64
nextSegSize tcp =
  min (fromIntegral (rwAvailable (tcpOut tcp))) (tcpOutMSS tcp)

isAccepting :: TcpSocket -> Bool
isAccepting  = not . Seq.null . tcpAcceptors

needsDelayedAck :: TcpSocket -> Bool
needsDelayedAck  = ttDelayedAck . tcpTimers

mkMSS :: TcpSocket -> TcpOption
mkMSS tcp = OptMaxSegmentSize (fromIntegral (tcpInMSS tcp))

data ConnState
  = Closed
  | Listen
  | SynSent
  | SynReceived
  | Established
  | CloseWait
  | FinWait1
  | FinWait2
  | Closing
  | LastAck
  | TimeWait
    deriving (Show,Eq,Ord)