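-- | Reduction `Rule`s for the symbolic STG machine, and the single-step
-- `reduce` function that applies them.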
module SSTG.Core.Execution.Rules
    ( Rule(..)
    , reduce
    , isStateValueForm
    ) where

import SSTG.Core.Language

-- | `Rule`s that are applied during STG reduction; `reduce` returns the
-- `Rule` it applied together with the resulting states.
--
-- The first group is stack independent:
--
--   * @RuleAtom*@: evaluate an `Atom` (a literal, a pointer to a literal, a
--     pointer to a heap value, or an uninterpreted variable);
--   * @RulePrimApp@ / @RuleConApp@: primitive and constructor application;
--   * @RuleFunApp*@: function application with exactly enough arguments,
--     too few arguments, a symbolic function, a constructor pointer, or an
--     uninterpreted function;
--   * @RuleLet@: allocate the bindings of a `Let`;
--   * @RuleCase*@: select a `Case` branch by literal, by constructor, via a
--     @DEFAULT@ branch, or by forking on a symbolic scrutinee.
--
-- The remaining groups create (@C@) or delete (@D@) update, case and apply
-- stack frames. @RuleIdentity@ is reported for states already in value form.
--
-- NOTE: `Ord` is derived, so it follows the constructor order below;
-- reordering constructors changes comparison results.
data Rule = RuleAtomLit | RuleAtomLitPtr | RuleAtomValPtr | RuleAtomUnInt
          | RulePrimApp
          | RuleConApp
          | RuleFunAppExact | RuleFunAppUnder | RuleFunAppSym
                            | RuleFunAppConPtr | RuleFunAppUnInt
          | RuleLet
          | RuleCaseLit | RuleCaseConPtr | RuleCaseAnyLit | RuleCaseAnyConPtr
                        | RuleCaseSym

          -- Update frames.
          | RuleUpdateCThunk
          | RuleUpdateDLit | RuleUpdateDValPtr

          -- Case frames.
          | RuleCaseCCaseNonVal
          | RuleCaseDLit | RuleCaseDValPtr

          -- Apply frames.
          | RuleApplyCFunThunk | RuleApplyCFunAppOver
          | RuleApplyDReturnFun | RuleApplyDReturnSym

          -- The state is already a value form.
          | RuleIdentity
          deriving (Show, Eq, Read, Ord)

-- Stack Independent Rules

-- | Is this `HeapObj` already fully evaluated? Symbolic objects, constructor
-- objects, and functions that still expect at least one parameter are value
-- forms. Everything else is not, including thunks (`FunObj` with no
-- parameters) and `LitObj`: a pointer to a literal is still reduced to the
-- literal itself.
isHeapValueForm :: HeapObj -> Bool
isHeapValueForm hobj = case hobj of
    SymObj _ -> True
    ConObj _ _ -> True
    FunObj params _ _ -> not (null params)
    _ -> False

-- | Is the `Expr` a value form? This holds for a literal atom, or for a
-- variable atom that resolves (through the `Locals`, `Globals`, and `Heap`)
-- to a heap object satisfying `isHeapValueForm`. A pointer to a `LitObj` does
-- not count. A variable that resolves to nothing does not count either: it is
-- uninterpreted and can still be lifted to a fresh symbolic recorded in the
-- `Globals`.
isExprValueForm :: Expr -> Locals -> Globals -> Heap -> Bool
isExprValueForm expr locals globals heap = case expr of
    Atom (LitAtom _) -> True
    Atom (VarAtom var) ->
        maybe False (isHeapValueForm . snd)
              (vlookupHeap var locals globals heap)
    _ -> False

-- | Is the `State` in a normal form that cannot be reduced further? It must
-- have an empty stack and be returning either a literal, or an address whose
-- heap object satisfies `isHeapValueForm`.
isStateValueForm :: State -> Bool
isStateValueForm state = case popStack (state_stack state) of
    Just _ -> False
    Nothing -> case state_code state of
        Return (LitVal _) -> True
        Return (MemVal addr) ->
            maybe False isHeapValueForm (lookupHeap addr (state_heap state))
        _ -> False

-- | Convert a `Value` to a `Lit`. Literal values are unwrapped; memory
-- addresses become `AddrLit`s carrying the address's integer form.
valueToLit :: Value -> Lit
valueToLit val = case val of
    LitVal lit -> lit
    MemVal addr -> AddrLit (addrInt addr)

-- | Uneven `zip` of two lists. The pairs are returned together with whatever
-- was left over: @Left@ the unconsumed tail of the first list, or @Right@ the
-- unconsumed tail of the second. When the lists have equal length (including
-- both empty) the leftover is @Left []@.
unevenZip :: [a] -> [b] -> ([(a, b)], Either [a] [b])
unevenZip xs ys = case (xs, ys) of
    (_, []) -> ([], Left xs)
    ([], _) -> ([], Right ys)
    (x : xs', y : ys') ->
        let (paired, leftover) = unevenZip xs' ys'
        in ((x, y) : paired, leftover)

-- | Lift action wrapper type: a payload of type @a@ threaded together with
-- the `Locals`, `Globals`, and `Heap` it is interpreted in, plus the list of
-- `Name`s handed to the fresh-name generators (each newly generated name is
-- prepended to this list).
data LiftAct a = LiftAct a Locals Globals Heap [Name]

-- | Lift an uninterpreted `Var` into the `Globals`. The steps are:
--
--   * allocate a `SymObj` for a freshly named copy of the variable;
--   * bind the original variable to that address in the `Globals`;
--   * record the fresh name.
--
-- The payload of the result is the new address. The `Locals` pass through
-- unchanged.
liftUnInt :: LiftAct Var -> LiftAct MemAddr
liftUnInt (LiftAct var locals globals heap confs) =
    LiftAct addr locals (insertGlobals (var, MemVal addr) globals) heap'
            (fresh : confs)
  where
    fresh = freshSeededName (varName var) confs
    symbol = Symbol (Var fresh (varType var)) Nothing
    (heap', addr) = allocHeap (SymObj symbol) heap

-- | Lift an `Atom` to a `Value`, doing extra work only when necessary.
-- Literals become `LitVal`s. Variables are resolved with `lookupValue`; a
-- variable with no binding (uninterpreted / out of scope) is first lifted to
-- a fresh symbolic via `liftUnInt`, and its new address is returned.
liftAtom :: LiftAct Atom -> LiftAct Value
liftAtom (LiftAct atom locals globals heap confs) =
    LiftAct value locals globals' heap' confs'
  where
    (value, globals', heap', confs') = resolve atom
    -- Result for atoms that need no new allocation.
    unchanged v = (v, globals, heap, confs)
    resolve (LitAtom lit) = unchanged (LitVal lit)
    resolve (VarAtom var) =
        maybe (fromUnInt (liftUnInt (LiftAct var locals globals heap confs)))
              unchanged
              (lookupValue var locals globals)
    fromUnInt (LiftAct addr _ g h c) = (MemVal addr, g, h, c)

-- | Lift a list of `Atom`s with `liftAtom`. The `Globals`, `Heap`, and name
-- list produced by each element are fed into the next, left to right.
liftAtomList :: LiftAct [Atom] -> LiftAct [Value]
liftAtomList (LiftAct atoms locals globals heap confs) = case atoms of
    [] -> LiftAct [] locals globals heap confs
    atom : rest ->
        let LiftAct val l1 g1 h1 c1 =
                liftAtom (LiftAct atom locals globals heap confs)
            LiftAct vals l2 g2 h2 c2 =
                liftAtomList (LiftAct rest l1 g1 h1 c1)
        in LiftAct (val : vals) l2 g2 h2 c2

-- | Lift a `BindRhs` to a `HeapObj`. A function form becomes a `FunObj` that
-- closes over the current `Locals`. A constructor form becomes a `ConObj`
-- whose arguments are lifted with `liftAtomList`.
liftBindRhs :: LiftAct BindRhs -> LiftAct HeapObj
liftBindRhs (LiftAct rhs locals globals heap confs) = case rhs of
    FunForm prms expr ->
        LiftAct (FunObj prms expr locals) locals globals heap confs
    ConForm dcon args ->
        let LiftAct vals locals' globals' heap' confs' =
                liftAtomList (LiftAct args locals globals heap confs)
        in LiftAct (ConObj dcon vals) locals' globals' heap' confs'

-- | Lift a list of `BindRhs`s with `liftBindRhs`. The environment, `Heap`,
-- and name list produced by each element are fed into the next, left to
-- right.
liftBindRhsList :: LiftAct [BindRhs] -> LiftAct [HeapObj]
liftBindRhsList (LiftAct rhss locals globals heap confs) = case rhss of
    [] -> LiftAct [] locals globals heap confs
    rhs : rest ->
        let LiftAct hobj l1 g1 h1 c1 =
                liftBindRhs (LiftAct rhs locals globals heap confs)
            LiftAct hobjs l2 g2 h2 c2 =
                liftBindRhsList (LiftAct rest l1 g1 h1 c1)
        in LiftAct (hobj : hobjs) l2 g2 h2 c2

-- | Lift a `Bind`: allocate its right-hand sides on the `Heap` and bind each
-- binder to its address in the `Locals`.
--
-- Non-recursive bindings lift their right-hand sides in the outer `Locals`.
-- Recursive bindings first reserve one address per binder (holding a
-- `Blackhole` placeholder). They then lift the right-hand sides with the
-- binders already in scope, so closures can refer to each other. Finally
-- they overwrite the placeholders with the real objects.
liftBind :: LiftAct Bind -> LiftAct ()
liftBind (LiftAct (Bind NonRec bnd) locals globals heap confs) =
    LiftAct () bound globals' heapf confs'
  where
    LiftAct hobjs locals' globals' heap' confs' =
        liftBindRhsList (LiftAct (map snd bnd) locals globals heap confs)
    (heapf, addrs) = allocHeapList hobjs heap'
    bound = insertLocalsList (zip (map fst bnd) (map MemVal addrs)) locals'
liftBind (LiftAct (Bind Rec bnd) locals globals heap confs) =
    LiftAct () localsf globals' (insertHeapList (zip addrs hobjs) heap'') confs'
  where
    -- Reserve the binders' addresses with placeholder BLACKHOLEs.
    (heap', addrs) = allocHeapList (map (const Blackhole) bnd) heap
    -- Register the reserved addresses locally so that the closures built
    -- below capture the recursive binders.
    locals' = insertLocalsList (zip (map fst bnd) (map MemVal addrs)) locals
    LiftAct hobjs localsf globals' heap'' confs' =
        liftBindRhsList (LiftAct (map snd bnd) locals' globals heap' confs)

-- | Keep only the @DEFAULT@ `Alt` branches of a `Case`, in their original
-- order.
--
-- Written without the whitespace-surrounded as-pattern (@a \@ p@) used
-- previously, which GHC >= 9.0 rejects under whitespace-sensitive operator
-- parsing.
defaultAlts :: [Alt] -> [Alt]
defaultAlts = filter isDefault
  where
    isDefault (Alt Default _ _) = True
    isDefault _ = False

-- | Keep only the non-@DEFAULT@ (literal or data constructor) `Alt` branches
-- of a `Case`, in their original order.
--
-- Written without the whitespace-surrounded as-pattern (@a \@ p@) used
-- previously, which GHC >= 9.0 rejects under whitespace-sensitive operator
-- parsing.
altConAlts :: [Alt] -> [Alt]
altConAlts = concatMap keep
  where
    keep alt@(Alt acon _ _) | acon /= Default = [alt]
    keep _ = []

-- | Keep the `LitAlt` branches whose literal equals the given `Lit`, in their
-- original order.
--
-- Written without the whitespace-surrounded as-pattern (@a \@ p@) used
-- previously, which GHC >= 9.0 rejects under whitespace-sensitive operator
-- parsing.
matchLitAlts :: Lit -> [Alt] -> [Alt]
matchLitAlts lit = filter matches
  where
    matches (Alt (LitAlt alit) _ _) = lit == alit
    matches _ = False

-- | Keep the `DataAlt` branches whose constructor equals the given `DataCon`,
-- in their original order.
--
-- Written without the whitespace-surrounded as-pattern (@a \@ p@) used
-- previously, which GHC >= 9.0 rejects under whitespace-sensitive operator
-- parsing.
matchDataAlts :: DataCon -> [Alt] -> [Alt]
matchDataAlts dc = filter matches
  where
    matches (Alt (DataAlt adc) _ _) = dc == adc
    matches _ = False

-- | Negate a `Constraint` by flipping its final `Bool` flag. The other
-- fields are kept as they are.
negateConstraint :: Constraint -> Constraint
negateConstraint cons = case cons of
    Constraint alt expr locs polarity -> Constraint alt expr locs (not polarity)

-- | Lift one `Alt` when branching on a symbolic scrutinee. The payload is:
--
--   * the scrutinee variable;
--   * its heap address;
--   * the case binder;
--   * the branch.
--
-- Each branch parameter is bound to a fresh symbolic. The case binder is
-- bound to the scrutinee's address. The result pairs the branch body with a
-- (positive) path `Constraint` that ties the scrutinee to this branch's
-- `AltCon` and parameters.
liftSymAlt :: LiftAct (Var, MemAddr, Var, Alt) -> LiftAct (Expr, [Constraint])
liftSymAlt (LiftAct args locals globals heap confs) =
    LiftAct (expr, [constraint]) locals' globals heap' (snames ++ confs)
  where
    (mvar, addr, cvar, Alt ac params expr) = args
    snames = freshSeededNameList (map varName params) confs
    -- One fresh symbolic per parameter, keeping the parameter's type.
    sym_objs = zipWith (\p n -> SymObj (Symbol (Var n (varType p)) Nothing))
                       params snames
    (heap', addrs) = allocHeapList sym_objs heap
    bindings = (cvar, MemVal addr) : zip params (map MemVal addrs)
    locals' = insertLocalsList bindings locals
    constraint = Constraint (ac, params) (Atom (VarAtom mvar)) locals' True

-- | Turn a lifted `Alt` closure into a successor `State`. The branch body
-- becomes the code to evaluate under the lifted `Locals`. The lifted
-- `Heap`, `Globals`, and names replace the old ones, and the branch's
-- constraints are added to the path constraints.
liftedAltToState :: State -> LiftAct (Expr, [Constraint]) -> State
liftedAltToState state (LiftAct args locals globals heap confs) =
    state { state_heap = heap
          , state_globals = globals
          , state_code = Evaluate (fst args) locals
          , state_names = confs
          , state_paths = insertPathConsList (snd args) (state_paths state) }

-- | Perform one reduction step on a `State`.
--
-- The guards below are tried top to bottom, and the first match decides the
-- `Rule`, so their order is significant. For example, Fun App Exact must come
-- before Fun App Under: for equal arity `unevenZip` yields @Left []@, which
-- would also satisfy the Under guard.
--
-- Returns the applied `Rule` together with the successor `State`s:
--
--   * most rules yield exactly one successor;
--   * Rule Case Sym yields one successor per `Alt` branch, forking on a
--     symbolic scrutinee;
--   * a `State` already in value form yields @Just (RuleIdentity, [state])@.
--
-- Returns `Nothing` when rule application has completely failed, i.e. no
-- rule applies (for instance a `Case` with no matching and no @DEFAULT@
-- branch).
reduce :: State -> Maybe (Rule, [State])
-- The as-pattern is written tight (@state\@State@): GHC >= 9.0 rejects the
-- whitespace-surrounded form.
reduce state@State { state_stack = stack
                   , state_heap = heap
                   , state_globals = globals
                   , state_code = code
                   , state_names = confs }

  -- Stack Independent Rules

  -- Atom Lit
  | Evaluate (Atom (LitAtom lit)) _ <- code =
    Just (RuleAtomLit
         ,[state { state_code = Return (LitVal lit) }])

  -- Atom Lit Pointer
  | Evaluate (Atom (VarAtom var)) locals <- code
  , Just (_, hobj) <- vlookupHeap var locals globals heap
  , LitObj lit <- hobj =
    Just (RuleAtomLitPtr
         ,[state { state_code = Evaluate (Atom (LitAtom lit)) locals }])

  -- Rule Atom Val Pointer
  | Evaluate (Atom (VarAtom var)) locals <- code
  , Just (addr, hobj) <- vlookupHeap var locals globals heap
  , isHeapValueForm hobj =
    Just (RuleAtomValPtr
         ,[state { state_code = Return (MemVal addr) }])

  -- Rule Atom Uninterpreted: lift the unbound variable to a symbolic and
  -- retry evaluating the same atom.
  | Evaluate (Atom (VarAtom uvar)) locals <- code
  , Nothing <- vlookupHeap uvar locals globals heap =
    let pass_in = LiftAct uvar locals globals heap confs
        LiftAct _ locals' globals' heap' confs' = liftUnInt pass_in
    in Just (RuleAtomUnInt
            ,[state { state_heap = heap'
                    , state_globals = globals'
                    , state_code = Evaluate (Atom (VarAtom uvar)) locals'
                    , state_names = confs' }])

  -- Prim Function App
  | Evaluate (PrimApp pfun args) locals <- code =
    let pass_in = LiftAct args locals globals heap confs
        LiftAct vals locals' globals' heap' confs' = liftAtomList pass_in
        eval = SymLitEval pfun (map valueToLit vals)
    in Just (RulePrimApp
            ,[state { state_heap = heap'
                    , state_globals = globals'
                    , state_code = Evaluate (Atom (LitAtom eval)) locals'
                    , state_names = confs' }])

  -- Rule Con App
  | Evaluate (ConApp dcon args) locals <- code =
    let pass_in = LiftAct args locals globals heap confs
        LiftAct vals _ globals' heap' confs' = liftAtomList pass_in
        (heapf, addr) = allocHeap (ConObj dcon vals) heap'
    in Just (RuleConApp
            ,[state { state_heap = heapf
                    , state_globals = globals'
                    , state_code = Return (MemVal addr)
                    , state_names = confs' }])

  -- Rule Fun App Exact
  | Evaluate (FunApp fun args) locals <- code
  , Just (_, hobj) <- vlookupHeap fun locals globals heap
  , FunObj params expr fun_locs <- hobj
  , length params == length args =
    let pass_in = LiftAct args locals globals heap confs
        LiftAct vals _ globals' heap' confs' = liftAtomList pass_in
        fun_locs' = insertLocalsList (zip params vals) fun_locs
    in Just (RuleFunAppExact
            ,[state { state_heap = heap'
                    , state_globals = globals'
                    , state_code = Evaluate expr fun_locs'
                    , state_names = confs' }])

  -- Rule Fun App Under: too few arguments, build a partial application.
  | Evaluate (FunApp fun args) locals <- code
  , Just (_, hobj) <- vlookupHeap fun locals globals heap
  , FunObj params expr fun_locs <- hobj
  , (_, Left ex_params) <- unevenZip params args =
    let pass_in = LiftAct args locals globals heap confs
        LiftAct vals _ globals' heap' confs' = liftAtomList pass_in
        fun_locs' = insertLocalsList (zip params vals) fun_locs
        -- New Fun Object.
        pobj = FunObj ex_params expr fun_locs'
        (heapf, paddr) = allocHeap pobj heap'
    in Just (RuleFunAppUnder
            ,[state { state_heap = heapf
                    , state_globals = globals'
                    , state_code = Return (MemVal paddr)
                    , state_names = confs' }])

  -- Rule Fun App Symbolic
  | Evaluate (FunApp sfun args) locals <- code
  , Just (_, hobj) <- vlookupHeap sfun locals globals heap
  , SymObj (Symbol svar _) <- hobj =
    let sname = freshSeededName (varName svar) confs
        svar' = Var sname (foldl AppTy (varType svar) (map atomType args))
        sym = Symbol svar' (Just (FunApp sfun args, locals))
        (heap', addr) = allocHeap (SymObj sym) heap
    in Just (RuleFunAppSym
            ,[state { state_heap = heap'
                    , state_code = Return (MemVal addr)
                    , state_names = sname : confs }])

  -- Rule Fun App ConObj
  | Evaluate (FunApp cvar []) locals <- code
  , Just (addr, hobj) <- vlookupHeap cvar locals globals heap
  , ConObj _ _ <- hobj =
    Just (RuleFunAppConPtr
         ,[state { state_code = Return (MemVal addr) }])

  -- Rule Fun App Uninterpreted: lift the unbound function and retry.
  | Evaluate (FunApp ufun args) locals <- code
  , Nothing <- vlookupHeap ufun locals globals heap =
    let pass_in = LiftAct ufun locals globals heap confs
        LiftAct _ locals' globals' heap' confs' = liftUnInt pass_in
    in Just (RuleFunAppUnInt
            ,[state { state_heap = heap'
                    , state_globals = globals'
                    , state_code = Evaluate (FunApp ufun args) locals'
                    , state_names = confs' }])

  -- Rule Let
  | Evaluate (Let bnd expr) locals <- code =
    let pass_in = LiftAct bnd locals globals heap confs
        LiftAct _ locals' globals' heap' confs' = liftBind pass_in
    in Just (RuleLet
            ,[state { state_heap = heap'
                    , state_globals = globals'
                    , state_code = Evaluate expr locals'
                    , state_names = confs' }])

  -- Rule Case Lit
  | Evaluate (Case (Atom (LitAtom lit)) cvar alts) locals <- code
  , (Alt _ _ expr):_ <- matchLitAlts lit alts =
    let locals' = insertLocals (cvar, LitVal lit) locals
    in Just (RuleCaseLit
            ,[state { state_code = Evaluate expr locals' }])

  -- Rule Case Con Pointer
  | Evaluate (Case (Atom (VarAtom mvar)) cvar alts) locals <- code
  , Just (addr, hobj) <- vlookupHeap mvar locals globals heap
  , ConObj dcon vals <- hobj
  , (Alt _ params expr):_ <- matchDataAlts dcon alts
  , length params == length vals =
    let kvs = (cvar, MemVal addr) : zip params vals
        locals' = insertLocalsList kvs locals
    in Just (RuleCaseConPtr
            ,[state { state_code = Evaluate expr locals' }])

  -- Rule Case Any Lit
  | Evaluate (Case (Atom (LitAtom lit)) cvar alts) locals <- code
  , [] <- matchLitAlts lit alts
  , (Alt _ _ expr):_ <- defaultAlts alts =
    let locals' = insertLocals (cvar, LitVal lit) locals
    in Just (RuleCaseAnyLit
            ,[state { state_code = Evaluate expr locals' }])

  -- Rule Case Any Con Pointer
  | Evaluate (Case (Atom (VarAtom mvar)) cvar alts) locals <- code
  , Just (addr, hobj) <- vlookupHeap mvar locals globals heap
  , ConObj dcon _ <- hobj
  , [] <- matchDataAlts dcon alts
  , (Alt _ _ expr):_ <- defaultAlts alts =
    let locals' = insertLocals (cvar, MemVal addr) locals
    in Just (RuleCaseAnyConPtr
            ,[state { state_code = Evaluate expr locals' }])

  -- Rule Case Sym: fork one state per branch of a symbolic scrutinee.
  | Evaluate (Case (Atom (VarAtom mvar)) cvar alts) locals <- code
  , Just (addr, hobj) <- vlookupHeap mvar locals globals heap
  , SymObj _ <- hobj
  , let acon_alts = altConAlts alts
  , let def_alts = defaultAlts alts
  -- There must be at least one branch to fork on. `null` avoids walking an
  -- appended list with `length` just to test for emptiness.
  , not (null acon_alts && null def_alts) =
    let acon_ins = map (\a -> LiftAct (mvar, addr, cvar, a)
                                      locals globals heap confs) acon_alts
        acon_lifts = map liftSymAlt acon_ins
        def_ins = map (\a -> LiftAct (mvar, addr, cvar, a)
                                     locals globals heap confs) def_alts
        def_lifts = map liftSymAlt def_ins
        -- Make AltCon states first.
        acon_sts = map (liftedAltToState state) acon_lifts
        -- Make DEFAULT states next: a DEFAULT branch is taken only when
        -- every AltCon branch fails, so it carries all of their negations.
        all_conss = concatMap (\(LiftAct (_, c) _ _ _ _) -> c) acon_lifts
        negatives = map negateConstraint all_conss
        def_lifts' = map (\(LiftAct (e, _) l g h c) ->
                                    (LiftAct (e, negatives) l g h c)) def_lifts
        def_sts = map (liftedAltToState state) def_lifts'
    in Just (RuleCaseSym, acon_sts ++ def_sts)

  -- Stack Dependent Rules

  -- Rule Update Frame Create Thunk
  | Evaluate (Atom (VarAtom var)) locals <- code
  , Just (addr, hobj) <- vlookupHeap var locals globals heap
  , FunObj [] expr fun_locs <- hobj =  -- Thunk form.
    let frame = UpdateFrame addr
    in Just (RuleUpdateCThunk
            ,[state { state_stack = pushStack frame stack
                    , state_heap = insertHeap (addr, Blackhole) heap
                    , state_code = Evaluate expr fun_locs }])

  -- Rule Update Frame Delete Lit
  | Just (UpdateFrame frm_addr, stack') <- popStack stack
  , Return (LitVal lit) <- code =
    Just (RuleUpdateDLit
         ,[state { state_stack = stack'
                 , state_heap = insertHeap (frm_addr, LitObj lit) heap
                 , state_code = Return (LitVal lit) }])

  -- Rule Update Frame Delete Val Pointer
  | Just (UpdateFrame frm_addr, stack') <- popStack stack
  , Return (MemVal addr) <- code
  , Just hobj <- lookupHeap addr heap
  , isHeapValueForm hobj =
    Just (RuleUpdateDValPtr
         ,[state { state_stack = stack'
                 , state_heap = insertHeap (frm_addr, AddrObj addr) heap
                 , state_code = Return (MemVal addr) }])

  -- Rule Case Frame Create Case Non LitVal or MemVal
  | Evaluate (Case mxpr cvar alts) locals <- code
  , not (isExprValueForm mxpr locals globals heap) =
    let frame = CaseFrame cvar alts locals
    in Just (RuleCaseCCaseNonVal
            ,[state { state_stack = pushStack frame stack
                    , state_code = Evaluate mxpr locals }])

  -- Rule Case Frame Delete Lit
  | Just (CaseFrame cvar alts frm_locs, stack') <- popStack stack
  , Return (LitVal lit) <- code =
    let mxpr = Atom (LitAtom lit)
    in Just (RuleCaseDLit
            ,[state { state_stack = stack'
                    , state_code = Evaluate (Case mxpr cvar alts) frm_locs }])

  -- Rule Case Frame Delete Heap Value: bind the returned address to a fresh
  -- variable so the case can scrutinise it as an atom.
  | Just (CaseFrame cvar alts frm_locs, stack') <- popStack stack
  , Return (MemVal addr) <- code
  , Just hobj <- lookupHeap addr heap
  , isHeapValueForm hobj =
    let vname = freshSeededName (varName cvar) confs
        vvar = Var vname (varType cvar)
        mxpr = Atom (VarAtom vvar)
        frm_locs' = insertLocals (vvar, MemVal addr) frm_locs
    in Just (RuleCaseDValPtr
            ,[state { state_stack = stack'
                    , state_code = Evaluate (Case mxpr cvar alts) frm_locs'
                    , state_names = vname : confs }])

  -- Rule Apply Frame Create Function Thunk
  | Evaluate (FunApp fun args) locals <- code
  , Just (_, hobj) <- vlookupHeap fun locals globals heap
  , FunObj [] _ _ <- hobj =
    let frame = ApplyFrame args locals
    in Just (RuleApplyCFunThunk
            ,[state { state_stack = pushStack frame stack
                    , state_code = Evaluate (Atom (VarAtom fun)) locals }])

  -- Rule Apply Frame Create Function Over Application
  | Evaluate (FunApp fun args) locals <- code
  , Just (_, hobj) <- vlookupHeap fun locals globals heap
  , FunObj params expr fun_locs <- hobj
  , (_, Right ex_args) <- unevenZip params args =
    let pass_in = LiftAct args locals globals heap confs
        LiftAct vals locals' globals' heap' confs' = liftAtomList pass_in
        fun_locs' = insertLocalsList (zip params vals) fun_locs
        frame = ApplyFrame ex_args locals'
    in Just (RuleApplyCFunAppOver
            ,[state { state_stack = pushStack frame stack
                    , state_heap = heap'
                    , state_globals = globals'
                    , state_code = Evaluate expr fun_locs'
                    , state_names = confs' }])

  -- Rule Apply Frame Delete ReturnPtr Function
  | Just (ApplyFrame args frm_locs, stack') <- popStack stack
  , Return (MemVal addr) <- code
  , Just hobj <- lookupHeap addr heap
  , FunObj _ _ _ <- hobj
  , Just ftype <- memAddrType addr heap =
    let fname = freshName VarNSpace confs
        fvar = Var fname ftype
        frm_locs' = insertLocals (fvar, MemVal addr) frm_locs
    in Just (RuleApplyDReturnFun
            ,[state { state_stack = stack'
                    , state_code = Evaluate (FunApp fvar args) frm_locs'
                    , state_names = fname : confs }])

  -- Rule Apply Frame Delete ReturnPtr Sym
  | Just (ApplyFrame args frm_locs, stack') <- popStack stack
  , Return (MemVal addr) <- code
  , Just hobj <- lookupHeap addr heap
  , SymObj (Symbol svar _) <- hobj =
    let sname = freshSeededName (varName svar) confs
        svar' = Var sname (varType svar)
        frm_locs' = insertLocals (svar', MemVal addr) frm_locs
    in Just (RuleApplyDReturnSym
            ,[state { state_stack = stack'
                    , state_code = Evaluate (FunApp svar' args) frm_locs'
                    , state_names = sname : confs }])

  -- State is Value Form
  | isStateValueForm state = Just (RuleIdentity, [state])

  -- No rule applies.
  | otherwise = Nothing