module Data.IP.Op where

import Data.Bits
import Data.IP.Addr
import Data.IP.Mask
import Data.IP.Range

----------------------------------------------------------------

{-|
  Address types that can be combined with a contiguous bit mask.
  This module provides instances for 'IPv4' and 'IPv6'.

>>> toIPv4 [127,0,2,1] `masked` intToMask 7
126.0.0.0
-}
class Eq a => Addr a where
    {-|
      The 'masked' function takes an 'Addr' and a contiguous
      mask, and returns the masked 'Addr'.
    -}
    masked :: a -> a -> a
    {-|

      The 'intToMask' function takes an 'Int' giving the number of bits to
      be set in the returned contiguous mask. When this integer is positive,
      the bits are set starting from the MSB; when it is negative, they are
      set starting from the LSB.

      >>> intToMask 16 :: IPv4
      255.255.0.0

      >>> intToMask (-16) :: IPv4
      0.0.255.255

      >>> intToMask 16 :: IPv6
      ffff::

      >>> intToMask (-16) :: IPv6
      ::ffff

    -}
    intToMask :: Int -> a

-- | IPv4 masking delegates to 'maskedIPv4' and 'maskIPv4'.
instance Addr IPv4 where
    masked    = maskedIPv4
    intToMask = maskIPv4

-- | IPv6 masking delegates to 'maskedIPv6' and 'maskIPv6'.
instance Addr IPv6 where
    masked    = maskedIPv6
    intToMask = maskIPv6

----------------------------------------------------------------

{-|
  The >:> operator takes two 'AddrRange's. It returns 'True' if
  the first 'AddrRange' contains the second 'AddrRange'. Otherwise,
  it returns 'False'.

  The first range contains the second when both of these hold:

  * its mask length is not longer than the second range's mask length;
  * the second range's address, masked with the first range's mask,
    equals the first range's address.

>>> makeAddrRange ("127.0.2.1" :: IPv4) 8 >:> makeAddrRange "127.0.2.1" 24
True
>>> makeAddrRange ("127.0.2.1" :: IPv4) 24 >:> makeAddrRange "127.0.2.1" 8
False
>>> makeAddrRange ("2001:DB8::1" :: IPv6) 16 >:> makeAddrRange "2001:DB8::1" 32
True
>>> makeAddrRange ("2001:DB8::1" :: IPv6) 32 >:> makeAddrRange "2001:DB8::1" 16
False
-}
(>:>) :: Addr a => AddrRange a -> AddrRange a -> Bool
a >:> b = mlen a <= mlen b && (addr b `masked` mask a) == addr a

{-|
  The 'isMatchedTo' function takes an 'Addr' address and an 'AddrRange',
  and returns 'True' if the range contains the address. That is the case
  when the address, masked with the range's mask, equals the range's address.

>>> ("127.0.2.0" :: IPv4) `isMatchedTo` makeAddrRange "127.0.2.1" 24
True
>>> ("127.0.2.0" :: IPv4) `isMatchedTo` makeAddrRange "127.0.2.1" 32
False
>>> ("2001:DB8::1" :: IPv6) `isMatchedTo` makeAddrRange "2001:DB8::1" 32
True
>>> ("2001:DB8::" :: IPv6) `isMatchedTo` makeAddrRange "2001:DB8::1" 128
False
-}
isMatchedTo :: Addr a => a -> AddrRange a -> Bool
isMatchedTo a r = a `masked` mask r == addr r

{-|
  The 'makeAddrRange' function takes an 'Addr' address and a mask
  length. It creates a bit mask from the mask length, masks the
  'Addr' address with it, and returns an 'AddrRange' made of them.

>>> makeAddrRange (toIPv4 [127,0,2,1]) 8
127.0.0.0/8
>>> makeAddrRange (toIPv6 [0x2001,0xDB8,0,0,0,0,0,1]) 8
2000::/8
-}
makeAddrRange :: Addr a => a -> Int -> AddrRange a
makeAddrRange ad len = AddrRange adr msk len
  where
    -- No local signatures here: without ScopedTypeVariables, writing
    -- 'msk :: a' would introduce a fresh type variable unrelated to the
    -- outer 'a' and fail to type-check.
    msk = intToMask len
    adr = ad `masked` msk


-- | Convert an IPv4 range to the corresponding IPv4-mapped IPv6 range.
--
-- The IPv4 address is placed in the low 32 bits under the @::ffff:0:0/96@
-- prefix. The mask length therefore grows by 96.
ipv4RangeToIPv6 :: AddrRange IPv4 -> AddrRange IPv6
ipv4RangeToIPv6 range =
  makeAddrRange (toIPv6 [0, 0, 0, 0, 0, 0xffff, hi, lo]) (masklen + 96)
  where
    (ip, masklen) = addrRangePair range
    -- Pack the four 8-bit IPv4 components into two 16-bit IPv6 words.
    -- The fallback branch is unreachable for a well-formed 'IPv4'; it
    -- replaces an incomplete list pattern binding.
    (hi, lo) = case fromIPv4 ip of
      [i1, i2, i3, i4] -> ((i1 `shift` 8) .|. i2, (i3 `shift` 8) .|. i4)
      _ -> error "Data.IP.Op.ipv4RangeToIPv6: fromIPv4 did not return four octets"


{-|
  The 'addrRangePair' function takes an 'AddrRange' and
  returns its address and its mask length as a pair.

>>> addrRangePair ("127.0.0.0/8" :: AddrRange IPv4)
(127.0.0.0,8)
>>> addrRangePair ("2000::/8" :: AddrRange IPv6)
(2000::,8)
-}
addrRangePair :: Addr a => AddrRange a -> (a, Int)
addrRangePair (AddrRange adr _ len) = (adr, len)