module Camfort.Specification.Stencils.Model ( Interval(..)
                                            , Bound(..)
                                            , approxVec
                                            , Offsets(..)
                                            , UnionNF
                                            , vecLength
                                            , unfCompare
                                            , optimise
                                            , maximas
                                            , Approximation(..)
                                            , lowerBound, upperBound
                                            , fromExact
                                            , Multiplicity(..)
                                            , Peelable(..)
                                            ) where
import qualified Control.Monad as CM
import           Algebra.Lattice
import           Data.Semigroup
import qualified Data.List.NonEmpty as NE
import qualified Data.Set as S
import           Data.Foldable
import           Data.SBV
import           Data.Data
import           Data.List (sortBy, nub)
import           Data.Maybe (fromJust)
import qualified Data.PartialOrd as PO
import qualified Camfort.Helpers.Vec as V
import System.IO.Unsafe
-- | Things that denote a collection of values and can be queried both
-- concretely (via 'member') and symbolically (via 'compile').
class Container a where
  -- | Type of the concrete values that can be tested for membership.
  type MemberTyp a
  -- | Type of the symbolic values the compiled predicate ranges over.
  type CompTyp a
  -- | @member x c@ tells whether the concrete value @x@ belongs to @c@.
  member :: MemberTyp a -> a -> Bool
  -- | Turn the container into a symbolic membership predicate ('SBool').
  compile :: a -> (CompTyp a -> SBool)
-- | A set of integer offsets: either an explicit finite set, or
-- 'SetOfIntegers', which the 'Container' instance treats as containing
-- every value and which serves as the lattice 'top'.
data Offsets =
    Offsets (S.Set Int64)
  | SetOfIntegers
  deriving Eq
-- | Total order on offsets: explicit sets are ordered by the 'Ord' instance
-- of 'S.Set', and 'SetOfIntegers' is greater than every explicit set.
instance Ord Offsets where
  compare lhs rhs =
    case (lhs, rhs) of
      (Offsets xs, Offsets ys)       -> compare xs ys
      (Offsets _, SetOfIntegers)     -> LT
      (SetOfIntegers, Offsets _)     -> GT
      (SetOfIntegers, SetOfIntegers) -> EQ
-- | Membership in an offset set. 'SetOfIntegers' accepts everything; an
-- explicit set accepts exactly its elements.
instance Container Offsets where
  type MemberTyp Offsets = Int64
  type CompTyp Offsets = SInt64

  member ix offs =
    case offs of
      Offsets known -> S.member ix known
      SetOfIntegers -> True

  -- The symbolic predicate checks the variable against every known offset.
  compile offs sym =
    case offs of
      Offsets known -> sym `sElem` [ fromIntegral k | k <- S.toList known ]
      SetOfIntegers -> true
-- | Join is set union; joining with 'SetOfIntegers' yields 'SetOfIntegers'.
instance JoinSemiLattice Offsets where
  lhs \/ rhs =
    case (lhs, rhs) of
      (Offsets xs, Offsets ys) -> Offsets (S.union xs ys)
      _                        -> SetOfIntegers

-- | Meet is set intersection; 'SetOfIntegers' is neutral for the meet.
instance MeetSemiLattice Offsets where
  lhs /\ rhs =
    case (lhs, rhs) of
      (Offsets xs, Offsets ys)   -> Offsets (S.intersection xs ys)
      (Offsets _, SetOfIntegers) -> lhs
      (SetOfIntegers, _)         -> rhs

instance Lattice Offsets

-- | The empty offset set is the least element.
instance BoundedJoinSemiLattice Offsets where
  bottom = Offsets S.empty

-- | 'SetOfIntegers' is the greatest element.
instance BoundedMeetSemiLattice Offsets where
  top = SetOfIntegers

instance BoundedLattice Offsets
-- | Tag distinguishing the two flavours of 'Interval'; it is used as the
-- type index of the 'Interval' GADT.
data Bound = Arbitrary | Standard

-- | Intervals of offsets along one dimension.
data Interval a where
  -- | Interval given by a lower and an upper bound.
  IntervArbitrary :: Int -> Int -> Interval Arbitrary
  -- | Unbounded interval in arbitrary form.
  IntervInfiniteArbitrary :: Interval Arbitrary
  -- | Lower bound, upper bound, and a flag that is the answer the
  -- 'Container' instance gives for membership of 0 (i.e. 'False' means the
  -- origin is a hole in the interval).
  IntervHoled     :: Int64 -> Int64 -> Bool -> Interval Standard
  -- | Unbounded interval in standard form; the lattice 'top'.
  IntervInfinite  :: Interval Standard

deriving instance Eq (Interval a)
-- | Textual rendering of standard intervals: bounds in brackets followed by
-- @^@ and the origin-membership flag.
instance Show (Interval Standard) where
  show interv =
    case interv of
      IntervInfinite -> "IntervInfinite"
      IntervHoled lo hi noHole ->
        concat [ "Interv [", show lo, ",", show hi, "]^", show noHole ]
-- | Convert an arbitrary interval into standard form, i.e. an interval whose
-- bounds reach the origin, with a flag saying whether 0 itself is included.
--
-- * An interval that already contains 0 is kept as is (flag 'True').
-- * An interval adjacent to the origin (@[a,-1]@ or @[1,b]@) is represented
--   exactly by extending it to 0 and marking 0 as a hole (flag 'False').
-- * An interval separated from the origin by a gap cannot be represented
--   exactly; only an upper bound (the interval extended to 0, with a hole at
--   0) is returned and the lower bound is left unspecified.
--
-- Calls 'error' if the lower bound exceeds the upper bound.
approxInterv :: Interval Arbitrary -> Approximation (Interval Standard)
approxInterv (IntervArbitrary a b)
  | a > b = error
    "Interval condition violated: lower bound is bigger than the upper bound."
  | a <=  0, b >=  0 = Exact  $ IntervHoled a' b' True
  -- Negative side: the interval ends right before the origin.
  | a <= -1, b == -1 = Exact  $ IntervHoled a' 0  False
  -- Positive side: the interval starts right after the origin.
  | a ==  1, b >=  1 = Exact  $ IntervHoled 0  b' False
  -- Strictly away from the origin: over-approximate.
  | a >   1, b >   1 = Bound Nothing $ Just $ IntervHoled 0  b' False
  | a <  -1, b <  -1 = Bound Nothing $ Just $ IntervHoled a' 0  False
  | otherwise = error "Impossible: All posibilities are covered."
  where
    a' = fromIntegral a
    b' = fromIntegral b
approxInterv IntervInfiniteArbitrary = Exact IntervInfinite
-- | Standardise every interval of a vector. The result is 'Exact' only when
-- each component could be converted exactly; otherwise it is an upper bound
-- made of the component-wise upper bounds.
approxVec :: forall n .
             V.Vec n (Interval Arbitrary)
          -> Approximation (V.Vec n (Interval Standard))
approxVec v
  | null boundIxs = Exact (fromExact <$> stdVec)
  | otherwise     = Bound Nothing (Just (upperBound <$> stdVec))
  where
    stdVec :: V.Vec n (Approximation (Interval Standard))
    stdVec = approxInterv <$> v

    -- Positions of the components that are only approximated.
    boundIxs :: [ Int ]
    boundIxs = fst (classify 0 stdVec ([], []))

    -- Walk the vector, collecting the indices of approximated components
    -- (first list) and of exact components (second list).
    classify :: forall m . Int
             -> V.Vec m (Approximation (Interval Standard))
             -> ([ Int ], [ Int ])
             -> ([ Int ], [ Int ])
    classify _ V.Nil acc = acc
    classify ix (V.Cons x rest) (approxIxs, exactIxs) =
      classify (ix + 1) rest $
        case x of
          Bound{} -> (ix : approxIxs, exactIxs)
          Exact{} -> (approxIxs, ix : exactIxs)
-- | Membership in a standard interval. For a bounded interval the origin is
-- governed solely by the hole flag; every other value is checked against the
-- bounds. The infinite interval contains everything.
instance Container (Interval Standard) where
  type MemberTyp (Interval Standard) = Int64
  type CompTyp (Interval Standard) = SInt64

  member ix interv =
    case interv of
      IntervInfinite -> True
      IntervHoled lo hi noHole
        | ix == 0   -> noHole
        | otherwise -> ix >= lo && ix <= hi

  compile interv sym =
    case interv of
      IntervInfinite -> true
      IntervHoled lo hi noHole ->
        let within = inRange sym (fromIntegral lo, fromIntegral hi)
        in if noHole
             then within
             -- Exclude the origin when the interval has a hole there.
             else within &&& sym ./= 0
-- | Join of bounded intervals widens the bounds and keeps the origin if
-- either side has it; joining with the infinite interval gives 'top'.
instance JoinSemiLattice (Interval Standard) where
  lhs \/ rhs =
    case (lhs, rhs) of
      (IntervHoled lo hi noHole, IntervHoled lo' hi' noHole') ->
        IntervHoled (lo `min` lo') (hi `max` hi') (noHole || noHole')
      _ -> top

-- | Meet of bounded intervals narrows the bounds and keeps the origin only
-- if both sides have it; the infinite interval is neutral.
instance MeetSemiLattice (Interval Standard) where
  lhs /\ rhs =
    case (lhs, rhs) of
      (IntervHoled lo hi noHole, IntervHoled lo' hi' noHole') ->
        IntervHoled (lo `max` lo') (hi `min` hi') (noHole && noHole')
      (IntervHoled{}, IntervInfinite) -> lhs
      (IntervInfinite, _)             -> rhs

instance Lattice (Interval Standard)

-- | The least element: bounds @[0,0]@ with the origin excluded.
instance BoundedJoinSemiLattice (Interval Standard) where
  bottom = IntervHoled 0 0 False

-- | The greatest element is the infinite interval.
instance BoundedMeetSemiLattice (Interval Standard) where
  top = IntervInfinite

instance BoundedLattice (Interval Standard)
-- | A non-empty list of @n@-dimensional vectors of @a@. The 'Container'
-- instance treats it as a union: a point is a member when it belongs to at
-- least one of the vectors.
type UnionNF n a = NE.NonEmpty (V.Vec n a)
-- | Dimensionality of a union, read off its first vector.
vecLength :: UnionNF n a -> V.Natural n
vecLength unf = V.lengthN (NE.head unf)
-- | A point belongs to a union when it belongs to at least one of its
-- vectors, and it belongs to a vector when every coordinate belongs to the
-- corresponding component.
instance Container a => Container (UnionNF n a) where
  type MemberTyp (UnionNF n a) = V.Vec n (MemberTyp a)
  type CompTyp (UnionNF n a) = V.Vec n (CompTyp a)

  member point = any inRegion
    where
      inRegion region = and (V.zipWith member point region)

  -- Disjunction over the vectors of the conjunction over their components.
  compile regions syms = foldr1 (|||) (NE.map regionPred regions)
    where
      regionPred region =
        foldr' (\(set, sym) rest -> compile set sym &&& rest)
               true
               (V.zip region syms)
-- | Joining two unions concatenates their vectors.
instance JoinSemiLattice (UnionNF n a) where
  (\/) = (<>)

-- | Meeting two unions takes the component-wise meet of every pair of
-- vectors, one from each side.
instance BoundedLattice a => MeetSemiLattice (UnionNF n a) where
  lhs /\ rhs = do
    v  <- lhs
    v' <- rhs
    return (V.zipWith (/\) v v')

instance BoundedLattice a => Lattice (UnionNF n a)
-- | Compare two unions by asking the SBV prover whether their compiled
-- membership predicates agree on every point.
--
-- * 'EQ' when the prover finds no counterexample.
-- * 'GT' when the counterexample found is a member of the first union.
-- * 'LT' when it is not, but is a member of the second union.
--
-- The prover is run through 'unsafePerformIO'; proof errors and impossible
-- states are reported with 'fail' in 'IO'.
--
-- NOTE(review): the ordering is decided from a single counterexample, which
-- presumably relies on callers only comparing unions where one contains the
-- other — confirm against callers.
unfCompare :: forall a b n . ( Container a,          Container b
                             , MemberTyp a ~ Int64,  MemberTyp b ~ Int64
                             , CompTyp a ~ SInt64,   CompTyp b ~ SInt64
                             )
           => UnionNF n a -> UnionNF n b -> Ordering
unfCompare oi oi' = unsafePerformIO $ do
    thmRes <- prove pred
    case thmRes of
      -- A hard proof error is surfaced by failing with the messages reported
      -- by the prover.
      -- NOTE(review): presumably this covers the solver not being
      -- installed or reachable — confirm.
      ThmResult (ProofError _ msgs) -> fail $ unlines msgs
      _ ->
        if modelExists thmRes
        then do
          ce <- counterExample thmRes
          case V.fromList ce of
             V.VecBox cev ->
               case V.proveEqSize (NE.head oi) cev of
                 Just V.ReflEq ->
                   -- The counterexample is a point on which the two
                   -- predicates disagree. The first union is tested first,
                   -- so a point reported as member of both yields GT.
                   -- The last branch is defensive: it fires only if the
                   -- point is in neither union.
                   if | cev `member` oi  -> return GT
                      | cev `member` oi' -> return LT
                      | otherwise -> fail
                         "Impossible: counter example is in \
                          \neither of the operands"
                 Nothing -> fail
                    "Impossible: Counter example size doesn't \
                    \match the original vector size."
        else return EQ
  where
    -- Extract the counterexample as a list of concrete coordinates; fails if
    -- the prover flags the model as only probable or cannot produce one.
    counterExample :: ThmResult -> IO [ Int64 ]
    counterExample thmRes =
      case getModelAssignment thmRes of
        Right (False, ce) -> return ce
        Right (True, _) -> fail "Returned probable model."
        Left str -> fail str
    -- The property to prove: for free variables, one per dimension, both
    -- unions give the same membership answer.
    pred :: Predicate
    pred = do
      freeVars <- (mkFreeVars . dimensionality) oi :: Symbolic [ SInt64 ]
      case V.fromList freeVars of
        V.VecBox freeVarVec ->
          case V.proveEqSize (NE.head oi) freeVarVec of
            Just V.ReflEq -> return $
              compile oi freeVarVec .== compile oi' freeVarVec
            Nothing -> fail $
              "Impossible: Free variables size doesn't match that of the " ++
              "union parameter."
    -- Number of dimensions, taken from the first vector of the union.
    dimensionality = V.length . NE.head
-- | Inclusion order on offsets, consistent with the lattice structure (join
-- is union, meet is intersection): an explicit set is below another when it
-- is a subset of it, and everything is below 'SetOfIntegers'.
instance PO.PartialOrd Offsets where
  -- Subset test, not the lexicographic 'Ord' comparison of 'S.Set', which
  -- would e.g. place {1,2} below {2}.
  (Offsets s) <= (Offsets s') = s `S.isSubsetOf` s'
  SetOfIntegers <= Offsets{} = False
  _ <= SetOfIntegers = True
-- | Inclusion order on standard intervals: a bounded interval is below
-- another when its bounds lie within the other's and it does not include
-- the origin unless the other does. Everything is below the infinite
-- interval, which is below no bounded interval.
instance PO.PartialOrd (Interval Standard) where
  lhs <= rhs =
    case (lhs, rhs) of
      (IntervHoled lo hi noHole, IntervHoled lo' hi' noHole') ->
        and [ noHole' || not noHole, lo >= lo', hi <= hi' ]
      (IntervInfinite, IntervHoled{}) -> False
      (_, IntervInfinite)             -> True
-- | Component-wise order: a vector is below another when every component is
-- below the corresponding one.
instance PO.PartialOrd a => PO.PartialOrd (V.Vec n a) where
  lhs <= rhs = and componentwise
    where
      componentwise = V.zipWith (PO.<=) lhs rhs
-- | Simplify a union: repeatedly keep only maximal vectors and merge them
-- with 'unionLemma' until nothing changes, then keep the maximal vectors of
-- the result.
optimise :: UnionNF n (Interval Standard) -> UnionNF n (Interval Standard)
optimise unf = NE.fromList (maximas (saturate (NE.toList unf)))
  where
    -- Iterate the merge step up to its fixed point.
    saturate regions
      | merged == regions = merged
      | otherwise         = saturate merged
      where
        merged = unionLemma (maximas regions)
-- | For every element of the list, collect all elements related to it by the
-- given predicate (with the element as first argument), sort that group with
-- the given comparison, and drop duplicate groups. Groups may overlap.
sensibleGroupBy :: Eq a =>
                   (a -> a -> Ordering)
                -> (a -> a -> Bool)
                -> [ a ]
                -> [ [ a ] ]
sensibleGroupBy ord p l =
  nub [ sortBy ord [ candidate | candidate <- l, p pivot candidate ]
      | pivot <- l
      ]
-- | Keep only maximal vectors (with respect to the component-wise partial
-- order), without duplicates.
--
-- For every vector the group of vectors above it is formed, and one maximal
-- element of that group is picked. Each group contains its own pivot, so it
-- is never empty and 'head' is safe.
maximas :: [ V.Vec n (Interval Standard) ] -> [ V.Vec n (Interval Standard) ]
maximas = nub
        . fmap (head . PO.maxima)
        . sensibleGroupBy ord (PO.<=)
  where
    -- Sorting only canonicalises groups so that duplicates can be dropped.
    -- Two vectors above the same pivot need not be comparable with each
    -- other, so incomparable pairs are treated as equal instead of crashing
    -- on a 'Nothing' from 'PO.compare'.
    ord a b = maybe EQ id (a `PO.compare` b)
-- | Merge vectors that are almost identical: for every vector, all vectors
-- differing from it in at most one component are joined component-wise into
-- a single vector.
--
-- NOTE(review): a group may contain two vectors that differ from the pivot
-- in different components; their join then presumably covers points that
-- are in none of the originals — confirm this is intended.
unionLemma :: [ V.Vec n (Interval Standard) ] -> [ V.Vec n (Interval Standard) ]
unionLemma regions =
  [ foldr1 (V.zipWith (\/)) grp
  | grp <- sensibleGroupBy pseudoOrd agreeButOne regions
  ]
  where
    -- Not a real ordering: it only distinguishes equal from unequal.
    pseudoOrd a b
      | a == b    = EQ
      | otherwise = LT

    -- True when the two vectors differ in at most one position.
    agreeButOne :: Eq a => V.Vec n a -> V.Vec n a -> Bool
    agreeButOne xs ys = mismatches <= 1
      where
        mismatches :: Int
        mismatches =
          foldr' (\differs count -> if differs then count + 1 else count)
                 0
                 (V.zipWith (/=) xs ys)
-- | A value that is either known exactly or only through bounds. In 'Bound'
-- the first field is the optional lower bound and the second the optional
-- upper bound.
data Approximation a = Exact a | Bound (Maybe a) (Maybe a)
  deriving (Eq, Show, Functor, Foldable, Traversable, Data, Typeable)
-- | Extract the value of an 'Exact' approximation; calls 'error' on 'Bound'.
fromExact :: Approximation a -> a
fromExact approx =
  case approx of
    Exact a -> a
    Bound{} -> error "Can't retrieve from bounded as if it was exact."
-- | Lower bound of an approximation; an exact value is its own lower bound.
-- Calls 'error' when the lower bound is missing.
lowerBound :: Approximation a -> a
lowerBound approx =
  case approx of
    Exact a        -> a
    Bound mLower _ ->
      maybe (error "Approximation doesn't have a lower bound.") id mLower
-- | Upper bound of an approximation; an exact value is its own upper bound.
-- Calls 'error' when the upper bound is missing.
upperBound :: Approximation a -> a
upperBound approx =
  case approx of
    Exact a        -> a
    Bound _ mUpper ->
      maybe (error "Approximation doesn't have a upper bound.") id mUpper
-- | Wrappers from which an underlying value can be extracted.
class Peelable a where
  -- | Type of the value inside the wrapper.
  type CoreTyp a
  -- | Remove the wrapper and return the underlying value.
  peel :: a -> CoreTyp a
-- | A value tagged with one of two multiplicity markers.
-- NOTE(review): presumably 'Once' marks something that occurs exactly once
-- and 'Mult' something that may occur several times — confirm against users.
data Multiplicity a = Mult a | Once a
  deriving (Eq, Show, Functor, Foldable, Traversable, Data, Typeable)
-- | Peeling a 'Multiplicity' discards the marker and returns the payload.
instance Peelable (Multiplicity a) where
  type CoreTyp (Multiplicity a) = a
  peel wrapped =
    case wrapped of
      Mult payload -> payload
      Once payload -> payload