-- | STG reduction `Rule`s and the single-step `reduce` function.
module SSTG.Core.Execution.Rules
    ( Rule(..)
    , reduce
    , isStateValForm
    ) where

import SSTG.Core.Language

import SSTG.Core.Execution.Support

-- | The reduction `Rule`s that `reduce` reports alongside each step.
--
-- The derived `Ord` instance follows constructor order; that order matches
-- the sequence in which `reduce` tries its rules. In the frame rules, @C@
-- marks a rule that creates (pushes) a frame and @D@ one that deletes (pops)
-- a frame.
data Rule
    -- Stack independent: atoms.
    = RuleAtomLit
    | RuleAtomLitPtr
    | RuleAtomValPtr
    | RuleAtomUnInt
    -- Stack independent: applications.
    | RulePrimApp
    | RuleConApp
    | RuleFunAppExact
    | RuleFunAppUnder
    | RuleFunAppSym
    | RuleFunAppConPtr
    | RuleFunAppUnInt
    -- Stack independent: let and case.
    | RuleLet
    | RuleCaseLit
    | RuleCaseConPtr
    | RuleCaseAnyLit
    | RuleCaseAnyConPtr
    | RuleCaseSym
    -- Update frames.
    | RuleUpdateCThunk
    | RuleUpdateDLit
    | RuleUpdateDValPtr
    -- Case frames.
    | RuleCaseCCaseNonVal
    | RuleCaseDLit
    | RuleCaseDValPtr
    -- Apply frames.
    | RuleApplyCFunThunk
    | RuleApplyCFunAppOver
    | RuleApplyDReturnFun
    | RuleApplyDReturnSym
    -- The state is already in value form.
    | RuleIdentity
    deriving (Show, Eq, Read, Ord)

-- Stack Independent Rules

-- | Is this `HeapObj` already a value? Symbolic objects, constructor objects
-- and functions taking at least one parameter qualify. `LitObj`s,
-- zero-parameter `FunObj`s (thunks) and every other object do not. A
-- pointer to a qualifying object needs no further reduction.
isHeapValForm :: HeapObj -> Bool
isHeapValForm hobj = case hobj of
    SymObj _ -> True
    ConObj _ _ -> True
    FunObj params _ _ -> not (null params)
    _ -> False

-- | Is the `Expr` already a value?
--
-- This is true for literal atoms. It is also true for variable atoms whose
-- heap object satisfies `isHeapValForm` (so a pointer to a `LitObj` does not
-- count). A variable with no heap binding is not (yet) a value, because it
-- can still be lifted to a fresh symbolic object in the `Globals`. Every
-- other expression is not a value.
isExprValForm :: Expr -> Locals -> Globals -> Heap -> Bool
isExprValForm expr locals globals heap = case expr of
    Atom (LitAtom _) -> True
    Atom (VarAtom var) ->
        maybe False (isHeapValForm . snd) (vlookupHeap var locals globals heap)
    _ -> False

-- | Is the `State` in a normal form that cannot be reduced further?
--
-- That holds when the stack has nothing to pop and the code returns either
-- a literal or a pointer to a heap object accepted by `isHeapValForm`.
isStateValForm :: State -> Bool
isStateValForm state =
    case (popStack (state_stack state), state_code state) of
        (Nothing, Return (LitVal _)) -> True
        (Nothing, Return (MemVal addr)) ->
            maybe False isHeapValForm (lookupHeap addr (state_heap state))
        _ -> False

-- | Convert a `Val` to a `Lit`. Literals pass through unchanged; a memory
-- address becomes an `AddrLit` carrying its integer address.
valToLit :: Val -> Lit
valToLit val = case val of
    LitVal lit -> lit
    MemVal addr -> AddrLit (memAddrInt addr)

-- | Zip two lists of possibly different lengths and return the unmatched
-- tail as well.
--
-- The tail is @Left@ when the first list has leftovers and @Right@ when the
-- second does. When both lists run out together the result is @Left []@.
unevenZip :: [a] -> [b] -> ([(a, b)], Either [a] [b])
unevenZip xs ys = case ys of
    [] -> ([], Left xs)
    y : ys' -> case xs of
        [] -> ([], Right ys)
        x : xs' ->
            let (pairs, leftover) = unevenZip xs' ys'
            in ((x, y) : pairs, leftover)

-- | A payload of type @a@ paired with the context that the lifting helpers
-- in this module thread through.
--
-- The context is the current `Locals`, `Globals` and `Heap`, plus the list
-- of names already in use. Names produced by `freshSeededName` or
-- `freshSeededNames` are prepended to that list.
data LiftAct a = LiftAct
    a        -- payload
    Locals   -- local bindings
    Globals  -- global bindings
    Heap     -- heap
    [Name]   -- names in use

-- | Lift an uninterpreted (unbound) `Var` into the `Globals`.
--
-- It allocates a `SymObj` for a fresh variable of the same type, whose name
-- is seeded by the original one, and binds the original `Var` to that
-- address in the `Globals`. The fresh name is prepended to the names in
-- use. The `Locals` are returned unchanged.
liftUnInt :: LiftAct Var -> LiftAct MemAddr
liftUnInt (LiftAct var locals globals heap confs) =
    let symName = freshSeededName (varName var) confs
        symObj = SymObj (Symbol (Var symName (typeOf var)) Nothing)
        (heap', addr) = allocHeapObj symObj heap
    in LiftAct addr
               locals
               (insertGlobalsVal var (MemVal addr) globals)
               heap'
               (symName : confs)

-- | Turn an `Atom` into a `Val`, lifting it first if necessary.
--
-- Literals become `LitVal`s. A variable resolves through `lookupVal`. A
-- variable that is not found (uninterpreted or out of scope) is lifted
-- with `liftUnInt` and becomes a `MemVal` of the new symbolic object.
liftAtom :: LiftAct Atom -> LiftAct Val
liftAtom (LiftAct atom locals globals heap confs) =
    LiftAct val locals globals' heap' confs'
  where
    (val, globals', heap', confs') = liftOne atom
    -- Literals are already values; the context is untouched.
    liftOne (LitAtom lit) = (LitVal lit, globals, heap, confs)
    -- Bound variables keep the context; unbound ones are lifted.
    liftOne (VarAtom var) =
        maybe (fromUnInt var) (\found -> (found, globals, heap, confs))
              (lookupVal var locals globals)
    fromUnInt var =
        let LiftAct addr _ globals1 heap1 confs1 =
                liftUnInt (LiftAct var locals globals heap confs)
        in (MemVal addr, globals1, heap1, confs1)

-- | Lift a list of `Atom`s left to right. The context produced by each
-- `liftAtom` call is fed into the next one.
liftAtoms :: LiftAct [Atom] -> LiftAct [Val]
liftAtoms (LiftAct atoms locals globals heap confs) = case atoms of
    [] -> LiftAct [] locals globals heap confs
    atom : rest ->
        let LiftAct val locals1 globals1 heap1 confs1 =
                liftAtom (LiftAct atom locals globals heap confs)
            LiftAct vals locals2 globals2 heap2 confs2 =
                liftAtoms (LiftAct rest locals1 globals1 heap1 confs1)
        in LiftAct (val : vals) locals2 globals2 heap2 confs2

-- | Turn a `BindRhs` into a `HeapObj`.
--
-- A function form becomes a `FunObj` that closes over the current
-- `Locals`. A constructor form has its argument atoms lifted with
-- `liftAtoms` and becomes a `ConObj`.
liftBindRhs :: LiftAct BindRhs -> LiftAct HeapObj
liftBindRhs (LiftAct rhs locals globals heap confs) = case rhs of
    FunForm prms expr ->
        LiftAct (FunObj prms expr locals) locals globals heap confs
    ConForm dcon args ->
        let LiftAct vals locals1 globals1 heap1 confs1 =
                liftAtoms (LiftAct args locals globals heap confs)
        in LiftAct (ConObj dcon vals) locals1 globals1 heap1 confs1

-- | Lift a list of `BindRhs` left to right. The context produced by each
-- `liftBindRhs` call is fed into the next one.
liftBindRhss :: LiftAct [BindRhs] -> LiftAct [HeapObj]
liftBindRhss (LiftAct rhss locals globals heap confs) = case rhss of
    [] -> LiftAct [] locals globals heap confs
    rhs : rest ->
        let LiftAct hobj locals1 globals1 heap1 confs1 =
                liftBindRhs (LiftAct rhs locals globals heap confs)
            LiftAct hobjs locals2 globals2 heap2 confs2 =
                liftBindRhss (LiftAct rest locals1 globals1 heap1 confs1)
        in LiftAct (hobj : hobjs) locals2 globals2 heap2 confs2

-- | Lift a `Binds` group into the heap and the `Locals`.
--
-- * `NonRec`: the right-hand sides are lifted against the incoming
--   `Locals` and then allocated. Only after that are the binders added to
--   the `Locals`.
-- * `Rec`: one `Blackhole` placeholder is allocated per binder, and the
--   binders are added to the `Locals` first. The right-hand sides, and any
--   `FunObj` closures built from them, therefore see every binder. The
--   lifted objects are then inserted at the placeholder addresses.
liftBinds :: LiftAct Binds -> LiftAct ()
liftBinds (LiftAct (Binds NonRec kvs) locals globals heap confs) =
    let binders = map fst kvs
        LiftAct hobjs locals1 globals1 heap1 confs1 =
            liftBindRhss (LiftAct (map snd kvs) locals globals heap confs)
        (heap2, addrs) = allocHeapObjs hobjs heap1
        locals2 = insertLocalsVals (zip binders (map MemVal addrs)) locals1
    in LiftAct () locals2 globals1 heap2 confs1
liftBinds (LiftAct (Binds Rec kvs) locals globals heap confs) =
    let binders = map fst kvs
        -- Allocate placeholder BLACKHOLEs so the binders have addresses.
        (heap1, addrs) = allocHeapObjs (map (const Blackhole) kvs) heap
        -- Bind every binder to its placeholder before lifting the RHSs.
        recLocals = insertLocalsVals (zip binders (map MemVal addrs)) locals
        LiftAct hobjs locals1 globals1 heap2 confs1 =
            liftBindRhss (LiftAct (map snd kvs) recLocals globals heap1 confs)
        heapf = insertHeapObjs (zip addrs hobjs) heap2
    in LiftAct () locals1 globals1 heapf confs1

-- | The `Default` `Alt` branches of a `Case`, in their original order.
--
-- The as-pattern is written without spaces around @\@@. GHC 9.0 and later
-- parse a whitespace-surrounded @\@@ as an infix operator, and reject it in
-- pattern position.
defaultAlts :: [Alt] -> [Alt]
defaultAlts alts = [alt | alt@(Alt Default _) <- alts]

-- | The non-`Default` `Alt` branches of a `Case` (literal and data-constructor
-- alternatives), in their original order.
--
-- The as-pattern is written without spaces around @\@@, because GHC 9.0 and
-- later reject a whitespace-surrounded @\@@ in pattern position.
nonDefaultAlts :: [Alt] -> [Alt]
nonDefaultAlts alts = [alt | alt@(Alt acon _) <- alts, acon /= Default]

-- | The `LitAlt` branches whose literal equals the given `Lit`, in their
-- original order.
--
-- The as-pattern is written without spaces around @\@@, because GHC 9.0 and
-- later reject a whitespace-surrounded @\@@ in pattern position.
matchLitAlts :: Lit -> [Alt] -> [Alt]
matchLitAlts lit alts = [alt | alt@(Alt (LitAlt alit) _) <- alts, lit == alit]

-- | The `DataAlt` branches whose `DataCon` equals the given one, in their
-- original order.
--
-- The as-pattern is written without spaces around @\@@, because GHC 9.0 and
-- later reject a whitespace-surrounded @\@@ in pattern position.
matchDataAlts :: DataCon -> [Alt] -> [Alt]
matchDataAlts dc alts = [alt | alt@(Alt (DataAlt adc _) _) <- alts, dc == adc]

-- | Flip the boolean polarity (last field) of a `Constraint`. The other
-- three fields are kept as they are.
negateConstraint :: Constraint -> Constraint
negateConstraint constraint = case constraint of
    Constraint acon expr locs polarity ->
        Constraint acon expr locs (not polarity)

-- | Lift one `Alt` branch of a `Case` whose scrutinee is symbolic.
--
-- The input tuple holds the scrutinee variable, the address it points to,
-- the case binder and the branch. The steps are:
--
-- * Each `DataAlt` parameter gets a fresh symbolic heap object. Its name is
--   seeded by the parameter's name.
-- * The case binder is bound to the scrutinee's address, and every
--   parameter is bound to its new object.
-- * The result is the branch body plus one `Constraint`. It pairs the
--   branch's `AltCon` with the scrutinee atom, with polarity `True`.
liftSymAlt :: LiftAct (Var, MemAddr, Var, Alt) -> LiftAct (Expr, [Constraint])
liftSymAlt (LiftAct args locals globals heap confs) =
    let (mvar, addr, cvar, Alt acon expr) = args
        params = case acon of
                     DataAlt _ ps -> ps
                     _ -> []
        freshNames = freshSeededNames (map varName params) confs
        symObjs = zipWith (\p n -> SymObj (Symbol (Var n (typeOf p)) Nothing))
                          params freshNames
        (heap', addrs) = allocHeapObjs symObjs heap
        bindings = (cvar, MemVal addr) : zip params (map MemVal addrs)
        locals' = insertLocalsVals bindings locals
        constraint = Constraint acon (Atom (VarAtom mvar)) locals' True
    in LiftAct (expr, [constraint]) locals' globals heap' (freshNames ++ confs)

-- | Build a successor `State` from a lifted `Alt` branch.
--
-- The new state takes the lifted heap, globals and names, and evaluates the
-- branch body in the lifted `Locals`. The branch constraints are added to
-- the path. The stack is left untouched.
liftedAltToState :: State -> LiftAct (Expr, [Constraint]) -> State
liftedAltToState state (LiftAct lifted locals globals heap confs) =
    let (expr, conss) = lifted
    in state { state_heap = heap
             , state_globals = globals
             , state_code = Evaluate expr locals
             , state_names = confs
             , state_path = insertPathConss conss (state_path state) }

-- | Reduce the state if it matches some type of reduction `Rule`. Return
-- `Nothing` to denote that rule application has completely failed.
--
-- The guards below are tried top to bottom and the first one that matches
-- wins, so their order is significant. In particular:
--
-- * "Fun App Exact" is tried before "Fun App Under". This matters because
--   for equal arities `unevenZip` also returns an (empty) `Left`.
-- * "Apply Frame Create Function Thunk" is tried before "Apply Frame Create
--   Function Over Application". This matters because a zero-parameter
--   `FunObj` applied to arguments also makes `unevenZip` return a `Right`.
--
-- Every rule produces exactly one successor `State`, except "Case Sym",
-- which produces one successor per `Alt` branch of the `Case`.
reduce :: State -> Maybe (Rule, [State])
-- The as-pattern must not have spaces around '@': GHC >= 9.0 parses a
-- whitespace-surrounded '@' as an infix operator and rejects it here.
reduce state@State { state_stack = stack
                   , state_heap = heap
                   , state_globals = globals
                   , state_code = code
                   , state_names = confs }

-- Stack Independent Rules

  -- Atom Lit
  | Evaluate (Atom (LitAtom lit)) _ <- code =
    Just (RuleAtomLit
         ,[state { state_code = Return (LitVal lit) }])

  -- Atom Lit Pointer
  | Evaluate (Atom (VarAtom var)) locals <- code
  , Just (_, hobj) <- vlookupHeap var locals globals heap
  , LitObj lit <- hobj =
    Just (RuleAtomLitPtr
         ,[state { state_code = Evaluate (Atom (LitAtom lit)) locals }])

  -- Rule Atom Val Pointer
  | Evaluate (Atom (VarAtom var)) locals <- code
  , Just (addr, hobj) <- vlookupHeap var locals globals heap
  , isHeapValForm hobj =
    Just (RuleAtomValPtr
         ,[state { state_code = Return (MemVal addr) }])

  -- Rule Atom Uninterpreted
  -- The variable has no heap binding: lift it to a fresh symbolic object in
  -- the Globals and evaluate the same atom again.
  | Evaluate (Atom (VarAtom uvar)) locals <- code
  , Nothing <- vlookupHeap uvar locals globals heap =
    let pass_in = LiftAct uvar locals globals heap confs
        LiftAct _ locals' globals' heap' confs' = liftUnInt pass_in
    in Just (RuleAtomUnInt
            ,[state { state_heap = heap'
                    , state_globals = globals'
                    , state_code = Evaluate (Atom (VarAtom uvar)) locals'
                    , state_names = confs' }])

  -- Prim Function App
  | Evaluate (PrimApp pfun args) locals <- code =
    let pass_in = LiftAct args locals globals heap confs
        LiftAct vals locals' globals' heap' confs' = liftAtoms pass_in
        eval = LitEval pfun (map valToLit vals)
    in Just (RulePrimApp
            ,[state { state_heap = heap'
                    , state_globals = globals'
                    , state_code = Evaluate (Atom (LitAtom eval)) locals'
                    , state_names = confs' }])

  -- Rule Con App
  | Evaluate (ConApp dcon args) locals <- code =
    let pass_in = LiftAct args locals globals heap confs
        LiftAct vals _ globals' heap' confs' = liftAtoms pass_in
        (heapf, addr) = allocHeapObj (ConObj dcon vals) heap'
    in Just (RuleConApp
            ,[state { state_heap = heapf
                    , state_globals = globals'
                    , state_code = Return (MemVal addr)
                    , state_names = confs' }])

  -- Rule Fun App Exact
  | Evaluate (FunApp fun args) locals <- code
  , Just (_, hobj) <- vlookupHeap fun locals globals heap
  , FunObj params expr fun_locs <- hobj
  , length params == length args =
    let pass_in = LiftAct args locals globals heap confs
        LiftAct vals _ globals' heap' confs' = liftAtoms pass_in
        fun_locs' = insertLocalsVals (zip params vals) fun_locs
    in Just (RuleFunAppExact
            ,[state { state_heap = heap'
                    , state_globals = globals'
                    , state_code = Evaluate expr fun_locs'
                    , state_names = confs' }])

  -- Rule Fun App Under
  -- Equal arities were handled by "Fun App Exact" above, so ex_params is
  -- non-empty whenever this guard matches.
  | Evaluate (FunApp fun args) locals <- code
  , Just (_, hobj) <- vlookupHeap fun locals globals heap
  , FunObj params expr fun_locs <- hobj
  , (_, Left ex_params) <- unevenZip params args =
    let pass_in = LiftAct args locals globals heap confs
        LiftAct vals _ globals' heap' confs' = liftAtoms pass_in
        fun_locs' = insertLocalsVals (zip params vals) fun_locs
        -- New Fun Object.
        pobj = FunObj ex_params expr fun_locs'
        (heapf, paddr) = allocHeapObj pobj heap'
    in Just (RuleFunAppUnder
            ,[state { state_heap = heapf
                    , state_globals = globals'
                    , state_code = Return (MemVal paddr)
                    , state_names = confs' }])

  -- Rule Fun App Symbolic
  | Evaluate (FunApp sfun args) locals <- code
  , Just (_, hobj) <- vlookupHeap sfun locals globals heap
  , SymObj (Symbol svar _) <- hobj =
    let sname = freshSeededName (varName svar) confs
        sres = Var sname (foldl AppTy (typeOf svar) (map typeOf args))
        sym_app = Symbol sres (Just (FunApp sfun args, locals))
        (heap', addr) = allocHeapObj (SymObj sym_app) heap
    in Just (RuleFunAppSym
            ,[state { state_heap = heap'
                    , state_code = Return (MemVal addr)
                    , state_names = sname : confs }])

  -- Rule Fun App ConObj
  | Evaluate (FunApp cvar []) locals <- code
  , Just (addr, hobj) <- vlookupHeap cvar locals globals heap
  , ConObj _ _ <- hobj =
    Just (RuleFunAppConPtr
         ,[state { state_code = Return (MemVal addr) }])

  -- Rule Fun App Uninterpreted
  | Evaluate (FunApp ufun args) locals <- code
  , Nothing <- vlookupHeap ufun locals globals heap =
    let pass_in = LiftAct ufun locals globals heap confs
        LiftAct _ locals' globals' heap' confs' = liftUnInt pass_in
    in Just (RuleFunAppUnInt
            ,[state { state_heap = heap'
                    , state_globals = globals'
                    , state_code = Evaluate (FunApp ufun args) locals'
                    , state_names = confs' }])

  -- Rule Let
  | Evaluate (Let binds expr) locals <- code =
    let pass_in = LiftAct binds locals globals heap confs
        LiftAct _ locals' globals' heap' confs' = liftBinds pass_in
    in Just (RuleLet
            ,[state { state_heap = heap'
                    , state_globals = globals'
                    , state_code = Evaluate expr locals'
                    , state_names = confs' }])

  -- Rule Case Lit
  | Evaluate (Case (Atom (LitAtom lit)) cvar alts) locals <- code
  , (Alt _ expr):_ <- matchLitAlts lit alts =
    let locals' = insertLocalsVal cvar (LitVal lit) locals
    in Just (RuleCaseLit
            ,[state { state_code = Evaluate expr locals' }])

  -- Rule Case Con Pointer
  | Evaluate (Case (Atom (VarAtom mvar)) cvar alts) locals <- code
  , Just (addr, hobj) <- vlookupHeap mvar locals globals heap
  , ConObj dcon vals <- hobj
  , (Alt (DataAlt _ params) expr):_ <- matchDataAlts dcon alts
  , length params == length vals =
    let kvs = (cvar, MemVal addr) : zip params vals
        locals' = insertLocalsVals kvs locals
    in Just (RuleCaseConPtr
            ,[state { state_code = Evaluate expr locals' }])

  -- Rule Case Any Lit
  | Evaluate (Case (Atom (LitAtom lit)) cvar alts) locals <- code
  , [] <- matchLitAlts lit alts
  , (Alt _ expr):_ <- defaultAlts alts =
    let locals' = insertLocalsVal cvar (LitVal lit) locals
    in Just (RuleCaseAnyLit
            ,[state { state_code = Evaluate expr locals' }])

  -- Rule Case Any Con Pointer
  | Evaluate (Case (Atom (VarAtom mvar)) cvar alts) locals <- code
  , Just (addr, hobj) <- vlookupHeap mvar locals globals heap
  , ConObj dcon _ <- hobj
  , [] <- matchDataAlts dcon alts
  , (Alt _ expr):_ <- defaultAlts alts =
    let locals' = insertLocalsVal cvar (MemVal addr) locals
    in Just (RuleCaseAnyConPtr
            ,[state { state_code = Evaluate expr locals' }])

  -- Rule Case Sym
  -- Branches into one successor per Alt. The DEFAULT successors carry the
  -- negation of every non-DEFAULT branch constraint.
  | Evaluate (Case (Atom (VarAtom mvar)) cvar alts) locals <- code
  , Just (addr, hobj) <- vlookupHeap mvar locals globals heap
  , SymObj _ <- hobj
  , (ndef_alts, def_alts) <- (nonDefaultAlts alts, defaultAlts alts)
  , not (null (ndef_alts ++ def_alts)) =
    let ndef_ins = map (\a -> LiftAct (mvar, addr, cvar, a)
                                      locals globals heap confs) ndef_alts
        ndef_lifts = map liftSymAlt ndef_ins
        def_ins = map (\a -> LiftAct (mvar, addr, cvar, a)
                                     locals globals heap confs) def_alts
        def_lifts = map liftSymAlt def_ins
        -- Make AltCon states first.
        ndef_sts = map (liftedAltToState state) ndef_lifts
        -- Make DEFAULT states next.
        all_conss = concatMap (\(LiftAct (_, c) _ _ _ _) -> c) ndef_lifts
        negatives = map negateConstraint all_conss
        def_lifts' = map (\(LiftAct (e, _) l g h c) ->
                           (LiftAct (e, negatives) l g h c)) def_lifts
        def_sts = map (liftedAltToState state) def_lifts'
    in Just (RuleCaseSym, ndef_sts ++ def_sts)

-- Stack Dependent Rules

  -- Rule Update Frame Create Thunk
  -- The thunk is overwritten with a Blackhole while its body is evaluated
  -- under an UpdateFrame for its address.
  | Evaluate (Atom (VarAtom var)) locals <- code
  , Just (addr, hobj) <- vlookupHeap var locals globals heap
  , FunObj [] expr fun_locs <- hobj =  -- Thunk form.
    let frame = UpdateFrame addr
    in Just (RuleUpdateCThunk
            ,[state { state_stack = pushStack frame stack
                    , state_heap = insertHeapObj addr Blackhole heap
                    , state_code = Evaluate expr fun_locs }])

  -- Rule Update Frame Delete Lit
  | Just (UpdateFrame frm_addr, stack') <- popStack stack
  , Return (LitVal lit) <- code =
    Just (RuleUpdateDLit
         ,[state { state_stack = stack'
                 , state_heap = insertHeapObj frm_addr (LitObj lit) heap
                 , state_code = Return (LitVal lit) }])

  -- Rule Update Frame Delete Val Pointer
  | Just (UpdateFrame frm_addr, stack') <- popStack stack
  , Return (MemVal addr) <- code
  , Just hobj <- lookupHeap addr heap
  , isHeapValForm hobj =
    Just (RuleUpdateDValPtr
         ,[state { state_stack = stack'
                 , state_heap = insertHeapRedir frm_addr addr heap
                 , state_code = Return (MemVal addr) }])

  -- Rule Case Frame Create Case Non LitVal or MemVal
  -- The scrutinee is not yet a value: evaluate it under a CaseFrame.
  | Evaluate (Case mexpr cvar alts) locals <- code
  , not (isExprValForm mexpr locals globals heap) =
    let frame = CaseFrame cvar alts locals
    in Just (RuleCaseCCaseNonVal
            ,[state { state_stack = pushStack frame stack
                    , state_code = Evaluate mexpr locals }])

  -- Rule Case Frame Delete Lit
  | Just (CaseFrame cvar alts frm_locs, stack') <- popStack stack
  , Return (LitVal lit) <- code =
    let cexpr = Case (Atom (LitAtom lit)) cvar alts
    in Just (RuleCaseDLit
            ,[state { state_stack = stack'
                    , state_code = Evaluate cexpr frm_locs }])

  -- Rule Case Frame Delete Heap Val
  -- A fresh variable bound to the returned address becomes the scrutinee.
  | Just (CaseFrame cvar alts frm_locs, stack') <- popStack stack
  , Return (MemVal addr) <- code
  , Just hobj <- lookupHeap addr heap
  , isHeapValForm hobj =
    let vname = freshSeededName (varName cvar) confs
        vvar = Var vname (typeOf cvar)
        cexpr = Case (Atom (VarAtom vvar)) cvar alts
        frm_locs' = insertLocalsVal vvar (MemVal addr) frm_locs
    in Just (RuleCaseDValPtr
            ,[state { state_stack = stack'
                    , state_code = Evaluate cexpr frm_locs'
                    , state_names = vname : confs }])

  -- Rule Apply Frame Create Function Thunk
  | Evaluate (FunApp fun args) locals <- code
  , Just (_, hobj) <- vlookupHeap fun locals globals heap
  , FunObj [] _ _ <- hobj =
    let frame = ApplyFrame args locals
    in Just (RuleApplyCFunThunk
            ,[state { state_stack = pushStack frame stack
                    , state_code = Evaluate (Atom (VarAtom fun)) locals }])

  -- Rule Apply Frame Create Function Over Application
  -- Zero-parameter FunObjs were handled by the thunk rule above.
  | Evaluate (FunApp fun args) locals <- code
  , Just (_, hobj) <- vlookupHeap fun locals globals heap
  , FunObj params expr fun_locs <- hobj
  , (_, Right ex_args) <- unevenZip params args =
    let pass_in = LiftAct args locals globals heap confs
        LiftAct vals locals' globals' heap' confs' = liftAtoms pass_in
        fun_locs' = insertLocalsVals (zip params vals) fun_locs
        frame = ApplyFrame ex_args locals'
    in Just (RuleApplyCFunAppOver
            ,[state { state_stack = pushStack frame stack
                    , state_heap = heap'
                    , state_globals = globals'
                    , state_code = Evaluate expr fun_locs'
                    , state_names = confs' }])

  -- Rule Apply Frame Delete ReturnPtr Function
  | Just (ApplyFrame args frm_locs, stack') <- popStack stack
  , Return (MemVal addr) <- code
  , Just hobj <- lookupHeap addr heap
  , FunObj _ _ _ <- hobj
  , Just ftype <- memAddrType addr heap =
    let fname = freshName VarNSpace confs
        fvar = Var fname ftype
        frm_locs' = insertLocalsVal fvar (MemVal addr) frm_locs
    in Just (RuleApplyDReturnFun
            ,[state { state_stack = stack'
                    , state_code = Evaluate (FunApp fvar args) frm_locs'
                    , state_names = fname : confs }])

  -- Rule Apply Frame Delete ReturnPtr Sym
  | Just (ApplyFrame args frm_locs, stack') <- popStack stack
  , Return (MemVal addr) <- code
  , Just hobj <- lookupHeap addr heap
  , SymObj (Symbol svar _) <- hobj =
    let sname = freshSeededName (varName svar) confs
        sfun = Var sname (typeOf svar)
        frm_locs' = insertLocalsVal sfun (MemVal addr) frm_locs
    in Just (RuleApplyDReturnSym
            ,[state { state_stack = stack'
                    , state_code = Evaluate (FunApp sfun args) frm_locs'
                    , state_names = sname : confs }])

  -- State is Val Form
  | isStateValForm state = Just (RuleIdentity, [state])

  -- Everything Broke!!!
  | otherwise = Nothing