{-# LANGUAGE TupleSections       #-}
{-# LANGUAGE ScopedTypeVariables #-}

{-|
    This module provides the low-level interface for communicating with a
    metaverse server.  It handles the details of packet encoding, accounting,
    handshaking, and so on.

    In general, you should try to use the higher-level functions in the
    "Network.Metaverse" module as often as possible, and fall down to this
    level only when there is no other option.
-}
module Network.Metaverse.Circuit (
    Circuit,
    circuitConnect,
    circuitAgentID,
    circuitSessionID,
    circuitCode,
    circuitSend,
    circuitSendSync,
    circuitIncoming,
    circuitClose,
    circuitIsClosed
    )
    where

import Prelude hiding (catch)

import Control.Arrow (first)
import Control.Concurrent
import Control.Event.Relative
import Control.Exception
import Control.Monad
import Control.Monad.Trans
import Data.Binary
import Data.Binary.Put
import Data.Binary.Get
import Data.Binary.IEEE754
import Data.Bits
import Data.Time.Clock
import Data.UUID hiding (null)

import Network.Socket hiding (send, sendTo, recv, recvFrom)
import Network.Socket.ByteString

import Network.Metaverse.Login
import Network.Metaverse.PacketTypes

import qualified Data.Map as M
import Data.Map (Map)

import Control.Monad.State hiding (get, put)
import qualified Control.Monad.State as S

import qualified Data.ByteString      as B
import qualified Data.ByteString.Lazy as L

------------------------------------------------------------------------

{-
    A wrapper for Control.Event.Relative, except that instead of returning
    an EventId when scheduling an event, the user supplies the event
    identifier as an ordered type of their own choosing.
-}

data TaskQueue k = TaskQueue
    { taskVar :: MVar (M.Map k EventId)
      -- ^ Maps each caller-chosen key to the event currently scheduled
      --   under it.  Entries are removed when the event fires or is
      --   cancelled.
    }

-- Creates a task queue with no scheduled events.
newTaskQueue :: Ord k => IO (TaskQueue k)
newTaskQueue = do
    tasks <- newMVar M.empty
    return (TaskQueue tasks)

{-
    Schedules an action to run after the given delay (in the units used by
    'addEvent'; the callers in this module pass microsecond-sized values),
    remembering the event under the supplied key.  When the event fires it
    first forgets its key and then runs the action.

    NOTE(review): the key is recorded only after 'addEvent' returns, so an
    event that fired before that point would leave a stale entry behind.
    With the delays used in this module the window is not expected to
    matter, but confirm before scheduling very short delays.
-}
schedule :: Ord k => TaskQueue k -> k -> Int -> IO () -> IO ()
schedule (TaskQueue tasks) key delay action = do
    let fire = do
            modifyMVar_ tasks (return . M.delete key)
            action
    eventId <- addEvent delay fire
    modifyMVar_ tasks (return . M.insert key eventId)

{-
    Cancels the event scheduled under the given key, if any.  Returns 'True'
    when an event was found (and removed), 'False' otherwise.

    The map is updated with 'modifyMVar', so an exception arriving mid-update
    (for example 'killThread' on a thread running 'confirmPacket') restores
    the map instead of leaving the MVar empty forever.  As before, the event
    itself is deleted only after the lock has been released.
-}
cancel :: Ord k => TaskQueue k -> k -> IO Bool
cancel (TaskQueue v) k = do
    found <- modifyMVar v $ \m -> return (M.delete k m, M.lookup k m)
    case found of
        Nothing  -> return False
        Just eid -> delEvent eid >> return True

{-
    Cancels every outstanding event and empties the queue.  Uses
    'modifyMVar_' so that an exception from 'delEvent' restores the map
    rather than leaving the MVar empty, which would block every later user
    of the queue.
-}
closeQueue :: Ord k => TaskQueue k -> IO ()
closeQueue (TaskQueue v) = modifyMVar_ v $ \tasks -> do
    mapM_ delEvent (M.elems tasks)
    return M.empty

------------------------------------------------------------------------

{-
    Packet handling, serialization, and deserialization.
-}

-- Packet sequence number.  Written as a 32-bit big-endian word in the packet
-- header and in the appended ack trailer (see 'serialize').
type SequenceNum = Word32

data Packet = Packet
    { packetZerocoded  :: Bool          -- ^ Body is zero-run encoded (flag bit 7)
    , packetReliable   :: Bool          -- ^ Sender wants an ack (flag bit 6)
    , packetRetransmit :: Bool          -- ^ This is a resend (flag bit 5)
    , packetSequence   :: SequenceNum   -- ^ Sequence number from the header
    , packetExtra      :: B.ByteString  -- ^ Length-prefixed extra header bytes
    , packetBody       :: PacketBody    -- ^ The message payload
    , packetAcks       :: [SequenceNum] -- ^ Acks appended after the body (flag bit 4)
    }
    deriving Show

{-
    Encodes a packet for the wire: a flags byte, the 32-bit big-endian
    sequence number, the length-prefixed extra bytes, the body (zerocoded
    when requested), and finally any acks followed by a one-byte ack count.
-}
serialize :: Packet -> B.ByteString
serialize (Packet zcode reliable retrans sequenceNo extra body acks) =
    B.concat (L.toChunks (runPut writer))
  where
    ackCount = length acks

    flagIf i b = if b then bit i else 0
    flags = flagIf 4 (ackCount > 0)
        .|. flagIf 5 retrans
        .|. flagIf 6 reliable
        .|. flagIf 7 zcode

    payload | zcode     = putLazyByteString (zeroencode (encode body))
            | otherwise = put body

    writer = do
        putWord8 flags
        putWord32be sequenceNo
        putWord8 (fromIntegral (B.length extra))
        putByteString extra
        payload
        mapM_ putWord32be acks
        when (ackCount > 0) $ putWord8 (fromIntegral ackCount)

{-
    Parses one received datagram into a 'Packet'.  The layout mirrors
    'serialize': a flags byte, a 32-bit big-endian sequence number, a
    length-prefixed "extra" header, the (possibly zerocoded) body and, when
    flag bit 4 is set, a trailer of 32-bit big-endian acks followed by a
    one-byte ack count.

    NOTE(review): this function is partial.  An empty datagram makes
    'B.head' fail, and a truncated one makes 'runGet' (or the body 'decode')
    fail.  Because the bindings below are lazy, such an error surfaces only
    when the corresponding field of the result is forced, i.e. in the caller.
-}
deserialize :: B.ByteString -> Packet
deserialize fullMsg = 
    -- Unfortunately, the encoding of the packets makes it impossible to
    -- just use Data.Binary.  The ack count sits in the *last* byte, so we
    -- need to preprocess the packet from both ends before parsing.

    let -- First, read the flags from the first byte of the header.
        flags    = B.head fullMsg
        hasAcks  = testBit flags 4
        retrans  = testBit flags 5
        reliable = testBit flags 6
        zcode    = testBit flags 7

        -- Next, if there are appended acks, peel them off and read them:
        -- the final byte is the count, preceded by that many 4-byte acks.
        (withoutAcks, acks) = if hasAcks
            then let msg1                        = B.init fullMsg
                     nacks                       = B.last fullMsg
                     (result, appended)          = B.splitAt (B.length msg1 - 4 * fromIntegral nacks) msg1
                     ackGetter                   = replicateM (fromIntegral nacks) getWord32be
                     acks                        = runGet ackGetter (L.fromChunks [ appended ])
                 in  (result, acks)
            else (fullMsg, [])

        -- Now take off the header, so we're left with only the body (which
        -- may or may not be zerocoded).
        headerGetter = do _        <- getWord8    -- Header flags (already seen)
                          seq      <- getWord32be
                          extralen <- getWord8
                          extra    <- getBytes (fromIntegral extralen)
                          body     <- getRemainingLazyByteString
                          return (seq, extra, body)
        (seq, extra, encodedBody) = runGet headerGetter (L.fromChunks [ withoutAcks ])

        -- Un-zerocode the body if needed.  Zerocoding is indicated by a flag
        -- in the message headers.
        decodedBody = if zcode then zerodecode encodedBody else encodedBody

        -- Parse the body into a PacketBody.  The "decode" here is
        -- Data.Binary's deserializer, not the zerocode decoding above,
        -- hence the funny-sounding code.
        body = decode (decodedBody)

    in  Packet zcode reliable retrans seq extra body acks

{-
    Reverses 'zeroencode'.  A zero byte followed by a count @n@ expands to
    @n@ zero bytes, and every other byte is copied through.  A zero byte in
    the final position, with no count after it, is passed through unchanged.

    Runs of non-zero bytes are copied as whole slices with 'L.break' instead
    of being rebuilt with one 'L.cons' per byte.  The old approach produced
    a lazy string with one chunk per byte and recomputed 'L.length' at every
    step.
-}
zerodecode :: L.ByteString -> L.ByteString
zerodecode r = case L.uncons rest of
    Nothing             -> literal
    Just (_, afterZero) -> case L.uncons afterZero of
        -- Trailing lone zero: keep it as-is.
        Nothing        -> L.append literal rest
        Just (n, more) -> L.concat [ literal
                                   , L.replicate (fromIntegral n) 0
                                   , zerodecode more
                                   ]
  where
    (literal, rest) = L.break (== 0) r

{-
    Zero-run encoding for packet bodies.  Every maximal run of zero bytes is
    replaced by a zero byte followed by the run length.  Runs longer than 255
    are split into several (0, count) pairs.  Non-zero bytes are copied
    through unchanged.

    Runs of non-zero bytes are copied as whole slices with 'L.break' instead
    of one 'L.cons' per byte, which used to leave a chunk per byte.
-}
zeroencode :: L.ByteString -> L.ByteString
zeroencode r
    | L.null r  = r
    | otherwise = L.concat [literal, zeros (L.length run), zeroencode rest]
  where
    (literal, afterLiteral) = L.break (== 0) r
    (run, rest)             = L.span (== 0) afterLiteral
    zeros n | n > 255   = L.append (L.pack [0, 255]) (zeros (n - 255))
            | n > 0     = L.pack [0, fromIntegral n]
            | otherwise = L.empty

------------------------------------------------------------------------

{-|
    A @Circuit@ is a UDP connection to a metaverse server.  One connects to
    the server using the information obtained from the login server, via
    'circuitConnect'.  Messages are then sent and received by operating
    on the circuit, until it is closed with 'circuitClose' or by a
    network timeout.
-}
data Circuit = Circuit {
    -- Some information that's nice to have access to when connected to a
    -- circuit.  This needs to be accessible to the next layer up, because
    -- this information is embedded into various message fields.

    {-|
        Gives the agent UUID associated with this circuit.
    -}
    circuitAgentID    :: !UUID,

    {-|
        Gives the session UUID associated with this circuit.
    -}
    circuitSessionID  :: !UUID,

    {-|
        Gives the circuit code, a 32-bit integer, associated with this circuit.
        This is only rarely used, but is occasionally needed (for example, it
        is sent in the @UseCircuitCode@ message when the circuit is opened).
    -}
    circuitCode       :: !Word32,

    {-
        Each circuit is associated with a socket and a remote address used to
        send and receive communication.  Datagrams arriving from any other
        address are not delivered.  This is used internally within the
        current module only.
    -}
    circuitSocket     :: Socket,
    circuitAddr       :: SockAddr,

    {-
        The worker threads forked by 'circuitConnect'.  Keeping them here lets
        'circuitClose' kill all of them when the circuit is closed.
    -}
    circuitThreads    :: MVar [ThreadId],

    {-|
        Gives the channel used to provide incoming packets from the server.
        In general it is not used directly, but rather in conjunction with
        'dupChan' so that each piece of the client can operate independently
        of all the others.

        When the circuit is closed, 'Nothing' is written to this channel.
    -}
    circuitIncoming   :: Chan (Maybe PacketBody),

    {-
        Packet accounting state, used to handle retransmissions and sequencing.
        The MVar acts as the circuit's lock; it is normally accessed through
        'runWithMVar'.  This is used only within the current module.
    -}
    circuitAccounting :: MVar Accounting
    }

{-
    Packet accounting.  This layer of the communication system sequences
    outgoing packets, tracks and resends dropped packets, acknowledges
    packets from the server, and prunes duplicates caused by lost acks.
-}
data Accounting = Accounting
    { acctClosed        :: Bool
      -- ^ When 'True', sends are suppressed and the worker loops stop.
    , acctSequence      :: SequenceNum
      -- ^ Sequence number to give the next outgoing packet.
    , acctRecentPackets :: [SequenceNum]
      -- ^ Most recent reliable packets received (newest first, at most
      --   100), used to drop retransmitted duplicates.
    , acctPendingAcks   :: [(UTCTime, SequenceNum)]
      -- ^ Received reliable packets still to be acknowledged, oldest first,
      --   with their arrival times.
    , acctReliableQueue :: TaskQueue SequenceNum
      -- ^ Retransmission timers for reliable packets we sent.
    , acctConfirmations :: Map SequenceNum (MVar Bool)
      -- ^ Callers waiting to hear whether a reliable packet was acked.
    , acctLastTime      :: UTCTime
      -- ^ When the last packet from the server arrived.
    }

{-
    Runs a state action against the contents of an MVar, holding the MVar
    for the duration and writing the final state back.  This makes each
    accounting update atomic with respect to other threads.
-}
runWithMVar :: MVar a -> StateT a IO b -> IO b
runWithMVar var action = modifyMVar var $ \start -> do
    (result, final) <- runStateT action start
    return (final, result)

{-
    Hands out the current sequence number and advances the counter.
-}
nextSequence :: StateT Accounting IO SequenceNum
nextSequence = do
    current <- S.gets acctSequence
    modify (\acct -> acct { acctSequence = current + 1 })
    return current

{-
    Serializes a packet and sends it to the given address with 'sendAllTo'.
    No accounting is done: the packet goes out exactly as given.
-}
sendRaw :: Socket -> SockAddr -> Packet -> IO ()
sendRaw sock addr packet =
    let datagram = serialize packet
    in  sendAllTo sock datagram addr

{-
    Takes as many pending acks as fit in the given number of bytes (four
    bytes each, and never more than 255, since the count is a single byte).
    The taken acks are removed from the pending list, oldest first.
-}
getAcks :: Int -> StateT Accounting IO [SequenceNum]
getAcks size = do
    pending <- S.gets acctPendingAcks
    let room                = min 255 (size `div` 4)
        (chosen, remaining) = splitAt room pending
    modify $ \acct -> acct { acctPendingAcks = remaining }
    return [ seqNo | (_, seqNo) <- chosen ]

{-
    Sends a packet, appending as many pending acks as fit in the space left
    after the body.  The budget is 10000 bytes minus 7, which presumably
    covers the fixed header and the trailing ack-count byte written by
    'serialize'.  Acks taken for this packet leave the pending list even if
    the send fails.

    Sending is best-effort: only 'IOException's (socket errors) are
    swallowed.  Asynchronous exceptions such as the 'ThreadKilled' delivered
    by 'circuitClose' must propagate; catching 'SomeException' here let
    killed threads keep running.
-}
sendWithAcks :: Socket -> SockAddr -> Packet -> StateT Accounting IO ()
sendWithAcks sock addr packet = do
    acks <- getAcks $ 10000 - 7 - packetLength (packetBody packet)
    liftIO $ sendRaw sock addr packet { packetAcks = acks }
        `catch` \(_ :: IOException) -> return ()

{-
    How a packet should be sent:

    * 'Unreliable': sent once, never retried.
    * 'Reliable' 'Nothing': retried until acknowledged (or the retries run
      out), but the client does not care about the outcome.
    * 'Reliable' ('Just' v): retried as above, and @v@ is filled with the
      outcome ('True' when acknowledged, 'False' when given up on).
-}
data Reliability
    = Unreliable
    | Reliable (Maybe (MVar Bool))

{-
    Whether a 'Reliability' asks for the packet to be flagged reliable (and
    so acknowledged and retried).
-}
isReliable :: Reliability -> Bool
isReliable rel = case rel of
    Unreliable -> False
    Reliable _ -> True

{-
    Runs the given accounting action only if the circuit is not marked
    closed.
-}
ifNotClosed :: Monad m => StateT Accounting m () -> StateT Accounting m ()
ifNotClosed action = S.gets acctClosed >>= \closed -> unless closed action

{-
    Sends a packet on the circuit: assigns the next sequence number, sends
    it with any pending acks appended, and then performs the reliability
    accounting.  Does nothing when the circuit is marked closed.  This is
    the function that 'circuitSend' and 'circuitSendSync' wrap.
-}
circuitSendImpl :: Circuit
                -> Reliability -- ^ How (and whether) to track delivery
                -> PacketBody  -- ^ The payload to send along
                -> StateT Accounting IO ()
circuitSendImpl circ rel body = ifNotClosed $ do
    sequenceNo <- nextSequence
    let packet = Packet
            { packetZerocoded  = shouldZerocode body
            , packetReliable   = isReliable rel
            , packetRetransmit = False
            , packetSequence   = sequenceNo
            , packetExtra      = B.empty
            , packetBody       = body
            , packetAcks       = []
            }
    sendWithAcks (circuitSocket circ) (circuitAddr circ) packet
    reliableAccounting rel circ packet

{-
    Accounting for a packet that was just sent reliably.  Registers the
    caller's confirmation MVar (if any) and schedules up to 'retryCount'
    retransmissions, 'retryTime' apart.  If all retries pass without an
    acknowledgement, the confirmation MVar is filled with 'False'.

    The give-up path removes the confirmation from 'acctConfirmations' and
    fills it while holding the accounting lock, and only if it is still
    registered.  'confirmPacket' does its lookup-and-fill under the same
    lock, so exactly one of them fills the MVar.  A second blocking
    'putMVar' would otherwise hang while the lock is held.
-}
reliableAccounting :: Reliability
                   -> Circuit
                   -> Packet
                   -> StateT Accounting IO ()
reliableAccounting Unreliable _ _ = return ()
reliableAccounting (Reliable mv) circ packet = do
    case mv of
        Nothing -> return ()
        Just v  -> modify $ \s ->
            s { acctConfirmations = M.insert seqNo v (acctConfirmations s) }
    queue <- S.gets acctReliableQueue
    liftIO $ schedule queue seqNo retryTime (retry retryCount)
  where
    retryTime  = 1500000 -- TODO: Find the right value
    retryCount = 3 :: Int
    sock       = circuitSocket circ
    addr       = circuitAddr   circ
    seqNo      = packetSequence packet
    retried    = packet { packetRetransmit = True }

    -- Out of retries: report failure, unless the packet was confirmed
    -- (and its confirmation removed) in the meantime.
    retry 0 = runWithMVar (circuitAccounting circ) $ do
        con <- S.gets acctConfirmations
        case M.lookup seqNo con of
            Nothing -> return ()
            Just v  -> do
                modify $ \s -> s { acctConfirmations = M.delete seqNo con }
                liftIO $ putMVar v False
    -- Resend and schedule the next attempt, unless the circuit is closed.
    retry n = runWithMVar (circuitAccounting circ) $ ifNotClosed $ do
        sendWithAcks sock addr retried
        queue <- S.gets acctReliableQueue
        liftIO $ schedule queue seqNo retryTime (retry (n - 1))

{-|
    Sends a packet to the server without waiting for any response.
-}
circuitSend :: Circuit    -- ^ The circuit to send on
            -> Bool       -- ^ Whether to send reliably.  This function never
                          --   waits, but 'True' makes the circuit watch for
                          --   an acknowledgement and retry a few times,
                          --   greatly improving the odds of delivery.
                          --   'False' gives a much cheaper packet.
            -> PacketBody -- ^ The packet contents to send
            -> IO ()
circuitSend circ reliable msg =
    runWithMVar (circuitAccounting circ) (circuitSendImpl circ mode msg)
  where
    mode | reliable  = Reliable Nothing
         | otherwise = Unreliable

{-|
    Sends a packet to the server reliably and waits for the outcome.  Returns
    'True' once it is acknowledged, or 'False' when the retries run out.

    If the circuit is already marked closed, nothing is sent and 'False' is
    returned immediately.  Waiting in that case would block forever, since
    no packet (and so no retry timer) exists to fill the MVar.

    NOTE(review): if the circuit is closed while this call is waiting, its
    retry timer is cancelled, and the call returns only if 'circuitClose'
    releases pending confirmations.  Confirm that it does.
-}
circuitSendSync :: Circuit    -- ^ The circuit to send on
                -> PacketBody -- ^ The packet contents to send
                -> IO Bool
circuitSendSync circ msg = do
    v <- newEmptyMVar
    sent <- runWithMVar (circuitAccounting circ) $ do
        closed <- S.gets acctClosed
        unless closed $ circuitSendImpl circ (Reliable (Just v)) msg
        return (not closed)
    if sent then takeMVar v else return False

{-
    Worker loop that flushes acks periodically.  Every half second it checks
    the oldest pending ack.  If that ack has waited longer than the
    threshold, it sends a PacketAck message carrying pending acks, so acks
    still go out when there is no other traffic to piggyback on.  The loop
    ends once the circuit is marked closed.
-}
ackSender :: Circuit -> IO ()
ackSender circ = do
    stillOpen <- runWithMVar (circuitAccounting circ) $ do
        pending <- S.gets acctPendingAcks
        now     <- liftIO getCurrentTime
        case pending of
            (oldest, _) : _ | now `diffUTCTime` oldest > ackThreshold -> do
                batch <- getAcks (10000 - 7)
                circuitSendImpl circ Unreliable
                    (PacketAck (map PacketAck_Packets batch))
            _ -> return ()
        S.gets (not . acctClosed)
    when stillOpen $ do
        threadDelay 500000 -- TODO: Find the right frequency
        ackSender circ
  where
    ackThreshold = 0.75 -- TODO: Find the right values

{-
    Handles the server's acknowledgement of one of our packets.  It cancels
    the packet's pending retransmission and, if a caller registered a
    confirmation MVar ('circuitSendSync'), reports success to it.

    Callers run this under 'runWithMVar', i.e. while holding the accounting
    lock.  'tryPutMVar' is used because the retry timer may already have
    given up and filled the MVar with 'False'; a blocking 'putMVar' would
    then hang with the lock held.
-}
confirmPacket :: SequenceNum -> StateT Accounting IO ()
confirmPacket seqNo = do
    queue   <- S.gets acctReliableQueue
    waiting <- S.gets acctConfirmations
    _ <- liftIO $ cancel queue seqNo
    case M.lookup seqNo waiting of
        Nothing -> return ()
        Just mv -> do
            _ <- liftIO $ tryPutMVar mv True
            modify $ \s -> s { acctConfirmations = M.delete seqNo waiting }

{-
    Receives one UDP datagram (up to 10000 bytes) and turns it into a
    'Packet' paired with the sender's address.  Returns 'Nothing' when the
    socket raises an 'IOException', e.g. after it has been closed.

    Only 'IOException' is caught.  Catching 'SomeException' also swallowed
    the 'ThreadKilled' sent by 'circuitClose', which turned a kill into a
    spurious 'Nothing'.

    NOTE(review): 'deserialize' is applied lazily, so a malformed datagram
    raises only when the packet's fields are forced, outside this handler.
-}
recvRaw :: Socket -> IO (Maybe (Packet, SockAddr))
recvRaw sock = fmap (Just . first deserialize) (recvFrom sock 10000)
    `catch` \(_ :: IOException) -> return Nothing

{-
    Worker loop that receives packets from the server, performs the
    accounting for them, and delivers payloads to 'circuitIncoming'.

    Datagrams from any address other than the circuit's server are ignored,
    and the loop carries on.  Previously such a datagram skipped the
    recursive call, so a single stray packet silently ended this thread.
    A socket error closes the circuit.
-}
packetReceiver :: Circuit -> IO ()
packetReceiver circ = do
    res <- recvRaw (circuitSocket circ)
    case res of
        Nothing -> circuitClose circ
        Just (packet, from)
            | from /= circuitAddr circ -> packetReceiver circ
            | otherwise -> do
                cont <- runWithMVar (circuitAccounting circ) (account packet)
                when cont $ packetReceiver circ
  where
    -- Accounting for one packet from the server; returns whether the
    -- circuit is still open.
    account packet = do
        t <- liftIO getCurrentTime
        modify $ \s -> s { acctLastTime = t }

        -- Acks piggybacked on any packet confirm our own reliable packets.
        mapM_ confirmPacket (packetAcks packet)

        -- Reliable packets must be acked, even duplicates (our ack was lost).
        when (packetReliable packet) $ modify $ \s ->
            s { acctPendingAcks = acctPendingAcks s ++ [ (t, packetSequence packet) ] }

        -- Snapshot before recording this packet, for duplicate detection.
        recent <- S.gets acctRecentPackets
        when (packetReliable packet) $ modify $ \s ->
            s { acctRecentPackets = take 100 (packetSequence packet : acctRecentPackets s) }

        -- PacketAck messages are handled here rather than delivered; other
        -- messages are delivered unless they are retransmissions we have
        -- already seen.
        case packetBody packet of
            PacketAck acks ->
                mapM_ (confirmPacket . packetAck_Packets_ID) acks
            body -> do
                let duplicate = packetRetransmit packet
                        && packetSequence packet `elem` recent
                unless duplicate $
                    liftIO $ writeChan (circuitIncoming circ) (Just body)

        S.gets (not . acctClosed)

{-
    The protocol asks clients to ping occasionally, so this worker sends a
    StartPingCheck every five seconds with an incrementing (wrapping) ping
    id.  If nothing has been heard from the server for more than 60
    seconds, it closes the circuit and stops.
-}
pingSender :: Circuit -> Word8 -> IO ()
pingSender circ pingId = do
    threadDelay 5000000
    lastHeard <- liftM acctLastTime (readMVar (circuitAccounting circ))
    now       <- getCurrentTime
    let silent = now `diffUTCTime` lastHeard > 60
    keepGoing <- if silent
        then do
            circuitClose circ
            return False
        else do
            runWithMVar (circuitAccounting circ) $
                circuitSendImpl circ Unreliable
                    (StartPingCheck (StartPingCheck_PingID pingId 0))
            return True
    when keepGoing $ pingSender circ (pingId + 1)

{-
    Worker loop that answers each StartPingCheck from the server with a
    CompletePingCheck carrying the same ping id.  Other messages are
    ignored.  The loop stops when the channel yields 'Nothing'.
-}
pingResponder :: Circuit -> Chan (Maybe PacketBody) -> IO ()
pingResponder circ source = readChan source >>= respond
  where
    respond Nothing = return ()
    respond (Just (StartPingCheck (StartPingCheck_PingID pingId _))) = do
        circuitSend circ False $
            CompletePingCheck (CompletePingCheck_PingID pingId)
        pingResponder circ source
    respond (Just _) = pingResponder circ source

{-|
    Closes a circuit: marks it closed, signals 'Nothing' on
    'circuitIncoming', cancels pending retransmissions, releases callers
    waiting in 'circuitSendSync' (with 'False'), kills the worker threads,
    and closes the socket.

    Closing is idempotent; only the first call does the cleanup.  The
    worker threads themselves call this function (on socket errors or a
    server timeout), so the calling thread is never killed here.  Killing
    it would abort the cleanup half-way, leaving the other threads running
    and the socket open.
-}
circuitClose :: Circuit -- ^ The circuit to close
             -> IO ()
circuitClose circ = do
    (wasClosed, waiters, queue) <-
        modifyMVar (circuitAccounting circ) $ \acct -> return
            ( acct { acctClosed = True, acctConfirmations = M.empty }
            , ( acctClosed acct
              , M.elems (acctConfirmations acct)
              , acctReliableQueue acct
              )
            )
    unless wasClosed $ do
        writeChan (circuitIncoming circ) Nothing
        closeQueue queue
        -- tryPutMVar: a retry timer may already have reported failure.
        mapM_ (\v -> tryPutMVar v False) waiters
        self    <- myThreadId
        threads <- swapMVar (circuitThreads circ) []
        mapM_ killThread (filter (/= self) threads)
        sClose (circuitSocket circ)

{-|
    Reports whether the circuit is marked closed in its accounting state.
-}
circuitIsClosed :: Circuit -> IO Bool
circuitIsClosed circ = do
    acct <- readMVar (circuitAccounting circ)
    return (acctClosed acct)

{-|
    Connects to a circuit, using connection information given in the login
    token provided.  This creates a UDP socket, sets up all the accounting
    and other data structures associated with the circuit, starts its worker
    threads, and sends the initial @UseCircuitCode@ message, waiting for the
    server's acknowledgement before returning.

    NOTE(review): the 'Bool' returned by 'circuitSendSync' for
    @UseCircuitCode@ is discarded, so a circuit is returned even if the
    server never acknowledged it.  Confirm whether callers should be told.
-}
circuitConnect :: MVToken -- ^ Token of circuit info obtained from login server
               -> IO Circuit
circuitConnect token = do
    sock    <- socket AF_INET Datagram defaultProtocol
    host    <- inet_addr (tokenSimIP token)
    let port = fromIntegral (tokenSimPort token)

    -- Created empty and filled below: the circuit record must exist before
    -- the worker threads (which take it as an argument) are forked.
    acct    <- newEmptyMVar
    threads <- newEmptyMVar
    inc     <- newChan

    let circ = Circuit {
        circuitAgentID    = tokenAgentID token,
        circuitSessionID  = tokenSessionID token,
        circuitCode       = tokenCircuitCode token,
        circuitSocket     = sock,
        circuitAddr       = SockAddrInet port host,
        circuitThreads    = threads,
        circuitIncoming   = inc,
        circuitAccounting = acct
        }

    queue      <- newTaskQueue
    -- The ping responder gets its own copy of the incoming channel.
    pingSource <- dupChan inc

    -- Workers are started before the accounting state exists; any of them
    -- that touches it blocks on the empty MVar until it is filled below.
    putMVar threads =<< mapM forkIO [
        ackSender      circ,
        packetReceiver circ,
        pingSender     circ 0,
        pingResponder  circ pingSource
        ]

    time <- getCurrentTime

    -- Initial accounting state; outgoing sequence numbers start at 1.
    putMVar acct $ Accounting {
        acctClosed        = False,
        acctSequence      = 1,
        acctRecentPackets = [],
        acctPendingAcks   = [],
        acctReliableQueue = queue,
        acctConfirmations = M.empty,
        acctLastTime      = time
        }

    -- Announce ourselves to the server and wait for its acknowledgement
    -- (result ignored; see the note above).
    circuitSendSync circ $ UseCircuitCode $ UseCircuitCode_CircuitCode
        (circuitCode circ) (circuitSessionID circ) (circuitAgentID circ)

    return circ