module DDC.Core.Flow.Transform.Extract
( extractModule
, extractProcedure)
where
import DDC.Core.Flow.Compounds
import DDC.Core.Flow.Procedure
import DDC.Core.Flow.Prim
import DDC.Core.Flow.Exp
import DDC.Core.Transform.Annotate
import DDC.Core.Module
-- | Replace the body of a module with the code extracted from the given
--   procedures. Every other field of the module is kept as it was.
--   The extracted body is given unit annotations with 'annotate'.
extractModule :: ModuleF -> [Procedure] -> ModuleF
extractModule mm procs
 = mm { moduleBody = xBody }
 where  xBody = annotate () (extractTop procs)
-- | Build the top-level expression of the module. It is a single recursive
--   let-group that binds every procedure, and its body is 'xUnit'.
extractTop :: [Procedure] -> ExpF
extractTop procs
 = XLet lsProcs xUnit
 where  lsProcs = LRec [ extractProcedure p | p <- procs ]
-- | Convert a procedure into a binding for the top-level recursive group.
--
--   The bound type quantifies ('TForall') over the type parameters. Its body
--   is a chain of function types from the value parameters to 'tUnit'.
--   The bound expression abstracts the type parameters ('xLAMs') and the
--   value parameters ('xLams') around the extracted loop nest. The nest's
--   final expression is 'xUnit'.
extractProcedure :: Procedure -> (Bind Name, ExpF)
extractProcedure (Procedure nProc bsType bsValue nest)
 = (BName nProc tProc, xProc)
 where
        -- foldr over the binders directly; same as mapping typeOfBind first.
        tResult = foldr (tFun . typeOfBind) tUnit bsValue
        tProc   = foldr TForall tResult bsType

        xProc   = xLAMs bsType
                $ xLams bsValue
                $ extractNest nest xUnit
-- | Extract the let-bindings of a loop nest with 'extractLoop', then wrap
--   them around the given expression using 'xLets'.
extractNest
        :: Nest         -- ^ Loop nest to extract.
        -> ExpF         -- ^ Expression the bindings are wrapped around.
        -> ExpF
extractNest = xLets . extractLoop
-- | Lower a loop nest to a flat list of let-bindings.
--
--   * 'NestLoop' produces three parts, in order:
--     the start statements;
--     one unit-typed binding that applies the 'OpControlLoop' primitive to
--     the loop rate and a body lambda;
--     the end statements.
--     The body lambda takes one anonymous 'tNat' argument. It runs the body
--     statements, followed by the bindings of the inner nest. The loop's
--     result field is not used here.
--
--   * 'NestGuard' produces one unit-typed binding of 'xGuard', applied to:
--     the @count@ variable of the inner rate;
--     the @elem@ variable of the flags series;
--     a lambda over one anonymous 'tNat' argument that runs the body
--     statements and the nested bindings.
--     The outer rate is not used.
--
--   * 'NestSegment' has the same shape as 'NestGuard', but uses 'xSegment',
--     the lengths series, and a lambda over two anonymous 'tNat' arguments.
--
--   * 'NestEmpty' yields no bindings.
--
--   * 'NestList' concatenates the bindings of its members in order.
--
--   For guards and segments, the series must be bound by name and the inner
--   rate must be a named type variable. Anything else is reported with a
--   descriptive 'error'; before this change it was an opaque
--   irrefutable-pattern failure.
extractLoop :: Nest -> [LetsF]
extractLoop nest
 = case nest of
        NestLoop tRate starts bodys inner ends _result
         -> let lsStart = concatMap extractStmtStart starts
                lsBody  = concatMap extractStmtBody  bodys
                lsInner = extractLoop inner
                lsEnd   = concatMap extractStmtEnd   ends

                -- The loop body binds one anonymous Nat argument.
                xBody   = XLam (BAnon tNat)
                        $ xLets (lsBody ++ lsInner) xUnit

                lLoop   = LLet (BNone tUnit)
                               (xApps (XVar (UPrim (NameOpControl OpControlLoop)
                                                   (typeOpControl OpControlLoop)))
                                      [ XType tRate
                                      , xBody ])
            in  lsStart ++ [lLoop] ++ lsEnd

        NestGuard _tRateOuter tRateInner uFlags stmtsBody nested
         -> let lsBody   = concatMap extractStmtBody stmtsBody
                lsNested = extractLoop nested
                xBody    = xGuard (counterOfRate "guard rate" tRateInner)
                                  (elemOfSeries  "guard flags series" uFlags)
                                  ( XLam (BAnon tNat)
                                  $ xLets (lsBody ++ lsNested) xUnit)
            in  [LLet (BNone tUnit) xBody]

        NestSegment _tRateOuter tRateInner uLengths stmtsBody nested
         -> let lsBody   = concatMap extractStmtBody stmtsBody
                lsNested = extractLoop nested
                xBody    = xSegment (counterOfRate "segment rate" tRateInner)
                                    (elemOfSeries  "segment lengths series" uLengths)
                                    ( XLam (BAnon tNat)
                                    $ XLam (BAnon tNat)
                                    $ xLets (lsBody ++ lsNested) xUnit)
            in  [LLet (BNone tUnit) xBody]

        NestEmpty
         -> []

        NestList nests
         -> concatMap extractLoop nests
 where
        -- Variable for the @elem@ modifier of a series that is bound by name.
        -- Evaluated lazily, so an error is raised only if the value is used.
        elemOfSeries what u
         = case u of
                UName n -> XVar (UName (NameVarMod n "elem"))
                _       -> error $ "DDC.Core.Flow.Transform.Extract.extractLoop: "
                                ++ what ++ " is not a named variable"

        -- Variable for the @count@ modifier of a rate that is a named type
        -- variable.
        counterOfRate what t
         = case t of
                TVar (UName n) -> XVar (UName (NameVarMod n "count"))
                _              -> error $ "DDC.Core.Flow.Transform.Extract.extractLoop: "
                                       ++ what ++ " is not a named type variable"
-- | Convert a loop-start statement into let-bindings.
--
--   * 'StartStmt' is passed through as a single binding.
--   * 'StartVecNew' binds the vector name at type @tVector tElem@ to
--     'xNewVectorR' applied to the element type and the rate.
--   * 'StartAcc' binds the accumulator name at type @tRef t@ to 'xNew'
--     applied to the type and the statement's expression.
extractStmtStart :: StmtStart -> [LetsF]
extractStmtStart (StartStmt bStart xStart)
 = [LLet bStart xStart]

extractStmtStart (StartVecNew nVec tElem tRate)
 = [LLet (BName nVec (tVector tElem)) (xNewVectorR tElem tRate)]

extractStmtStart (StartAcc nAcc tAcc xInit)
 = [LLet (BName nAcc (tRef tAcc)) (xNew tAcc xInit)]
-- | Convert a loop-body statement into let-bindings.
--
--   * 'BodyStmt' is passed through as a single binding.
--   * 'BodyVecWrite' is a unit-typed binding of 'xWriteVector' on the named
--     vector, the index expression, and the value expression.
--   * 'BodyAccRead' binds the given binder to 'xRead' of the named
--     accumulator.
--   * 'BodyAccWrite' is a unit-typed binding of 'xWrite' of the worker
--     expression to the named accumulator.
extractStmtBody :: StmtBody -> [LetsF]
extractStmtBody (BodyStmt bStmt xStmt)
 = [LLet bStmt xStmt]

extractStmtBody (BodyVecWrite nVec tElem xIx xVal)
 = [LLet (BNone tUnit) (xWriteVector tElem (XVar (UName nVec)) xIx xVal)]

extractStmtBody (BodyAccRead nAcc tAcc bVar)
 = [LLet bVar (xRead tAcc (XVar (UName nAcc)))]

extractStmtBody (BodyAccWrite nAcc tElem xWorker)
 = [LLet (BNone tUnit) (xWrite tElem (XVar (UName nAcc)) xWorker)]
-- | Convert a loop-end statement into let-bindings.
--
--   * 'EndStmt' is passed through as a single binding.
--
--   * 'EndAcc' binds @n@ at type @t@ to 'xRead' of the accumulator @nAcc@.
--
--   * 'EndVecTrunc' produces two bindings:
--     first, an anonymous 'tNat' binder holding 'xRead' of the @count@
--     variable derived from the rate's name;
--     second, a unit-typed binding that applies 'xTruncVector' to that
--     anonymous value (de Bruijn index 0) and the named vector.
--     The rate must be a named type variable. Anything else is reported with
--     a descriptive 'error'; before this change it was an opaque
--     irrefutable-pattern failure.
extractStmtEnd :: StmtEnd -> [LetsF]
extractStmtEnd se
 = case se of
        EndStmt b x
         -> [LLet b x]

        EndAcc n t nAcc
         -> [LLet (BName n t)
                  (xRead t (XVar (UName nAcc))) ]

        EndVecTrunc nVec tElem tRate
         -> let uCounter = counterOfRate tRate
                xCounter = xRead tNat (XVar uCounter)
                xVec     = XVar (UName nVec)
            in  [ LLet (BAnon tNat)
                       xCounter
                  -- UIx 0 refers to the anonymous binder introduced just above.
                , LLet (BNone tUnit)
                       (xTruncVector tElem (XVar (UIx 0)) xVec) ]
 where
        -- Bound for the @count@ modifier of a rate that is a named type
        -- variable. Evaluated lazily, like the original pattern binding.
        counterOfRate tRate'
         = case tRate' of
                TVar (UName nK) -> UName (NameVarMod nK "count")
                _               -> error $ "DDC.Core.Flow.Transform.Extract.extractStmtEnd: "
                                        ++ "rate of EndVecTrunc is not a named type variable"