-- Net.TCP_Client: TCP endpoint (active and passive open) on top of an IPv4 link.
--
-- NOTE(review): this copy of the file is damaged and should be restored from
-- a pristine copy before any code change is attempted:
--   * The original line breaks are gone.  Because "--" comments run to the end
--     of the physical line, most of the code below is swallowed by whichever
--     comment precedes it on the same line, and the layout-sensitive
--     do/where/case blocks have lost their indentation.
--   * At least three spans have lost text (it looks as if everything between
--     a "<" and a following ">" was stripped -- confirm against upstream):
--       - in server:       loop (handle=< handlePacket ipPacket
--       - in established:  whileM ((0) $ do sendData' dat seq
--       - in trySendData:  when (seq-unacked=win && q0 || q>0)
--     The definitions around those spans are therefore incomplete here.
--
-- What the visible code does:
--
-- Public API
--   * Interface{listen,connect}: listen takes a Port and yields a Passive
--     handle; connect takes a Peer and yields Maybe an Active handle.
--   * Passive{accept,unlisten}; Active{close,io}; tx and rx are Net.tx . io
--     and Net.rx . io.
--   * initialize putStrLn myIP iface: creates a request channel, forks a loop
--     that wraps every packet read with Net.rx iface as FromNetwork and writes
--     it to that channel, forks server, and returns an Interface whose listen
--     and connect are doReq reqChan . Listen and doReq reqChan . Connect.
--     The first parameter is named putStrLn and is only used to build the
--     "TCP: "-prefixed debug function.
--
-- Server thread (server, running under evalStateT)
--   * State T holds listeners (Map Port ...) and connections
--     (Map (Port,Peer) ...).
--   * Visible request dispatch: Listen -> addListener, Connect -> activate
--     peer (sendSyn reply), Unlisten / Disconnect -> delete the map entry.
--   * handlePacket logs "Dropping packet with bad checksum" when okTCPchksum
--     fails, otherwise calls handleOkPacket.
--   * handleOkPacket forwards to the handler registered for (port, peer);
--     otherwise, if a listener exists and the packet has SYN without ACK, it
--     starts a connection with synReceived; otherwise a packet with ACK set is
--     answered via reset (an RST packet); anything else is logged and dropped.
--   * pickPort takes the first Port in 32768..65535 that is not a key of
--     either map.  NOTE(review): it uses head, which fails if none is free.
--   * activate' creates the outCh/inCh channels and the flowctl MVar,
--     registers the forwarding function in connections, and forks a thread
--     that starts a timer thread, runs the handler, kills the timer and then
--     sends Disconnect to the server.  The client-side tx does takeMVar
--     flowctl before queueing ConTx; close queues Close.
--
-- Handshake
--   * synReceived (passive side) sends SYN+ACK via solicitPacket and, once the
--     expected ACK arrives, calls reply (peer,active), acknowledges and
--     delivers any data carried by that ACK, then enters established.
--   * sendSyn (active side) sends SYN via solicitPacket; on SYN+ACK it sends
--     an ACK, replies Just active and enters established; otherwise it
--     replies Nothing.
--   * Initial sequence numbers are the constants 10000000 and 20000000; the
--     original comments already say they should be chosen randomly.
--   * solicitPacket starts as loop 3 0: when the tick countdown is 0 it logs
--     "Retrying", sends the request, decrements retries and resets the
--     countdown to 3*ticksPerSecond.  NOTE(review): "Retrying" is also logged
--     for the first transmission, and after the third transmission the call
--     is loop 0 _, which returns Nothing immediately, so the last
--     transmission is never waited on -- confirm whether that is intended.
--
-- Established connection (established, running under evalStateT)
--   * ConReq is dispatched by conReq to close / conTx / conRx / tick.
--   * ticksPerSecond is 10; timer calls delay us (us = 1000000 `div`
--     ticksPerSecond, presumably microseconds -- TODO confirm) and then its
--     argument with an increasing tick count.
--   * tick retransmits every unacked segment for which now > 1+t+2*rtt and,
--     if anything was retransmitted, sets roundTripTime to
--     min (5*ticksPerSecond) (2*(max 1 rtt)).
--   * close sends FIN from Established (-> FinWait1) or CloseWait (-> LastAck)
--     and otherwise only logs.
--   * conTx queues data and calls trySendData unless phase > CloseWait.
--   * conRx: a packet with RST delivers emptyInPack if the phase is
--     Established and goes to TimeWait.  Otherwise data is delivered when
--     new>0 && new<=rxwin (skipping bytes already received), else the current
--     rxSeq is re-acknowledged and a duplicate is logged; FIN is accepted only
--     when got+l equals rxSeq; a new ACK number drops acknowledged segments
--     from unackedData, averages the round trip estimate and may advance
--     FinWait1/Closing/LastAck.  The sequence-number fields are declared
--     Word32 and the original comment marks the ACK comparison as relying on
--     modulo arithmetic.
--
-- Checksums
--   * tcp_chksum serialises the pseudo header (source, dest, 0, IPv4.TCP,
--     TCP length) together with the TCP packet and checksums the bytes as
--     big-endian words; okTCPchksum accepts a packet when that result is 0;
--     setTCPchksum stores the result in the TCP checksum field.
{-# LANGUAGE FlexibleContexts #-} module Net.TCP_Client( initialize,Active(..),tx,rx,Passive(..),Interface(..),Peer,Port(..) ) where -- Transmission Control Protocol -- See http://www.networksorcery.com/enp/protocol/tcp.htm -- http://www.networksorcery.com/enp/rfc/rfc793.txt import Net.Concurrent import Control.Monad.State import Control.Monad.Trans(lift) import Data.Map (Map) import qualified Data.Map as Map import Data.List((\\)) import Data.Word(Word8,Word16,Word32) import Net.TCP as TCP import Net.PortNumber import qualified Net.IPv4 as IPv4 import qualified Net.Interface as Net import Net.Utils as Util(doReq,contents,checksum,bytes_to_words_big) import Net.Packet(InPacket,len,dropInPack, OutPacket,outLen,outBytes,emptyInPack, emptyOutPack,appendOutPack,splitOutPack) import Net.PacketParsing(doUnparse) import Monad.Util data Active m = Active { close:: m (), io::Net.Interface m InPacket OutPacket } data Passive m = Passive { accept::m (Peer,Active m), unlisten::m () } type Peer = (IPv4.Addr,Port) tx = Net.tx . io rx = Net.rx . io data Interface m = Interface { listen :: Port -> m (Passive m), connect :: Peer -> m (Maybe (Active m)) } -------------------------------------------------------------------------------- data Req m = Listen Port (Passive m->m ()) | Unlisten Port | Connect Peer (Maybe (Active m) ->m ()) | Disconnect Port Peer -- from connection handling thread | FromNetwork TCPPacketIn data State m = T { listeners::Listeners m, connections::Connections m } type Connections m = Map (Port,Peer) (TCPPacketIn->m ()) type Listeners m = Map Port (Listening m) type Listening m = (Peer,Active m)->m () type TCPPacketIn = TCPPacket InPacket type TCPPacketOut = TCPPacket OutPacket type TCPPacket contents = IPv4.Packet (Packet contents) type TCPIPLink m = Net.Interface m TCPPacketIn TCPPacketOut {-# NOINLINE initialize #-} initialize putStrLn myIP iface = do reqChan <- newChan fork $ loop $ writeChan reqChan . 
FromNetwork =<< Net.rx iface fork $ server debug myIP iface reqChan return $ Interface { listen = doReq reqChan . Listen, connect = doReq reqChan . Connect } where debug = putStrLn . ("TCP: "++) server debugIO myIP iface reqChan = flip evalStateT init $ loop (handle=< handlePacket ipPacket Listen port reply -> addListener port reply Connect peer reply -> activate peer (sendSyn reply) Unlisten port -> modify $ unlisten port Disconnect port peer -> modify $ disconnect (port,peer) -- State updates: listen port accept s@T{listeners=l} = s{listeners=Map.insert port accept l} unlisten port s@T{listeners=l} = s{listeners=Map.delete port l} connect c fwd s@T{connections=cs} = s{connections=Map.insert c fwd cs} disconnect c s@T{connections=cs} = s{connections=Map.delete c cs} addListener port reply = do -- check that port is not already listening acceptCh <- lift newChan lift $ reply Passive { accept=readChan acceptCh, unlisten=writeChan reqChan (Unlisten port) } let accept = writeChan acceptCh modify $ listen port accept handlePacket ipPacket = if okTCPchksum ipPacket then handleOkPacket ipPacket else debug "Dropping packet with bad checksum" handleOkPacket ipPacket = do let packet = IPv4.content ipPacket peer = (IPv4.source ipPacket,sourcePort packet) me = (IPv4.dest ipPacket,port) c = (me,peer) port = destPort packet CB{ack=a,syn=s} = controlBits packet acknr = ackNr packet dropit = debug $ "Dropped packet from "++show peer ++" to "++show me ++ "\n"++show ipPacket optcon <- gets (Map.lookup (port, peer).connections) case optcon of Just toConnection -> do --debug $ "Forwarding "++show c lift $ toConnection ipPacket _ -> do optlistener <- gets (Map.lookup port . 
listeners) case optlistener of Just listener | s && not a -> activate' c (synReceived ipPacket listener) _ | a -> reset c acknr -- Half-open connection detected _ -> dropit reset c acknr = do let rst = minBound{rst=True} debug $ "RST "++show c lift $ Net.tx iface (setTCPchksum (tcpPacket () rst c acknr 0)) pickPort = do T{listeners=l,connections=c} <- get let inuse = Map.keys l++map fst (Map.keys c) -- duplicates, slow? return $ head (map Port [32768..65535]\\inuse) activate peer handler = do port <- pickPort -- find an unused port let me = (myIP,port) activate' (me,peer) handler activate' c@(me@(_,port),peer) handler = do outCh <- lift newChan -- packets from client to connection inCh <- lift newChan -- packets from connection to client flowctl <- lift $ newMVar () -- for client output flow control let cdebug msg = debugIO $ show me++"<->"++show peer++"\n "++msg forward = writeChan outCh . ConFromNetwork modify $ connect (port,peer) forward let io = Net.Interface { Net.rx=readChan outCh, Net.tx=Net.tx iface . setTCPchksum } active = Active { close=writeChan outCh Close, io=Net.Interface { Net.rx=readChan inCh, Net.tx=tx}} where tx p = do --debugIO $ "takeMVar flowctl "++show (outLen p) takeMVar flowctl writeChan outCh (ConTx p) --debugIO $ "tookMvar flowctl" lift $ fork $ do t <- fork $ timer (writeChan outCh . 
Tick) handler c cdebug (writeChan inCh) io flowctl active kill t writeChan reqChan (Disconnect port peer) return () -------------------------------------------------------------------------------- synReceived ipPacket reply c@(_,peer) debug deliver io flowctl active = do --debug $ "SYN received " ++show rxSeqNr let synackP = synackPacket c txSeqNr (rxSeqNr+1) --debug $ "Sending SYN ACK " ++show txSeqNr++" "++show (rxSeqNr+1) maybe done gotAck =<< waitForAck synackP where gotAck ip = do --debug $ "Got ACK, connection is established" reply (peer,active) -- not until connection is established let tcp = contents ip dat = contents tcp l = fromIntegral (len dat) when (l>0) $ do debug $ "ACK and delever initial bytes "++show l Net.tx io (ackPacket c (txSeqNr+1) (rxSeqNr+1+l)) deliver dat established c debug deliver io flowctl (rxSeqNr+1+l,txSeqNr+1,txWindow) tcp = contents ipPacket rxSeqNr = seqNr tcp txWindow = window tcp txSeqNr = 10000000 -- should be chosen randomly!!! waitForAck synackP = solicitPacket debug io synackP expected where expected p = if cb==minBound{ack=True} && ackNr tcp==txSeqNr+1 && seqNr tcp==rxSeqNr+1 then Just p else Nothing where tcp = contents p cb = controlBits tcp -------------------------------------------------------------------------------- solicitPacket debug io request expected = loop 3 0 where loop 0 _ = return Nothing loop retries 0 = do debug "Retrying" Net.tx io request loop (retries-1) (3*ticksPerSecond) loop retries t = do r <- Net.rx io case r of Tick _ -> loop retries (t-1) ConFromNetwork p -> case expected p of Just r -> return (Just r) _ -> loop retries t _ -> loop retries t -------------------------------------------------------------------------------- sendSyn reply c debug deliver io flowctl active = do let synP = synPacket c iss 0 --debug "Sent SYN, waiting for SYN ACK" maybe noreply gotAck =<< waitForAck synP where iss = 20000000 -- Initial Send Sequence number, should be chosen randomly!!! 
noreply = reply Nothing gotAck (irs,txWindow) = do --debug $ "Got SYN ACK, sending ACK, IRS="++show irs Net.tx io (ackPacket c (iss+1) (irs+1)) reply (Just active) established c debug deliver io flowctl (irs+1,iss+1,txWindow) waitForAck synP = solicitPacket debug io synP expected where expected p = let tcp = contents p cb = controlBits tcp in if cb==minBound{ack=True,syn=True} && ackNr tcp==iss+1 then Just (seqNr tcp,window tcp) else Nothing -------------------------------------------------------------------------------- dataPacket dat = tcpPacket dat minBound{ack=True} ackPacket = dataPacket () finPacket = emptyPacket minBound{ack=True,fin=True} synPacket = emptyPacket minBound{syn=True} synackPacket = emptyPacket minBound{syn=True,ack=True} emptyPacket = tcpPacket () tcpPacket dat cb ((myIP,myPort),(peerIP,peerPort)) seqnr acknr = iptemplate tcp{content=doUnparse dat} where tcp = template{sourcePort=myPort,destPort=peerPort, ackNr=acknr,seqNr=seqnr,controlBits=cb} iptemplate = IPv4.template IPv4.TCP myIP peerIP -------------------------------------------------------------------------------- -- Requests to connection handling thread: data ConReq = Close | ConTx OutPacket | ConFromNetwork TCPPacketIn | Tick Int data ConState = S { phase::Phase, now,roundTripTime::Int, unackedData::[(Word32,Int,OutPacket)], unsentData::OutPacket, txUnacked,txSeq,txWindow,rxSeq,rxWindow::Word32 } data Phase = Established | CloseWait | Closing | FinWait1 | FinWait2 | LastAck | TimeWait | Closed deriving (Eq,Ord,Show) conReq disc tx rx tick req = case req of Close -> disc ConTx p -> tx p ConFromNetwork p -> rx p Tick t -> tick t ticksPerSecond=10 timer m = loop 0 where loop t = do delay us m t loop (t+1) us = 1000000 `div` ticksPerSecond established c debugIO deliver io flowctl (rxseq,txseq,txwin) = flip evalStateT state0 $ do debug $ "Transmit window = "++show txwin whileM ((0) $ do sendData' dat seq put s{txSeq=seq+l,unackedData=ps++[(seq,t,dat)]} --debug $ "Sent "++show l++" bytes 
upto "++show (seq+l) sendData' dat seq = tx . dataPacket dat c seq =<< gets rxSeq trySendData = do S{txSeq=seq,txUnacked=unacked,txWindow=win} <- get when (seq-unacked=win && q0 || q>0) $ debug $ "Sending "++show l++" bytes, " ++show q++ " bytes left in transmit queue" --} return p1 sendFin = do s@S{txSeq=seq,rxSeq=ack} <- get tx (finPacket c seq ack) put s{txSeq=seq+1} goto p = do modify $ \ s -> s{phase=p} --debug $ "Go to state "++show p -- Some time has passed: tick now = do rtt <- gets roundTripTime (ps',timeout) <- flip runStateT False . mapM (retransmit rtt) =<< gets unackedData when timeout $ modify $ \ s -> s{roundTripTime=backoff rtt} modify $ \ s -> s{now=now,unackedData=ps'} where backoff rtt = min (5*ticksPerSecond) (2*(max 1 rtt)) retransmit rtt p@(seq,t,buf) = if now>1+t+2*rtt then do lift $ debug $ "Retransmitting seqNr "++show seq ++ " len "++show (outLen buf) ++" after "++show(now-t)++" ticks" lift $ sendData' buf seq put True return (seq,now,buf) else return p -- Local request to close the connection: close = do p <- gets phase case p of Established -> do sendFin ; goto FinWait1 CloseWait -> do sendFin ; goto LastAck _ -> debug "Buggy local client closing more than once" -- Local request to send some data: conTx dat = do p <- gets phase --let l=fromIntegral (outLen dat) --when (l>0) $ if p>CloseWait then debug "Buggy local client sending after closing" else do queueData dat trySendData -- Receiver a packet from the network: conRx ip | rst (controlBits (contents ip)) = -- also check seqNr -- Should probably notify client that the connection was reset -- and not just closed the normal way... do p <- gets phase when (p==Established) $ lift (deliver emptyInPack) -- EOS goto TimeWait -- or Closed? conRx ip = do let tcp = contents ip got = seqNr tcp dat = contents tcp l=fromIntegral (len dat) cb=controlBits tcp --debug $ "Got packet with "++show cb++" and "++show l++" bytes of data" expecting <- gets rxSeq -- also check RST flag! 
when (l>0) $ do let new=got+l-expecting rxwin <- gets rxWindow if new>0 && new<=rxwin then do --debug $ "ACK upto "++show ack -- ++" and deliver "++show new++" bytes" let ack=got+l dup=fromIntegral (l-new) modify $ \ s->s{rxSeq=ack} lift $ deliver (skipIn dup dat) else do acknowledge expecting debug $ "got duplicate input " ++show (got,l,expecting) S{phase=p,rxSeq=expecting} <- get if fin cb then let finseq=got+l ack=finseq+1 ackgoto p = do acknowledge ack;goto p in if finseq/=expecting then debug "FIN with unexpected sequence number" else case p of Established -> do lift (deliver emptyInPack) -- EOS ackgoto CloseWait FinWait1 -> do ackgoto Closing FinWait2 -> do ackgoto TimeWait _ -> debug "Unexpected FIN" else acknowledge expecting S{txUnacked=unacked,txSeq=seq} <- get let acknr = ackNr tcp when (ack cb && acknr/=unacked && acknr-unacked<=seq-unacked) $ -- !! modulo arithmetic do s@S{now=now,roundTripTime=oldrtt,unackedData=ps} <- get let (ps1,ps2) =span (isAcked acknr) ps newrtt=if null ps1 then oldrtt else maximum [ now-t | (_,t,_)<-ps1] rtt=(oldrtt+newrtt) `div` 2 put s{txUnacked=acknr,unackedData=ps2,roundTripTime=rtt} {- debug $ "Update ACKed output to "++show acknr ++", "++show (length ps2)++" unacked packets" ++", new roundtrip time="++show rtt++" ticks" -} trySendData seq <- gets txSeq when (acknr==seq) $ -- when everything sent has been acked case phase s of FinWait1 -> goto FinWait2 Closing -> goto TimeWait LastAck -> goto Closed _ -> return () isAcked acknr (seq,t,buf) = seq+fromIntegral (outLen buf)<=acknr -------------------------------------------------------------------------------- okTCPchksum ip = tcp_chksum ip == 0 setTCPchksum ip = ip{IPv4.content=tcp'} where tcp' = tcp{TCP.checksum=tcp_chksum ip} tcp = contents ip tcp_chksum ip = outPacketChecksum pseudoTCP where tcp = contents ip pseudoHeader = (IPv4.source ip,IPv4.dest ip,0::Word8,IPv4.TCP,tcpLength) tcpLength = fromIntegral (outLen utcp)::Word16 pseudoTCP = doUnparse (pseudoHeader,utcp) 
utcp = doUnparse tcp -- TCP packet will be serialized twice!! outPacketChecksum = Util.checksum . bytes_to_words_big . outBytes --pre: n<=len p skipIn n p = dropInPack n p