-- | Rules
module SSTG.Core.Execution.Rules
    ( Rule(..)
    , reduce
    , isStateValueForm
    ) where

import SSTG.Core.Syntax
import SSTG.Core.Execution.Models
import SSTG.Core.Execution.Naming

-- | Rules
--
--   Tags naming the individual reduction steps. 'reduce' returns the tag of
--   the rule that fired together with the successor states.
--
--   NOTE: the derived 'Ord' instance follows the declaration order below, so
--   reordering constructors changes comparison results.
data Rule = RuleAtomLit | RuleAtomLitPtr | RuleAtomValPtr | RuleAtomUnInt
          | RulePrimApp
          | RuleConApp
          | RuleFunAppExact | RuleFunAppUnder  | RuleFunAppSym
                            | RuleFunAppConPtr | RuleFunAppUnInt
          | RuleLet
          | RuleCaseLit | RuleCaseConPtr | RuleCaseAnyLit | RuleCaseAnyConPtr
                        | RuleCaseSym

          -- Update frame rules: C = frame created, D = frame deleted.
          | RuleUpdateCThunk
          | RuleUpdateDLit | RuleUpdateDValPtr

          -- Case frame rules: C = frame created, D = frame deleted.
          | RuleCaseCCaseNonVal
          | RuleCaseDLit | RuleCaseDValPtr

          -- Apply frame rules: C = frame created, D = frame deleted.
          | RuleApplyCFunThunk  | RuleApplyCFunAppOver
          | RuleApplyDReturnFun | RuleApplyDReturnSym

          -- Emitted by 'reduce' when the state is already in value form.
          | RuleIdentity
          deriving (Show, Eq, Read, Ord)

--  Stack Independent Rules

-- | Is Heap Normal Form?
--
--   A heap object counts as a value when it is a symbolic object, a
--   constructor object, or a function object that takes at least one
--   parameter. Everything else is not a value; in particular LitObj is
--   excluded, as is a parameterless function object.
isHeapValueForm :: HeapObj -> Bool
isHeapValueForm hobj = case hobj of
    SymObj _          -> True
    ConObj _ _        -> True
    FunObj params _ _ -> not (null params)
    _                 -> False

-- | Is Value Form
--
--   An expression is in value form when it is a literal atom, or a variable
--   atom that resolves (via 'vlookupHeap') to a heap object accepted by
--   'isHeapValueForm' (so not a LitObj). A variable with no heap entry is not
--   a value: it can still be upcast to a symbolic. Any non-atom expression is
--   not a value.
isExprValueForm :: Expr -> Locals -> Globals -> Heap -> Bool
isExprValueForm expr locals globals heap = case expr of
    Atom (LitAtom _)   -> True
    Atom (VarAtom var) ->
        maybe False (isHeapValueForm . snd) (vlookupHeap var locals globals heap)
    _                  -> False

-- | Is State Value?
--
--   A state is a value when its stack is empty and its code returns either a
--   literal, or an address whose heap object satisfies 'isHeapValueForm'.
isStateValueForm :: State -> Bool
isStateValueForm state = case (state_stack state, state_code state) of
    (Stack [], Return (LitVal _))    -> True
    (Stack [], Return (MemVal addr)) ->
        -- A dangling address (no heap entry) is not a value.
        maybe False isHeapValueForm (lookupHeap addr (state_heap state))
    _                                -> False

-- | Value to Lit
--
--   A literal value yields its literal; a memory value is wrapped as an
--   'AddrLit' built from the address' integer ('memAddrInt').
valueToLit :: Value -> Lit
valueToLit val = case val of
    LitVal lit  -> lit
    MemVal addr -> AddrLit (memAddrInt addr)

-- | Uneven Zipping
--
--   Zips two lists and also reports what was left over: @Left@ carries the
--   unconsumed tail of the first list, @Right@ the unconsumed tail of the
--   second. When both lists run out together the leftover is @Left []@.
unevenZip :: [a] -> [b] -> ([(a, b)], Either [a] [b])
unevenZip = go
  where
    go xs     []     = ([], Left xs)
    go []     ys     = ([], Right ys)
    go (x:xs) (y:ys) =
        let (pairs, leftover) = go xs ys
        in ((x, y) : pairs, leftover)

-- | Lift Action Wrap Type
--
--   Bundles a payload with the environment that the lift* helpers thread
--   through: each helper takes one of these and returns another carrying the
--   (possibly updated) environment.
data LiftAct a = LiftAct
    a        -- payload being lifted, or the result of lifting
    Locals   -- local bindings
    Globals  -- global bindings
    Heap     -- heap
    [Name]   -- names handed to the fresh-name generators; grows as names are made

-- | Lift Uninterpreted Variable
--
--   Gives a variable a symbolic heap object: a fresh name seeded from the
--   variable's own name is generated, a symbolic object for it (same type as
--   the variable) is allocated, and the variable is bound to the new address
--   in the globals. The result carries that address, the unchanged locals, and
--   the updated globals, heap and name list.
liftUnInt :: LiftAct Var -> LiftAct MemAddr
liftUnInt (LiftAct var locals globals heap confs) =
    let fresh         = freshSeededName (varName var) confs
        sym_var       = Var fresh (varType var)
        (heap', addr) = allocHeap (SymObj (Symbol sym_var Nothing)) heap
        globals'      = insertGlobals var (MemVal addr) globals
    in LiftAct addr locals globals' heap' (fresh : confs)

-- | Lift Atom
--
--   Turns an atom into a value. A literal atom becomes a literal value and a
--   variable that 'lookupValue' resolves becomes that value, both leaving the
--   environment untouched. An unresolved variable is lifted with 'liftUnInt'
--   and becomes the address of its new symbolic object.
liftAtom :: LiftAct Atom -> LiftAct Value
liftAtom (LiftAct atom locals globals heap confs) =
    LiftAct value locals globals_out heap_out confs_out
  where
    (value, globals_out, heap_out, confs_out) = resolve atom

    -- Result for the cases that leave the environment as it was.
    unchanged v = (v, globals, heap, confs)

    resolve (LitAtom lit) = unchanged (LitVal lit)
    resolve (VarAtom var) =
        maybe (symbolic var) unchanged (lookupValue var locals globals)

    -- The locals returned by liftUnInt are ignored; the incoming ones are kept.
    symbolic var =
        let LiftAct addr _ g h c = liftUnInt (LiftAct var locals globals heap confs)
        in (MemVal addr, g, h, c)

-- | Lift Atom List
--
--   Lifts the atoms left to right with 'liftAtom', feeding the environment
--   produced by each atom into the next. The values come back in the same
--   order as the atoms, together with the final environment.
liftAtomList :: LiftAct [Atom] -> LiftAct [Value]
liftAtomList (LiftAct atoms locals globals heap confs) = case atoms of
    []          -> LiftAct [] locals globals heap confs
    (atom:rest) ->
        let LiftAct val  l1 g1 h1 c1 = liftAtom (LiftAct atom locals globals heap confs)
            LiftAct vals l2 g2 h2 c2 = liftAtomList (LiftAct rest l1 g1 h1 c1)
        in LiftAct (val : vals) l2 g2 h2 c2

-- | Lift Bind Rhs
--
--   Builds the heap object for a binding's right-hand side. A function form
--   becomes a function object closing over the current locals; a constructor
--   form has its arguments lifted with 'liftAtomList' and becomes a
--   constructor object over the resulting values.
liftBindRhs :: LiftAct BindRhs -> LiftAct HeapObj
liftBindRhs (LiftAct rhs locals globals heap confs) = case rhs of
    FunForm prms expr ->
        LiftAct (FunObj prms expr locals) locals globals heap confs
    ConForm dcon args ->
        let LiftAct vals locals' globals' heap' confs' =
                liftAtomList (LiftAct args locals globals heap confs)
        in LiftAct (ConObj dcon vals) locals' globals' heap' confs'

-- | Lift Bind Rhs List
--
--   Lifts the right-hand sides left to right with 'liftBindRhs', feeding the
--   environment produced by each one into the next. The heap objects come
--   back in the same order, together with the final environment.
liftBindRhsList :: LiftAct [BindRhs] -> LiftAct [HeapObj]
liftBindRhsList (LiftAct rhss locals globals heap confs) = case rhss of
    []         -> LiftAct [] locals globals heap confs
    (rhs:rest) ->
        let LiftAct hobj  l1 g1 h1 c1 = liftBindRhs (LiftAct rhs locals globals heap confs)
            LiftAct hobjs l2 g2 h2 c2 = liftBindRhsList (LiftAct rest l1 g1 h1 c1)
        in LiftAct (hobj : hobjs) l2 g2 h2 c2

-- | Lift Binding
--
--   Allocates the heap objects of a binding group and binds the group's
--   variables to the new addresses in the locals.
--
--   * Non-recursive: the right-hand sides are lifted under the incoming
--     locals, then allocated, then bound.
--
--   * Recursive: placeholder BLACKHOLEs are allocated first and the variables
--     are bound to those addresses, so the right-hand sides are lifted under
--     locals that already contain the whole group. The placeholders are then
--     overwritten with the real heap objects.
liftBinding :: LiftAct Binding -> LiftAct ()
liftBinding (LiftAct (Binding rec_flag bnd) locals globals heap confs) =
    case rec_flag of
        NonRec ->
            let LiftAct hobjs locals' globals' heap' confs' =
                    liftBindRhsList (LiftAct rhss locals globals heap confs)
                (heap_out, addrs) = allocHeapList hobjs heap'
                locals_out        = bindAt addrs locals'
            in LiftAct () locals_out globals' heap_out confs'
        Rec ->
            let -- Reserve one BLACKHOLE per binding.
                (heap_bh, addrs) = allocHeapList (Blackhole <$ bnd) heap
                -- Locals closure that already sees the reserved addresses.
                locals_rec       = bindAt addrs locals
                LiftAct hobjs locals_out globals' heap_lifted confs' =
                    liftBindRhsList (LiftAct rhss locals_rec globals heap_bh confs)
                -- Replace the placeholders with the lifted objects.
                heap_out         = insertHeapList (zip addrs hobjs) heap_lifted
            in LiftAct () locals_out globals' heap_out confs'
  where
    (vars, rhss)     = unzip bnd
    -- Bind each variable of the group to its address.
    bindAt addrs env = insertLocalsList (zip vars (map MemVal addrs)) env

-- | Default Alts
--
--   Keeps, in their original order, the alternatives whose constructor is
--   @Default@.
defaultAlts :: [Alt] -> [Alt]
-- The as-pattern has no whitespace around @: GHC >= 9.0 rejects `a @ (..)`.
defaultAlts alts = [a | a@(Alt Default _ _) <- alts]

-- | AltCon Based Alts
--
--   Keeps, in their original order, the alternatives whose constructor is
--   anything other than @Default@.
altConAlts :: [Alt] -> [Alt]
-- The as-pattern has no whitespace around @: GHC >= 9.0 rejects `a @ (..)`.
altConAlts alts = [a | a@(Alt acon _ _) <- alts, acon /= Default]

-- | Match Lit Alts
--
--   Keeps, in their original order, the literal alternatives whose literal
--   equals the given one.
matchLitAlts :: Lit -> [Alt] -> [Alt]
-- The as-pattern has no whitespace around @: GHC >= 9.0 rejects `a @ (..)`.
matchLitAlts lit alts = [a | a@(Alt (LitAlt alit) _ _) <- alts, lit == alit]

-- | Match Data Alts
--
--   Keeps, in their original order, the data alternatives whose data
--   constructor equals the given one.
matchDataAlts :: DataCon -> [Alt] -> [Alt]
-- The as-pattern has no whitespace around @: GHC >= 9.0 rejects `a @ (..)`.
matchDataAlts dc alts = [a | a@(Alt (DataAlt adc) _ _) <- alts, dc == adc]

-- | Negate Path Cons
--
--   Flips the boolean flag of every path condition; the other three fields
--   are carried over unchanged.
negatePathCons :: PathCons -> PathCons
negatePathCons pcs = map flipCond pcs
  where
    flipCond (PathCond alt expr locals holds) =
        PathCond alt expr locals (not holds)

-- | Lift Sym Alt
--
--   Prepares one case alternative for a symbolic scrutinee. The payload is
--   @(scrutinee variable, scrutinee address, case binder, alternative)@.
--
--   Every parameter of the alternative gets a fresh symbolic heap object
--   (named after the parameter, with the parameter's type); the case binder
--   is bound to the scrutinee's address and the parameters to the new
--   objects. The result pairs the alternative's body with a single positive
--   path condition relating the scrutinee variable to this alternative,
--   recorded under the extended locals.
liftSymAlt :: LiftAct (Var, MemAddr, Var, Alt) -> LiftAct (Expr, PathCons)
liftSymAlt (LiftAct ~(mvar, addr, cvar, Alt acon params body)
                    locals globals heap confs) =
    LiftAct (body, [constraint]) alt_locals globals sym_heap (fresh ++ confs)
  where
    fresh      = freshSeededNameList (map varName params) confs
    -- One symbolic object per parameter.
    sym_objs   = zipWith (\p n -> SymObj (Symbol (Var n (varType p)) Nothing))
                         params fresh
    (sym_heap, sym_addrs) = allocHeapList sym_objs heap
    alt_locals = insertLocalsList
                     ((cvar, MemVal addr) : zip params (map MemVal sym_addrs))
                     locals
    constraint = PathCond (acon, params) (Atom (VarAtom mvar)) alt_locals True

-- | Alt Closure to State
--
--   Builds the successor state for a lifted alternative: the base state keeps
--   its remaining fields, takes the heap, globals and names from the lift,
--   evaluates the alternative's body under the lifted locals, and has the new
--   path conditions placed in front of the existing ones.
liftedAltToState :: State -> LiftAct (Expr, PathCons) -> State
liftedAltToState state (LiftAct ~(body, new_paths) locals globals heap confs) =
    state { state_heap    = heap
          , state_globals = globals
          , state_code    = Evaluate body locals
          , state_names   = confs
          , state_paths   = new_paths ++ state_paths state }

-- | Reduce
--
--   Performs one reduction step on a 'State'.
--
--   The guards below are tried top to bottom and the first one that matches
--   decides the result, so their order is significant (for instance the
--   exact-arity function application rule is tried before the
--   under-application rule, and all stack independent rules are tried before
--   the stack dependent ones).
--
--   Returns @Just (rule, states)@ with the 'Rule' that fired and the successor
--   states, or @Nothing@ when no rule applies. Only 'RuleCaseSym' can produce
--   more than one successor (one per alternative). A state that satisfies
--   'isStateValueForm' is returned unchanged under 'RuleIdentity'.
reduce :: State -> Maybe (Rule, [State])
-- The as-pattern has no whitespace around @: GHC >= 9.0 rejects `state @ State`.
reduce state@State { state_stack   = stack
                   , state_heap    = heap
                   , state_globals = globals
                   , state_code    = code
                   , state_names   = confs }

  -- Stack Independent Rules

  -- Atom Lit: a literal atom returns itself as a value.
  | Evaluate (Atom (LitAtom lit)) _ <- code =
       Just (RuleAtomLit
            ,[state { state_code = Return (LitVal lit) }])

  -- Atom Lit Pointer: a variable pointing at a LitObj is replaced by the
  -- literal atom.
  | Evaluate (Atom (VarAtom var)) locals <- code
  , Just (_, hobj) <- vlookupHeap var locals globals heap
  , LitObj lit     <- hobj =
       Just (RuleAtomLitPtr
            ,[state { state_code = Evaluate (Atom (LitAtom lit)) locals }])

  -- Rule Atom Val Pointer: a variable pointing at a heap value returns its
  -- address.
  | Evaluate (Atom (VarAtom var)) locals <- code
  , Just (addr, hobj) <- vlookupHeap var locals globals heap
  , isHeapValueForm hobj =
       Just (RuleAtomValPtr
            ,[state { state_code = Return (MemVal addr) }])

  -- Rule Atom Uninterpreted: a variable with no heap entry is lifted to a
  -- symbolic object and the same atom is issued for evaluation again.
  | Evaluate (Atom (VarAtom uvar)) locals <- code
  , Nothing <- vlookupHeap uvar locals globals heap =
    let pass_in = LiftAct uvar locals globals heap confs
        LiftAct _ locals' globals' heap' confs' = liftUnInt pass_in
    in Just (RuleAtomUnInt
            ,[state { state_heap    = heap'
                    , state_globals = globals'
                    , state_code    = Evaluate (Atom (VarAtom uvar)) locals'
                    , state_names   = confs' }])

  -- Prim Function App: the arguments are lifted to values and the application
  -- is kept as a symbolic literal (SymLitEval) over their literals.
  | Evaluate (PrimApp pfun args) locals <- code =
    let pass_in = LiftAct args locals globals heap confs
        LiftAct vals locals' globals' heap' confs' = liftAtomList pass_in
        eval    = SymLitEval pfun (map valueToLit vals)
    in Just (RulePrimApp
            ,[state { state_heap    = heap'
                    , state_globals = globals'
                    , state_code    = Evaluate (Atom (LitAtom eval)) locals'
                    , state_names   = confs' }])

  -- Rule Con App: allocate a constructor object over the lifted arguments and
  -- return its address.
  | Evaluate (ConApp dcon args) locals <- code =
    let pass_in = LiftAct args locals globals heap confs
        LiftAct vals _ globals' heap' confs' = liftAtomList pass_in
        (heapf, addr) = allocHeap (ConObj dcon vals) heap'
    in Just (RuleConApp
            ,[state { state_heap    = heapf
                    , state_globals = globals'
                    , state_code    = Return (MemVal addr)
                    , state_names   = confs' }])

  -- Rule Fun App Exact: as many arguments as parameters; evaluate the body
  -- with the parameters bound in the function's own locals.
  | Evaluate (FunApp fun args) locals <- code
  , Just (_, hobj)              <- vlookupHeap fun locals globals heap
  , FunObj params expr fun_locs <- hobj
  , length params == length args =
    let pass_in   = LiftAct args locals globals heap confs
        LiftAct vals _ globals' heap' confs' = liftAtomList pass_in
        fun_locs' = insertLocalsList (zip params vals) fun_locs
    in Just (RuleFunAppExact
            ,[state { state_heap    = heap'
                    , state_globals = globals'
                    , state_code    = Evaluate expr fun_locs'
                    , state_names   = confs' }])

  -- Rule Fun App Under: fewer arguments than parameters (ex_ps are the
  -- parameters left over); return a new function object over the rest.
  | Evaluate (FunApp fun args) locals <- code
  , Just (_, hobj)              <- vlookupHeap fun locals globals heap
  , FunObj params expr fun_locs <- hobj
  , (_, Left ex_ps)             <- unevenZip params args =
    let pass_in   = LiftAct args locals globals heap confs
        LiftAct vals _ globals' heap' confs' = liftAtomList pass_in
        fun_locs' = insertLocalsList (zip params vals) fun_locs
        -- New Fun Object.
        pobj      = FunObj ex_ps expr fun_locs'
        (heapf, paddr) = allocHeap pobj heap'
    in Just (RuleFunAppUnder
            ,[state { state_heap    = heapf
                    , state_globals = globals'
                    , state_code    = Return (MemVal paddr)
                    , state_names   = confs' }])

  -- Rule Fun App Symbolic: applying a symbolic function yields a fresh
  -- symbolic object that remembers the application and its locals.
  | Evaluate (FunApp sfun args) locals <- code
  , Just (_, hobj)         <- vlookupHeap sfun locals globals heap
  , SymObj (Symbol svar _) <- hobj =
    let sname = freshSeededName (varName svar) confs
        svar' = Var sname (foldl AppTy (varType svar) (map atomType args))
        sym   = Symbol svar' (Just (FunApp sfun args, locals))
        (heap', addr) = allocHeap (SymObj sym) heap
    in Just (RuleFunAppSym
            ,[state { state_heap  = heap'
                    , state_code  = Return (MemVal addr)
                    , state_names = sname : confs }])

  -- Rule Fun App ConObj: a constructor object applied to no arguments returns
  -- its address.
  | Evaluate (FunApp cvar []) locals <- code
  , Just (addr, hobj) <- vlookupHeap cvar locals globals heap
  , ConObj _ _        <- hobj =
       Just (RuleFunAppConPtr
            ,[state { state_code = Return (MemVal addr) }])

  -- Rule Fun App Uninterpreted: a function variable with no heap entry is
  -- lifted to a symbolic object and the same application is issued again.
  | Evaluate (FunApp ufun args) locals <- code
  , Nothing <- vlookupHeap ufun locals globals heap =
    let pass_in = LiftAct ufun locals globals heap confs
        LiftAct _ locals' globals' heap' confs' = liftUnInt pass_in
    in Just (RuleFunAppUnInt
            ,[state { state_heap    = heap'
                    , state_globals = globals'
                    , state_code    = Evaluate (FunApp ufun args) locals'
                    , state_names   = confs' }])

  -- Rule Let: allocate the bindings, then evaluate the body under the
  -- extended locals.
  | Evaluate (Let bnd expr) locals <- code =
    let pass_in = LiftAct bnd locals globals heap confs
        LiftAct _ locals' globals' heap' confs' = liftBinding pass_in
    in Just (RuleLet
            ,[state { state_heap    = heap'
                    , state_globals = globals'
                    , state_code    = Evaluate expr locals' }
                    { state_names   = confs' }])

  -- Rule Case Lit: literal scrutinee with a matching literal alternative; the
  -- first match is taken and the case binder is bound to the literal.
  | Evaluate (Case (Atom (LitAtom lit)) cvar alts) locals <- code
  , (Alt _ _ expr):_ <- matchLitAlts lit alts =
    let locals' = insertLocals cvar (LitVal lit) locals
    in Just (RuleCaseLit
            ,[state { state_code = Evaluate expr locals' }])

  -- Rule Case Con Pointer: scrutinee points at a constructor object with a
  -- matching data alternative of the same arity; bind binder and parameters.
  | Evaluate (Case (Atom (VarAtom mvar)) cvar alts) locals <- code
  , Just (addr, hobj)     <- vlookupHeap mvar locals globals heap
  , ConObj dcon vals      <- hobj
  , (Alt _ params expr):_ <- matchDataAlts dcon alts
  , length params == length vals =
    let llist   = (cvar, MemVal addr) : zip params vals
        locals' = insertLocalsList llist locals
    in Just (RuleCaseConPtr
            ,[state { state_code = Evaluate expr locals' }])

  -- Rule Case Any Lit: literal scrutinee, no literal alternative matches, so
  -- the first default alternative is taken.
  | Evaluate (Case (Atom (LitAtom lit)) cvar alts) locals <- code
  , []               <- matchLitAlts lit alts
  , (Alt _ _ expr):_ <- defaultAlts alts =
    let locals' = insertLocals cvar (LitVal lit) locals
    in Just (RuleCaseAnyLit
            ,[state { state_code = Evaluate expr locals' }])

  -- Rule Case Any Con Pointer: constructor scrutinee, no data alternative
  -- matches, so the first default alternative is taken.
  | Evaluate (Case (Atom (VarAtom mvar)) cvar alts) locals <- code
  , Just (addr, hobj) <- vlookupHeap mvar locals globals heap
  , ConObj dcon _     <- hobj
  , []                <- matchDataAlts dcon alts
  , (Alt _ _ expr):_  <- defaultAlts alts =
    let locals' = insertLocals cvar (MemVal addr) locals
    in Just (RuleCaseAnyConPtr
            ,[state { state_code = Evaluate expr locals' }])

  -- Rule Case Sym: symbolic scrutinee; branch into one state per alternative.
  -- Every alternative is lifted from the same incoming heap and name list.
  | Evaluate (Case (Atom (VarAtom mvar)) cvar alts) locals <- code
  , Just (addr, hobj)     <- vlookupHeap mvar locals globals heap
  , SymObj _              <- hobj
  , (acon_alts, def_alts) <- (altConAlts alts, defaultAlts alts)
  , not (null (acon_alts ++ def_alts)) =
    let acon_ins   = map (\a -> LiftAct (mvar, addr, cvar, a)
                                        locals globals heap confs) acon_alts
        acon_lifts = map liftSymAlt acon_ins
        def_ins    = map (\a -> LiftAct (mvar, addr, cvar, a)
                                        locals globals heap confs) def_alts
        def_lifts  = map liftSymAlt def_ins
        -- Make AltCon states first.
        acon_sts   = map (liftedAltToState state) acon_lifts
        -- Make DEFAULT states next. A default branch carries the negation of
        -- every AltCon branch's path condition instead of its own.
        all_pcons  = concatMap (\(LiftAct (_, pc) _ _ _ _) -> pc) acon_lifts
        negs       = negatePathCons all_pcons
        def_lifts' = map (\(LiftAct (e, _) l g h c) ->
                           (LiftAct (e, negs) l g h c)) def_lifts
        def_sts    = map (liftedAltToState state) def_lifts'
    in Just (RuleCaseSym, acon_sts ++ def_sts)

  -- Stack Dependent Rules

  -- Rule Update Frame Create Thunk: push an update frame for the thunk's
  -- address, blackhole the address, and evaluate the thunk's body.
  | Stack frames                         <- stack
  , Evaluate (Atom (VarAtom var)) locals <- code
  , Just (addr, hobj)       <- vlookupHeap var locals globals heap
  , FunObj [] expr fun_locs <- hobj = -- Thunk form.
       Just (RuleUpdateCThunk
            ,[state { state_stack = Stack (UpdateFrame addr : frames)
                    , state_heap  = insertHeap addr Blackhole heap
                    , state_code  = Evaluate expr fun_locs }])

  -- Rule Update Frame Delete Lit: pop the frame and store the returned
  -- literal at the frame's address.
  | Stack (UpdateFrame frm_addr : rest) <- stack
  , Return (LitVal lit)                 <- code =
       Just (RuleUpdateDLit
            ,[state { state_stack = Stack rest
                    , state_heap  = insertHeap frm_addr (LitObj lit) heap
                    , state_code  = Return (LitVal lit) }])

  -- Rule Update Frame Delete Val Pointer: pop the frame and store the
  -- returned heap value at the frame's address.
  | Stack (UpdateFrame frm_addr : rest) <- stack
  , Return (MemVal addr)                <- code
  , Just hobj <- lookupHeap addr heap
  , isHeapValueForm hobj =
       Just (RuleUpdateDValPtr
            ,[state { state_stack = Stack rest
                    , state_heap  = insertHeap frm_addr hobj heap
                    , state_code  = Return (MemVal addr) }])

  -- Rule Case Frame Create Case Non LitVal or MemVal: the scrutinee is not a
  -- value yet, so push a case frame and evaluate the scrutinee first.
  | Stack frames                          <- stack
  , Evaluate (Case mxpr cvar alts) locals <- code
  , not (isExprValueForm mxpr locals globals heap) =
       Just (RuleCaseCCaseNonVal
            ,[state { state_stack = Stack (CaseFrame cvar alts locals : frames)
                    , state_code  = Evaluate mxpr locals }])

  -- Rule Case Frame Delete Lit: pop the frame and rebuild the case over the
  -- returned literal.
  | Stack (CaseFrame cvar alts frm_locs : rest) <- stack
  , Return (LitVal lit)                         <- code =
    let mxpr = Atom (LitAtom lit)
    in Just (RuleCaseDLit
            ,[state { state_stack = Stack rest
                    , state_code  = Evaluate (Case mxpr cvar alts) frm_locs }])

  -- Rule Case Frame Delete Heap Value: pop the frame and rebuild the case over
  -- a fresh variable bound to the returned address.
  | Stack (CaseFrame cvar alts frm_locs : rest) <- stack
  , Return (MemVal addr)                        <- code
  , Just hobj <- lookupHeap addr heap
  , isHeapValueForm hobj =
    let vname     = freshSeededName (varName cvar) confs
        vvar      = Var vname (varType cvar)
        mxpr      = Atom (VarAtom vvar)
        frm_locs' = insertLocals vvar (MemVal addr) frm_locs
    in Just (RuleCaseDValPtr
            ,[state { state_stack = Stack rest
                    , state_code  = Evaluate (Case mxpr cvar alts) frm_locs'
                    , state_names = vname : confs }])

  -- Rule Apply Frame Create Function Thunk: the function position is a thunk;
  -- push the arguments as an apply frame and evaluate the thunk's body.
  | Stack frames                      <- stack
  , Evaluate (FunApp fun args) locals <- code
  , Just (_, hobj)          <- vlookupHeap fun locals globals heap
  , FunObj [] expr fun_locs <- hobj =
       Just (RuleApplyCFunThunk
            ,[state { state_stack = Stack (ApplyFrame args locals : frames)
                    , state_code  = Evaluate expr fun_locs }])

  -- Rule Apply Frame Create Function Over Application: more arguments than
  -- parameters; push the surplus arguments (ex_as) as an apply frame.
  | Stack frames                      <- stack
  , Evaluate (FunApp fun args) locals <- code
  , Just (_, hobj)              <- vlookupHeap fun locals globals heap
  , FunObj params expr fun_locs <- hobj
  , (_, Right ex_as)            <- unevenZip params args =
    let pass_in   = LiftAct args locals globals heap confs
        LiftAct vals locals' globals' heap' confs' = liftAtomList pass_in
        fun_locs' = insertLocalsList (zip params vals) fun_locs
    in Just (RuleApplyCFunAppOver
            ,[state { state_stack   = Stack (ApplyFrame ex_as locals' : frames)
                    , state_heap    = heap'
                    , state_globals = globals'
                    , state_code    = Evaluate expr fun_locs'
                    , state_names   = confs' }])

  -- Rule Apply Frame Delete ReturnPtr Function: pop the frame and apply a
  -- fresh variable, bound to the returned function, to the saved arguments.
  | Stack (ApplyFrame args frm_locs : rest) <- stack
  , Return (MemVal addr)                    <- code
  , Just hobj    <- lookupHeap addr heap
  , FunObj _ _ _ <- hobj
  , Just ftype   <- memAddrType addr heap =
    let fname     = freshName VarNSpace confs
        fvar      = Var fname ftype
        frm_locs' = insertLocals fvar (MemVal addr) frm_locs
    in Just (RuleApplyDReturnFun
            ,[state { state_stack = Stack rest
                    , state_code  = Evaluate (FunApp fvar args) frm_locs'
                    , state_names = fname : confs }])

  -- Rule Apply Frame Delete ReturnPtr Sym: pop the frame and apply a fresh
  -- variable, bound to the returned symbolic object, to the saved arguments.
  | Stack (ApplyFrame args frm_locs : rest) <- stack
  , Return (MemVal addr)                    <- code
  , Just hobj             <- lookupHeap addr heap
  , SymObj (Symbol sym _) <- hobj =
    let sname     = freshSeededName (varName sym) confs
        svar      = Var sname (varType sym)
        frm_locs' = insertLocals svar (MemVal addr) frm_locs
    in Just (RuleApplyDReturnSym
            ,[state { state_stack = Stack rest
                    , state_code  = Evaluate (FunApp svar args) frm_locs'
                    , state_names = sname : confs }])

  -- State is Value Form
  | isStateValueForm state = Just (RuleIdentity, [state])

  -- Everything Broke!!!
  | otherwise = Nothing