{-# LANGUAGE TemplateHaskell   #-}
{-# LANGUAGE DeriveAnyClass  #-}
module Data.Range( EndPoint(..)
                 , isOpen, isClosed
                 , unEndPoint
                 , Range(..)
                 , prettyShow
                 , lower, upper
                 , pattern OpenRange, pattern ClosedRange, pattern Range'
                 , inRange, width, clipLower, clipUpper, midPoint, clampTo
                 , isValid, covers
                 , shiftLeft, shiftRight
                 ) where
import Control.DeepSeq
import Control.Lens
import Data.Intersection
import Data.Vinyl.CoRec
import GHC.Generics (Generic)
import Test.QuickCheck
import Text.Printf (printf)
-- | An endpoint of a range: a value tagged as either 'Open' or 'Closed'.
-- Both constructors store their value strictly.
data EndPoint a
  = Open   !a -- ^ an open endpoint
  | Closed !a -- ^ a closed endpoint
  deriving (Show,Read,Eq,Functor,Foldable,Traversable,Generic,NFData)
-- | Endpoints are ordered by their value first; when the values compare
-- equal, an 'Open' endpoint is ordered before a 'Closed' one.
instance Ord a => Ord (EndPoint a) where
  compare p q = compare (key p) (key q)
    where
      -- Pair the value with a flag that is 'True' only for closed endpoints,
      -- so that the tuple ordering breaks ties with Open < Closed.
      key e = (_unEndPoint e, isClosed e)
-- | Generates random endpoints; 'Closed' endpoints are weighted nine times
-- as heavily as 'Open' ones.
instance Arbitrary r => Arbitrary (EndPoint r) where
  arbitrary = frequency
    [ (1, fmap Open   arbitrary)
    , (9, fmap Closed arbitrary)
    ]
-- | Extract the value stored in an endpoint, discarding whether it was open
-- or closed.
_unEndPoint :: EndPoint a -> a
_unEndPoint e = case e of
  Open v   -> v
  Closed v -> v
-- | Lens onto the value stored in an endpoint. Writing through the lens keeps
-- the constructor ('Open' or 'Closed') and may change the type of the value.
unEndPoint :: Lens (EndPoint a) (EndPoint b) a b
unEndPoint = lens _unEndPoint replaceValue
  where
    -- rebuild the endpoint with the same constructor around the new value
    replaceValue e new = case e of
      Open _   -> Open new
      Closed _ -> Closed new
{-# INLINE unEndPoint #-}
-- | 'True' exactly when the endpoint was built with the 'Open' constructor.
isOpen :: EndPoint a -> Bool
isOpen e = case e of
  Open _   -> True
  Closed _ -> False
-- | 'True' exactly when the endpoint was built with the 'Closed'
-- constructor; the negation of 'isOpen'.
isClosed :: EndPoint a -> Bool
isClosed (Closed _) = True
isClosed (Open _)   = False
-- | A range described by a lower and an upper endpoint. Both endpoints are
-- stored strictly, and each may independently be open or closed.
data Range a = Range
  { _lower :: !(EndPoint a) -- ^ the lower endpoint
  , _upper :: !(EndPoint a) -- ^ the upper endpoint
  }
  deriving (Eq,Functor,Foldable,Traversable,Generic,NFData)

-- Template Haskell splice: generates lenses for the record fields of 'Range'.
makeLenses ''Range
-- | Renders a range as the word @Range@ followed by the 'show' output of the
-- lower and the upper endpoint, each wrapped in parentheses.
instance Show a => Show (Range a) where
  show rng = printf "Range (%s) (%s)" lowerStr upperStr
    where
      lowerStr = show (_lower rng)
      upperStr = show (_upper rng)
-- | Bidirectional pattern for a range whose endpoints are both 'Open':
-- it matches only such ranges, and constructs one from the two values.
pattern OpenRange :: a -> a -> Range a
pattern OpenRange l u <- Range (Open l) (Open u)
  where
    OpenRange l u = Range (Open l) (Open u)
-- | Bidirectional pattern for a range whose endpoints are both 'Closed':
-- it matches only such ranges, and constructs one from the two values.
pattern ClosedRange :: a -> a -> Range a
pattern ClosedRange l u <- Range (Closed l) (Closed u)
  where
    ClosedRange l u = Range (Closed l) (Closed u)
-- | Match-only pattern that exposes the values of the lower and upper
-- endpoints of any range, ignoring whether the endpoints are open or closed.
-- It matches every range, hence the COMPLETE pragma.
pattern Range' :: a -> a -> Range a
pattern Range' l u <-
  ((\rng -> (_unEndPoint (_lower rng), _unEndPoint (_upper rng))) -> (l, u))
{-# COMPLETE Range' #-}
-- | Generates a random lower endpoint first, then keeps drawing upper
-- endpoints until one is compatible with it: the upper value must be strictly
-- greater than the lower value when the lower endpoint is open, and greater
-- than or equal to it when the lower endpoint is closed.
instance (Arbitrary r, Ord r) => Arbitrary (Range r) where
  arbitrary = do
      start <- arbitrary
      end   <- arbitrary `suchThat` fitsAbove start
      pure (Range start end)
    where
      -- does the candidate upper endpoint lie far enough above the lower one?
      fitsAbove lo hi = case lo of
        Open   v -> v <  _unEndPoint hi
        Closed v -> v <= _unEndPoint hi
-- | Render a range in interval notation: a round bracket for an open
-- endpoint, a square bracket for a closed one, and the two endpoint values
-- separated by a comma (no spaces).
prettyShow :: Show a => Range a -> String
prettyShow (Range l u) =
    leftBracket ++ show (_unEndPoint l) ++ "," ++ show (_unEndPoint u) ++ rightBracket
  where
    leftBracket
      | isOpen l  = "("
      | otherwise = "["
    rightBracket
      | isOpen u  = ")"
      | otherwise = "]"
-- | Test whether a value lies in a range. A value strictly between the two
-- endpoint values is always in the range; a value equal to an endpoint value
-- is in the range only if that endpoint is closed. In particular, when both
-- endpoint values equal the query value, both endpoints must be closed.
inRange :: Ord a => a -> Range a -> Bool
inRange x (Range l u) = aboveLower && belowUpper
  where
    -- is x admitted by the lower endpoint?
    aboveLower = case compare (_unEndPoint l) x of
      LT -> True
      EQ -> isClosed l
      GT -> False
    -- is x admitted by the upper endpoint?
    belowUpper = case compare x (_unEndPoint u) of
      LT -> True
      EQ -> isClosed u
      GT -> False
-- | Intersecting two ranges yields either 'NoIntersection' or a 'Range'.
type instance IntersectionOf (Range a) (Range a) = [ NoIntersection, Range a]

instance Ord a => (Range a) `IsIntersectableWith` (Range a) where
  nonEmptyIntersection = defaultNonEmptyIntersection

  -- Clip the second range by the upper and then the lower endpoint of the
  -- first one; the clipped range is the intersection if 'isValid' accepts it,
  -- and otherwise there is no intersection.
  (Range lo hi) `intersect` other
      | isValid clipped = coRec clipped
      | otherwise       = coRec NoIntersection
    where
      clipped = clipLower' lo (clipUpper' hi other)
-- | The width of a range: the upper endpoint value minus the lower endpoint
-- value. Whether the endpoints are open or closed plays no role.
width :: Num r => Range r -> r
width rng = _unEndPoint (_upper rng) - _unEndPoint (_lower rng)
-- | The value halfway between the two endpoint values, computed as the lower
-- endpoint value plus half of the 'width'.
midPoint :: Fractional r => Range r -> r
midPoint r = start + halfWidth
  where
    start     = r^.lower.unEndPoint
    halfWidth = width r / 2
-- | Clamp a value to a range: values below the lower endpoint value are
-- replaced by it, values above the upper endpoint value are replaced by that.
-- The endpoint values are used regardless of whether the endpoints are open
-- or closed, so the result may be the value of an open endpoint.
clampTo :: Ord r => Range r -> r -> r
clampTo (Range' l u) x = min (max x l) u
-- | Clip a range from below by the given endpoint (see 'clipLower''). Returns
-- 'Nothing' when the clipped range is rejected by 'isValid'.
clipLower :: Ord a => EndPoint a -> Range a -> Maybe (Range a)
clipLower l r
    | isValid clipped = Just clipped
    | otherwise       = Nothing
  where
    clipped = clipLower' l r
-- | Clip a range from above by the given endpoint (see 'clipUpper''). Returns
-- 'Nothing' when the clipped range is rejected by 'isValid'.
clipUpper :: Ord a => EndPoint a -> Range a -> Maybe (Range a)
clipUpper u r
    | isValid clipped = Just clipped
    | otherwise       = Nothing
  where
    clipped = clipUpper' u r
-- | @x `covers` y@ holds when the intersection of @x@ and @y@ is a range
-- that equals @y@. If the two ranges do not intersect the answer is 'False'.
covers :: forall a. Ord a => Range a -> Range a -> Bool
covers x y = case asA @(Range a) (x `intersect` y) of
  Just common -> common == y
  Nothing     -> False
-- | Check that the endpoints of a range are in order: the lower value must
-- be smaller than the upper value, or equal to it with at least one of the
-- two endpoints closed.
--
-- NOTE(review): a range with equal values and exactly one closed endpoint
-- (e.g. @Range (Closed x) (Open x)@) is accepted here, although 'inRange'
-- reports no value as lying in such a range -- confirm this is intended.
isValid :: Ord a => Range a -> Bool
isValid (Range l u) = case compare (_unEndPoint l) (_unEndPoint u) of
  LT -> True
  EQ -> isClosed l || isClosed u
  GT -> False
-- | Replace the lower endpoint of a range by the given endpoint when the
-- given one is the larger of the two according to 'cmpLower'; otherwise the
-- range is returned unchanged. The result is not checked for validity.
clipLower' :: Ord a => EndPoint a -> Range a -> Range a
clipLower' bound rng@(Range lo hi)
  | bound `cmpLower` lo == GT = Range bound hi
  | otherwise                 = rng
-- | Replace the upper endpoint of a range by the given endpoint when the
-- given one is the smaller of the two according to 'cmpUpper'; otherwise the
-- range is returned unchanged. The result is not checked for validity.
clipUpper' :: Ord a => EndPoint a -> Range a -> Range a
clipUpper' bound rng@(Range lo hi)
  | bound `cmpUpper` hi == LT = Range lo bound
  | otherwise                 = rng
-- | Compare two endpoints in their role as lower endpoints: by value first,
-- and for equal values a 'Closed' endpoint is ordered before an 'Open' one.
-- Two endpoints with equal values and the same constructor compare as 'EQ'.
cmpLower :: Ord a => EndPoint a -> EndPoint a -> Ordering
cmpLower a b = case compare (_unEndPoint a) (_unEndPoint b) of
  -- tie on the value: False < True, so the closed endpoint comes first
  EQ      -> compare (isOpen a) (isOpen b)
  decided -> decided
-- | Compare two endpoints in their role as upper endpoints: by value first,
-- and for equal values an 'Open' endpoint is ordered before a 'Closed' one.
-- Two endpoints with equal values and the same constructor compare as 'EQ'.
cmpUpper :: Ord a => EndPoint a -> EndPoint a -> Ordering
cmpUpper a b = case compare (_unEndPoint a) (_unEndPoint b) of
  -- tie on the value: False < True, so the open endpoint comes first
  EQ      -> compare (isClosed a) (isClosed b)
  decided -> decided
-- | Shift a range to the left by @x@: @x@ is subtracted from the value of
-- both endpoints, and each endpoint stays open or closed as it was.
--
-- The offset is subtracted directly instead of being negated and added.
-- For ordinary numeric types the result is the same, but this also works for
-- types whose 'negate' is partial (e.g. @Natural@, where negating a non-zero
-- number throws) as long as the shifted endpoint values are representable.
shiftLeft :: Num r => r -> Range r -> Range r
shiftLeft x = fmap (subtract x)
-- | Shift a range to the right by @x@: @x@ is added to the value of both
-- endpoints, and each endpoint stays open or closed as it was.
shiftRight :: Num r => r -> Range r -> Range r
shiftRight x rng = (+x) <$> rng