{-# LANGUAGE DeriveGeneric             #-}
{-# LANGUAGE OverloadedStrings         #-}
{-# LANGUAGE PatternGuards             #-}

module Language.Fixpoint.Solver.Rewrite
  ( getRewrite
  , subExprs
  , unify
  , RewriteArgs(..)
  , RWTerminationOpts(..)
  , SubExpr
  , TermOrigin(..)
  ) where

import           Control.Monad.State
import           Control.Monad.Trans.Maybe
import           GHC.Generics
import           Data.Hashable
import qualified Data.HashMap.Strict  as M
import qualified Data.HashSet         as S
import qualified Data.List            as L
import qualified Data.Maybe           as Mb
import           Language.Fixpoint.Types hiding (simplify)
import qualified Data.Text as TX

type Op = Symbol
type OpOrdering = [Symbol]
data Term = Term Symbol [Term] deriving (Eq, Generic)
instance Hashable Term

termSym :: Term -> Symbol
termSym (Term s _) = s

instance Show Term where
  show (Term op [])   = TX.unpack $ symbolText op
  show (Term op args) =
    TX.unpack (symbolText op) ++ "(" ++ L.intercalate ", " (map show args) ++ ")"

data SCDir =
    SCUp
  | SCEq
  | SCDown
  deriving (Eq, Ord, Show, Generic)

instance Hashable SCDir

type SCPath = ((Op, Int), (Op, Int), [SCDir])
type SubExpr = (Expr, Expr -> Expr)

data SCEntry = SCEntry {
    from :: (Op, Int)
  , to   :: (Op, Int)
  , dir  :: SCDir
} deriving (Eq, Ord, Show, Generic)

instance Hashable SCEntry

getDir :: OpOrdering -> Term -> Term -> SCDir
getDir o from to =
  case (synGTE o from to, synGTE o to from) of
      (True, True)  -> SCEq
      (True, False) -> SCDown
      (False, _)    -> SCUp

getSC :: OpOrdering -> Term -> Term -> S.HashSet SCEntry
getSC o (Term op ts) (Term op' us) =
  S.fromList $ do
    (i, from) <- zip [0..] ts
    (j, to)   <- zip [0..] us
    return $ SCEntry (op, i) (op', j) (getDir o from to)

scp :: OpOrdering -> [Term] -> S.HashSet SCPath
scp _ []       = S.empty
scp _ [_]      = S.empty
scp o [t1, t2] = S.fromList $ do
  (SCEntry a b d) <- S.toList $ getSC o t1 t2
  return (a, b, [d])
scp o (t1:t2:trms) = S.fromList $ do
  (SCEntry a b' d) <- S.toList $ getSC o t1 t2
  (a', b, ds)      <- S.toList $ scp o (t2:trms)
  guard $ b' == a'
  return (a, b, d:ds)

synEQ :: OpOrdering -> Term -> Term -> Bool
synEQ o l r = synGTE o l r && synGTE o r l

opGT :: OpOrdering -> Op -> Op -> Bool
opGT ordering op1 op2 = case (L.elemIndex op1 ordering, L.elemIndex op2 ordering) of
  (Just index1, Just index2) -> index1 < index2
  (Just _, Nothing)          -> True
  _                          -> False

removeSynEQs :: OpOrdering -> [Term] -> [Term] -> ([Term], [Term])
removeSynEQs _ [] ys      = ([], ys)
removeSynEQs ordering (x:xs) ys
  | Just yIndex <- L.findIndex (synEQ ordering x) ys
  = removeSynEQs ordering xs $ take yIndex ys ++ drop (yIndex + 1) ys
  | otherwise =
    let
      (xs', ys') = removeSynEQs ordering xs ys
    in
      (x:xs', ys')

synGTEM :: OpOrdering -> [Term] -> [Term] -> Bool
synGTEM ordering xs ys =
  case removeSynEQs ordering xs ys of
    (_   , []) -> True
    (xs', ys') -> any (\x -> all (synGT ordering x) ys') xs'

synGT :: OpOrdering -> Term -> Term -> Bool
synGT o t1 t2 = synGTE o t1 t2 && not (synGTE o t2 t1)

synGTM :: OpOrdering -> [Term] -> [Term] -> Bool
synGTM o t1 t2 = synGTEM o t1 t2 && not (synGTEM o t2 t1)

synGTE :: OpOrdering -> Term -> Term -> Bool
synGTE ordering t1@(Term x tms) t2@(Term y tms') =
  if opGT ordering x y then
    synGTM ordering [t1] tms'
  else if opGT ordering y x then
    synGTEM ordering tms [t2]
  else
    synGTEM ordering tms tms'

subsequencesOfSize :: Int -> [a] -> [[a]]
subsequencesOfSize n xs = let l = length xs
                          in if n>l then [] else subsequencesBySize xs !! (l-n)
 where
   subsequencesBySize [] = [[[]]]
   subsequencesBySize (x:xs) = let next = subsequencesBySize xs
                             in zipWith (++) ([]:next) (map (map (x:)) next ++ [[]])

data TermOrigin = PLE | RW OpOrdering deriving (Show, Eq)

data DivergeResult = Diverging | NotDiverging OpOrdering

fromRW :: TermOrigin -> Bool
fromRW (RW _) = True
fromRW PLE    = False

getOrdering :: TermOrigin -> Maybe OpOrdering
getOrdering (RW o) = Just o
getOrdering PLE    = Nothing

diverges :: Maybe Int -> [(Term, TermOrigin)] -> Term -> DivergeResult
diverges maxOrderingConstraints path term = go 0
  where
   path' = map fst path ++ [term]
   go n |    n > length syms'
          || n > Mb.fromMaybe (length syms') maxOrderingConstraints = Diverging
   go n = case L.find (not . diverges') (orderings' n) of
     Just ordering -> NotDiverging ordering
     Nothing       -> go (n + 1)
   ops (Term o xs) = o:concatMap ops xs
   syms'           = L.nub $ concatMap ops path'
   suggestedOrderings :: [OpOrdering]
   suggestedOrderings =
     reverse $ Mb.catMaybes $ map (getOrdering . snd) path
   orderings' n    =
     suggestedOrderings ++ concatMap L.permutations ((subsequencesOfSize n) syms')
   diverges' o     = divergesFor o path term

divergesFor :: OpOrdering -> [(Term, TermOrigin)] -> Term -> Bool
divergesFor o path term = any diverges' terms'
  where
    terms = map fst path ++ [term]
    lastRWIndex =
      Mb.fromMaybe 0 (fmap fst $ L.find (fromRW . snd . snd) $ reverse $ zip [1..] path)
    okTerms    = take lastRWIndex terms
    checkTerms = drop lastRWIndex terms
    terms' = L.subsequences checkTerms ++ do
      firstpart  <- L.tails okTerms
      secondpart <- L.inits checkTerms
      return $ firstpart ++ secondpart
    diverges' :: [Term] -> Bool
    diverges' trms' =
      if length trms' <= 1 || termSym (head trms') /= termSym (last trms') then
        False
      else
        any ascending (scp o trms') && all (not . descending) (scp o trms')

descending :: SCPath -> Bool
descending (a, b, ds) = a == b && L.elem SCDown ds && L.notElem SCUp ds

ascending :: SCPath -> Bool
ascending  (a, b, ds) = a == b && L.elem SCUp ds

data RWTerminationOpts =
    RWTerminationCheckEnabled (Maybe Int) -- # Of constraints to consider
  | RWTerminationCheckDisabled

data RewriteArgs = RWArgs
 { isRWValid          :: Expr -> IO Bool
 , rwTerminationOpts  :: RWTerminationOpts
 }

getRewrite :: RewriteArgs -> [(Expr, TermOrigin)] -> SubExpr -> AutoRewrite -> MaybeT IO (Expr, TermOrigin)
getRewrite rwArgs path (subE, toE) (AutoRewrite args lhs rhs) =
  do
    su <- MaybeT $ return $ unify freeVars lhs subE
    let subE' = subst su rhs
    let expr' = toE subE'
    guard $ all ( (/= expr') . fst) path
    mapM_ (check . subst su) exprs
    let termPath = map (\(t, o) -> (convert t, o)) path
    case rwTerminationOpts rwArgs of
      RWTerminationCheckEnabled maxConstraints ->
        case diverges maxConstraints termPath (convert expr') of
          NotDiverging opOrdering  ->
            return (expr', RW opOrdering)
          Diverging ->
            mzero
      RWTerminationCheckDisabled -> return (expr', RW [])
  where

    convert (EIte i t e) = Term "$ite" $ map convert [i,t,e]
    convert (EApp (EVar s) (EVar var))
      | dcPrefix `isPrefixOfSym` s
      = Term (symbol $ TX.concat [symbolText s, "$", symbolText var]) []

    convert e@(EApp{})    | (EVar fName, terms) <- splitEApp e
                          = Term fName $ map convert terms
    convert (EVar s)      = Term s []
    convert (PAnd es)     = Term "$and" $ map convert es
    convert (POr es)      = Term "$or" $ map convert es
    convert (PAtom s l r) = Term (symbol $ "$atom" ++ show s) [convert l, convert r]
    convert (EBin o l r)  = Term (symbol $ "$ebin" ++ show o) [convert l, convert r]
    convert (ECon c)      = Term (symbol $ "$econ" ++ show c) []
    convert e             = error (show e)


    check :: Expr -> MaybeT IO ()
    check e = do
      valid <- MaybeT $ Just <$> isRWValid rwArgs e
      guard valid

    dcPrefix = "lqdc"

    freeVars = [s | RR _ (Reft (s, _)) <- args ]
    exprs    = [e | RR _ (Reft (_, e)) <- args ]

subExprs :: Expr -> [SubExpr]
subExprs e = (e,id):subExprs' e

subExprs' :: Expr -> [SubExpr]
subExprs' (EIte c lhs rhs)  = c'' ++ l'' ++ r''
  where
    c' = subExprs c
    l' = subExprs lhs
    r' = subExprs rhs
    c'' = map (\(e, f) -> (e, \e' -> EIte (f e') lhs rhs)) c'
    l'' = map (\(e, f) -> (e, \e' -> EIte c (f e') rhs)) l'
    r'' = map (\(e, f) -> (e, \e' -> EIte c lhs (f e'))) r'

subExprs' (EBin op lhs rhs) = lhs'' ++ rhs''
  where
    lhs' = subExprs lhs
    rhs' = subExprs rhs
    lhs'' :: [SubExpr]
    lhs'' = map (\(e, f) -> (e, \e' -> EBin op (f e') rhs)) lhs'
    rhs'' :: [SubExpr]
    rhs'' = map (\(e, f) -> (e, \e' -> EBin op lhs (f e'))) rhs'

subExprs' (PImp lhs rhs) = lhs'' ++ rhs''
  where
    lhs' = subExprs lhs
    rhs' = subExprs rhs
    lhs'' :: [SubExpr]
    lhs'' = map (\(e, f) -> (e, \e' -> PImp (f e') rhs)) lhs'
    rhs'' :: [SubExpr]
    rhs'' = map (\(e, f) -> (e, \e' -> PImp lhs (f e'))) rhs'

subExprs' (PAtom op lhs rhs) = lhs'' ++ rhs''
  where
    lhs' = subExprs lhs
    rhs' = subExprs rhs
    lhs'' :: [SubExpr]
    lhs'' = map (\(e, f) -> (e, \e' -> PAtom op (f e') rhs)) lhs'
    rhs'' :: [SubExpr]
    rhs'' = map (\(e, f) -> (e, \e' -> PAtom op lhs (f e'))) rhs'

-- subExprs' e@(EApp{}) = concatMap replace indexedArgs
--   where
--     (f, es)          = splitEApp e
--     indexedArgs      = zip [0..] es
--     replace (i, arg) = do
--       (subArg, toArg) <- subExprs arg
--       return (subArg, \subArg' -> eApps f $ (take i es) ++ (toArg subArg'):(drop (i+1) es))

subExprs' _ = []

unifyAll :: [Symbol] -> [Expr] -> [Expr] -> Maybe Subst
unifyAll _ []     []               = Just (Su M.empty)
unifyAll freeVars (template:xs) (seen:ys) =
  do
    rs@(Su s1) <- unify freeVars template seen
    let xs' = map (subst rs) xs
    let ys' = map (subst rs) ys
    (Su s2) <- unifyAll (freeVars L.\\ M.keys s1) xs' ys'
    return $ Su (M.union s1 s2)
unifyAll _ _ _ = undefined

unify :: [Symbol] -> Expr -> Expr -> Maybe Subst
unify _ template seenExpr | template == seenExpr = Just (Su M.empty)
unify freeVars template seenExpr = case (template, seenExpr) of
  (EVar rwVar, _) | rwVar `elem` freeVars ->
    return $ Su (M.singleton rwVar seenExpr)
  (EApp templateF templateBody, EApp seenF seenBody) ->
    unifyAll freeVars [templateF, templateBody] [seenF, seenBody]
  (ENeg rw, ENeg seen) ->
    unify freeVars rw seen
  (EBin op rwLeft rwRight, EBin op' seenLeft seenRight) | op == op' ->
    unifyAll freeVars [rwLeft, rwRight] [seenLeft, seenRight]
  (EIte cond rwLeft rwRight, EIte seenCond seenLeft seenRight) ->
    unifyAll freeVars [cond, rwLeft, rwRight] [seenCond, seenLeft, seenRight]
  (ECst rw _, ECst seen _) ->
    unify freeVars rw seen
  (ETApp rw _, ETApp seen _) ->
    unify freeVars rw seen
  (ETAbs rw _, ETAbs seen _) ->
    unify freeVars rw seen
  (PAnd rw, PAnd seen ) ->
    unifyAll freeVars rw seen
  (POr rw, POr seen ) ->
    unifyAll freeVars rw seen
  (PNot rw, PNot seen) ->
    unify freeVars rw seen
  (PImp templateF templateBody, PImp seenF seenBody) ->
    unifyAll freeVars [templateF, templateBody] [seenF, seenBody]
  (PIff templateF templateBody, PIff seenF seenBody) ->
    unifyAll freeVars [templateF, templateBody] [seenF, seenBody]
  (PAtom rel templateF templateBody, PAtom rel' seenF seenBody) | rel == rel' ->
    unifyAll freeVars [templateF, templateBody] [seenF, seenBody]
  (PAll _ rw, PAll _ seen) ->
    unify freeVars rw seen
  (PExist _ rw, PExist _ seen) ->
    unify freeVars rw seen
  (PGrad _ _ _ rw, PGrad _ _ _ seen) ->
    unify freeVars rw seen
  (ECoerc _ _ rw, ECoerc _ _ seen) ->
    unify freeVars rw seen
  _ -> Nothing