module DDC.Core.Flow.Transform.Extract
(extractModule)
where
import DDC.Core.Flow.Transform.Extract.Intersperse
import DDC.Core.Flow.Compounds
import DDC.Core.Flow.Procedure
import DDC.Core.Flow.Prim
import DDC.Core.Flow.Exp
import DDC.Core.Transform.Annotate
import DDC.Core.Module
-- | Rebuild a module so that its body is made from the given procedures.
--
--   Every field of the original module is kept except 'moduleBody', which is
--   replaced by the expression built by 'extractTop', with all annotations
--   set to @()@ via 'annotate'.
extractModule :: ModuleF -> [Procedure] -> ModuleF
extractModule orig procs
 = orig { moduleBody = xBody }
 where  xBody   = annotate () (extractTop procs)
-- | Bind every procedure in a single recursive let-group whose body is unit.
extractTop :: [Procedure] -> ExpF
extractTop procs
 = XLet (LRec bindings) xUnit
 where  bindings = [ extractProcedure p | p <- procs ]
-- | Convert a single procedure into a binder and its defining expression.
--
--   The binder carries the procedure's name. Its type is built by folding
--   'tFun' over the types of the value parameters (ending in the result
--   type), then wrapping that in one 'TForall' per type parameter, with the
--   first type parameter outermost.
--
--   The expression abstracts over the type parameters ('xLAMs'), then the
--   value parameters ('xLams'). Its body is produced by 'extractNest' from
--   the procedure's loop nest, statements and result expression.
extractProcedure :: Procedure -> (Bind Name, ExpF)
extractProcedure (Procedure n bsParam xsParam nest stmts xResult tResult)
 = (bProc, xProc)
 where
        -- Types of the value parameters, in declaration order.
        tsParam = map typeOfBind xsParam
        tFunc   = foldr tFun    tResult tsParam
        tProc   = foldr TForall tFunc   bsParam
        bProc   = BName n tProc

        xInner  = extractNest nest stmts xResult
        xProc   = xLAMs bsParam (xLams xsParam xInner)
-- | Build the body expression of a procedure.
--
--   The bindings extracted from the loop nest are merged with the given
--   statements by 'intersperseStmts', which decides where they go. The
--   merged bindings are then wrapped around the result expression with
--   'xLets'.
extractNest
        :: Nest         -- ^ Loop nest to extract.
        -> [LetsF]      -- ^ Statements of the procedure.
        -> ExpF         -- ^ Result expression of the procedure.
        -> ExpF
extractNest nest stmts xResult
 = xLets stmtsAll xResult
 where  stmtsAll = intersperseStmts (extractLoop nest) stmts
-- | Extract the bindings that implement a loop nest.
--
--   * 'NestLoop' yields its start bindings, then one unit-typed binding that
--     applies the 'OpLoopLoop' primitive to the loop rate and a body lambda,
--     then its end bindings. The body lambda takes an anonymous @Nat@ and
--     runs this level's body statements followed by the inner nest.
--   * 'NestIf' yields one unit-typed binding built with 'xLoopGuard'. Its
--     arguments are the flag variable (the flags name with the @"elem"@
--     modifier), the counter variable (the inner rate's name with the
--     @"count"@ modifier), and a lambda holding the guarded body statements
--     followed by the nested nest.
--   * 'NestEmpty' yields nothing; 'NestList' concatenates its members.
--
--   For 'NestIf', the flags must be a named variable and the inner rate a
--   named type variable. Otherwise a descriptive 'error' is raised when the
--   corresponding name is demanded.
extractLoop :: Nest -> [LetsF]
extractLoop (NestLoop tRate starts bodys inner ends _result)
 = lsStart ++ [lLoop] ++ lsEnd
 where
        lsStart  = concatMap extractStmtStart starts
        lsEnd    = concatMap extractStmtEnd   ends

        xLoopOp  = XVar (UPrim (NameOpLoop OpLoopLoop) (typeOpLoop OpLoopLoop))
        lLoop    = LLet (BNone tUnit) (xApps xLoopOp [XType tRate, xBody])

        xBody    = XLam (BAnon tNat)
                 $ xLets (lsBody ++ lsInner) xUnit
        lsBody   = concatMap extractStmtBody bodys
        lsInner  = extractLoop inner

extractLoop (NestIf _tRateOuter tRateInner uFlags stmtsBody nested)
 = [LLet (BNone tUnit) xGuard]
 where
        -- These were irrefutable partial pattern bindings. The lazy 'case'
        -- keeps the same evaluation order but reports a clear error instead
        -- of "irrefutable pattern failed".
        nFlags   = case uFlags of
                     UName n -> n
                     _       -> error
                        $  "DDC.Core.Flow.Transform.Extract.extractLoop: "
                        ++ "NestIf flags are not a named variable"
        xFlag    = XVar (UName (NameVarMod nFlags "elem"))

        nK       = case tRateInner of
                     TVar (UName n) -> n
                     _              -> error
                        $  "DDC.Core.Flow.Transform.Extract.extractLoop: "
                        ++ "NestIf inner rate is not a named type variable"
        uCounter = UName (NameVarMod nK "count")

        xGuard   = xLoopGuard xFlag (XVar uCounter)
                 $ XLam (BAnon tNat)
                 $ xLets (lsBody ++ lsNested) xUnit
        lsBody   = concatMap extractStmtBody stmtsBody
        lsNested = extractLoop nested

extractLoop NestEmpty
 = []

extractLoop (NestList nests)
 = concatMap extractLoop nests
-- | Bindings to emit before a loop starts.
--
--   * 'StartVecNew' binds a new vector of the given element type, created
--     with 'xNewVectorR' at the given rate.
--   * 'StartAcc' binds a reference of the given type, created by 'xNew'
--     from the given initial expression.
extractStmtStart :: StmtStart -> [LetsF]
extractStmtStart (StartVecNew nVec tElem tRate')
 = [ LLet (BName nVec (tVector tElem)) (xNewVectorR tElem tRate') ]

extractStmtStart (StartAcc n t x)
 = [ LLet (BName n (tRef t)) (xNew t x) ]
-- | Bindings to emit inside a loop body, one per statement.
--
--   * 'BodyStmt' is passed through as a plain let-binding.
--   * 'BodyVecWrite' writes a value at an index of the named vector.
--   * 'BodyAccRead' binds the value read from the named reference.
--   * 'BodyAccWrite' writes the worker expression to the named reference.
extractStmtBody :: StmtBody -> [LetsF]
extractStmtBody (BodyStmt b x)
 = [ LLet b x ]

extractStmtBody (BodyVecWrite nVec tElem xIx xVal)
 = [ LLet (BNone tUnit) (xWriteVector tElem xVec xIx xVal) ]
 where  xVec    = XVar (UName nVec)

extractStmtBody (BodyAccRead n t bVar)
 = [ LLet bVar (xRead t xAcc) ]
 where  xAcc    = XVar (UName n)

extractStmtBody (BodyAccWrite nAcc tElem xWorker)
 = [ LLet (BNone tUnit) (xWrite tElem xAcc xWorker) ]
 where  xAcc    = XVar (UName nAcc)
-- | Bindings to emit after a loop finishes.
--
--   * 'EndStmt' is passed through as a plain let-binding.
--   * 'EndAcc' binds @n@ (at type @t@) to the value read from the reference
--     named @nAcc@.
--   * 'EndVecSlice' reads the counter belonging to the rate (the rate
--     variable's name with the @"count"@ modifier) into an anonymous @Int@
--     binding. It then rebinds @nVec@ to 'xSliceVector' applied to that
--     count and the original vector.
--
--   For 'EndVecSlice' the rate must be a named type variable. Otherwise a
--   descriptive 'error' is raised when the counter is demanded.
extractStmtEnd :: StmtEnd -> [LetsF]
extractStmtEnd se
 = case se of
        EndStmt b x
         -> [LLet b x]

        EndAcc n t nAcc
         -> [LLet (BName n t)
                  (xRead t (XVar (UName nAcc))) ]

        EndVecSlice nVec tElem tRate
         -> let xCounter = xRead tInt (XVar (counterOfRate tRate))
                xVec     = XVar (UName nVec)
            in  [ LLet (BAnon tInt)
                       xCounter
                  -- NOTE(review): UIx 0 presumably refers to the anonymous
                  -- counter binding just above (de Bruijn index) — confirm.
                , LLet (BName nVec (tVector tElem))
                       (xSliceVector tElem (XVar (UIx 0)) xVec) ]
 where
        -- Counter variable for a rate. This replaces a partial irrefutable
        -- pattern binding; the 'case' is only forced when the counter is
        -- demanded, so evaluation order is unchanged.
        counterOfRate tRate
         = case tRate of
                TVar (UName nK) -> UName (NameVarMod nK "count")
                _               -> error
                        $  "DDC.Core.Flow.Transform.Extract.extractStmtEnd: "
                        ++ "EndVecSlice rate is not a named type variable"