module Data.IP.Op where

import Data.Bits
import Data.IP.Addr
import Data.IP.Mask
import Data.IP.Range

----------------------------------------------------------------

-- | Address types that can be masked with a contiguous bit mask.
--
-- >>> toIPv4 [127,0,2,1] `masked` intToMask 7
-- 126.0.0.0
class Eq a => Addr a where
    -- | Given an address and a contiguous mask, return the address
    --   with that mask applied.
    masked :: a -> a -> a

    -- | Build a contiguous mask whose number of set bits is the given
    --   'Int'.  For a positive count the set bits start at the most
    --   significant bit; otherwise they start at the least significant
    --   bit.
    --
    --   >>> intToMask 16 :: IPv4
    --   255.255.0.0
    --
    --   >>> intToMask (-16) :: IPv4
    --   0.0.255.255
    --
    --   >>> intToMask 16 :: IPv6
    --   ffff::
    --
    --   >>> intToMask (-16) :: IPv6
    --   ::ffff
    intToMask :: Int -> a

-- | 'IPv4' addresses: both methods delegate to the IPv4-specific
--   helpers 'maskedIPv4' and 'maskIPv4'.
instance Addr IPv4 where
    masked    = maskedIPv4
    intToMask = maskIPv4

-- | 'IPv6' addresses: both methods delegate to the IPv6-specific
--   helpers 'maskedIPv6' and 'maskIPv6'.
instance Addr IPv6 where
    masked    = maskedIPv6
    intToMask = maskIPv6

----------------------------------------------------------------

-- |
--   The @>:>@ operator takes two 'AddrRange's and returns 'True' when
--   the first range contains the second one, 'False' otherwise.
--   Containment holds when the outer prefix is not longer than the
--   inner one and the inner network address, masked with the outer
--   mask, equals the outer network address.
--
-- >>> makeAddrRange ("127.0.2.1" :: IPv4) 8 >:> makeAddrRange "127.0.2.1" 24
-- True
-- >>> makeAddrRange ("127.0.2.1" :: IPv4) 24 >:> makeAddrRange "127.0.2.1" 8
-- False
-- >>> makeAddrRange ("2001:DB8::1" :: IPv6) 16 >:> makeAddrRange "2001:DB8::1" 32
-- True
-- >>> makeAddrRange ("2001:DB8::1" :: IPv6) 32 >:> makeAddrRange "2001:DB8::1" 16
-- False
(>:>) :: Addr a => AddrRange a -> AddrRange a -> Bool
outer >:> inner = mlen outer <= mlen inner && narrowed == addr outer
  where
    -- Only evaluated when the length check passes ('&&' short-circuits).
    narrowed = addr inner `masked` mask outer

-- |
--   The 'isMatchedTo' function takes an address and an 'AddrRange'
--   and returns 'True' if the range contains the address: masking the
--   address with the range's mask must yield the range's network
--   address.
--
-- >>> ("127.0.2.0" :: IPv4) `isMatchedTo` makeAddrRange "127.0.2.1" 24
-- True
-- >>> ("127.0.2.0" :: IPv4) `isMatchedTo` makeAddrRange "127.0.2.1" 32
-- False
-- >>> ("2001:DB8::1" :: IPv6) `isMatchedTo` makeAddrRange "2001:DB8::1" 32
-- True
-- >>> ("2001:DB8::" :: IPv6) `isMatchedTo` makeAddrRange "2001:DB8::1" 128
-- False
isMatchedTo :: Addr a => a -> AddrRange a -> Bool
isMatchedTo ip range = network == addr range
  where
    network = ip `masked` mask range

-- |
--   The 'makeAddrRange' function takes an address and a mask length.
--   It builds a contiguous mask of that length with 'intToMask', masks
--   the address with it, and stores the masked address, the mask and
--   the length in the resulting 'AddrRange'.
--
-- >>> makeAddrRange (toIPv4 [127,0,2,1]) 8
-- 127.0.0.0/8
-- >>> makeAddrRange (toIPv6 [0x2001,0xDB8,0,0,0,0,0,1]) 8
-- 2000::/8
makeAddrRange :: Addr a => a -> Int -> AddrRange a
makeAddrRange ad len =
    let netmask = intToMask len
    in  AddrRange (ad `masked` netmask) netmask len

-- | Convert an IPv4 range to the corresponding IPv4-mapped IPv6 range
--   (the @::ffff:0:0/96@ prefix, RFC 4291 section 2.5.5.2).
--
--   The IPv4 network address becomes the low 32 bits of the IPv6
--   address, preceded by 80 zero bits and 16 one bits.  The mask
--   length grows by 96 to cover that fixed prefix.
ipv4RangeToIPv6 :: AddrRange IPv4 -> AddrRange IPv6
ipv4RangeToIPv6 range =
    makeAddrRange (toIPv6 (mappedPrefix ++ toHextets (fromIPv4 ip)))
                  (masklen + 96)
  where
    (ip, masklen) = addrRangePair range
    -- First six 16-bit groups of an IPv4-mapped address: 0:0:0:0:0:ffff.
    mappedPrefix = [0, 0, 0, 0, 0, 0xffff]
    -- Pack consecutive octets into 16-bit groups (high octet first).
    -- Written as a total recursion rather than the irrefutable list
    -- pattern @[i1,i2,i3,i4] = fromIPv4 ip@, which is an incomplete
    -- pattern binding (-Wincomplete-uni-patterns).
    toHextets (hi : lo : rest) = ((hi `shift` 8) .|. lo) : toHextets rest
    toHextets rest             = rest

-- |
--   The 'addrRangePair' function takes an 'AddrRange' and returns its
--   network address together with its mask length.  The mask itself
--   is discarded.
--
-- >>> addrRangePair ("127.0.0.0/8" :: AddrRange IPv4)
-- (127.0.0.0,8)
-- >>> addrRangePair ("2000::/8" :: AddrRange IPv6)
-- (2000::,8)
addrRangePair :: Addr a => AddrRange a -> (a, Int)
addrRangePair (AddrRange adr _ len) = (adr, len)