{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE FlexibleInstances    #-}
{-# LANGUAGE PatternGuards        #-}
{-# LANGUAGE FlexibleContexts     #-}

module Language.Fixpoint.Solver.Extensionality (expand) where

import           Control.Monad.State
import qualified Data.HashMap.Strict       as M
import           Data.Maybe  (fromMaybe)

import           Language.Fixpoint.Types.Config
import           Language.Fixpoint.SortCheck
import           Language.Fixpoint.Solver.Sanitize (symbolEnv)
import           Language.Fixpoint.Types hiding (mapSort)
import           Language.Fixpoint.Types.Visitor ( (<$$>), mapSort )

mytracepp :: (PPrint a) => String -> a -> a
mytracepp = notracepp

expand :: Config -> SInfo a -> SInfo a
expand cfg si = evalState (extend si) $ initST (symbolEnv cfg si) (ddecls si)


class Extend a where
  extend :: a -> Ex a


instance Extend (SInfo a) where
  extend si = do
    setBEnv (bs si)
    cm'      <- extend (cm si)
    bs'      <- exbenv <$> get
    return $ si{ cm = cm' , bs = bs' }

instance (Extend a) => Extend (M.HashMap SubcId a) where
  extend h = M.fromList <$> mapM extend (M.toList h)

instance (Extend a, Extend b) => Extend (a,b) where
  extend (a,b) = (,) <$> extend a <*> extend b

instance Extend SubcId where
  extend i = return i

instance Extend (SimpC a) where
  extend c = do
    setExBinds (_cenv c)
    rhs <- extendExpr Pos (_crhs c)
    is  <- exbinds <$> get
    return $ c{_crhs = rhs, _cenv = is }


extendExpr :: Pos -> Expr -> Ex Expr
extendExpr p e
  | p == Pos
  = mapMPosExpr Pos goP e' >>= mapMPosExpr Pos goN
  | otherwise
  = mapMPosExpr Neg goP e' >>= mapMPosExpr Neg goN
    where
      e' = normalize e
      goP Pos (PAtom b e1 e2)
       | b == Eq || b == Ne
       , Just s <- getArg (exprSort "extensionality" e1)
       = mytracepp ("extending POS = " ++ showpp e) <$> (extendRHS b e1 e2 s >>= goP Pos)
      goP _ e = return e
      goN Neg (PAtom b e1 e2)
       | b == Eq || b == Ne
       , Just s <- getArg (exprSort "extensionality" e1)
       = mytracepp ("extending NEG = " ++ showpp e) <$> (extendLHS b e1 e2 s >>= goN Neg)
      goN _ e = return e

getArg :: Sort -> Maybe Sort
getArg s = case bkFFunc s of
             Just (_,(a:_:_)) -> Just a
             _                -> Nothing

extendRHS, extendLHS :: Brel -> Expr -> Expr -> Sort -> Ex Expr
extendRHS b e1 e2 s =
  do es <- generateArguments s
     (mytracepp "extendRHS = " . pAnd) <$> mapM (makeEq b e1 e2) es

extendLHS b e1 e2 s =
  do es  <- generateArguments s
     dds <- exddecl <$> get
     is  <- instantiate dds s
     (mytracepp "extendLHS = " . pAnd . (PAtom b e1 e2:)) <$> mapM (makeEq b e1 e2) (es ++ is)


generateArguments :: Sort -> Ex [Expr]
generateArguments s = do
  st   <- get
  case breakSort (exddecl st) s of
    Left dds -> mapM freshArgDD dds
    Right s  -> (\x -> [EVar x]) <$> freshArgOne s

makeEq :: Brel-> Expr -> Expr -> Expr -> Ex Expr
makeEq b e1 e2 e = do
  env <- exenv <$> get
  let elab = elaborate (dummyLoc "extensionality") env
  return $ PAtom b (elab $ EApp (unElab e1) e)  (elab $ EApp (unElab e2) e)

instantiate :: [DataDecl]  -> Sort -> Ex [Expr]
instantiate ds s = instantiateOne (breakSort ds s)

instantiateOne :: Either [(LocSymbol, [Sort])] Sort  -> Ex [Expr]
instantiateOne (Right s@(FVar _)) =
  (\x -> [EVar x]) <$> freshArgOne s
instantiateOne (Right s) = do
  xss <- excbs <$> get
  return [EVar x | (x,xs) <- xss, xs == s ]
instantiateOne (Left [(dc, ts)]) =
  (map (mkEApp dc) . combine) <$>  mapM instantiateOne (Right <$> ts)
instantiateOne _ = undefined

combine :: [[a]] -> [[a]]
combine []          = [[]]
combine ([]:_)      = []
combine ((x:xs):ys) = map (x:) (combine ys) ++ combine (xs:ys)


data Pos = Pos | Neg deriving Eq
negatePos :: Pos -> Pos
negatePos Pos = Neg
negatePos Neg = Pos

mapMPosExpr :: (Monad m) => Pos -> (Pos -> Expr -> m Expr) -> Expr -> m Expr
mapMPosExpr p f = go p
  where
    go p e@(ESym _)      = f p e
    go p e@(ECon _)      = f p e
    go p e@(EVar _)      = f p e
    go p e@(PKVar _ _)   = f p e
    go p (ENeg e)        = f p =<< (ENeg        <$>  go p e                     )
    go p (ECst e t)      = f p =<< ((`ECst` t)  <$>  go p e                     )
    go p (ECoerc a t e)  = f p =<< (ECoerc a t  <$>  go p e                     )
    go p (EApp g e)      = f p =<< (EApp        <$>  go p g  <*> go p e             )
    go p (EBin o e1 e2)  = f p =<< (EBin o      <$>  go p e1 <*> go p e2            )
    go p (PAtom r e1 e2) = f p =<< (PAtom r     <$>  go p e1 <*> go p e2            )

    go p (PImp p1 p2)    = f p =<< (PImp        <$>  go (negatePos p) p1 <*> go p p2)
    go p (PAnd ps)       = f p =<< (PAnd        <$> (go p <$$> ps)                  )

    -- The below cannot appear due to normalization
    go p (PNot e)        = f p =<< (PNot        <$>  go p e                     )
    go p (PIff p1 p2)    = f p =<< (PIff        <$>  go p p1 <*> go p p2            )
    go p (EIte e e1 e2)  = f p =<< (EIte        <$>  go p e  <*> go p e1 <*> go p e2)
    go p (POr  ps)       = f p =<< (POr         <$> (go p <$$> ps)                  )

    -- The following canot appear in general 
    go p (PAll xts e)    = f p =<< (PAll   xts  <$>  go p e                     )
    go p (ELam (x,t) e)  = f p =<< (ELam (x,t)  <$>  go p e                     )
    go p (PExist xts e)  = f p =<< (PExist xts  <$>  go p e                     )
    go p (ETApp e s)     = f p =<< ((`ETApp` s) <$>  go p e                     )
    go p (ETAbs e s)     = f p =<< ((`ETAbs` s) <$>  go p e                     )
    go p (PGrad k s i e) = f p =<< (PGrad k s i <$>  go p e                     )

normalize :: Expr -> Expr
normalize e = mytracepp ("normalize: " ++ showpp e) $ go e
  where
    go e@(ESym _)        = e
    go e@(ECon _)        = e
    go e@(EVar _)        = e
    go e@(PKVar _ _)     = e
    go e@(ENeg _)        = e
    go (PNot e)          = PImp e PFalse
    go e@(ECst _ _)      = e
    go e@(ECoerc _ _ _)  = e
    go e@(EApp _ _)      = e
    go e@(EBin _ _ _)    = e
    go (PImp p1 p2)      = PImp (go p1) (go p2)
    go (PIff p1 p2)      = PAnd [PImp p1' p2', PImp p2' p1'] where (p1', p2') = (go p1, go p2)
    go e@(PAtom _ _ _)   = e
    go (EIte e e1 e2)    = go $ PAnd [PImp e e1, PImp (PNot e) e2]
    go (PAnd ps)         = pAnd (go <$> ps)
    go (POr  ps)         = foldl (\x y -> PImp (PImp (go x) PFalse) y) PFalse ps
    go e@(PAll _ _)      = e -- Cannot appear
    go e@(ELam _ _)      = e -- Cannot appear
    go e@(PExist _ _)    = e -- Cannot appear
    go e@(ETApp _ _)     = e -- Cannot appear
    go e@(ETAbs _ _)     = e -- Cannot appear
    go e@(PGrad _ _ _ _) = e -- Cannot appear


type Ex    = State ExSt
data ExSt = ExSt { unique  :: Int
                 , exddecl :: [DataDecl]
                 , exenv   :: SymEnv        -- used for elaboration 
                 , exbenv  :: BindEnv
                 , exbinds :: IBindEnv
                 , excbs   :: [(Symbol, Sort)]
                 }

initST :: SymEnv -> [DataDecl]  -> ExSt
initST env dd = ExSt 0 (d:dd) env mempty mempty mempty
  where
    -- NV: hardcore Haskell pairs because they do not appear in DataDecl (why?)
    d = mytracepp "Tuple DataDecl" $ DDecl (symbolFTycon (dummyLoc tupConName)) 2 [ct]
    ct = DCtor (dummyLoc (symbol "GHC.Tuple.(,)")) [
            DField (dummyLoc (symbol "lqdc$select$GHC.Tuple.(,)$1")) (FVar 0)
          , DField (dummyLoc (symbol "lqdc$select$GHC.Tuple.(,)$2")) (FVar 1)
          ]


setBEnv :: BindEnv -> Ex ()
setBEnv benv = modify (\st -> st{exbenv = benv})

setExBinds :: IBindEnv-> Ex ()
setExBinds bids = modify (\st -> st{ exbinds = bids
                                   , excbs   = [ (x, sr_sort r) | (i, x, r) <- bindEnvToList (exbenv st)
                                                                , memberIBindEnv i bids]})


freshArgDD :: (LocSymbol, [Sort]) -> Ex Expr
freshArgDD (dc, xs) = do
  xs <- mapM freshArgOne xs
  return $ mkEApp dc (EVar <$> xs)


freshArgOne :: Sort -> Ex Symbol
freshArgOne s = do
  st   <- get
  let x = symbol ("ext$" ++ show (unique st))
  let (id, benv') = insertBindEnv x (trueSortedReft s) (exbenv st)
  modify (\st -> st{ exenv   = insertSymEnv x s (exenv st)
                   , exbenv  = benv'
                   , exbinds = insertsIBindEnv [id] (exbinds st)
                   , unique   = 1 + (unique st)
                   , excbs = (x,s):(excbs st)
                   })
  return x


breakSort :: [DataDecl] -> Sort -> Either [(LocSymbol, [Sort])] Sort
breakSort ddecls s
    | Just (tc, ts) <- splitTC s
    , [(dds,i)] <- [ (ddCtors dd,ddVars dd) | dd <- ddecls, ddTyCon dd == tc ]
    = Left ((\dd -> (dcName dd, (instSort  (Sub $ zip [0..(i-1)] ts)) <$> dfSort <$> dcFields dd)) <$> dds)
    | otherwise
    = Right s

instSort :: Sub -> Sort -> Sort
instSort (Sub su) x = mapSort go x
  where
    go :: Sort -> Sort
    go (FVar i) = fromMaybe (FVar i) $ lookup i su
    go s        = s

splitTC :: Sort -> Maybe (FTycon, [Sort])
splitTC s
     | (FTC f, ts) <- splitFApp s
     = Just (f, ts)
     | otherwise
     = Nothing

splitFApp :: Sort -> (Sort, [Sort])
splitFApp = go []
    where go acc (FApp s1 s2) = go (s2:acc) s1
          go acc s            = (s, acc)