{-# LANGUAGE PatternGuards    #-}
{-# LANGUAGE FlexibleContexts #-}

module Language.Fixpoint.Solver.Eliminate
       (eliminateAll, elimKVar, findWfC) where

import           Language.Fixpoint.Types
import qualified Language.Fixpoint.Solver.Deps as D
import           Language.Fixpoint.Visitor (kvars, mapKVars')
import           Language.Fixpoint.Names   (existSymbol)
import           Language.Fixpoint.Misc    (errorstar)

import qualified Data.HashMap.Strict           as M
import           Data.List (partition, (\\))
import           Data.Foldable (foldlM)
import           Control.Monad.State (get, put, runState, evalState, State)


--------------------------------------------------------------
eliminateAll :: FInfo a -> FInfo a
eliminateAll fi = evalState (foldlM eliminate fi nonCuts) 0
  where
    nonCuts = D.depNonCuts $ D.deps fi
--------------------------------------------------------------


class Elimable a where
  elimKVar :: ((KVar, Subst) -> Maybe Pred) -> a -> a

instance Elimable (SubC a) where
  elimKVar f x = x { slhs = elimKVar f (slhs x) 
                   , srhs = elimKVar f (srhs x)
                   }

instance Elimable SortedReft where
  elimKVar f x = x { sr_reft = mapKVars' f (sr_reft x) }

instance Elimable (FInfo a) where
  elimKVar f x = x { cm = M.map (elimKVar f) (cm x)
                   , bs = elimKVar f (bs x)
                   }

instance Elimable BindEnv where
  elimKVar f = mapBindEnv (\(sym, sr) -> (sym, elimKVar f sr))


eliminate :: FInfo a -> KVar -> State Integer (FInfo a)
eliminate fi kv = do
  let relevantSubCs  = M.filter (   elem kv . D.rhsKVars) (cm fi)
  let remainingSubCs = M.filter (notElem kv . D.rhsKVars) (cm fi)
  let (kvWfC, remainingWs) = findWfC kv (ws fi)
  foo <- mapM (extractPred kvWfC (bs fi)) (M.elems relevantSubCs)
  let orPred = POr $ map fst foo
  let symSrtList = concatMap snd foo
  let symSReftList = [(sym, trueSortedReft srt) | (sym, srt) <- symSrtList]
  let (ids, be) = insertsBindEnv symSReftList $ bs fi
  let newSubCs = M.map (\s -> s { senv = insertsIBindEnv ids (senv s)}) remainingSubCs
  let go (k, _) = if kv == k then Just orPred else Nothing
  return $ elimKVar go (fi { cm = newSubCs , ws = remainingWs , bs = be })

insertsBindEnv :: [(Symbol, SortedReft)] -> BindEnv -> ([BindId], BindEnv)
insertsBindEnv = runState . mapM go
  where
    go (sym, srft) = do be <- get
                        let (id, be') = insertBindEnv sym srft be
                        put be'
                        return id

findWfC :: KVar -> [WfC a] -> (WfC a, [WfC a])
findWfC kv ws = (w', ws')
  where
    (w, ws') = partition (elem kv . kvars . sr_reft . wrft) ws
    w' | [x] <- w  = x
       | otherwise = errorstar $ (show kv) ++ " needs exactly one wf constraint"

extractPred :: WfC a -> BindEnv -> SubC a -> State Integer (Pred, [(Symbol, Sort)])
extractPred wfc be subC = do foo <- mapM renameVar vars
                             let (bs, subs) = unzip foo
                             return (subst (mkSubst subs) finalPred, bs)
  where
    wfcIBinds  = elemsIBindEnv $ wenv wfc
    subcIBinds = elemsIBindEnv $ senv subC
    unmatchedIBinds = subcIBinds \\ wfcIBinds
    unmatchedIBindEnv = insertsIBindEnv unmatchedIBinds emptyIBindEnv
    unmatchedBindings = envCs be unmatchedIBindEnv
    lhs = slhs subC
    (vars, prList) = baz $ (reftBind $ sr_reft lhs, lhs) : unmatchedBindings

    suPreds = substPreds (domain be wfc) $ reftPred $ sr_reft $ srhs subC
    finalPred = PAnd $ prList ++ suPreds

-- on rhs, $k0[v:=e1][x:=e2] -> [v = e1, x = e2]
substPreds :: [Symbol] -> Pred -> [Pred]
substPreds dom (PKVar _ (Su subs)) = [PAtom Eq (eVar sym) expr | (sym, expr) <- subs , sym `elem` dom]

domain :: BindEnv -> WfC a -> [Symbol]
domain be wfc = (reftBind $ sr_reft $ wrft wfc) : (map fst $ envCs be $ wenv wfc)

renameVar :: (Symbol, Sort) -> State Integer ((Symbol, Sort), (Symbol, Expr))
renameVar (sym, srt) = do n <- get
                          let sym' = existSymbol sym n
                          put (n+1)
                          return ((sym', srt), (sym, eVar sym'))

-- [ x:{v:int|v=10} , y:{v:int|v=20} ] -> [x:int, y:int], [(x=10), (y=20)]
baz :: [(Symbol, SortedReft)] -> ([(Symbol,Sort)],[Pred])
baz = unzip . map blah

blah :: (Symbol, SortedReft) -> ((Symbol,Sort), Pred)
blah (sym, sr) = ((sym, sr_sort sr), subst1 (reftPred reft) sub)
  where
    reft = sr_reft sr
    sub = ((reftBind reft), (eVar sym))