{-# LANGUAGE CPP, MultiParamTypeClasses, FlexibleInstances, DisambiguateRecordFields , RecordWildCards , NamedFieldPuns #-}

module Nettle.FRPControl.PacketPredicate 
    (
     -- * Packet predicates and match semantics
     Logic(..), 
     (|-), 
     PacketPredicate(..), Clause, Literal,
     (<&>), 
     (<|>), 
     anyPacket, 
     inPortIs,
     ethSourceIs, 
     ethDestIs, 
     vLANIDIs,
     vlanPriority,
     ethFrameTypeIs, 
     ipTypeOfService,
     transportProtocolIs, 
     ipSourceIn,
     ipDestIn,
     senderTransportIs, 
     receiverTransportIs, 
     receiverTransportIn,
     ands, 
     ors,
     satisfies, 
     clauses, 
     literals, 
     overlaps,

     -- * Commonly occurring packet predicates.
     dhcp, dns, arp, lldp, ip, udp, ethSourceDestAre,

     -- * Packet predicates and matches for this version
     fromMatch,
     toMatches,
     realizable,
     packetInFrame, 
     exactPredicate, 
     packetInMatches

    )
    where

import Nettle.OpenFlow.Messages  
import Nettle.OpenFlow.Port
import Nettle.OpenFlow.Match hiding (ipTypeOfService)
import qualified Nettle.OpenFlow.Match as Match
import Nettle.OpenFlow.Packet
import Nettle.Ethernet.EthernetFrame
import Nettle.Ethernet.EthernetAddress
import Nettle.IPv4.IPAddress
import Nettle.IPv4.IPPacket
import Data.Word
import qualified Data.List as List
import Data.Maybe
import Data.Binary.Get
import Control.Monad.Error

-- | A satisfaction relation between a type @m@ of structures (for example,
-- received packets) and a type @p@ of predicates over those structures.
class Logic m p where
    -- | @holds s q@ is 'True' exactly when the structure @s@ satisfies the
    -- predicate @q@.
    holds :: m -> p -> Bool

-- Non-associative, and looser than the predicate combinators '<|>' and
-- '<&>' (fixity 8), so @s |- a <&> b@ parses as @s |- (a <&> b)@.
infix 5 |-

-- | Infix synonym for 'holds': @s |- q@ is 'True' when @s@ satisfies @q@.
(|-) :: Logic m p => m -> p -> Bool
structure |- predicate = holds structure predicate

-- | Packet predicates.
--
-- Values of this data type should NOT be constructed using its constructors
-- directly, but rather through the smart constructors and combinators defined
-- below ('anyPacket', 'inPortIs', '<&>', '<|>', ...). Those functions maintain
-- this data type's invariants (for example, '<&>' distributes conjunction over
-- disjunction, which 'clauses' and 'literals' rely on); the raw constructors
-- do not.
--
-- Several literals constrain a header that a frame need not carry (VLAN tag,
-- IP header); per 'satisfies', such literals hold for frames lacking it.
data PacketPredicate 
    = AndPP PacketPredicate PacketPredicate   -- ^ Conjunction.
    | OrPP PacketPredicate PacketPredicate    -- ^ Disjunction.
    | TruePP                                  -- ^ Satisfied by every packet.
    | FalsePP                                 -- ^ Satisfied by no packet.
    | InPortIs PortID                         -- ^ Packet arrived on this switch port.
    | MacSourceIs EthernetAddress             -- ^ Ethernet source address.
    | MacDestIs EthernetAddress               -- ^ Ethernet destination address.
    | VLANIDIs VLANID                         -- ^ VLAN id; holds for untagged frames.
#if OPENFLOW_VERSION==152 || OPENFLOW_VERSION==1    
    | VLANPriority VLANPriority               -- ^ VLAN priority code point; holds for untagged frames.
#endif
    | MacFrameTypeIs EthernetTypeCode         -- ^ Ethernet frame type code.
#if OPENFLOW_VERSION==1      
    | IPTypeOfService IPTypeOfService         -- ^ Compared with the IP header dscp field; holds for non-IP frames.
#endif
    | IPProtocolIs IPProtocol                 -- ^ IP protocol number; holds for non-IP frames.
    | IPSourceIn IPAddressPrefix              -- ^ IP source address lies in the prefix; holds for non-IP frames.
    | IPDestIn IPAddressPrefix                -- ^ IP destination address lies in the prefix; holds for non-IP frames.
    | SenderTransportPortIs Word16            -- ^ TCP/UDP source port; holds for non-IP frames.
    | ReceiverTransportPortIs Word16          -- ^ TCP/UDP destination port; holds for non-IP frames.
      deriving (Show,Read,Eq)
 
-- | Satisfied by every packet.
anyPacket :: PacketPredicate
anyPacket              = TruePP

-- | The packet arrived on the given switch port.
inPortIs :: PortID -> PacketPredicate
inPortIs p             = InPortIs p

-- | The Ethernet source address equals the given address.
ethSourceIs :: EthernetAddress -> PacketPredicate
ethSourceIs a          = MacSourceIs a

-- | The Ethernet destination address equals the given address.
ethDestIs :: EthernetAddress -> PacketPredicate
ethDestIs a            = MacDestIs a

-- | The VLAN id equals the given id (untagged frames also satisfy this;
-- see 'satisfies').
vLANIDIs :: VLANID -> PacketPredicate
vLANIDIs x             = VLANIDIs x

#if OPENFLOW_VERSION==152 || OPENFLOW_VERSION==1    
-- | The VLAN priority code point equals the given value (untagged frames
-- also satisfy this; see 'satisfies').
vlanPriority :: VLANPriority -> PacketPredicate
vlanPriority x         = VLANPriority x
#endif

-- | The Ethernet frame type code equals the given code.
ethFrameTypeIs :: EthernetTypeCode -> PacketPredicate
ethFrameTypeIs t       = MacFrameTypeIs t

#if OPENFLOW_VERSION==1      
-- | The IP type-of-service value equals the given value (non-IP frames also
-- satisfy this; see 'satisfies').
ipTypeOfService :: IPTypeOfService -> PacketPredicate
ipTypeOfService x      = IPTypeOfService x
#endif

-- | The IP protocol number equals the given number (non-IP frames also
-- satisfy this; see 'satisfies').
transportProtocolIs :: IPProtocol -> PacketPredicate
transportProtocolIs x  = IPProtocolIs x

-- | The IP source address lies in the given prefix (non-IP frames also
-- satisfy this; see 'satisfies').
ipSourceIn :: IPAddressPrefix -> PacketPredicate
ipSourceIn x           = IPSourceIn x

-- | The IP destination address lies in the given prefix (non-IP frames also
-- satisfy this; see 'satisfies').
ipDestIn :: IPAddressPrefix -> PacketPredicate
ipDestIn x             = IPDestIn x

-- | The transport-layer source port equals the given port.
senderTransportIs :: Word16 -> PacketPredicate
senderTransportIs x    = SenderTransportPortIs x

-- | The transport-layer destination port equals the given port.
receiverTransportIs :: Word16 -> PacketPredicate
receiverTransportIs x  = ReceiverTransportPortIs x

-- | The transport-layer destination port is one of the given ports. An empty
-- list yields 'FalsePP' (via 'ors').
receiverTransportIn :: [Word16] -> PacketPredicate
receiverTransportIn xs = ors [receiverTransportIs x | x <- xs]

-- Both combinators share one precedence level and associate to the left,
-- so @a <|> b <&> c@ parses as @(a <|> b) <&> c@.
infixl 8 <|>, <&>

-- | Disjunction ('<|>') and conjunction ('<&>') of packet predicates.
-- Disjunction just builds an 'OrPP' node. Conjunction goes through 'andPP',
-- which simplifies trivial cases and distributes over disjunctions.
(<|>), (<&>) :: PacketPredicate -> PacketPredicate -> PacketPredicate
(<|>) = OrPP
(<&>) = andPP

-- | Simplifying conjunction; this is the implementation of '<&>'.
--
-- * 'TruePP' is a unit and 'FalsePP' is absorbing.
-- * Two literals on a field that every packet carries (in port, MAC source,
--   MAC destination, frame type) with different values conflict, giving
--   'FalsePP'.
-- * Two literals on a header-dependent field (VLAN id, IP protocol,
--   transport ports) with different values are kept as an 'AndPP'. Per
--   'satisfies', such literals hold for packets lacking that header, so
--   their conjunction is still satisfiable. 'normalizeClause' reports the
--   conflict when the clause also requires the header's frame type.
-- * Conjunction distributes over disjunction, so disjunctions float to the
--   top where 'clauses' splits them.
-- * Anything else becomes a plain 'AndPP'.
andPP :: PacketPredicate -> PacketPredicate -> PacketPredicate
andPP TruePP p1 = p1
andPP FalsePP _ = FalsePP
andPP p1 TruePP = p1
andPP _ FalsePP = FalsePP

-- Fields present in every packet: distinct values can never both hold.
andPP (InPortIs p1) (InPortIs p2) = 
    if p1 /= p2 then FalsePP else InPortIs p1

andPP (MacSourceIs x1) (MacSourceIs x2) = 
    if x1 /= x2 then FalsePP else MacSourceIs x1

andPP (MacDestIs x1) (MacDestIs x2) = 
    if x1 /= x2 then FalsePP else MacDestIs x1

andPP (MacFrameTypeIs x1) (MacFrameTypeIs x2) = 
    if x1 /= x2 then FalsePP else MacFrameTypeIs x1

-- Header-dependent fields: distinct values still hold together on packets
-- without that header, so only identical literals are merged.
andPP (VLANIDIs x1) (VLANIDIs x2) = 
    if x1 == x2 then VLANIDIs x1 else VLANIDIs x1 `AndPP` VLANIDIs x2

andPP (IPProtocolIs x1) (IPProtocolIs x2) = 
    if x1 == x2 then IPProtocolIs x1 else IPProtocolIs x1 `AndPP` IPProtocolIs x2

andPP (SenderTransportPortIs x1) (SenderTransportPortIs x2) = 
    if x1 == x2 then SenderTransportPortIs x1 else SenderTransportPortIs x1 `AndPP` SenderTransportPortIs x2

andPP (ReceiverTransportPortIs x1) (ReceiverTransportPortIs x2) = 
    if x1 == x2 then ReceiverTransportPortIs x1 else ReceiverTransportPortIs x1 `AndPP` ReceiverTransportPortIs x2

-- Distribute over disjunction on either side.
andPP (p1 `OrPP` p2) p3 = andPP p1 p3 `OrPP` andPP p2 p3
andPP p1 (p2 `OrPP` p3) = andPP p1 p2 `OrPP` andPP p1 p3

andPP p1 p2 = p1 `AndPP` p2

-- | Disjunction ('ors') and conjunction ('ands') of a list of predicates.
-- @ors []@ is 'FalsePP' and @ands []@ is 'TruePP'. Strict left folds keep
-- long lists from building a chain of suspended applications.
ands, ors :: [PacketPredicate] -> PacketPredicate
ors  = List.foldl' (<|>) FalsePP
ands = List.foldl' (<&>) TruePP

-- | This function defines when an incoming packet (as received by a switch,
-- on the given port) satisfies a given packet predicate.
--
-- Literals that constrain a header the frame does not carry hold vacuously:
-- VLAN literals on untagged frames, IP-level literals on non-IP frames, and
-- transport-port literals on frames that are not TCP or UDP over IP.
satisfies :: (PortID, EthernetFrame) -> PacketPredicate -> Bool
satisfies _ TruePP  = True
satisfies _ FalsePP = False
satisfies (portid',_) (InPortIs portid) = portid==portid'
satisfies (_, frame) (MacSourceIs a) = a == sourceAddress frame
satisfies (_, frame) (MacDestIs a)   = a == destAddress frame
satisfies (_, EthernetFrame hdr _) (VLANIDIs a) = 
    case hdr of 
      EthernetHeader {}    -> True
      (Ethernet8021Q {..}) -> a == vlanId
#if OPENFLOW_VERSION==152 || OPENFLOW_VERSION==1    
satisfies (_, EthernetFrame hdr _) (VLANPriority p) = 
    case hdr of 
      EthernetHeader {}    -> True
      (Ethernet8021Q {..}) -> p == priorityCodePoint
#endif      
satisfies (_, EthernetFrame hdr _) (MacFrameTypeIs a) = a == typeCode hdr
#if OPENFLOW_VERSION==1    
satisfies (_, EthernetFrame _ body) (IPTypeOfService tos) =
  case body of 
    IPInEthernet (IPPacket (IPHeader {..}) _) -> tos == dscp
    _ -> True
#endif    
satisfies (_, EthernetFrame _ body) (IPProtocolIs a) = 
    case body of 
      IPInEthernet (IPPacket (IPHeader {..}) _) -> ipProtocol == a
      _ -> True
satisfies (_, EthernetFrame _ body) (IPSourceIn prefix) = 
    case body of 
      IPInEthernet ipPkt -> sourceAddress ipPkt `elemOfPrefix` prefix
      _                  -> True
satisfies (_, EthernetFrame _ body) (IPDestIn prefix) = 
    case body of 
      IPInEthernet ipPkt -> destAddress ipPkt `elemOfPrefix` prefix
      _                  -> True
satisfies (_, EthernetFrame _ body) (SenderTransportPortIs a) = 
    case body of 
      IPInEthernet (IPPacket _ ipbody) -> 
          case ipbody of 
            TCPInIP (srcp,_) -> srcp == a
            UDPInIP (srcp,_) -> srcp == a
            -- IP payload without TCP/UDP ports: holds vacuously, like the
            -- other header-dependent literals, rather than failing the match.
            _                -> True
      _ -> True
satisfies (_, EthernetFrame _ body) (ReceiverTransportPortIs a) = 
    case body of 
      IPInEthernet (IPPacket _ ipbody) -> 
          case ipbody of 
            TCPInIP (_,destp) -> destp == a
            UDPInIP (_,destp) -> destp == a
            -- IP payload without TCP/UDP ports: holds vacuously, like the
            -- other header-dependent literals, rather than failing the match.
            _                 -> True
      _ -> True
satisfies x (p1 `AndPP` p2) = satisfies x p1 && satisfies x p2
satisfies x (p1 `OrPP` p2)  = satisfies x p1 || satisfies x p2

-- | Packet predicates are interpreted over frames paired with the port they
-- arrived on; the satisfaction relation is 'satisfies'.
instance Logic (PortID, EthernetFrame) PacketPredicate where
    holds packet predicate = satisfies packet predicate


-- | A literal is any packet predicate except those formed using conjunction
--   ('AndPP') or disjunction ('OrPP').
--   The type synonym does not enforce this constraint; it only serves as a
--   reminder of the intent.
type Literal = PacketPredicate

-- | A clause is a list of literals, read as their conjunction: a packet
-- satisfies a clause if it satisfies all the literals in the clause. It
-- follows that any packet satisfies an empty clause, i.e. the empty clause
-- is equivalent to True.
type Clause = [Literal]

-- | Computes the clauses of a packet predicate (a disjunction of normalized
-- conjunctions); assumes the data type invariants hold.
-- A packet satisfies a list of clauses if it satisfies some clause in the
-- list, so the empty list of clauses is equivalent to False. Clauses that
-- normalize to @[FalsePP]@ are dropped.
clauses :: PacketPredicate -> [Clause]
clauses pp = case pp of
    OrPP left right -> clauses left ++ clauses right
    _               -> [ clause | let clause = normalizeClause (literals pp)
                                , clause /= [FalsePP] ]

-- Helper function to normalize a clause. In order, it:
-- (1) merges all IPDestIn literals, then all IPSourceIn literals, into one
--     literal each using 'intersects'; when 'intersects' yields Nothing and
--     the clause requires an IP frame type, the clause becomes [FalsePP],
-- (2) removes duplicate literals,
-- (3) identifies unsatisfiable clauses and represents them as [FalsePP],
-- (4) drops TruePP literals, representing clauses equivalent to TruePP by
--     [TruePP].
-- Note that distinct "normal" forms may be logically equivalent, and
-- logically equivalent clauses may be reduced to distinct "normal" forms by
-- this procedure. This is due to the semantics of literals such as
-- senderTransportIs, IPSourceIn, etc, which are really implications.
normalizeClause :: Clause -> Clause
normalizeClause = 
    convertToTrue . conflictToFalse . remdups . normalizeSourceIPAddressConditions . normalizeDestIPAddressConditions
    where remdups = List.nub

          -- After deduplication, more than one literal on a field means
          -- distinct values. That is a conflict for fields every packet
          -- carries. For VLAN/IP fields it is a conflict only when the clause
          -- also requires that header's frame type; otherwise the literals
          -- may hold vacuously.
          conflictToFalse ls = if (length inPortLits > 1 
                                   || length macSourceLits > 1 
                                   || length macDestLits > 1 
                                   || length macFrameTypeLits > 1 
                                   || (MacFrameTypeIs ethTypeVLAN `elem` ls && length vlanLits > 1)              
#if OPENFLOW_VERSION==152 || OPENFLOW_VERSION==1                                       
                                   || (MacFrameTypeIs ethTypeVLAN `elem` ls && length vlanPriorityLits > 1)      
#endif                                   
                                   || (MacFrameTypeIs ethTypeIP `elem` ls   && length ipProtocolLits > 1)        
#if OPENFLOW_VERSION==1                                   
                                   || (MacFrameTypeIs ethTypeIP `elem` ls   && length dscpLits > 1)
#endif                                   
                                   || (MacFrameTypeIs ethTypeIP `elem` ls   && length senderTransportLits > 1)   
                                   || (MacFrameTypeIs ethTypeIP `elem` ls   && length receiverTransportLits > 1) 
                                   || elem FalsePP ls)
                               then [FalsePP]
                               else ls
              where inPortLits            = [ x | InPortIs x <- ls ]
                    macSourceLits         = [ x | MacSourceIs x <- ls]
                    macDestLits           = [ a | MacDestIs a <- ls ]
                    vlanLits              = [ x | VLANIDIs x <- ls]
#if OPENFLOW_VERSION==152 || OPENFLOW_VERSION==1
                    -- The VLANPriority constructor only exists for these
                    -- OpenFlow versions, so this binding must be guarded too.
                    vlanPriorityLits      = [ x | VLANPriority x <- ls ]
#endif
                    macFrameTypeLits      = [ x | MacFrameTypeIs x <- ls ]
#if OPENFLOW_VERSION==1                                   
                    dscpLits              = [ x | IPTypeOfService x <- ls]
#endif                                            
                    ipProtocolLits        = [ x | IPProtocolIs x <- ls ]
                    senderTransportLits   = [ x | SenderTransportPortIs x <- ls ]
                    receiverTransportLits = [ x | ReceiverTransportPortIs x <- ls ]

          -- TruePP is the unit of conjunction: drop it, but keep a lone
          -- [TruePP] to represent the always-true clause.
          convertToTrue ls = let ls' = filter (/=TruePP) ls
                             in if null ls' then [TruePP] else ls'

          normalizeSourceIPAddressConditions ls = 
              case intersects [ x | IPSourceIn x <- ls ] of 
                Nothing -> if MacFrameTypeIs ethTypeIP `elem` ls 
                           then [FalsePP]
                           else ls
                Just x  -> IPSourceIn x : filter (not . isIPSourcePred) ls

          normalizeDestIPAddressConditions ls = 
              case intersects [ x | IPDestIn x <- ls ] of 
                Nothing -> if MacFrameTypeIs ethTypeIP `elem` ls 
                           then [FalsePP]
                           else ls
                Just x  -> IPDestIn x : filter (not . isIPDestPred) ls


-- | True exactly for 'IPSourceIn' literals.
isIPSourcePred :: PacketPredicate -> Bool
isIPSourcePred pp = case pp of
    IPSourceIn _ -> True
    _            -> False

-- | True exactly for 'IPDestIn' literals.
isIPDestPred :: PacketPredicate -> Bool
isIPDestPred pp = case pp of
    IPDestIn _ -> True
    _          -> False

-- | Flattens a conjunction into its literals, left to right. Any predicate
-- that is not an 'AndPP' (including an 'OrPP') is returned as a single
-- element, so callers should only pass conjunctions of literals.
literals :: PacketPredicate -> [Literal]
literals pp = case pp of
    AndPP left right -> literals left ++ literals right
    _                -> [pp]

-- | Whether two packet predicates overlap, i.e. whether their intersection
--   is non-empty. This holds when the conjunction of the two predicates
--   normalizes to at least one clause.
overlaps :: PacketPredicate -> PacketPredicate -> Bool
overlaps p1 p2 = not (null (clauses (p1 <&> p2)))

-- | Commonly occurring packet predicates:
--
-- * 'dhcp': UDP/IP packets with source or destination port 67.
-- * 'dns': UDP/IP packets with source or destination port 53.
-- * 'arp', 'lldp', 'ip': frames with the corresponding Ethernet type code.
-- * 'udp': IP protocol number 17.
dhcp, dns, arp, lldp, ip, udp :: PacketPredicate
dhcp     = ip <&> udp <&> (senderTransportIs dhcpPort <|> receiverTransportIs dhcpPort)
dns      = ip <&> udp <&> (senderTransportIs dnsPort  <|> receiverTransportIs dnsPort)
arp      = ethFrameTypeIs ethTypeARP
lldp     = ethFrameTypeIs ethTypeLLDP
ip       = ethFrameTypeIs ethTypeIP
udp      = transportProtocolIs udpCode

-- IP protocol number of UDP.
udpCode :: IPProtocol
udpCode  = 17

-- UDP ports used to recognise DHCP and DNS traffic.
dhcpPort, dnsPort :: Word16
dhcpPort = 67
dnsPort  = 53

-- | Packets with the given Ethernet source and destination addresses.
ethSourceDestAre :: EthernetAddress -> EthernetAddress -> PacketPredicate
ethSourceDestAre src dst = sourceIs <&> destIs
  where sourceIs = ethSourceIs src
        destIs   = ethDestIs dst

-- | The disjunction of (toMatches pred) matches the same set of packets as
--   the packet predicate pred does. Returns Nothing when some clause of the
--   predicate is not realizable as a single Match (see 'realizableClause').
toMatches :: PacketPredicate -> Maybe [Match]
toMatches p
    | all realizableClause cs = Just (mapMaybe clauseToMatch cs)
    | otherwise               = Nothing
    where cs = clauses p   -- computed once, used for both the check and the conversion

-- | A predicate is realizable when each of its clauses can be encoded as a
-- single Match (see 'realizableClause').
realizable :: PacketPredicate -> Bool
realizable p = all realizableClause (clauses p)

-- | A clause is realizable when it contains no FalsePP and has at most one
-- literal for each of the single-valued Match fields counted below.
realizableClause :: Clause -> Bool
realizableClause lits = FalsePP `notElem` lits && not (any (> 1) fieldCounts)
    where
      -- Number of literals constraining each such field; a Match can hold
      -- only one value per field.
      fieldCounts =
        [ length [ () | VLANIDIs _                <- lits ]
#if OPENFLOW_VERSION==152 || OPENFLOW_VERSION==1
        , length [ () | VLANPriority _            <- lits ]
#endif
        , length [ () | IPProtocolIs _            <- lits ]
#if OPENFLOW_VERSION==1
        , length [ () | IPTypeOfService _         <- lits ]
#endif
        , length [ () | IPSourceIn _              <- lits ]
        , length [ () | IPDestIn _                <- lits ]
        , length [ () | SenderTransportPortIs _   <- lits ]
        , length [ () | ReceiverTransportPortIs _ <- lits ]
        ]

-- | Converts a single clause to maybe a match by refining 'matchAny' with
-- each literal in turn (see 'literalToMatchUpdate'). The empty clause
-- converts to 'matchAny'.
clauseToMatch :: Clause -> Maybe Match
clauseToMatch clause = toMatchAux clause (Just matchAny)

-- | Applies the literals of a clause, left to right, to an accumulated
-- match. Once the accumulator is Nothing, the result is Nothing.
toMatchAux :: Clause -> Maybe Match -> Maybe Match
toMatchAux []           acc      = acc
toMatchAux _            Nothing  = Nothing
toMatchAux (lit : rest) (Just m) = toMatchAux rest (literalToMatchUpdate lit m)

-- | Refines a match with the constraint expressed by one literal.
-- TruePP leaves the match unchanged. FalsePP yields Nothing, so an
-- unsatisfiable clause converts to Nothing in 'clauseToMatch'. Conjunctions
-- and disjunctions are not literals; passing one is a programming error.
literalToMatchUpdate :: Literal -> Match -> Maybe Match
literalToMatchUpdate TruePP m                      = Just m
literalToMatchUpdate FalsePP _                     = Nothing
-- Constructor syntax rather than record update, so DisambiguateRecordFields
-- resolves inPort against Match. NOTE(review): presumably another import also
-- exports a field named inPort — confirm before switching to record update.
literalToMatchUpdate (InPortIs p)    (Match {..})  = Just $ Match { inPort = Just p, .. } 
literalToMatchUpdate (MacSourceIs a) m             = Just $ m { srcEthAddress    = Just a }
literalToMatchUpdate (MacDestIs a)   m             = Just $ m { dstEthAddress    = Just a } 
literalToMatchUpdate (VLANIDIs  a)   m             = Just $ m { vLANID           = Just a }
#if OPENFLOW_VERSION==152 || OPENFLOW_VERSION==1    
literalToMatchUpdate (VLANPriority a)   m          = Just $ m { vLANPriority     = Just a }
#endif
literalToMatchUpdate (MacFrameTypeIs a) m          = Just $ m { ethFrameType     = Just a } 
literalToMatchUpdate (IPProtocolIs a) m            = Just $ m { Match.ipProtocol = Just a }
#if OPENFLOW_VERSION==1    
literalToMatchUpdate (IPTypeOfService a) m         = Just $ m { Match.ipTypeOfService = Just a }
#endif
literalToMatchUpdate (IPSourceIn prefix) m         = Just $ m { srcIPAddress     = prefix }
literalToMatchUpdate (IPDestIn prefix) m           = Just $ m { dstIPAddress     = prefix }
literalToMatchUpdate (SenderTransportPortIs  a)  m = Just $ m { srcTransportPort = Just a } 
literalToMatchUpdate (ReceiverTransportPortIs a) m = Just $ m { dstTransportPort = Just a } 
literalToMatchUpdate p _ =
    error ("Nettle.FRPControl.PacketPredicate.literalToMatchUpdate: not a literal: " ++ show p)

-- | Calculates a packet predicate that matches the same in-packets as the
-- given match. Each field contributes one conjunct, in field order.
fromMatch :: Match -> PacketPredicate
fromMatch (Match {Match.ipTypeOfService=ipTypeOfService',..}) = ands conjuncts
  where
    -- A wildcarded (Nothing) field imposes no constraint.
    opt mk = maybe TruePP mk

    -- An IP prefix equal to the default prefix imposes no constraint.
    prefixPred mk prefix
      | prefix == defaultIPPrefix = TruePP
      | otherwise                 = mk prefix

    conjuncts =
      [ opt inPortIs inPort
      , opt ethSourceIs srcEthAddress
      , opt ethDestIs dstEthAddress
      , opt vLANIDIs vLANID
#if OPENFLOW_VERSION==152 || OPENFLOW_VERSION==1
      , opt vlanPriority vLANPriority
#endif
      , opt ethFrameTypeIs ethFrameType
      , opt transportProtocolIs ipProtocol
#if OPENFLOW_VERSION==1
      , opt ipTypeOfService ipTypeOfService'
#endif
      , prefixPred ipSourceIn srcIPAddress
      , prefixPred ipDestIn dstIPAddress
      , opt senderTransportIs srcTransportPort
      , opt receiverTransportIs dstTransportPort
      ]

-- | Decodes the Ethernet frame of a packet-in message and tests it, together
-- with its arrival port, against a packet predicate. Returns Left when the
-- frame cannot be decoded.
packetInMatches :: PacketInfo -> PacketPredicate -> Either ErrorMessage Bool
packetInMatches pktIn predicate = fmap test (packetInFrame pktIn)
  where test frame = (receivedOnPort pktIn, frame) |- predicate

-- | Decodes the Ethernet frame carried in a packet-in message's data.
packetInFrame :: PacketInfo -> Either ErrorMessage EthernetFrame
packetInFrame pktIn = runGet (runErrorT getEthernetFrame) (packetData pktIn)

-- | The packet predicate corresponding to the exact match that
-- 'getExactMatch' computes from a packet-in message's data and arrival
-- port. Returns Left when decoding fails.
exactPredicate :: PacketInfo -> Either ErrorMessage PacketPredicate
exactPredicate pktIn = fmap fromMatch exactMatch
    where arrivalPort = receivedOnPort pktIn
          exactMatch  = runGet (runErrorT (getExactMatch arrivalPort)) (packetData pktIn)