{-# LANGUAGE CPP #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wall #-}
#if ( __GLASGOW_HASKELL__ < 820 )
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
{-# OPTIONS_GHC -fno-warn-unrecognised-pragmas #-}
#endif

module NumHask.Rect
  ( Rect(..)
  , pattern Rect
  , pattern Ranges
  , corners
  , projectRect
  ) where

import NumHask.Prelude

import Data.Functor.Compose
import qualified Data.Semigroup as Semigroup

import Data.Distributive
import Data.Functor.Apply (Apply(..))
import Data.Functor.Rep
import Data.Semigroup.Foldable (Foldable1(..))

import NumHask.Pair
import NumHask.Range
import NumHask.Space

-- | An axis-aligned rectangle in two dimensions, stored as a 'Compose'd
-- 'Pair' of 'Range's: the first 'Range' is the x extent and the second is
-- the y extent.
--
-- Build and match values with the 'Rect' (four bounds) or 'Ranges' (two
-- ranges) pattern synonyms rather than the raw constructor.
newtype Rect a = Rect' (Compose Pair Range a)
  deriving
    ( Show
    , Eq
    , Functor
    , Apply
    , Applicative
    , Foldable
    , Foldable1
    , Traversable
    )

-- | View a 'Rect' as its four bounds, in the order
-- x-lower, x-upper, y-lower, y-upper.
pattern Rect :: a -> a -> a -> a -> Rect a
pattern Rect xl xu yl yu = Rect' (Compose (Pair (Range xl xu) (Range yl yu)))
{-# COMPLETE Rect #-}

-- | View a 'Rect' as its x 'Range' followed by its y 'Range'.
pattern Ranges :: Range a -> Range a -> Rect a
pattern Ranges rx ry = Rect' (Compose (Pair rx ry))
{-# COMPLETE Ranges #-}

-- | Multiplication is axis-wise: the x 'Range's are multiplied together and
-- so are the y 'Range's, using the 'Range' instance.
instance (Ord a, BoundedField a, FromInteger a) => MultiplicativeMagma (Rect a) where
    times (Ranges xa ya) (Ranges xb yb) = Ranges (times xa xb) (times ya yb)

-- | The unit is the 'Range' unit on both axes.
instance (Ord a, BoundedField a, FromInteger a) => MultiplicativeUnital (Rect a) where
    one = Ranges one one

-- The following instances define no methods; they only declare class
-- membership for the axis-wise multiplication.
-- NOTE(review): the associated laws hold only as far as they hold for the
-- underlying 'Range' instance.
instance (Ord a, FromInteger a, BoundedField a) => MultiplicativeAssociative (Rect a)

instance (Ord a, BoundedField a, FromInteger a) => MultiplicativeCommutative (Rect a)

instance (Ord a, BoundedField a, FromInteger a) => Multiplicative (Rect a)

-- | 'recip' is taken axis-wise, using the 'Range' instance on each axis.
instance (Ord a, FromInteger a, BoundedField a) => MultiplicativeInvertible (Rect a) where
    recip (Ranges rx ry) = Ranges (recip rx) (recip ry)

-- | No methods are defined here; this only declares class membership.
instance (Ord a, BoundedField a, FromInteger a) => MultiplicativeGroup (Rect a)

-- | 'sign' and 'abs' are applied to the x and y 'Range's independently.
instance (AdditiveInvertible a, BoundedField a, Ord a, FromInteger a) => Signed (Rect a) where
    sign (Ranges rx ry) = Ranges (sign rx) (sign ry)
    -- abs is written as @sign r * r@ on each 'Range', i.e. via the 'Range'
    -- multiplication rather than the 'Range' 'abs'.
    -- NOTE(review): confirm the two agree for 'Range'.
    abs (Ranges rx ry) = Ranges (unsign rx) (unsign ry)
      where
        unsign r = sign r * r

-- | The size of a 'Rect' is the 'Pair' of the sizes of its x and y 'Range's.
instance (AdditiveGroup a) => Normed (Rect a) (Pair a) where
    size (Ranges rx ry) = Pair (size rx) (size ry)

-- | Distribute a functor through 'Rect' one bound at a time: each bound of
-- the result is the functor of that same bound taken from every 'Rect'
-- produced by @f@.
instance Distributive Rect where
  collect f xs = Rect (slot xLo) (slot xHi) (slot yLo) (slot yHi)
    where
      -- project one bound out of every rect inside the functor
      slot pick = pick . f <$> xs
      xLo (Rect v _ _ _) = v
      xHi (Rect _ v _ _) = v
      yLo (Rect _ _ v _) = v
      yHi (Rect _ _ _ v) = v

-- | A 'Rect' is indexed by a pair of 'Bool's. The first picks the axis
-- ('False' for the x range, 'True' for the y range). The second picks the
-- bound ('False' for lower, 'True' for upper).
instance Representable Rect where
  type Rep Rect = (Bool, Bool)
  tabulate f =
      Rect (at False False) (at False True) (at True False) (at True True)
    where
      at axis bound = f (axis, bound)
  index (Rect xl xu yl yu) key = case key of
    (False, False) -> xl
    (False, True) -> xu
    (True, False) -> yl
    (True, True) -> yu

-- | 'Rect' as a 'Space' of 'Pair' points. Each operation treats the two axes
-- independently: the first component of a point comes from the x 'Range'
-- and the second component from the y 'Range'.
instance (FromInteger a, Ord a, BoundedField a) => Space (Rect a) where
    type Element (Rect a) = Pair a
    type Grid (Rect a) = Pair Int

    -- 'nul' on both axes
    nul = Ranges nul nul

    -- union axis by axis
    union (Ranges xa ya) (Ranges xb yb) = Ranges (union xa xb) (union ya yb)

    -- the corner built from both lower bounds, and from both upper bounds
    lower (Rect xl _ yl _) = Pair xl yl
    upper (Rect _ xu _ yu) = Pair xu yu

    -- the degenerate rect whose lower and upper corners are both the point
    singleton (Pair x y) = Rect x x y y

    -- Points of a regular grid over the rect, with @cells@ divisions per
    -- axis. The 'Pos' selects the inclusive integer index range used on each
    -- axis; 'MidPos' additionally shifts every point by half a cell.
    -- Points are listed with the first coordinate varying slowest.
    grid :: (FromInteger a) => Pos -> Rect a -> Pair Int -> [Pair a]
    grid pos r cells = nudge <$> points
      where
        cellSize = width r / (fromIntegral <$> cells)
        offset = if pos == MidPos then cellSize / (one + one) else zero
        nudge p = p + offset
        -- inclusive index bounds, per axis
        (Pair xFrom yFrom, Pair xTo yTo) = case pos of
            OuterPos -> (zero, cells)
            InnerPos -> (one, cells - one)
            LowerPos -> (zero, cells - one)
            UpperPos -> (one, cells)
            MidPos -> (zero, cells - one)
        toPoint ix = lower r + cellSize * (fromIntegral <$> ix)
        points =
            [ toPoint (Pair i j)
            | i <- [xFrom .. xTo]
            , j <- [yFrom .. yTo]
            ]

    -- One sub-rect per grid cell. Each is anchored at a 'LowerPos' grid
    -- point of each axis and spans one cell width on that axis.
    gridSpace (Ranges rx ry) (Pair nx ny) =
        [ Rect x (x + dx) y (y + dy)
        | x <- grid LowerPos rx nx
        , y <- grid LowerPos ry ny
        ]
      where
        dx = width rx / fromIntegral nx
        dy = width ry / fromIntegral ny

-- | Rects combine with 'union', which works axis by axis.
--
-- A 'Semigroup' instance is required as a superclass of 'Monoid' from
-- base-4.11 (GHC 8.4) onwards; without it this module does not compile on
-- modern GHC. It is referenced through a qualified import so that it also
-- works where the prelude only exports the old 'Data.Monoid' @(<>)@.
instance (Ord a, BoundedField a, FromInteger a) => Semigroup.Semigroup (Rect a) where
    (<>) = union

-- | 'mempty' is 'nul'.
-- NOTE(review): this assumes 'nul' is an identity for 'union', which in
-- turn depends on the 'Range' instances.
instance (Ord a, BoundedField a, FromInteger a) => Monoid (Rect a) where
    mempty = nul
    -- Canonical definition: delegate to the Semigroup method (i.e. 'union').
    -- It is qualified so that it can never resolve to a prelude @(<>)@ that is
    -- itself 'mappend', which would loop on older base.
    mappend = (Semigroup.<>)

-- | The lower and upper corners of a 'Rect', in that order.
--
-- Only these two diagonally opposite corners are returned, not all four.
corners :: (FromInteger a, BoundedField a, Ord a) => Rect a -> [Pair a]
corners r = [corner r | corner <- [lower, upper]]

-- | Project a 'Rect' from an old 'Rect' range to a new one.
--
-- @projectRect r0 r1 r@ maps the lower and upper corners of @r@ from the
-- space of @r0@ into the space of @r1@ with 'project'. It then rebuilds a
-- 'Rect' from the two projected corners.
projectRect :: (FromInteger a, Ord a, BoundedField a) =>
    Rect a -> Rect a -> Rect a -> Rect a
projectRect r0 r1 (Rect xl xu yl yu) =
    let Pair xl' yl' = project r0 r1 (Pair xl yl)
        Pair xu' yu' = project r0 r1 (Pair xu yu)
    in Rect xl' xu' yl' yu'