{-# LANGUAGE RecursiveDo #-}
{-# OPTIONS_GHC -fno-warn-deprecations #-}
-- |
-- Module: Network.Transport.InMemory.Internal
--
-- Internal part of the implementation. This module is for internal use
-- or advanced debugging. There are no guarantees about the stability of
-- this module.
module Network.Transport.InMemory.Internal
  ( createTransportExposeInternals
    -- * Internal structures
  , TransportInternals(..)
  , TransportState(..)
  , ValidTransportState(..)
  , LocalEndPoint(..)
  , LocalEndPointState(..)
  , ValidLocalEndPointState(..)
  , LocalConnection(..)
  , LocalConnectionState(..)
    -- * Low level functionality
  , apiNewEndPoint
  , apiCloseEndPoint
  , apiBreakConnection
  , apiConnect
  , apiSend
  , apiClose
  ) where

import Network.Transport
import Network.Transport.Internal ( mapIOException )
import Control.Category ((>>>))
import Control.Concurrent.STM
import Control.Exception (handle, throw)
import Data.Map (Map)
import Data.Maybe (fromJust)
import Data.Monoid
import Data.Foldable
import qualified Data.Map as Map
import Data.Set (Set)
import qualified Data.Set as Set
import Data.ByteString (ByteString)
import qualified Data.ByteString.Char8 as BSC (pack)
import Data.Accessor (Accessor, accessor, (^.), (^=), (^:))
import qualified Data.Accessor.Container as DAC (mapMaybe)
import Data.Typeable (Typeable)
import Prelude hiding (foldr)

-- | State of the whole in-memory transport.
data TransportState
  = TransportValid {-# UNPACK #-} !ValidTransportState
    -- ^ The transport is open; carries its bookkeeping data.
  | TransportClosed
    -- ^ The transport has been closed (installed by @closeTransport@).

-- | Bookkeeping data of a transport that has not been closed.
data ValidTransportState = ValidTransportState
  { _localEndPoints :: !(Map EndPointAddress LocalEndPoint)
    -- ^ End points currently registered on the transport, keyed by address.
  , _nextLocalEndPointId :: !Int
    -- ^ Counter behind end point addresses: 'apiNewEndPoint' increments it
    -- and renders the incremented value with 'show' to form the address.
  }

-- | An end point living inside the in-memory transport.
data LocalEndPoint = LocalEndPoint
  { localEndPointAddress :: !EndPointAddress
    -- ^ Address under which the end point is registered in the transport.
  , localEndPointChannel :: !(TChan Event)
    -- ^ Queue of events delivered to this end point; @receive@ reads from it.
  , localEndPointState   :: !(TVar LocalEndPointState)
    -- ^ Mutable state of the end point (valid or closed).
  }

-- | State of a single end point.
data LocalEndPointState
  = LocalEndPointValid {-# UNPACK #-} !ValidLocalEndPointState
    -- ^ The end point is open; carries its bookkeeping data.
  | LocalEndPointClosed
    -- ^ The end point has been closed (by 'apiCloseEndPoint' or by closing
    -- the transport).

-- | Bookkeeping data of an end point that has not been closed.
data ValidLocalEndPointState = ValidLocalEndPointState
  { _nextConnectionId :: !ConnectionId
    -- ^ Counter for connection identifiers. 'apiConnect' increments the
    -- counter of the /remote/ end point and uses the incremented value as the
    -- identifier of the new connection.
  , _connections :: !(Map (EndPointAddress,ConnectionId) LocalConnection)
    -- ^ Connections opened /by/ this end point, keyed by the remote address
    -- and the connection identifier.
  , _multigroups :: Map MulticastAddress (TVar (Set EndPointAddress))
    -- ^ Multicast groups registered on this end point, each with the set of
    -- subscribed end point addresses. Only touched by the (currently
    -- disabled) multicast code.
  }

-- | The initiating side's record of a connection between two end points.
data LocalConnection = LocalConnection
  { localConnectionId :: !ConnectionId
    -- ^ Identifier reported in the events sent to the remote end point.
  , localConnectionLocalAddress :: !EndPointAddress
    -- ^ Address of the end point that opened the connection.
  , localConnectionRemoteAddress :: !EndPointAddress
    -- ^ Address of the end point the connection leads to.
  , localConnectionState :: !(TVar LocalConnectionState)
    -- ^ Mutable state of the connection.
  }

-- | State of a single connection.
data LocalConnectionState
  = LocalConnectionValid
    -- ^ Initial state given to a connection by 'apiConnect'; 'apiSend'
    -- delivers messages only in this state.
  | LocalConnectionClosed
    -- ^ Set by 'apiClose', by 'apiCloseEndPoint' and when the transport is
    -- closed.
  | LocalConnectionFailed
    -- ^ Set by 'apiBreakConnection' on connections that were still valid.

-- | Handle on the mutable state of a transport, as handed out by
-- 'createTransportExposeInternals'.
newtype TransportInternals = TransportInternals (TVar TransportState)

-- | Create a new 'Transport' together with a handle on its internal state.
--
-- Intended for testing and debugging only; do not rely on it in production.
-- The internals API carries no stability guarantees.
--
-- The returned transport starts with no end points and an end point counter
-- of @0@. Its @closeTransport@ action marks the transport as closed, then
-- closes every registered end point (posting 'EndPointClosed' to it) and
-- marks each of that end point's connections as closed. Closing an already
-- closed transport does nothing.
createTransportExposeInternals :: IO (Transport, TransportInternals)
createTransportExposeInternals = do
  state <- newTVarIO (TransportValid initialState)
  let transport = Transport
        { newEndPoint    = apiNewEndPoint state
        , closeTransport = shutdown state
        }
  return (transport, TransportInternals state)
  where
    initialState = ValidTransportState
      { _localEndPoints = Map.empty
      , _nextLocalEndPointId = 0
      }

    -- The shutdown is intentionally split into several small transactions
    -- instead of one big one.
    shutdown tv = do
      previous <- atomically $ swapTVar tv TransportClosed
      case previous of
        TransportClosed   -> return ()
        TransportValid vst -> forM_ (vst ^. localEndPoints) closeLocalEndPoint

    -- Close one end point (if still valid) and then each of its connections,
    -- every connection in a transaction of its own.
    closeLocalEndPoint lep = do
      conns <- atomically $ whenValidLocalEndPointState lep $ \lepvst -> do
        writeTChan (localEndPointChannel lep) EndPointClosed
        writeTVar  (localEndPointState lep) LocalEndPointClosed
        return (lepvst ^. connections)
      forM_ conns $ \conn ->
        atomically $ writeTVar (localConnectionState conn) LocalConnectionClosed


-- | Create a new end point and register it on the transport.
--
-- Runs in a single transaction. Returns
-- @Left (TransportError NewEndPointFailed _)@ when the transport is closed.
--
-- The @receive@ action of the resulting 'EndPoint' blocks until an event is
-- available; once the end point is closed and its queue is drained it throws
-- @userError "Channel is closed."@. Multicast operations always report that
-- multicast is unsupported.
apiNewEndPoint :: TVar TransportState
               -> IO (Either (TransportError NewEndPointErrorCode) EndPoint)
apiNewEndPoint state = handle (return . Left) $ atomically $ do
  chan <- newTChan
  lep  <- withValidTransportState state NewEndPointFailed (register chan)
  return (Right (toEndPoint lep))
  where
    -- Allocate the next address, build the end point and store it in the
    -- transport state.
    register chan vst = do
      lepState <- newTVar $ LocalEndPointValid ValidLocalEndPointState
        { _nextConnectionId = 1
        , _connections = Map.empty
        , _multigroups = Map.empty
        }
      let bumped = nextLocalEndPointId ^: (+ 1) $ vst
          addr   = EndPointAddress . BSC.pack . show $ bumped ^. nextLocalEndPointId
          lep    = LocalEndPoint
            { localEndPointAddress = addr
            , localEndPointChannel = chan
            , localEndPointState = lepState
            }
      writeTVar state (TransportValid $ localEndPointAt addr ^= Just lep $ bumped)
      return lep

    -- Public API record for a registered end point.
    toEndPoint lep = EndPoint
      { receive       = atomically (nextEvent lep)
      , address       = localEndPointAddress lep
      , connect       = apiConnect (localEndPointAddress lep) state
      , closeEndPoint = apiCloseEndPoint state (localEndPointAddress lep)
      , newMulticastGroup     = return (Left newMulticastGroupError)
      , resolveMulticastGroup = \_ -> return (Left resolveMulticastGroupError)
      }

    -- Take the next queued event; when the queue is empty either wait (end
    -- point still valid) or fail (end point closed).
    nextEvent lep = do
      pending <- tryReadTChan (localEndPointChannel lep)
      case pending of
        Just ev -> return ev
        Nothing -> do
          st <- readTVar (localEndPointState lep)
          case st of
            LocalEndPointValid{} -> retry
            LocalEndPointClosed  -> throwSTM (userError "Channel is closed.")

    -- see [Multicast] section
    newMulticastGroupError =
      TransportError NewMulticastGroupUnsupported "Multicast not supported"
    resolveMulticastGroupError =
      TransportError ResolveMulticastGroupUnsupported "Multicast not supported"

-- | Close the end point registered under the given address.
--
-- Everything happens in one transaction. Nothing is done when the transport
-- is closed or the address is not registered. Otherwise the end point is
-- marked closed and removed from the transport. If it was still valid, each
-- of its connections is marked closed, the remote end point of every
-- connection that was still valid receives 'ConnectionClosed' (provided that
-- end point is still registered and valid), and the end point itself
-- receives 'EndPointClosed'.
apiCloseEndPoint :: TVar TransportState -> EndPointAddress -> IO ()
apiCloseEndPoint state addr = atomically $ whenValidTransportState state $ \vst ->
    case vst ^. localEndPointAt addr of
      Nothing  -> return ()
      Just lep -> do
        previous <- swapTVar (localEndPointState lep) LocalEndPointClosed
        case previous of
          LocalEndPointClosed -> return ()
          LocalEndPointValid lepvst -> do
            forM_ (Map.elems (lepvst ^. connections)) (closeConnection vst)
            writeTChan (localEndPointChannel lep) EndPointClosed
            writeTVar  (localEndPointState lep)    LocalEndPointClosed
        writeTVar state (TransportValid $ (localEndPoints ^: Map.delete addr) vst)
  where
    -- Mark a connection closed; only a connection that was still valid
    -- triggers a notification of the remote side.
    closeConnection vst lconn = do
      st <- swapTVar (localConnectionState lconn) LocalConnectionClosed
      case st of
        LocalConnectionValid  -> notifyPeer vst lconn
        LocalConnectionClosed -> return ()
        LocalConnectionFailed -> return ()

    -- Tell the remote end point, if it is registered and valid, that the
    -- connection is gone.
    notifyPeer vst lconn =
      forM_ (vst ^. localEndPointAt (localConnectionRemoteAddress lconn)) $ \peer ->
        whenValidLocalEndPointState peer $ \_ ->
          writeTChan (localEndPointChannel peer)
                     (ConnectionClosed (localConnectionId lconn))

-- | Tear-down routine to be called when a connection fails.
--
-- Does nothing when both addresses are equal or when the transport is
-- closed. Otherwise, for each of the two end points that is registered and
-- valid: every connection it holds towards the other address is switched
-- from valid to failed and dropped from its connection table, and an
-- 'EventConnectionLost' error event naming the other address (with the given
-- message) is posted to its queue.
apiBreakConnection :: TVar TransportState
                   -> EndPointAddress
                   -> EndPointAddress
                   -> String
                   -> STM ()
apiBreakConnection state us them msg
  | us == them = return ()
  | otherwise  = whenValidTransportState state $ \vst -> do
      breakOne vst us them
      breakOne vst them us
  where
    -- Break everything that @local@ holds towards @remote@.
    breakOne vst local remote =
      case vst ^. localEndPointAt local of
        Nothing  -> return ()
        Just lep -> whenValidLocalEndPointState lep $ \lepvst -> do
          let leadsToRemote (peer, _) _ = peer == remote
              (broken, kept) =
                Map.partitionWithKey leadsToRemote (lepvst ^. connections)
          forM_ broken $ \conn ->
            modifyTVar (localConnectionState conn) markFailed
          writeTChan (localEndPointChannel lep)
                     (ErrorEvent (TransportError (EventConnectionLost remote) msg))
          writeTVar (localEndPointState lep)
                    (LocalEndPointValid $ (connections ^= kept) lepvst)

    -- Only a valid connection becomes failed; other states are left alone.
    markFailed LocalConnectionValid = LocalConnectionFailed
    markFailed other                = other


-- | Create a new connection from our end point to theirs.
--
-- Runs in a single transaction: the connection identifier is taken from the
-- remote end point's counter, the connection is recorded in our end point's
-- connection table, and 'ConnectionOpened' (always 'ReliableOrdered') is
-- posted to the remote end point. The reliability and hints arguments are
-- ignored.
--
-- Returns @Left@ with 'ConnectFailed' when the transport is closed or when
-- either end point is closed or our end point is no longer registered, and
-- with 'ConnectNotFound' when the remote address is not registered.
apiConnect :: EndPointAddress
           -> TVar TransportState
           -> EndPointAddress
           -> Reliability
           -> ConnectHints
           -> IO (Either (TransportError ConnectErrorCode) Connection)
apiConnect ourAddress state theirAddress _reliability _hints =
    handle (return . Left) $ fmap Right $ atomically $ do
      (theirChan, lconn) <- withValidTransportState state ConnectFailed establish
      writeTChan theirChan $
        ConnectionOpened (localConnectionId lconn) ReliableOrdered ourAddress
      return Connection
        { send  = apiSend theirChan state lconn
        , close = apiClose theirChan state lconn
        }
  where
    -- Look an end point up or abort the transaction with the given error.
    requireEndPoint vst addr err =
      maybe (throwSTM err) return (vst ^. localEndPointAt addr)

    -- Allocate the connection id on the remote side, then register the
    -- connection on our side.
    establish vst = do
      ourlep <- requireEndPoint vst ourAddress $
        TransportError ConnectFailed "Endpoint closed"
      theirlep <- requireEndPoint vst theirAddress $
        TransportError ConnectNotFound "Endpoint not found"
      conid <- withValidLocalEndPointState theirlep ConnectFailed $ \theirSt -> do
        let bumped = nextConnectionId ^: (+ 1) $ theirSt
        writeTVar (localEndPointState theirlep) (LocalEndPointValid bumped)
        return (bumped ^. nextConnectionId)
      withValidLocalEndPointState ourlep ConnectFailed $ \ourSt -> do
        connState <- newTVar LocalConnectionValid
        let lconn = LocalConnection
              { localConnectionId = conid
              , localConnectionLocalAddress = ourAddress
              , localConnectionRemoteAddress = theirAddress
              , localConnectionState = connState
              }
        writeTVar (localEndPointState ourlep) $
          LocalEndPointValid $
            connectionAt (theirAddress, conid) ^= lconn $ ourSt
        return (localEndPointChannel theirlep, lconn)

-- | Send a message over a connection.
--
-- * Valid connection: every chunk of the message is forced to weak head
--   normal form and a 'Received' event is queued on the given channel.
-- * Closed connection: the result is @Left SendClosed@ when our end point is
--   still registered and valid; otherwise a 'SendFailed' error is raised
--   (transport closed, end point gone or end point closed).
-- * Failed connection: @Left SendFailed@ is returned directly.
--
-- Any raised 'SendFailed' error (including 'IOException's, which are mapped
-- to 'SendFailed') is turned into a @Left@ after 'apiBreakConnection' has
-- been run for the two end points of the connection.
apiSend :: TChan Event
        -> TVar TransportState
        -> LocalConnection
        -> [ByteString]
        -> IO (Either (TransportError SendErrorCode) ())
apiSend chan state lconn msg = handle handleFailure $ mapIOException sendFailed $
    atomically $ readTVar (localConnectionState lconn) >>= dispatch
    where
      dispatch LocalConnectionValid = do
        foldr seq () msg `seq`
          writeTChan chan (Received (localConnectionId lconn) msg)
        return $ Right ()
      dispatch LocalConnectionClosed =
        withValidTransportState state SendFailed explainClosed
      dispatch LocalConnectionFailed =
        return $ Left $ TransportError SendFailed "Endpoint closed"

      -- The local connection was closed: find out why.
      explainClosed vst =
        case vst ^. localEndPointAt (localConnectionLocalAddress lconn) of
          Nothing  -> throwSTM endPointClosed
          Just lep -> do
            lepst <- readTVar (localEndPointState lep)
            case lepst of
              LocalEndPointValid _ ->
                return $ Left $ TransportError SendClosed "Connection closed"
              LocalEndPointClosed ->
                throwSTM endPointClosed

      endPointClosed = TransportError SendFailed "Endpoint closed"

      sendFailed = TransportError SendFailed . show

      handleFailure ex@(TransportError SendFailed reason) = atomically $ do
        apiBreakConnection state (localConnectionLocalAddress lconn)
                                 (localConnectionRemoteAddress lconn)
                                 reason
        return (Left ex)
      handleFailure ex = return (Left ex)

-- | Close a connection.
--
-- Only a connection that is still valid is affected: 'ConnectionClosed' is
-- queued on the given channel, the connection is marked closed and, when the
-- transport and our end point are still valid, the connection is removed
-- from our end point's connection table. Closed or failed connections are
-- left untouched. Everything happens in one transaction.
apiClose :: TChan Event
         -> TVar TransportState
         -> LocalConnection
         -> IO ()
apiClose chan state lconn =
  atomically $ readTVar (localConnectionState lconn) >>= \connst ->
    -- XXX: whenValidConnectionState
    case connst of
      LocalConnectionValid  -> closeValid
      LocalConnectionClosed -> return ()
      LocalConnectionFailed -> return ()
  where
    -- Key of this connection in the owning end point's connection table.
    connKey = (localConnectionRemoteAddress lconn, localConnectionId lconn)

    closeValid = do
      writeTChan chan $ ConnectionClosed (localConnectionId lconn)
      writeTVar (localConnectionState lconn) LocalConnectionClosed
      whenValidTransportState state $ \vst ->
        forM_ (vst ^. localEndPointAt (localConnectionLocalAddress lconn)) $ \lep ->
          whenValidLocalEndPointState lep $ \lepvst ->
            writeTVar (localEndPointState lep) $
              LocalEndPointValid $ connections ^: Map.delete connKey $ lepvst

-- [Multicast]
-- The multicast implementation currently doesn't pass its tests, so it is
-- disabled. The old code is kept here so that it can be improved; see GitHub
-- issue 5:
-- https://github.com/haskell-distributed/network-transport-inmemory/issues/5

-- | Construct a multicast group
--
-- When the group is deleted some endpoints may still receive messages, but
-- subsequent calls to resolveMulticastGroup will fail. This mimics the fact
-- that some multicast messages may still be in transit when the group is
-- deleted.
--
-- Arguments: the transport state, the address of the end point that owns
-- this handle, the address of the group and the shared set of subscribers.
--
-- * @deleteMulticastGroup@ removes the group from the owning end point's
--   group table. It is a no-op when the transport is closed, or when the
--   owning end point is no longer registered or is closed.
-- * @multicastSend@ queues a 'ReceivedMulticast' event on every subscriber
--   that is still registered on the transport; subscribers whose end point
--   has disappeared are skipped. It throws a 'SendFailed' 'TransportError'
--   when the transport is closed.
-- * @multicastSubscribe@ / @multicastUnsubscribe@ add / remove the owning
--   end point's address to / from the subscriber set.
createMulticastGroup :: TVar TransportState
                     -> EndPointAddress
                     -> MulticastAddress
                     -> TVar (Set EndPointAddress)
                     -> MulticastGroup
createMulticastGroup state ourAddress groupAddress group = MulticastGroup
    { multicastAddress     = groupAddress
    , deleteMulticastGroup = atomically $
        whenValidTransportState state $ \vst ->
          -- XXX best we can do given current broken API, which needs fixing.
          -- The end point may already have been closed and removed from the
          -- transport, so its absence must not crash the transaction.
          forM_ (vst ^. localEndPointAt ourAddress) $ \lep ->
            modifyTVar' (localEndPointState lep) $ \lepst -> case lepst of
              LocalEndPointValid lepvst ->
                LocalEndPointValid $ multigroups ^: Map.delete groupAddress $ lepvst
              LocalEndPointClosed ->
                LocalEndPointClosed
    , maxMsgSize           = Nothing
    , multicastSend        = \payload -> atomically $
        withValidTransportState state SendFailed $ \vst -> do
          es <- readTVar group
          forM_ (Set.elems es) $ \ep ->
            -- A subscriber may have closed its end point without
            -- unsubscribing; skip it instead of failing the whole send.
            forM_ (vst ^. localEndPointAt ep) $ \lep ->
              writeTChan (localEndPointChannel lep)
                         (ReceivedMulticast groupAddress payload)
    , multicastSubscribe   = atomically $ modifyTVar' group $ Set.insert ourAddress
    , multicastUnsubscribe = atomically $ modifyTVar' group $ Set.delete ourAddress
    , multicastClose       = return ()
    }

-- | Create a new multicast group
--
-- The group starts with an empty subscriber set and is registered in the
-- group table of the end point with the given address. Returns @Left@ with
-- 'NewMulticastGroupFailed' when the transport is closed, or when the end
-- point is not registered or is closed.
--
-- The group address is a decimal number. Probing starts at the current size
-- of the group table (so addresses are unchanged as long as no group was
-- deleted) and moves upwards until an address is found that no live group
-- uses; a new group therefore never overwrites an existing one.
_apiNewMulticastGroup :: TVar TransportState
                     -> EndPointAddress
                     -> IO (Either (TransportError NewMulticastGroupErrorCode) MulticastGroup)
_apiNewMulticastGroup state ourAddress = handle (return . Left) $ do
  group <- newTVarIO Set.empty
  groupAddr <- atomically $
    withValidTransportState state NewMulticastGroupFailed $ \vst -> do
      lep <- maybe (throwSTM $ TransportError NewMulticastGroupFailed "Endpoint closed")
                   return
                   (vst ^. localEndPointAt ourAddress)
      withValidLocalEndPointState lep NewMulticastGroupFailed $ \lepvst -> do
        let groups = lepvst ^. multigroups
            mkAddr :: Int -> MulticastAddress
            mkAddr = MulticastAddress . BSC.pack . show
            -- After a deletion the table size can coincide with the address
            -- of a group that is still alive. The table is finite, so the
            -- search terminates.
            freeIndex = until (\i -> mkAddr i `Map.notMember` groups)
                              (+ 1)
                              (Map.size groups)
            addr = mkAddr freeIndex
        writeTVar (localEndPointState lep) (LocalEndPointValid $ multigroupAt addr ^= group $ lepvst)
        return addr
  return . Right $ createMulticastGroup state ourAddress groupAddr group

-- | Resolve a multicast group
--
-- Looks the group address up in the group table of the end point with the
-- given address. Returns @Left@ with 'ResolveMulticastGroupFailed' when the
-- transport is closed or the end point is not registered or is closed, and
-- with 'ResolveMulticastGroupNotFound' when the end point has no group under
-- that address.
_apiResolveMulticastGroup :: TVar TransportState
                         -> EndPointAddress
                         -> MulticastAddress
                         -> IO (Either (TransportError ResolveMulticastGroupErrorCode) MulticastGroup)
_apiResolveMulticastGroup state ourAddress groupAddress =
    handle (return . Left) $ atomically $
      withValidTransportState state ResolveMulticastGroupFailed $ \vst -> do
        lep <- case vst ^. localEndPointAt ourAddress of
          Just found -> return found
          Nothing ->
            throwSTM $ TransportError ResolveMulticastGroupFailed "Endpoint closed"
        withValidLocalEndPointState lep ResolveMulticastGroupFailed $ \lepvst ->
          case lepvst ^. (multigroups >>> DAC.mapMaybe groupAddress) of
            Just members ->
              return (Right (createMulticastGroup state ourAddress groupAddress members))
            Nothing ->
              return (Left notFound)
  where
    notFound =
      TransportError ResolveMulticastGroupNotFound
        ("Group " ++ show groupAddress ++ " not found")

--------------------------------------------------------------------------------
-- Lens definitions                                                           --
--------------------------------------------------------------------------------

-- | Accessor for the end point identifier counter of a transport.
nextLocalEndPointId :: Accessor ValidTransportState Int
nextLocalEndPointId = accessor _nextLocalEndPointId update
  where
    update newId vst = vst { _nextLocalEndPointId = newId }

-- | Accessor for the table of end points registered on a transport.
localEndPoints :: Accessor ValidTransportState (Map EndPointAddress LocalEndPoint)
localEndPoints = accessor _localEndPoints update
  where
    update table vst = vst { _localEndPoints = table }

-- | Accessor for the connection identifier counter of an end point.
nextConnectionId :: Accessor ValidLocalEndPointState ConnectionId
nextConnectionId = accessor _nextConnectionId update
  where
    update newId lepvst = lepvst { _nextConnectionId = newId }

-- | Accessor for the connection table of an end point.
connections :: Accessor ValidLocalEndPointState (Map (EndPointAddress,ConnectionId) LocalConnection)
connections = accessor _connections update
  where
    update table lepvst = lepvst { _connections = table }

-- | Accessor for the multicast group table of an end point.
multigroups :: Accessor ValidLocalEndPointState (Map MulticastAddress (TVar (Set EndPointAddress)))
multigroups = accessor _multigroups update
  where
    update table lepvst = lepvst { _multigroups = table }

-- | Accessor for the map entry stored under a key.
--
-- Reading is partial: when the key is absent the result is @error@ applied
-- to the given message. Writing inserts (or overwrites) the entry.
at :: Ord k => k -> String -> Accessor (Map k v) v
at key message = accessor readEntry writeEntry
  where
    readEntry table        = Map.findWithDefault (error message) key table
    writeEntry value table = Map.insert key value table

-- | Accessor for the end point registered under an address, if any.
localEndPointAt :: EndPointAddress -> Accessor ValidTransportState (Maybe LocalEndPoint)
localEndPointAt addr = localEndPoints >>> entry
  where
    entry = DAC.mapMaybe addr

-- | Accessor for the connection stored under a (remote address, connection
-- identifier) key. Reading a missing key fails with @"Invalid connection"@.
connectionAt :: (EndPointAddress, ConnectionId) -> Accessor ValidLocalEndPointState LocalConnection
connectionAt key = connections >>> entry
  where
    entry = at key "Invalid connection"

-- | Accessor for the multicast group stored under an address. Reading a
-- missing address fails with @"Invalid multigroup"@.
multigroupAt :: MulticastAddress -> Accessor ValidLocalEndPointState (TVar (Set EndPointAddress))
multigroupAt groupAddr = multigroups >>> entry
  where
    entry = at groupAddr "Invalid multigroup"

---------------------------------------------------------------------------------
-- Helpers
---------------------------------------------------------------------------------

-- | Eliminator for the state of a 'LocalEndPoint': reads the current state
-- and runs the action on the valid state, or the fallback when the end point
-- is closed.
overValidLocalEndPointState :: LocalEndPoint -> STM a -> (ValidLocalEndPointState -> STM a) -> STM a
overValidLocalEndPointState lep fallback action =
  readTVar (localEndPointState lep) >>= dispatch
  where
    dispatch (LocalEndPointValid lepvst) = action lepvst
    dispatch LocalEndPointClosed         = fallback

-- | Specialized deconstructor that aborts the transaction with a
-- 'TransportError' when the end point is closed.
--
-- The error carries the given error code and the message
-- @"EndPoint closed"@. It is raised with 'throwSTM' rather than 'throw', so
-- that the exception is ordered with respect to the other STM operations.
withValidLocalEndPointState :: (Typeable e, Show e) => LocalEndPoint -> e -> (ValidLocalEndPointState -> STM a) -> STM a
withValidLocalEndPointState lep ex =
  overValidLocalEndPointState lep (throwSTM $ TransportError ex "EndPoint closed")

-- | Specialized deconstructor that does nothing when the end point is
-- closed: the result is then 'mempty'.
whenValidLocalEndPointState :: Monoid m => LocalEndPoint -> (ValidLocalEndPointState -> STM m) -> STM m
whenValidLocalEndPointState lep action =
  overValidLocalEndPointState lep nothingToDo action
  where
    nothingToDo = return mempty

-- | Eliminator for the state of a transport: reads the current state and
-- runs the action on the valid state, or the fallback when the transport is
-- closed.
overValidTransportState :: TVar TransportState -> STM a -> (ValidTransportState -> STM a) -> STM a
overValidTransportState ts fallback action = readTVar ts >>= dispatch
  where
    dispatch (TransportValid tsvst) = action tsvst
    dispatch TransportClosed        = fallback

-- | Specialized deconstructor that aborts the transaction with a
-- 'TransportError' when the transport is closed.
--
-- The error carries the given error code and the message
-- @"Transport closed"@. It is raised with 'throwSTM' rather than 'throw', so
-- that the exception is ordered with respect to the other STM operations.
withValidTransportState :: (Typeable e, Show e) => TVar TransportState -> e -> (ValidTransportState -> STM a) -> STM a
withValidTransportState ts ex =
  overValidTransportState ts (throwSTM $ TransportError ex "Transport closed")

-- | Specialized deconstructor that does nothing when the transport is
-- closed: the result is then 'mempty'.
whenValidTransportState :: Monoid m => TVar TransportState -> (ValidTransportState -> STM m) -> STM m
whenValidTransportState ts action =
  overValidTransportState ts nothingToDo action
  where
    nothingToDo = return mempty