{-# LANGUAGE Arrows, GeneralizedNewtypeDeriving, TypeOperators, CPP, DisambiguateRecordFields, RecordWildCards #-}

-- | Nettle signal functions and drivers. These drivers take care
-- of low level details, such as message numbering, correlation of stats
-- requests and replies, translation of higher level flow rules expressed
-- using packet predicates to low level flow rules expressed in terms of
-- matches.
module Nettle.FRPControl.NettleSF
    (
      
      -- * Nettle Signal Functions
    runNettleSF
     , simpleNettleDriver
    
      
     -- * Switch event sources
     , SwitchMessage(..) 
     , arrivalE  
     , departureE
     , featureUpdateE
     , portUpdateE
     , switchErrorE
     , packetInE
     , filteredPacketInE
     , flowRemovedE
     , portStatReplyE
     , flowStatReplyE       
     
     -- * Switch commands
     , SwitchCommand
     , sendPacket
     , modifyFlowTable
     , deleteFlowRules
     , clearTables
     , configurePort
     , requestStats
     , requestFeatures
     , FlowRule
     , PrioritizedFlowRule
     , addFlowRule
     , addFlowRule'
     , addFlowRules

     , (<+>)
     , noOp
     , (==>)
     , expiringAfter
     , expiringAfterInactive
     , withPriority


     , module Nettle.FRPControl.AFRP
     , module Nettle.FRPControl.PacketPredicate
     , module Nettle.Ethernet.EthernetAddress
     , module Data.Monoid

    ) where

import Nettle.FRPControl.AFRP
import Nettle.FRPControl.AFRPEvent

import Nettle.FRPControl.SwitchInterface ((<+>), noOp)
import qualified Nettle.FRPControl.SwitchInterface as SI
import Nettle.FRPControl.PacketPredicate
import Nettle.FRPControl.AFRPUtils

import Nettle.Ethernet.EthernetAddress
import Nettle.Ethernet.EthernetFrame

import Nettle.Servers.TCPServer (SockAddr)
import Nettle.Servers.MultiplexedTCPServer (TCPMessage(..))
import Nettle.OpenFlow.Messages hiding (SCMessage(..), CSMessage(..))
import qualified Nettle.OpenFlow.Messages as M
import Nettle.OpenFlow.Switch
import Nettle.OpenFlow.Match
import qualified Nettle.OpenFlow.FlowTable as FlowTable
import Nettle.OpenFlow.Statistics hiding (StatsReply(..))
import qualified Nettle.OpenFlow.Statistics as M
import Nettle.OpenFlow.Port
import Nettle.OpenFlow.Error
import Nettle.OpenFlow.Packet
import Nettle.OpenFlow.Action
import Data.Monoid
import Data.Bimap (Bimap)
import qualified Data.Bimap as Bimap
import qualified Data.List as List
import Data.Map (Map)
import qualified Data.Map as Map
import Data.ByteString.Lazy (ByteString)
import Data.Word
import Control.Exception
import qualified Control.Category as Category
import Control.Monad.State hiding (lift)
import Data.Maybe (mapMaybe)
import System.IO
import Control.Concurrent
import Nettle.Servers.TCPServer (ServerPortNumber, SockAddr)

-- | High-level messages delivered to a Nettle signal function. In the event
-- streams of this module each message is paired with the @SwitchID@ of the
-- switch it concerns.
data SwitchMessage
  = Arrival SwitchFeatures                 -- ^ a switch was bound, reporting these features
  | Departure IOException                  -- ^ a bound switch left; the exception gives the reason
  | FeatureUpdate SwitchFeatures           -- ^ a features message from an already-bound switch
  | PortUpdate PortStatus                  -- ^ a port status message
  | SwitchError SwitchError                -- ^ an error message sent by the switch
  | PacketIn PacketInfo                    -- ^ a packet-in message
  | FlowRemoved FlowTable.FlowRemoved      -- ^ a flow-removed notification
  | PortStatsUpdate [(PortID, PortStats)]  -- ^ port statistics, aggregated over one request
  | FlowStatsUpdate [FlowStats]            -- ^ flow statistics, aggregated over one request
  deriving (Show, Eq)


-- | Outputs an event whenever a switch connects with the controller.
-- The event carries the @SwitchID@ of the switch together with the
-- @SwitchFeatures@ it reported.
arrivalE :: Event (SwitchID, SwitchMessage) -> Event (SwitchID, SwitchFeatures)
arrivalE = mapFilterE onlyArrivals
  where
    onlyArrivals (dpid, msg) =
      case msg of
        Arrival features -> Just (dpid, features)
        _                -> Nothing

-- | Outputs an event whenever a switch disconnects from the controller.
-- The event carries the @SwitchID@ of the switch and an @IOException@
-- value indicating the reason for the disconnection.
departureE ::  Event (SwitchID, SwitchMessage) -> Event (SwitchID, IOException)
departureE = mapFilterE departed
  where
    departed (dpid, msg) =
      case msg of
        Departure reason -> Just (dpid, reason)
        _                -> Nothing

-- | Outputs an event whenever an already-connected switch sends a switch
-- features message. The event carries the @SwitchID@ of the sending switch
-- and the @SwitchFeatures@ data.
featureUpdateE :: Event (SwitchID, SwitchMessage) -> Event (SwitchID, SwitchFeatures)
featureUpdateE = mapFilterE featuresOf
  where
    featuresOf (dpid, msg) =
      case msg of
        FeatureUpdate features -> Just (dpid, features)
        _                      -> Nothing

-- | Outputs an event whenever a switch sends a port status update.
-- The event carries the @SwitchID@ of the sending switch and the @PortStatus@.
portUpdateE :: Event (SwitchID, SwitchMessage) -> Event (SwitchID, PortStatus)
portUpdateE = mapFilterE portStatusOf
  where
    portStatusOf (dpid, msg) =
      case msg of
        PortUpdate status -> Just (dpid, status)
        _                 -> Nothing

-- | Outputs an event whenever a switch sends an error message.
-- The event carries the @SwitchID@ of the sending switch and the @SwitchError@.
switchErrorE :: Event (SwitchID, SwitchMessage) -> Event (SwitchID, SwitchError)
switchErrorE = mapFilterE errorOf
  where
    errorOf (dpid, msg) =
      case msg of
        SwitchError err -> Just (dpid, err)
        _               -> Nothing

-- | Outputs an event whenever a switch sends a packet-in message.
-- The event carries the @SwitchID@ of the sending switch and the @PacketInfo@.
packetInE :: Event (SwitchID, SwitchMessage) -> Event (SwitchID, PacketInfo)
packetInE = mapFilterE packetOf
  where
    packetOf (dpid, msg) =
      case msg of
        PacketIn info -> Just (dpid, info)
        _             -> Nothing


-- | Packet-in events, filtered by a packet predicate applied to the packet
-- carried by the packet-in event. The output stream only includes packet-in
-- messages that satisfy the predicate; a packet for which @packetInMatches@
-- yields a @Left@ result is dropped.
filteredPacketInE :: PacketPredicate 
                       -> Event (SwitchID, SwitchMessage) 
                       -> Event (SwitchID, PacketInfo)
filteredPacketInE p evs = filterE (accepted . snd) (packetInE evs)
  where
    accepted info =
      case packetInMatches info p of
        Right ok -> ok
        Left _   -> False

-- | Outputs an event whenever a switch notifies the controller of a flow removal.
-- The event carries the @SwitchID@ of the switch and the flow-removed record.
flowRemovedE :: Event (SwitchID, SwitchMessage) -> Event (SwitchID, FlowTable.FlowRemoved)
flowRemovedE = mapFilterE removalOf
  where
    removalOf (dpid, msg) =
      case msg of
        FlowRemoved removal -> Just (dpid, removal)
        _                   -> Nothing

-- | Outputs an event whenever a switch sends port statistics to the controller.
-- A switch may split this information over several messages; the statistics
-- from all messages answering the same request are aggregated and output as a
-- single event from this function.
portStatReplyE :: Event (SwitchID, SwitchMessage) -> Event (SwitchID, [(PortID, PortStats)])
portStatReplyE = mapFilterE portStatsOf
  where
    portStatsOf (dpid, msg) =
      case msg of
        PortStatsUpdate stats -> Just (dpid, stats)
        _                     -> Nothing

-- | Outputs an event whenever a switch sends flow statistics to the controller.
-- A switch may split this information over several messages; the statistics
-- from all messages answering the same request are aggregated and output as a
-- single event from this function.
flowStatReplyE :: Event (SwitchID, SwitchMessage) -> Event (SwitchID, [FlowStats])
flowStatReplyE = mapFilterE flowStatsOf
  where
    flowStatsOf (dpid, msg) =
      case msg of
        FlowStatsUpdate stats -> Just (dpid, stats)
        _                     -> Nothing
  
-- | A batch of messages for switches, each paired with the @SwitchID@ of the
-- switch it is addressed to. Commands are combined with the derived
-- @Monoid@ instance, which concatenates the underlying lists.
newtype SwitchCommand = SwitchCommand [(SwitchID, M.CSMessage)]
#if MIN_VERSION_base(4,11,0)
  -- Since base 4.11 (GHC 8.4) Semigroup is a superclass of Monoid, so the
  -- Monoid instance can only be derived together with a Semigroup one.
  deriving (Semigroup, Monoid, Show, Eq)
#else
  deriving (Monoid, Show, Eq)
#endif

-- | Command that sends the given @PacketOut@ message to a switch.
sendPacket :: SwitchID -> PacketOut -> SwitchCommand
sendPacket dpid pkt = SwitchCommand [(dpid, M.PacketOut pkt)]
{-
sendPacketIn :: (SwitchID, PacketIn) -> ActionSequence -> SwitchCommand
floodPacketIn :: (SwitchID, PacketIn) -> SwitchCommand

emptyAddFlow = AddFlow { match    = undefined
                       , priority = 1
                       , actions  = drop
                       , cookie   = 0
                       , idleTimeOut = Permanent
                       , hardTimeOut = Permanent
                       , notifyWhenRemoved = True
                       , applyToPacket = Nothing
                       , overlapAllowed = True 
                       } 
                                    
-}

-- | Command that sends a single flow-table modification to a switch.
modifyFlowTable :: SwitchID -> FlowTable.FlowMod -> SwitchCommand
modifyFlowTable dpid flowMod = SwitchCommand [(dpid, M.FlowMod flowMod)]
  
-- | Delete all flow entries whose match conditions fall inside a given packet
-- predicate. The predicate is translated with @toMatches@, and one delete
-- command (with no output-port restriction) is issued per resulting match.
-- Calls 'error' if the predicate cannot be expressed as matches.
deleteFlowRules :: SwitchID -> PacketPredicate -> SwitchCommand
deleteFlowRules dpid predicate =
  maybe unrealizable (mconcat . map deleteMatching) (toMatches predicate)
  where
    deleteMatching m =
      modifyFlowTable dpid (FlowTable.DeleteFlows { match = m, outPort = Nothing })
    unrealizable =
      error ("Attempted to delete flow rules with an unrealizable predicate: " ++ show predicate)

-- | Clear the flow table of a switch by deleting every rule covered by
-- @anyPacket@; equivalent to @deleteFlowRules switch anyPacket@.
clearTables :: SwitchID -> SwitchCommand
clearTables = flip deleteFlowRules anyPacket

-- | Command that sends a port modification message to a switch.
configurePort :: SwitchID -> PortMod -> SwitchCommand
configurePort dpid portMod = SwitchCommand [(dpid, M.PortMod portMod)]

-- | Command that sends a statistics request to a switch. The replies surface
-- through 'portStatReplyE' / 'flowStatReplyE'.
requestStats :: SwitchID -> StatsRequest -> SwitchCommand
requestStats dpid req = SwitchCommand [(dpid, M.StatsRequest req)]

-- | Command that asks a switch for its features. The reply surfaces
-- through 'featureUpdateE'.
requestFeatures :: SwitchID -> SwitchCommand
requestFeatures dpid = SwitchCommand [(dpid, M.FeaturesRequest)]

-- | A flow rule: the packet predicate it applies to, the actions to perform,
-- then the idle timeout and the hard timeout (in that order, as consumed by
-- 'addFlowRule').
type FlowRule = (PacketPredicate, ActionSequence, FlowTable.TimeOut, FlowTable.TimeOut)
-- | A flow rule paired with the priority it is installed with.
type PrioritizedFlowRule = (FlowTable.Priority, FlowRule)

-- | Builds a 'FlowRule' from a packet predicate and the actions to apply to
-- matching packets. Both timeouts of the resulting rule are
-- @FlowTable.Permanent@; 'expiringAfter' and 'expiringAfterInactive' change
-- them. Provides suggestive infix syntax for writing addFlowRule commands.
(==>) :: PacketPredicate -> ActionSequence -> FlowRule
predicate ==> acts = (predicate, acts, FlowTable.Permanent, FlowTable.Permanent)

infix 4 ==>

-- | Sets the hard timeout of a flow rule: the fourth tuple component, which
-- 'addFlowRule' installs as the entry's @hardTimeOut@. In OpenFlow this
-- expires the entry a fixed time after installation, regardless of traffic
-- (units are the switch's timeout units, presumably seconds -- confirm).
expiringAfter :: FlowRule -> Word16 -> FlowRule
(predicate, acts, idle, _) `expiringAfter` to = (predicate, acts, idle, FlowTable.ExpireAfter to)

-- | Sets the idle timeout of a flow rule: the third tuple component, which
-- 'addFlowRule' installs as the entry's @idleTimeOut@. In OpenFlow this
-- expires the entry after it has gone unmatched for the given time.
expiringAfterInactive :: FlowRule -> Word16 -> FlowRule
(predicate, acts, _, hard) `expiringAfterInactive` to = (predicate, acts, FlowTable.ExpireAfter to, hard)

-- | Attaches a priority to a flow rule.
withPriority ::  FlowRule -> FlowTable.Priority -> PrioritizedFlowRule
withPriority = flip (,)

-- | Installs a prioritized flow rule on a switch. The predicate is compiled by
-- 'compileFlowRule' into one flow-table addition per match, all sent to @dpid@.
addFlowRule :: PrioritizedFlowRule -> SwitchID -> SwitchCommand
addFlowRule (prio, (predicate, acts, idle, hard)) dpid =
  SwitchCommand (map ((,) dpid) (compileFlowRule prio (predicate, acts) idle hard))

-- | Translates a packet predicate and an action list into flow-table addition
-- messages, one per match produced by @toMatches@. Every generated entry
-- shares the given priority, idle timeout and hard timeout, and leaves
-- @applyToPacket@ as @Nothing@. Calls 'error' when the predicate cannot be
-- expressed as matches.
compileFlowRule :: FlowTable.Priority -> 
                   (PacketPredicate, [Action]) -> 
                   FlowTable.TimeOut -> 
                   FlowTable.TimeOut -> 
                   [M.CSMessage]
compileFlowRule prio (predicate, acts) idle hard =
    maybe unrealizable (map addFlow) (toMatches predicate)
  where
    unrealizable =
      error ("Attempted to add a flow rule with an unrealizable predicate: " ++ show predicate)
    addFlow m = M.FlowMod $ FlowTable.AddFlow {
                  match             = m,
                  actions           = acts,
                  priority          = prio,
                  idleTimeOut       = idle,
                  hardTimeOut       = hard,
                  applyToPacket     = Nothing
#if OPENFLOW_VERSION==152 || OPENFLOW_VERSION==1
                , overlapAllowed    = True
                , notifyWhenRemoved = True
#endif
#if OPENFLOW_VERSION==1
                , cookie            = 0
#endif
                }

-- | Like 'addFlowRule', but every generated flow-table addition carries
-- @Just bufid@ as its @applyToPacket@ field (presumably so the switch also
-- applies the new rule to the packet held in that buffer -- confirm).
addFlowRule' :: PrioritizedFlowRule -> BufferID -> SwitchID -> SwitchCommand
addFlowRule' (prio, (predicate, acts, idle, hard)) bufid dpid =
  SwitchCommand (map ((,) dpid) (compileFlowRule' prio (predicate, acts) idle hard bufid))

-- | Like @compileFlowRule@, but every generated flow-table addition has
-- @applyToPacket@ set to @Just bufid@. One message is produced per match
-- from @toMatches@; calls 'error' when the predicate cannot be expressed as
-- matches (with the same wording as @compileFlowRule@).
compileFlowRule' :: FlowTable.Priority -> 
                    (PacketPredicate, [Action]) -> 
                    FlowTable.TimeOut -> 
                    FlowTable.TimeOut -> 
                    BufferID -> 
                    [M.CSMessage]
compileFlowRule' prio (predicate, acts) idle hard bufid =
    maybe unrealizable (map addFlow) (toMatches predicate)
  where
    unrealizable =
      error ("Attempted to add a flow rule with an unrealizable predicate: " ++ show predicate)
    addFlow m = M.FlowMod $ FlowTable.AddFlow {
                  match             = m,
                  actions           = acts,
                  priority          = prio,
                  idleTimeOut       = idle,
                  hardTimeOut       = hard,
                  applyToPacket     = Just bufid
#if OPENFLOW_VERSION==152 || OPENFLOW_VERSION==1
                , overlapAllowed    = True
                , notifyWhenRemoved = True
#endif
#if OPENFLOW_VERSION==1
                , cookie            = 0
#endif
                }
-- | Adds a collection of prioritized flow rules to a switch: the 'addFlowRule'
-- commands for each rule, concatenated in list order.
addFlowRules 
    :: [PrioritizedFlowRule] -> 
       SwitchID -> 
       SwitchCommand
addFlowRules rules dpid = mconcat (map (`addFlowRule` dpid) rules)

-- | Runs a signal function carrying messages and commands as defined in this module,
-- converting it into a signal function in terms of basic switch messages and switch output,
-- which can be run using the switch driver provided in @SwitchInterface@.
--
-- Besides running @sf@, the wrapper takes care of protocol housekeeping:
--
--   * each hello from a switch is answered with a hello carrying the id from the
--     hello event, and is also followed by a features request;
--
--   * each echo request is answered with an echo reply carrying the same id and bytes;
--
--   * commands addressed by @SwitchID@ are routed to the socket address currently
--     bound to that switch; commands for switches with no binding are dropped;
--
--   * the features requests and the commands from @sf@ are given transaction ids
--     by @xidTagger@.
runNettleSF :: SF (Event (SwitchID, SwitchMessage), i) (Event SwitchCommand, o) 
               -> SF (Event SI.SwitchMessage, i) (SI.SwitchOutput, o)
runNettleSF sf = proc (smsgE, i) -> do 
  -- Direct replies to hellos and echo requests, reusing the request's id.
  let helloReplyE = liftE (\(addr, msgID) -> [(addr, (msgID, M.CSHello))]) (SI.switchHelloE smsgE) 
  let echoReplyE  = liftE (\(addr, xid, bytes) -> [(addr, (xid, M.CSEchoReply bytes))]) (SI.switchEchoRequestE smsgE)
  -- Current socket-address <-> switch-ID binding, plus arrival/departure events.
  (bindings, bindingChange)  <- switchSockAddressBindingSF -< smsgE
  -- Raw messages converted to this module's SwitchMessage events.
  msgsE <- messageConverter -< (smsgE, bindings, bindingChange)
  (cmdsE, sfo) <- sf -< (msgsE, i)
  -- Route by socket address; mapMaybe drops commands for unbound switch IDs.
  let cmdsE' = liftE (\(SwitchCommand cmds) -> mapMaybe (translateAddress bindings) cmds) cmdsE
  -- Ask every switch that says hello for its features, and number all of these
  -- outgoing messages with fresh transaction ids.
  switchOutput <- xidTagger -< mergeBy (<+>) (liftE (\(addr,_) ->  [(addr, M.FeaturesRequest)]) (SI.switchHelloE smsgE)) cmdsE' 
  returnA -< (fromEvent (mergeEventsBy (<+>) [ helloReplyE, echoReplyE, switchOutput ]), sfo)

-- | Turns raw switch-interface messages into this module's 'SwitchMessage'
-- events:
--
--   * binding changes become arrivals and departures;
--
--   * other messages from bound switches are translated one-to-one;
--
--   * statistics replies are aggregated per request before being emitted.
--
-- NOTE(review): when several of these occur at the same instant, @mergeEvents@
-- keeps only one, presumably the earliest in the list (as in Yampa) -- confirm.
messageConverter :: SF (Event SI.SwitchMessage, Bimap SockAddr SwitchID, Event BindingChange) (Event (SwitchID, SwitchMessage))
messageConverter = proc (rawE, table, changeE) -> do
  portStatsE <- portStatRepliesSF' -< (rawE, table)
  flowStatsE <- flowStatRepliesSF' -< (rawE, table)
  let membershipE = liftE bindingChangeToMessage changeE
      otherE      = mapFilterE (msgToNonStatMessage table) rawE
  returnA -< mergeEvents [membershipE, otherE, portStatsE, flowStatsE]
  
-- | Aggregates port statistics replies: all parts answering one request are
-- collected and emitted as a single 'PortStatsUpdate' when the final part
-- (the one without the more-coming flag) arrives.
portStatRepliesSF' :: SF (Event SI.SwitchMessage, Bimap SockAddr SwitchID) (Event (SwitchID, SwitchMessage))
portStatRepliesSF' = proc (sMsgE, bindings) ->
  aggregateStatReplies PortStatsUpdate -< (SI.portStatsReplyE sMsgE, bindings)

-- | Aggregates flow statistics replies into a single 'FlowStatsUpdate' per
-- request, in the same way as @portStatRepliesSF'@.
flowStatRepliesSF' :: SF (Event SI.SwitchMessage, Bimap SockAddr SwitchID) (Event (SwitchID, SwitchMessage))
flowStatRepliesSF' = proc (sMsgE, bindings) ->
  aggregateStatReplies FlowStatsUpdate -< (SI.flowStatsReplyE sMsgE, bindings)

-- | Shared implementation of the two statistics aggregators. Each input event
-- is one reply part: sender address, transaction id, more-coming flag, and
-- the statistics it carries. Parts are buffered per (address, transaction id),
-- so that equal ids from different connections never mix. When the last part
-- arrives, all parts are concatenated in arrival order and wrapped with @wrap@.
-- The result is emitted only if the address is currently bound to a switch.
--
-- NOTE(review): parts buffered for a connection that drops before its final
-- part are never removed from the pending map.
aggregateStatReplies :: Ord t
                     => ([s] -> SwitchMessage)
                     -> SF (Event (SockAddr, t, Bool, [s]), Bimap SockAddr SwitchID)
                           (Event (SwitchID, SwitchMessage))
aggregateStatReplies wrap = proc (replyE, table) -> do
  completeE <- accumFilter step Map.empty -< replyE
  returnA -< mapFilterE (toSwitch table) completeE
  where
    step pending (addr, xid, moreComing, stats)
      -- insertWith passes (new, old); flip (++) gives old ++ new, keeping arrival order.
      | moreComing = (Map.insertWith (flip (++)) key stats pending, Nothing)
      | otherwise  = let earlier = Map.findWithDefault [] key pending
                     in (Map.delete key pending, Just (addr, wrap (earlier ++ stats)))
      where key = (addr, xid)
    toSwitch bound (addr, msg) = fmap (\sid -> (sid, msg)) (Bimap.lookup addr bound)


-- | Turns a change in the address/switch binding into the corresponding switch
-- message: a new binding becomes 'Arrival', a removed one becomes 'Departure'.
bindingChangeToMessage :: BindingChange -> (SwitchID, SwitchMessage)  
bindingChangeToMessage change =
  case change of
    AddSwitch dpid features -> (dpid, Arrival features)
    SwitchRemoved dpid exc  -> (dpid, Departure exc)

-- | Converts a raw message from a bound switch into a 'SwitchMessage', for
-- every message kind except statistics replies, which are aggregated
-- elsewhere. Connection events, messages from addresses without a binding,
-- and hello/echo/barrier traffic all yield 'Nothing'.
msgToNonStatMessage :: Bimap SockAddr SwitchID -> SI.SwitchMessage -> Maybe (SwitchID, SwitchMessage)
msgToNonStatMessage _ (ConnectionEstablished _) = Nothing
msgToNonStatMessage _ (ConnectionTerminated _ _) = Nothing
msgToNonStatMessage binding (PeerMessage addr (_, msg)) = do
  dpid      <- Bimap.lookup addr binding
  converted <- translate msg
  return (dpid, converted)
  where
    translate M.SCHello            = Nothing
    translate (M.SCEchoRequest _)  = Nothing
    translate (M.SCEchoReply _)    = Nothing
    translate (M.Features sfr)     = Just (FeatureUpdate sfr)
    translate (M.PacketIn pktInfo) = Just (PacketIn pktInfo)
    translate (M.PortStatus ps)    = Just (PortUpdate ps)
    translate (M.FlowRemoved fr)   = Just (FlowRemoved fr)
    translate (M.StatsReply _)     = Nothing
    translate (M.Error se)         = Just (SwitchError se)
    translate M.BarrierReply       = Nothing


-- | Re-keys a pair using the reverse direction of a bimap. The result is
-- 'Nothing' when the key has no counterpart in the bimap.
translateAddress :: (Ord k, Ord k') => Bimap k' k -> (k, a) -> Maybe (k', a)
translateAddress bimap (k, payload) = fmap (\k' -> (k', payload)) (Bimap.lookupR k bimap)

-- | Assigns transaction ids to batches of outgoing messages.
--
-- A running total of all messages seen so far is kept with @accum@ and held
-- as a signal with @hold@. For a batch of @len@ messages, the ids
-- @xid-len, xid-len+1, ...@ are zipped onto the batch in order, where @xid@
-- is the held total.
--
-- NOTE(review): this gives fresh, consecutive ids only if @hold@ outputs the
-- updated total at the same instant as the batch event (as Yampa's @hold@
-- does); confirm for AFRP.
xidTagger :: SF (Event [(SockAddr, M.CSMessage)]) (Event SI.SwitchOutput)
xidTagger = proc cmdE -> do 
  -- Size of each batch, converted with fromIntegral to the counter's type.
  let lenE = liftE (fromIntegral . length) cmdE
  xid <- hold 0 <<< accum 0 -< liftE (+) lenE
  returnA -< liftE (\(len, cmds) -> zipWith (\x (a,c) -> (a, (x,c))) [xid-len..] cmds) (joinE lenE cmdE)

-- | Maintains a binding from socket address to switch ID for every connected
-- switch, and outputs an event whenever a binding is added or removed.
--
-- A features message from an unbound address binds that address to the
-- switch ID in the features; features from an already-bound address change
-- nothing. A leave event for a bound address removes its binding; leave
-- events for unbound addresses are ignored.
switchSockAddressBindingSF :: SF (Event SI.SwitchMessage) (Bimap SockAddr SwitchID, Event BindingChange)
switchSockAddressBindingSF = proc msgE -> do
  let updatesE = liftE Left (SI.switchFeaturesE msgE) `lMerge` liftE Right (SI.switchLeaveE msgE)
  changeE <- accumFilter step Bimap.empty -< updatesE
  current <- hold Bimap.empty -< liftE fst changeE
  returnA -< (current, liftE snd changeE)
  where
    step bound (Left (addr, _, features))
      | addr `Bimap.member` bound = (bound, Nothing)
      | otherwise =
          let dpid   = switchID features
              bound' = Bimap.insert addr dpid bound
          in (bound', Just (bound', AddSwitch dpid features))
    step bound (Right (addr, exc)) =
      case Bimap.lookup addr bound of
        Nothing   -> (bound, Nothing)
        Just dpid ->
          let bound' = Bimap.delete addr bound
          in (bound', Just (bound', SwitchRemoved dpid exc))

-- | A change to the socket-address/switch-ID binding maintained by
-- 'switchSockAddressBindingSF'.
data BindingChange
  = AddSwitch SwitchID SwitchFeatures   -- ^ a switch was bound, with the features it reported
  | SwitchRemoved SwitchID IOException  -- ^ a bound switch left, with the exception reported for it
  deriving (Show, Eq)


-- | Input vector of the standard driver: a single stream of raw switch messages.
type StandardInputVector = SI.SwitchMessage ::: Nil
type StandardInput       = SFInput StandardInputVector

-- | Runtime representation of 'StandardInputVector'.
inputRep :: Rep StandardInputVector
inputRep = RCons RNil

-- | Output vector of the standard driver: switch output followed by text output.
type OutputVector   = SI.SwitchOutput ::: String ::: Nil
type StandardOutput = SFOutput OutputVector

-- | Runtime representation of 'OutputVector'.
outputRep :: Rep OutputVector
outputRep = RCons (RCons RNil)

-- | Runs a signal function with a single input stream of switch messages
-- and two output streams, one for switch commands and one for text messages;
-- starts a switch server at the specified port.
--
-- Every non-empty text message is printed to standard output and also written
-- to the file @foo.out@, which is truncated when the driver starts. The thread
-- that forwards messages from the switch server is killed, and the log file is
-- closed, when the driver returns or throws.
simpleNettleDriver :: ServerPortNumber 
                      -> SF (Event (SwitchID, SwitchMessage)) (Event SwitchCommand, Event String) 
                      -> IO ()
simpleNettleDriver pstring sf = do
  (switchSensor, switchActuator) <- SI.switchInterfaceDriver pstring
  switchSensorCh <- newChan
  -- Forward every sensed switch message into the channel read by sfDriver.
  pump <- forkIO (forever (switchSensor >>= writeChan switchSensorCh))
  -- Stop the forwarding thread on exit so it cannot outlive the driver.
  flip finally (killThread pump) $
    -- WriteMode truncates the output of earlier runs; ReadWriteMode would
    -- overwrite from the start and leave stale bytes after the new output.
    bracket (openFile "foo.out" WriteMode) hClose $ \hdl -> do
      let senseChans = (switchSensorCh, ())
          consoleActuator msg = unless (null msg) (putStrLn msg >> hPutStrLn hdl msg)
          actuators  = (switchActuator, (consoleActuator, ()))
          sf'        = arr (\(mx, _) -> (maybeToEvent mx, ())) >>>
                       runNettleSF (arr (\(msg, ()) -> msg) >>> sf) >>>
                       arr (\(swOut, msgE) -> (swOut, (fromEvent msgE, ())))
      sfDriver inputRep senseChans outputRep actuators sf'