module Data.IP.RouteTable.Internal where
#if __GLASGOW_HASKELL__ < 709
import Control.Applicative ((<$>),(<*>),pure)
#endif
import Control.Monad
import Data.Bits
import Data.Foldable (Foldable(..))
import Data.IP.Addr
import Data.IP.Op
import Data.IP.Range
import Data.IntMap (IntMap, (!))
import qualified Data.IntMap as IM (fromList)
import Data.Monoid
import Data.Traversable
import Data.Word
import GHC.Generics (Generic, Generic1)
import Prelude hiding (lookup)
-- | Address types that can key an 'IPRTable'.
class Addr a => Routable a where
    -- | Map a prefix length to the test-bit value used when branching on
    -- ranges of that length (see 'keyToTestBit').
    intToTBit :: Int -> a
    -- | @isZero a b@ is 'True' when @a@ masked with @b@ is the all-zero
    -- address.
    isZero :: a -> a -> Bool

instance Routable IPv4 where
    intToTBit = intToTBitIPv4
    isZero addr4 bits = masked addr4 bits == IP4 0

instance Routable IPv6 where
    intToTBit = intToTBitIPv6
    isZero addr6 bits = masked addr6 bits == IP6 (0, 0, 0, 0)
-- | IPv4 test-bit value for a prefix length, looked up in 'intToTBitsIPv4'.
-- Partial: lengths outside the table's keys (0..32) make '!' fail.
intToTBitIPv4 :: Int -> IPv4
intToTBitIPv4 = IP4 . (intToTBitsIPv4 !)

-- | IPv6 test-bit value for a prefix length, looked up in 'intToTBitsIPv6'.
-- Partial: lengths outside the table's keys (0..128) make '!' fail.
intToTBitIPv6 :: Int -> IPv6
intToTBitIPv6 = IP6 . (intToTBitsIPv6 !)
-- | Infinite list of single-bit 'Word32' test words.
--
-- Element @i@ (for @0 <= i <= 31@) has only bit @31 - i@ set, i.e. the
-- @i@-th bit counted from the most significant end.  Element 32 and every
-- later element is 0.  Callers take the first 32 or 33 entries.
--
-- The seed must be shifted /right/ at each step.  Shifting left would
-- overflow the 'Word32' after the first element and make every later
-- entry 0.
intToTBitsWord32 :: [Word32]
intToTBitsWord32 = iterate (`shiftR` 1) 0x80000000
-- | IPv4 test-bit words keyed by prefix length 0..32, taken in order from
-- 'intToTBitsWord32'.
intToTBitsIPv4 :: IntMap IPv4Addr
intToTBitsIPv4 = IM.fromList table
  where
    table = zip [0 .. 32] intToTBitsWord32
-- | IPv6 test-bit words keyed by prefix length 0..128.
--
-- Keys 0..127 place each of the first 32 entries of 'intToTBitsWord32'
-- into the first, then second, third and fourth 32-bit word of the
-- address, with the other three words zero.  Key 128 maps to the all-zero
-- address.
intToTBitsIPv6 :: IntMap IPv6Addr
intToTBitsIPv6 = IM.fromList (zip [0 .. 128] entries)
  where
    entries = concatMap spread placements ++ [(zero, zero, zero, zero)]
    -- one quadrant's worth of entries: every 32-bit test word in one slot
    spread place = map place word32Bits
    placements =
      [ \w -> (w, zero, zero, zero)
      , \w -> (zero, w, zero, zero)
      , \w -> (zero, zero, w, zero)
      , \w -> (zero, zero, zero, w)
      ]
    word32Bits = take 32 intToTBitsWord32
    zero = 0x00000000
-- | Routing table: a binary tree of address ranges, each optionally
-- carrying a value.
data IPRTable k a =
    -- | The empty table.
    Nil
    -- | An inner node or leaf.
  | Node
      !(AddrRange k)   -- ^ range this node is keyed on
      !k               -- ^ test bit for this range, produced by 'keyToTestBit'
      !(Maybe a)       -- ^ value stored for exactly this range; 'Nothing' for
                       --   glue nodes created by 'link' or after 'delete'
      !(IPRTable k a)  -- ^ left subtree: ranges for which 'isLeft' holds
                       --   against this node's test bit
      !(IPRTable k a)  -- ^ right subtree: the remaining ranges
  deriving (Eq, Generic, Generic1, Show)
-- | A table holding no routes.
empty :: Routable k => IPRTable k a
empty = Nil
-- | Applies the function to every stored value.  Ranges, test bits and
-- tree shape are unchanged.
instance Functor (IPRTable k) where
    fmap f = go
      where
        go Nil = Nil
        go (Node r a mv b1 b2) = Node r a (fmap f mv) (go b1) (go b2)
-- | Folds stored values in pre-order: the node's own value (if any), then
-- the left subtree, then the right subtree.
instance Foldable (IPRTable k) where
    foldMap f = go
      where
        go Nil = mempty
        go (Node _ _ mv b1 b2) = foldMap f mv <> (go b1 <> go b2)
-- | Traverses stored values in the same order as the 'Foldable' instance,
-- rebuilding a tree of the same shape.
instance Traversable (IPRTable k) where
    traverse f = go
      where
        go Nil = pure Nil
        go (Node r a mv b1 b2) =
            Node r a <$> traverse f mv <*> go b1 <*> go b2
-- | Insert a value for a range, replacing any value already stored for
-- exactly that range.
--
-- At each node, one of four cases applies:
--
-- * The node's range satisfies @node >:> key@: descend into the child
--   chosen by the node's test bit.
-- * The new range satisfies @key >:> node@: the new node goes above the
--   existing subtree.
-- * Neither holds: the new leaf and the subtree are joined under a glue
--   node by 'link'.
-- * The ranges are equal: the value is replaced.
insert :: (Routable k) => AddrRange k -> a -> IPRTable k a -> IPRTable k a
insert key val table = case table of
    Nil -> leaf
    Node k2 tb2 v2 l r
      | key == k2 -> Node key tb (Just val) l r
      | k2 >:> key ->
          if isLeft key tb2
            then Node k2 tb2 v2 (insert key val l) r
            else Node k2 tb2 v2 l (insert key val r)
      | key >:> k2 ->
          if isLeft k2 tb
            then Node key tb (Just val) table Nil
            else Node key tb (Just val) Nil table
      | otherwise -> link leaf table
  where
    tb = keyToTestBit key
    leaf = Node key tb (Just val) Nil Nil
-- | Join two non-empty trees under a new value-less node.
--
-- The new node is keyed on @glue 0 k1 k2@.  The first tree goes left when
-- 'isLeft' holds for its range against the new node's test bit; otherwise
-- it goes right.  Calls 'error' if either tree is 'Nil'.
link :: Routable k => IPRTable k a -> IPRTable k a -> IPRTable k a
link s1@(Node k1 _ _ _ _) s2@(Node k2 _ _ _ _) =
    let kg = glue 0 k1 k2
        tbg = keyToTestBit kg
        (lhs, rhs) = if isLeft k1 tbg then (s1, s2) else (s2, s1)
    in Node kg tbg Nothing lhs rhs
link _ _ = error "link"
-- | Common-prefix range of two ranges' addresses.
--
-- Starting at mask length @n@, the mask is lengthened while both addresses
-- agree under @intToMask n@.  At the first length where they differ, the
-- result is @addr k1@ with the previous length, @n - 1@: the last length
-- at which they still agreed.
--
-- The recursion only stops at a disagreement, so the two addresses must
-- differ somewhere.  'link' is reached from 'insert' only when the ranges
-- are unequal and neither contains the other.
glue :: (Routable k) => Int -> AddrRange k -> AddrRange k -> AddrRange k
glue n k1 k2
    | addr k1 `masked` mk == addr k2 `masked` mk = glue (n + 1) k1 k2
    | otherwise = makeAddrRange (addr k1) (n - 1)
  where
    mk = intToMask n
-- | Test-bit value for a range, selected by the range's mask length.
keyToTestBit :: Routable k => AddrRange k -> k
keyToTestBit range = intToTBit (mlen range)
-- | 'True' when the range's address, masked with the given test bit, is
-- zero.  Such ranges belong in a node's left subtree.
isLeft :: Routable k => AddrRange k -> k -> Bool
isLeft range tbit = isZero (addr range) tbit
-- | Remove the value stored for exactly this range, if present.
--
-- Descends like 'insert'.  If the range is not found, the table is
-- returned unchanged.  Every rebuilt node goes through 'node', so
-- value-less nodes with an empty child are collapsed.
delete :: (Routable k) => AddrRange k -> IPRTable k a -> IPRTable k a
delete key table = case table of
    Nil -> Nil
    Node k2 tb2 v2 l r
      | key == k2 -> node k2 tb2 Nothing l r
      | k2 >:> key ->
          if isLeft key tb2
            then node k2 tb2 v2 (delete key l) r
            else node k2 tb2 v2 l (delete key r)
      | otherwise -> table
-- | Smart constructor for 'Node'.
--
-- A node with no value and a 'Nil' child is replaced by its other child,
-- which is 'Nil' when both children are empty.  All other combinations
-- build a plain 'Node'.
node :: (Routable k) => AddrRange k -> k -> Maybe a -> IPRTable k a -> IPRTable k a -> IPRTable k a
node k tb v l r = case (v, l, r) of
    (Nothing, Nil, _) -> r
    (Nothing, _, Nil) -> l
    _ -> Node k tb v l r
-- | Value of the deepest stored range on the search path that equals or
-- contains the query (see 'search'), without its key.
lookup :: Routable k => AddrRange k -> IPRTable k a -> Maybe a
lookup range table = snd <$> search range table Nothing
-- | Like 'lookup', but also returns the stored range that matched.
lookupKeyValue :: Routable k => AddrRange k -> IPRTable k a -> Maybe (AddrRange k, a)
lookupKeyValue range table = search range table Nothing
-- | Walk toward the query range, remembering the deepest node on the way
-- that carries a value.
--
-- The third argument is the best match found so far.  It is returned when
-- the walk hits 'Nil' or a node whose range neither equals nor satisfies
-- @node >:> query@.
--
-- On an exact match:
--
-- * if the node has a value, that value is returned paired with the query;
-- * otherwise the best-so-far is returned.
search :: Routable k => AddrRange k
       -> IPRTable k a
       -> Maybe (AddrRange k, a)
       -> Maybe (AddrRange k, a)
search _ Nil best = best
search k1 (Node k2 tb2 mv l r) best
    | k1 == k2 = maybe best (\v -> Just (k1, v)) mv
    | k2 >:> k1 =
        let next = if isLeft k1 tb2 then l else r
            best' = maybe best (\v -> Just (k2, v)) mv
        in search k1 next best'
    | otherwise = best
-- | All stored entries whose range @r@ satisfies @k1 >:> r@, combined with
-- 'mplus' in pre-order.
--
-- Subtrees are explored under nodes related to the query in either
-- direction.  Only nodes that satisfy @k1 >:> r@ and carry a value are
-- reported.
findMatch :: MonadPlus m => Routable k => AddrRange k -> IPRTable k a -> m (AddrRange k, a)
findMatch _ Nil = mzero
findMatch k1 (Node k2 _ mv l r)
    | k1 >:> k2 = case mv of
        Nothing -> leftHits `mplus` rightHits
        Just v -> return (k2, v) `mplus` leftHits `mplus` rightHits
    | k2 >:> k1 = leftHits `mplus` rightHits
    | otherwise = mzero
  where
    leftHits = findMatch k1 l
    rightHits = findMatch k1 r
-- | Build a table by inserting the pairs from left to right, starting from
-- 'empty'.  For duplicate ranges, the later pair's value wins.
fromList :: Routable k => [(AddrRange k, a)] -> IPRTable k a
fromList = foldl' (flip (uncurry insert)) empty
-- | All stored (range, value) pairs.
--
-- 'foldt' visits nodes in pre-order and each pair is consed onto the
-- accumulator, so the result is in reverse pre-order.  Value-less nodes
-- are skipped.
toList :: Routable k => IPRTable k a -> [(AddrRange k, a)]
toList = foldt collect []
  where
    collect (Node k _ (Just a) _ _) acc = (k, a) : acc
    collect _ acc = acc
-- | Left fold over the tree's nodes (not just their values) in pre-order.
--
-- Each node is passed to the step function before its left subtree, and
-- the left subtree before the right.  'Nil' subtrees are never passed to
-- the step function.
foldt :: (IPRTable k a -> b -> b) -> b -> IPRTable k a -> b
foldt step = go
  where
    go acc Nil = acc
    go acc t@(Node _ _ _ l r) =
        let afterSelf = step t acc
            afterLeft = go afterSelf l
        in go afterLeft r