module DDC.Core.Flow.Transform.Rates.Constraints
        ( Constraint(..)
        , ConstraintMap, EquivClass
        , canonName
        , checkBindConstraints
        , getMaxSize )
where
import DDC.Core.Flow.Compounds
import DDC.Core.Flow.Prim
import DDC.Core.Flow.Exp
import DDC.Core.Flow.Transform.Rates.Fail
import Control.Monad
import qualified Data.Map               as Map
import qualified Data.Set               as Set



-- | Size constraint recorded for a single binding.
--
--   * 'ConEqual': the binding has the same size as every listed name.
--     There can be several of these, e.g. one per vector argument of map3,
--     or none at all (as is generated for a vector generate).
--
--   * 'ConFiltered': the binding was produced by filtering the named
--     source vector, which is the only input a filter has.
data Constraint
 = ConEqual     [Name]
 | ConFiltered   Name
 deriving (Eq,Show)

-- | Constraints keyed by the name of the binding each was generated for.
type ConstraintMap = Map.Map Name Constraint
-- | A list of equivalence classes, each one a set of names.
type EquivClass    = [Set.Set Name]


-- | Canonical representative of the equivalence class containing a name.
--
-- The representative is the least element of the class. A name that is
-- not a member of any class is returned unchanged.
canonName :: EquivClass -> Name -> Name
canonName equivs n
 = maybe n Set.findMin (equivSet equivs n)


-- | Look up the first equivalence class that has the given name as a
-- member, or 'Nothing' when no class mentions it.
equivSet :: EquivClass -> Name -> Maybe (Set.Set Name)
equivSet equivs n
 = foldr pick Nothing equivs
 where
  -- Take this class if it contains @n@, otherwise defer to the
  -- (lazily evaluated) result for the remaining classes.
  pick cls rest
   | Set.member n cls = Just cls
   | otherwise        = rest


-- | Build and check the constraints for the bindings of a single
-- function body. The bindings must be in a-normal form.
--
-- Ill-formed filter constraints are reported through 'checkFilters':
--
--   * Filter "a <= a" is bad, as it restricts to a=a.
--
--   * Filter "a <= b" together with "a <= c" is bad, because 'a' is
--     mentioned twice on the left-hand side.
--
-- The constraint map and the equivalence classes are returned either way.
checkBindConstraints :: [(Name,ExpF)] -> LogFailures (ConstraintMap, EquivClass)
checkBindConstraints binds
 = do   checkFilters filts
        return (constrs, equivs)
 where
  -- All constraints generated from the bindings
  constrs = getConstraints binds
  -- Equality constraints squashed down into equivalence classes
  equivs  = equivConstrs   constrs
  -- Filter constraints, paired up with their canonical names
  filts   = filterConstrs  constrs equivs


-- | Find the name to use for the size of @get@.
--
-- Starting from @get@, repeatedly step from an equivalence class to the
-- source of the first filter constraint held by one of its members, until
-- a class with no filtered member (or a name in no class) is reached.
-- Then return the first name in @mans@ that shares a canonical name with
-- that result, or the canonical name itself if there is none.
--
-- Names already visited are tracked, so the climb also terminates on
-- cyclic filter constraints (such as a filter whose result is forced into
-- the same class as its source). 'checkFilters' only warns about those,
-- and without the check the climb would loop forever.
getMaxSize :: ConstraintMap -> EquivClass -> [Name] -> Name -> Name
getMaxSize constrs equivs mans get
 = pickManifest (canonName equivs (climb Set.empty get))
 where
  -- Keep moving up through filtered constraints until we hit the top,
  -- or until we come back to a name we have already been at.
  climb seen g
   | Set.member g seen
   = g
   | Just eqs    <- equivSet equivs g
   , (src : _)   <- filterSources eqs
   = climb (Set.insert g seen) src
   | otherwise
   = g

  -- Sources of the filter constraints held by members of a class,
  -- in ascending order of the member names.
  filterSources eqs
   = [ src | e                       <- Set.toList eqs
           , Just (ConFiltered src)  <- [Map.lookup e constrs] ]

  -- Find a manifest vector in the same equivalence class,
  -- falling back to the canonical name.
  pickManifest canon
   = case [ m | m <- mans, canonName equivs m == canon ] of
      (m : _) -> m
      []      -> canon
   
 

-- | Squash constraints into equivalence classes.
--
-- Each constraint first becomes a set of names, and sets that share
-- any member are then unioned until all remaining sets are disjoint.
-- This restarts from scratch after every merge, so it could be smarter.
equivConstrs :: ConstraintMap -> EquivClass
equivConstrs m
 = squash [] (map gen (Map.toList m))
 where
  -- An equality constraint relates the bound name to all listed names.
  -- Every generated set contains at least its own key, so none is empty.
  gen (k, ConEqual eqs)
   = Set.fromList (k : eqs)
  -- A filter constraint equates nothing: the bound name starts out
  -- in a class of its own.
  gen (k, ConFiltered _src)
   = Set.singleton k

  -- @done@ holds pairwise disjoint classes; the list holds pending sets.
  squash done []
   = done
  squash done (s : pending)
   = case absorb s done of
      -- @s@ overlapped a finished class: redo everything with the
      -- merged classes put back in front of the pending sets.
      Just done' -> squash [] (done' ++ pending)
      -- @s@ is disjoint from all finished classes: keep it as is.
      Nothing    -> squash (s : done) pending

  -- Union @s@ into the first class it shares a member with,
  -- or give 'Nothing' when it overlaps none of them.
  absorb _ []
   = Nothing
  absorb s (c : cs)
   | Set.null (s `Set.intersection` c)
   = fmap (c :) (absorb s cs)
   | otherwise
   = Just (s `Set.union` c : cs)


-- | Collect all filter constraints, in ascending order of the bound name.
--
-- For each binding @k@ constrained as filtered from @src@ the result holds
--
-- > (canon k, canon src, k, src)
--
-- Non-filter constraints are skipped.
--
-- Written as a comprehension over 'Map.toList' rather than with
-- @Map.foldWithKey@, which is deprecated in the containers package.
filterConstrs :: ConstraintMap -> EquivClass -> [(Name,Name, Name, Name)]
filterConstrs m equivs
 = [ (canonName equivs k, canonName equivs src, k, src)
   | (k, ConFiltered src) <- Map.toList m ]


-- | Generate the constraint map for a list of bindings.
--
-- Bindings that yield no constraint are left out. If a name yields a
-- constraint more than once, the one from the latest binding is kept.
getConstraints :: [(Name,ExpF)] -> ConstraintMap
getConstraints lets
 = Map.fromList
        [ constr
        | (n, x)        <- lets
        , Just constr   <- [getConstraint n x] ]

-- | Generate the constraint, if any, for a binding of name @n@ to the
-- expression @xx@.
--
-- Only applications of a primitive vector operator yield a constraint,
-- and then only when the arguments have the expected form.
getConstraint :: Name -> ExpF -> Maybe (Name, Constraint)
getConstraint n xx
 = case takeXApps xx of
    Just (XVar (UPrim (NameOpVector ov) _), args)
      -> fromOp ov args
    _ -> Nothing
 where
  -- Args:
  -- map1 :: [a b   : *]. (a -> b)      -> Vector a -> Vector b
  -- (drop 3)
  -- map2 :: [a b c : *]. (a -> b -> c) -> Vector a -> Vector b -> Vector c
  -- (drop 4)
  fromOp (OpVectorMap i) args
   | vecs         <- drop (i + 2) args
   -- Must be fully applied
   , length vecs  == i
   , names        <- getNames vecs
   -- Each vector argument must also be a bound variable
   , length names == i
   = Just (n, ConEqual names)

  -- The result is filtered from the source vector
  fromOp OpVectorFilter [_tyA, _p, XVar (UName vec)]
   = Just (n, ConFiltered vec)

  -- Not really sure about this: a generate is equal to nothing else
  fromOp OpVectorGenerate _args
   = Just (n, ConEqual [])

  fromOp OpVectorReduce [_tyA, _f, _z, XVar (UName vec)]
   = Just (n, ConEqual [vec])

  fromOp OpVectorLength [_tyA, XVar (UName vec)]
   = Just (n, ConEqual [vec])

  -- Any other operator, or arguments of an unexpected form
  fromOp _ _
   = Nothing

-- | Collect the names of those expressions that are variables of bound
-- names, in order.
-- Any other expression is dropped, so the result is shorter than the
-- input unless every expression is such a variable.
getNames :: [ExpF] -> [Name]
getNames vs
 = [ v | XVar (UName v) <- vs ]


-- | Warn about ill-formed filter constraints:
--
--   * Filter 'a <= a' is bad, as it restricts to 'a=a'.
--
--   * Filter 'a <= b' together with 'a <= c' is bad, because 'a' is
--     mentioned twice on the left-hand side.
--
-- For some filter
--
-- > bs = filter p as
--
-- the corresponding entry is
--
-- > (canon bs, canon as,  bs, as)
--
-- Comparisons are done on the canonical names; the 'raw' variable names
-- bs and as are only used for the error messages.
checkFilters :: [(Name,Name, Name,Name)] -> LogFailures ()
checkFilters []
 = return ()
checkFilters ((canonL, canonR, rawL, rawR) : rest)
 = do   -- Result and source of the filter are in the same class
        when (canonL == canonR)
         $ warn (FailConstraintFilteredLessFiltered rawL rawR)

        -- Any later filter whose result is in the same class as this one
        mapM_ (warn . FailConstraintFilteredNotUnique rawL)
              [ rawL' | (canonL', _, rawL', _) <- rest
                      , canonL == canonL' ]

        checkFilters rest



{-

f = \(as : Vector a).
    as'    = vmap [:a b:] g as
    return as'

==>
[as=as']
==>

f = \(as : Vector a).
    runSeries as /\(k1 : Rate). \(asS : Series k1 a).
    as'    = valloc [:k1 b:]
    as'S   = smap   [:k1 a b:] g asS
    sfill [:k1 b:] as' as'S
    return as'

---

f = \(as : Vector a).
    as'    = vmap [:a b:] g as
    as''   = vmap [:b b:] h as'
    return as''

==>
[as = as' = as'']
==>

f = \(as : Vector a).
    runSeries as /\(k1 : Rate). \(asS : Series k1 a).
    as'S   = smap   [:k1 a b:] g asS
    as''   = valloc [:k1 b:]
    as''S  = smap   [:k1 b b:] h as'S
    sfill [:k1 b:] as'' as''S
    return as''

---

f = \(as : Vector a).
    as' = filter p as
    n   = length   as'
    ns  = map (/n) as'

==>
[as' <= as
,ns   = as']
==>

f = \(as : Vector a).
    runSeries as /\(k1 : Rate). \(asS : Series k1 a).
    as'F = smap [:k1 a Bool:] p asS
    mkSel [:k1:] as'F /\(k2 : Rate). \(as'Se : Sel k1 k2).
    as'S = spack   [:k1 k2 a:] as'Se asS
    n    = slength [:k2:]
    nsS  = smap    [:k2 a a:] (/n) as'S
    ns   = valloc  [:k2 a:]
    sfill [:k2 a:] ns nsS

    return ns

---

f = \(as : Vector a).
    bs = filter p as
    cs = map2   f as bs
    return cs

==>
[bs <= as
,cs = as = bs]
==>
[as <= as]
Error!

---

f = \(as bs : Vector a).
    cs = filter p as
    ds = filter p bs
    es = map2   f cs ds
    return es

==>
[cs <= as
,ds <= bs
,cs=ds=es]
==>
[cs <= as
,cs <= bs]
Error, cs mentioned twice in lhs!
-}