{-# LANGUAGE RecordWildCards #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE BangPatterns #-} module Hans.Tcp.State ( -- * Tcp State HasTcpState(..), TcpState(), newTcpState, -- ** Responder Interaction tcpQueue, TcpResponderRequest(..), -- ** Listen Sockets incrSynBacklog, decrSynBacklog, registerListening, lookupListening, deleteListening, -- ** Active Sockets Key(), tcbKey, tcpActive, lookupActive, registerActive, closeActive, deleteActive, -- ** TimeWait Sockets registerTimeWait, lookupTimeWait, resetTimeWait, deleteTimeWait, -- ** Port Management nextTcpPort, -- ** Sequence Numbers nextIss, ) where import Hans.Addr (Addr,wildcardAddr,putAddr) import Hans.Config (HasConfig(..),Config(..)) import qualified Hans.HashTable as HT import Hans.Lens import Hans.Network.Types (RouteInfo(..)) import Hans.Tcp.Packet import Hans.Tcp.Tcb import Hans.Threads (forkNamed) import Hans.Time import Control.Concurrent (threadDelay,MVar,newMVar,modifyMVar) import qualified Control.Concurrent.BoundedChan as BC import Control.Monad (guard) import qualified Data.ByteString as S import qualified Data.ByteString.Lazy as L import Data.Digest.Pure.SHA(sha1,integerDigest) import qualified Data.Foldable as F import Data.Hashable (Hashable) import qualified Data.Heap as H import Data.IORef (IORef,newIORef,atomicModifyIORef',readIORef) import Data.Serialize (runPutLazy,putByteString) import Data.Time.Clock (UTCTime,getCurrentTime,addUTCTime,diffUTCTime) import Data.Word (Word32) import GHC.Generics (Generic) import System.Random (newStdGen,random,randoms) -- General State --------------------------------------------------------------- data ListenKey = ListenKey !Addr !TcpPort deriving (Show,Eq,Ord,Generic) listenKey :: Getting r ListenTcb ListenKey listenKey = to (\ ListenTcb { .. } -> ListenKey lSrc lPort) {-# INLINE listenKey #-} data Key = Key !Addr -- Remote address !TcpPort -- Remote port !Addr -- Local address !TcpPort -- Local port deriving (Show,Eq,Ord,Generic) tcbKey :: Getting r Tcb Key tcbKey = to $ \Tcb { tcbRouteInfo = RouteInfo { .. }, .. } -> Key tcbRemote tcbRemotePort riSource tcbLocalPort {-# INLINE tcbKey #-} instance Hashable ListenKey instance Hashable Key type TimeWaitHeap = ExpireHeap TimeWaitTcb data TcpState = TcpState { tcpListen_ :: {-# UNPACK #-} !(HT.HashTable ListenKey ListenTcb) , tcpActive_ :: {-# UNPACK #-} !(HT.HashTable Key Tcb) , tcpTimeWait_ :: {-# UNPACK #-} !(IORef TimeWaitHeap) , tcpSynBacklog_ :: {-# UNPACK #-} !(IORef Int) -- ^ Decrements when a connection enters SynReceived or SynSent, -- and increments back up once it's closed, or enters Established. , tcpPorts :: {-# UNPACK #-} !(MVar TcpPort) , tcpISSTimer :: {-# UNPACK #-} !(IORef Tcp4USTimer) , tcpQueue_ :: {-# UNPACK #-} !(BC.BoundedChan TcpResponderRequest) } -- | Requests that can be made to the responder thread. data TcpResponderRequest = SendSegment !(RouteInfo Addr) !Addr !TcpHeader !L.ByteString | SendWithTcb !Tcb !TcpHeader !L.ByteString tcpQueue :: HasTcpState state => Getting r state (BC.BoundedChan TcpResponderRequest) tcpQueue = tcpState . to tcpQueue_ {-# INLINE tcpQueue #-} tcpListen :: HasTcpState state => Getting r state (HT.HashTable ListenKey ListenTcb) tcpListen = tcpState . to tcpListen_ {-# INLINE tcpListen #-} tcpActive :: HasTcpState state => Getting r state (HT.HashTable Key Tcb) tcpActive = tcpState . to tcpActive_ {-# INLINE tcpActive #-} tcpTimeWait :: HasTcpState state => Getting r state (IORef TimeWaitHeap) tcpTimeWait = tcpState . to tcpTimeWait_ {-# INLINE tcpTimeWait #-} tcpSynBacklog :: HasTcpState state => Getting r state (IORef Int) tcpSynBacklog = tcpState . to tcpSynBacklog_ {-# INLINE tcpSynBacklog #-} class HasTcpState state where tcpState :: Getting r state TcpState instance HasTcpState TcpState where tcpState = id {-# INLINE tcpState #-} data Tcp4USTimer = Tcp4USTimer { tcpTimer :: {-# UNPACK #-} !Word32 , tcpSecret :: {-# UNPACK #-} !S.ByteString , tcpLastUpdate :: !UTCTime } newTcp4USTimer :: IO Tcp4USTimer newTcp4USTimer = do tcpLastUpdate <- getCurrentTime gen <- newStdGen let (tcpTimer,gen') = random gen tcpSecret = S.pack (take 256 (randoms gen')) return Tcp4USTimer { .. } newTcpState :: Config -> IO TcpState newTcpState Config { .. } = do tcpListen_ <- HT.newHashTable cfgTcpListenTableSize tcpActive_ <- HT.newHashTable cfgTcpActiveTableSize tcpTimeWait_ <- newIORef emptyHeap tcpSynBacklog_ <- newIORef cfgTcpMaxSynBacklog tcpPorts <- newMVar 32767 tcpISSTimer <- newIORef =<< newTcp4USTimer tcpQueue_ <- BC.newBoundedChan 128 return TcpState { .. } -- | Returns 'True' when there is space in the Syn backlog, and False if the -- connection should be rejected. decrSynBacklog :: HasTcpState state => state -> IO Bool decrSynBacklog state = atomicModifyIORef' (view tcpSynBacklog state) $ \ backlog -> if backlog > 0 then (backlog - 1, True) else (backlog, False) -- | Yield back an entry in the Syn backlog. incrSynBacklog :: HasTcpState state => state -> IO () incrSynBacklog state = atomicModifyIORef' (view tcpSynBacklog state) (\ backlog -> (backlog + 1, ())) -- Listening Sockets ----------------------------------------------------------- -- | Register a new listening socket. registerListening :: HasTcpState state => state -> ListenTcb -> IO Bool registerListening state tcb = HT.alter update (view listenKey tcb) (view tcpListen state) where update Nothing = (Just tcb, True) update mb@Just{} = (mb, False) {-# INLINE registerListening #-} -- | Remove a listening socket. deleteListening :: HasTcpState state => state -> ListenTcb -> IO () deleteListening state tcb = HT.delete (view listenKey tcb) (view tcpListen state) {-# INLINE deleteListening #-} -- | Lookup a socket in the Listen state. lookupListening :: HasTcpState state => state -> Addr -> TcpPort -> IO (Maybe ListenTcb) lookupListening state src port = do mb <- HT.lookup (ListenKey src port) (view tcpListen state) case mb of Just {} -> return mb Nothing -> HT.lookup (ListenKey (wildcardAddr src) port) (view tcpListen state) {-# INLINE lookupListening #-} -- TimeWait Sockets ------------------------------------------------------------ -- | Register a socket in the TimeWait state. If the heap was empty, fork off a -- thread to reap its contents after the timeWaitTimeout. -- -- NOTE: this doesn't remove the original socket from the Active set. registerTimeWait :: (HasConfig state, HasTcpState state) => state -> TimeWaitTcb -> IO () registerTimeWait state tcb = let Config { .. } = view config state in updateTimeWait state $ \ now heap -> let heap' = if H.size heap >= cfgTcpTimeWaitSocketLimit then H.deleteMin heap else heap in fst (expireAt (addUTCTime cfgTcpTimeoutTimeWait now) tcb heap') -- | Reset the timer associated with a TimeWaitTcb. resetTimeWait :: (HasConfig state, HasTcpState state) => state -> TimeWaitTcb -> IO () resetTimeWait state tcb = let Config { .. } = view config state in updateTimeWait state $ \ now heap -> fst $ expireAt (addUTCTime cfgTcpTimeoutTimeWait now) tcb $ filterHeap (/= tcb) heap -- | Modify the TimeWait heap, and spawn a reaping thread when necessary. updateTimeWait :: (HasConfig state, HasTcpState state) => state -> (UTCTime -> TimeWaitHeap -> TimeWaitHeap) -> IO () updateTimeWait state update = do now <- getCurrentTime mbReap <- atomicModifyIORef' (view tcpTimeWait state) $ \ heap -> let heap' = update now heap -- Return a reaping action if: -- -- 1. The original heap was empty, signifying that there was no -- existing reaper running -- -- 2. The user action added something to the heap reaper = do guard (nullHeap heap) future <- nextEvent heap' return $ do delayDiff now future reapLoop in (heap', reaper) case mbReap of Just reaper -> do _ <- forkNamed "TimeWait Reaper" reaper return () Nothing -> return () where -- delay by at least half a second, until some point in the future. delayDiff now future = threadDelay (max 500000 (toUSeconds (diffUTCTime future now))) -- The reap thread will reap TimeWait sockets according to their expiration -- time, and then exit. reapLoop = do now <- getCurrentTime mbExpire <- atomicModifyIORef' (view tcpTimeWait state) $ \ heap -> let heap' = dropExpired now heap in (heap', nextEvent heap') case mbExpire of Just future -> do delayDiff now future reapLoop Nothing -> return () -- | Lookup a socket in the TimeWait state. lookupTimeWait :: HasTcpState state => state -> Addr -> TcpPort -> Addr -> TcpPort -> IO (Maybe TimeWaitTcb) lookupTimeWait state dst dstPort src srcPort = do heap <- readIORef (view tcpTimeWait state) return (payload `fmap` F.find isConn heap) where isConn Entry { payload = TimeWaitTcb { .. } } = and [ twRemote == dst , twRemotePort == dstPort , riSource twRouteInfo == src , twLocalPort == srcPort ] {-# INLINE lookupTimeWait #-} -- | Delete an entry from the TimeWait heap. deleteTimeWait :: HasTcpState state => state -> TimeWaitTcb -> IO () deleteTimeWait state tw = atomicModifyIORef' (view tcpTimeWait state) $ \ heap -> (filterHeap (/= tw) heap, ()) {-# INLINE deleteTimeWait #-} -- Active Sockets -------------------------------------------------------------- -- | Register a new active socket. registerActive :: HasTcpState state => state -> Tcb -> IO Bool registerActive state tcb = HT.alter update (view tcbKey tcb) (view tcpActive state) where update Nothing = (Just tcb, True) update mb = (mb, False) {-# INLINE registerActive #-} -- | Lookup an active socket. lookupActive :: HasTcpState state => state -> Addr -> TcpPort -> Addr -> TcpPort -> IO (Maybe Tcb) lookupActive state dst dstPort src srcPort = HT.lookup (Key dst dstPort src srcPort) (view tcpActive state) {-# INLINE lookupActive #-} -- | Delete the 'Tcb', and notify any waiting processes. closeActive :: HasTcpState state => state -> Tcb -> IO () closeActive state tcb = do finalizeTcb tcb deleteActive state tcb {-# INLINE closeActive #-} -- | Delete an active connection from the tcp state. deleteActive :: HasTcpState state => state -> Tcb -> IO () deleteActive state tcb = HT.delete (view tcbKey tcb) (view tcpActive state) {-# INLINE deleteActive #-} -- Port Management ------------------------------------------------------------- -- | Pick a fresh port for a connection. nextTcpPort :: HasTcpState state => state -> Addr -> Addr -> TcpPort -> IO (Maybe TcpPort) nextTcpPort state src dst dstPort = modifyMVar tcpPorts (pickFreshPort tcpActive_ (Key dst dstPort src)) where TcpState { .. } = view tcpState state pickFreshPort :: HT.HashTable Key Tcb -> (TcpPort -> Key) -> TcpPort -> IO (TcpPort, Maybe TcpPort) pickFreshPort ht mkKey p0 = go 0 p0 where go :: Int -> TcpPort -> IO (TcpPort,Maybe TcpPort) go i _ | i > 65535 = return (p0, Nothing) go i 0 = go (i+1) 1025 go i port = do used <- HT.hasKey (mkKey port) ht if not used then return (port, Just port) else go (i + 1) (port + 1) -- Sequence Numbers ------------------------------------------------------------ nextIss :: HasTcpState state => state -> Addr -> TcpPort -> Addr -> TcpPort -> IO TcpSeqNum nextIss state src srcPort dst dstPort = do let TcpState { .. } = view tcpState state now <- getCurrentTime (m,f_digest) <- atomicModifyIORef' tcpISSTimer $ \ Tcp4USTimer { .. } -> let diff = diffUTCTime now tcpLastUpdate ticks = tcpTimer + truncate (diff * 250000) -- 4us chunks timers' = Tcp4USTimer { tcpTimer = ticks , tcpLastUpdate = now , .. } digest = integerDigest $ sha1 $ runPutLazy $ do putAddr src putTcpPort srcPort putAddr dst putTcpPort dstPort putByteString tcpSecret in (timers', (ticks, digest)) return (fromIntegral (m + fromIntegral f_digest))