module DDC.Core.Flow.Transform.Schedule.Scalar
(scheduleScalar)
where
import DDC.Core.Flow.Transform.Schedule.Nest
import DDC.Core.Flow.Transform.Schedule.Error
import DDC.Core.Flow.Transform.Schedule.Base
import DDC.Core.Flow.Procedure
import DDC.Core.Flow.Process
import DDC.Core.Flow.Compounds
import DDC.Core.Flow.Prim
import DDC.Core.Flow.Exp
import Control.Monad
-- | Schedule a process into a procedure built around a single scalar loop.
--
--   The loop rate is the rate slurped from the types of the series value
--   parameters. The first type parameter of the process must be that same
--   rate variable (with kind 'kRate'). Otherwise the result is:
--
--   * 'ErrorNoRateParameters' when there are no type parameters at all;
--   * 'ErrorPrimaryRateMismatch' when the first type parameter is anything
--     else.
--
--   The loop body starts with one 'xNext' statement per series parameter,
--   binding that series' current element. The process contexts are then
--   inserted into the nest, and finally each operator is scheduled in order.
scheduleScalar :: Process -> Either Error Procedure
scheduleScalar
        (Process
                { processName           = nProc
                , processParamTypes     = bsTypeParams
                , processParamValues    = bsValueParams
                , processOperators      = ops
                , processContexts       = ctxs })
 = do
        -- Rate of the series parameters; this is the rate of the loop.
        tRate   <- slurpRateOfParamTypes
                        [ t | t <- map typeOfBind bsValueParams
                            , isSeriesType t ]

        -- The first type parameter must be that rate variable.
        checkPrimaryRate tRate bsTypeParams

        -- Bind the current element of every series parameter at the top
        -- of the loop body. The bound `UIx 0` is the index argument passed
        -- to 'xNext' (presumably the loop counter -- confirm against Nest).
        let ssElems
                = [ BodyStmt bElem
                        (xNext tRate tElem (XVar (UName nSeries)) (XVar (UIx 0)))
                  | bSeries@(BName nSeries tSeries) <- bsValueParams
                  , isSeriesType tSeries
                  , let Just tElem = elemTypeOfSeriesType tSeries
                  , let Just bElem = elemBindOfSeriesBind bSeries ]

        let nestInit
                = NestLoop
                { nestRate      = tRate
                , nestStart     = []
                , nestBody      = ssElems
                , nestInner     = NestEmpty
                , nestEnd       = []
                , nestResult    = xUnit }

        -- NOTE(review): if a context cannot be inserted this is a lazy
        -- irrefutable-pattern failure rather than a Left.
        let Just nestCtx = foldM insertContext nestInit ctxs

        nestOps <- foldM scheduleOperator nestCtx ops

        return  Procedure
                { procedureName         = nProc
                , procedureParamTypes   = bsTypeParams
                , procedureParamValues  = bsValueParams
                , procedureNest         = nestOps }

 where
        -- Check that the first type parameter is the rate variable tK.
        -- A first parameter that fails either guard falls through to the
        -- mismatch case.
        checkPrimaryRate _ []
         = Left ErrorNoRateParameters

        checkPrimaryRate tK (BName n k : _)
         | k == kRate
         , TVar (UName n) == tK
         = Right ()

        checkPrimaryRate _ _
         = Left ErrorPrimaryRateMismatch
-- | Schedule a single operator into a scalar loop nest, adding the
--   statements that implement it to the contexts for the operator's rate.
--
--   Returns @Left ('ErrorUnsupported' op)@ when:
--
--   * the operator kind is not handled by this scheduler;
--   * a series argument has no corresponding element binder or bound;
--   * a result or target variable is not a named variable;
--   * inserting into the nest fails. Each 'insertStarts', 'insertBody',
--     'insertEnds' call can return 'Nothing' — presumably when the nest has
--     no context for the requested rate (TODO confirm against the insert
--     functions).
--
--   All of these used to be irrefutable @let Just@ patterns. They returned
--   a nest containing pattern-failure thunks that crashed only later, far
--   from the cause.
scheduleOperator
        :: Nest                 -- ^ Nest to add the operator to.
        -> Operator             -- ^ Operator to schedule.
        -> Either Error Nest
scheduleOperator nest0 op
 -- Copy the current input element to the result at the input rate.
 | OpId{} <- op
 = do   let tK  = opInputRate op
        bResult <- orFail $ elemBindOfSeriesBind   (opResultSeries op)
        uInput  <- orFail $ elemBoundOfSeriesBound (opInputSeries op)
        orFail  $ insertBody nest0 tK
                [ BodyStmt bResult (XVar uInput) ]

 -- Replicate a value: bind it once in the start statements for the output
 -- rate, then use that variable as every element of the result.
 | OpRep{} <- op
 = do   let tK  = opOutputRate op
        nResult <- bindNameOrFail (opResultSeries op)
        let nVal = NameVarMod nResult "val"
        bResult <- orFail $ elemBindOfSeriesBind (opResultSeries op)
        nest1   <- orFail $ insertStarts nest0 tK
                        [ StartStmt (BName nVal (opElemType op)) (opInputExp op) ]
        orFail  $ insertBody nest1 tK
                [ BodyStmt bResult (XVar (UName nVal)) ]

 -- Copy the current input element to the result at the output rate.
 | OpReps{} <- op
 = do   uInput  <- orFail $ elemBoundOfSeriesBound (opInputSeries op)
        bResult <- orFail $ elemBindOfSeriesBind   (opResultSeries op)
        orFail  $ insertBody nest0 (opOutputRate op)
                [ BodyStmt bResult (XVar uInput) ]

 -- The result element is the bound `UIx 1` at the output rate.
 -- NOTE(review): which loop index this names depends on how the nest binds
 -- indices -- confirm.
 | OpIndices{} <- op
 = do   bResult <- orFail $ elemBindOfSeriesBind (opResultSeries op)
        orFail  $ insertBody nest0 (opOutputRate op)
                [ BodyStmt bResult (XVar (UIx 1)) ]

 -- Write each input element into the target vector at index `UIx 0`.
 | OpFill{} <- op
 = do   let tK    = opInputRate op
        let tElem = opElemType op
        uInput  <- orFail $ elemBoundOfSeriesBound (opInputSeries op)
        nVec    <- boundNameOrFail (opTargetVector op)
        nest1   <- orFail $ insertBody nest0 tK
                        [ BodyVecWrite nVec tElem (XVar (UIx 0)) (XVar uInput) ]

        -- When the rate is guarded, the vector is truncated to rate tK in the
        -- end statements (presumably because fewer elements than the vector
        -- length may have been written).
        if nestContainsGuardedRate nest1 tK
         then orFail $ insertEnds nest1 tK [ EndVecTrunc nVec tElem tK ]
         else return nest1

 -- Read the source vector at the current index element.
 | OpGather{} <- op
 = do   let tK  = opInputRate op
        bResult <- orFail $ elemBindOfSeriesBind   (opResultBind op)
        uIndex  <- orFail $ elemBoundOfSeriesBound (opSourceIndices op)
        orFail  $ insertBody nest0 tK
                [ BodyStmt bResult
                        (xReadVector (opElemType op)
                                (XVar (opSourceVector op))
                                (XVar uIndex)) ]

 -- Write the current element to the target vector at the current index,
 -- then bind the operator's result to unit in the end statements.
 | OpScatter{} <- op
 = do   let tK  = opInputRate op
        uIndex  <- orFail $ elemBoundOfSeriesBound (opSourceIndices op)
        uElem   <- orFail $ elemBoundOfSeriesBound (opSourceElems op)
        nest1   <- orFail $ insertBody nest0 tK
                        [ BodyStmt (BNone tUnit)
                                (xWriteVector (opElemType op)
                                        (XVar (opTargetVector op))
                                        (XVar uIndex) (XVar uElem)) ]
        orFail  $ insertEnds nest1 tK
                [ EndStmt (opResultBind op) xUnit ]

 -- Apply the worker to the current element of every input series.
 | OpMap{} <- op
 = do   let tK  = opInputRate op
        bResult <- orFail $ elemBindOfSeriesBind (opResultSeries op)
        usInput <- orFail $ mapM elemBoundOfSeriesBound (opInputSeriess op)

        -- Wrap the worker body in one beta-redex per parameter, binding
        -- each parameter to the matching input element. As with the
        -- original parallel comprehension, extra parameters or inputs are
        -- dropped.
        let xBody
                = foldl (\x (b, p) -> XApp (XLam b x) p)
                        (opWorkerBody op)
                        (zip (opWorkerParams op) (map XVar usInput))

        orFail  $ insertBody nest0 tK [ BodyStmt bResult xBody ]

 -- Copy the current input element to the result at the output rate.
 | OpPack{} <- op
 = do   uInput  <- orFail $ elemBoundOfSeriesBound (opInputSeries op)
        bResult <- orFail $ elemBindOfSeriesBind   (opResultSeries op)
        orFail  $ insertBody nest0 (opOutputRate op)
                [ BodyStmt bResult (XVar uInput) ]

 -- Fold the input series into the target reference via an accumulator.
 | OpReduce{} <- op
 = do   let tK       = opInputRate op
        let uRef     = opTargetRef op
        nResult      <- boundNameOrFail uRef
        uInput       <- orFail $ elemBoundOfSeriesBound (opInputSeries op)
        let tAcc     = typeOfBind (opWorkerParamAcc op)
        let nAccInit = NameVarMod nResult "init"
        let nAcc     = NameVarMod nResult "acc"
        let nAccVal  = NameVarMod nResult "val"
        let nAccRes  = NameVarMod nResult "res"

        -- Worker applied to an accumulator value and an element.
        let xWorker xAcc xElem
                = XApp (XApp (XLam (opWorkerParamAcc op)
                                   (XLam (opWorkerParamElem op)
                                         (opWorkerBody op)))
                             xAcc)
                       xElem

        -- Starts: read the initial value from the reference and use it
        -- to initialise the accumulator.
        nest1   <- orFail $ insertStarts nest0 tK
                        [ StartStmt (BName nAccInit tAcc) (xRead tAcc (XVar uRef))
                        , StartAcc  nAcc tAcc (XVar (UName nAccInit)) ]

        -- Body: read the accumulator, combine it with the current element,
        -- and write the result back.
        nest2   <- orFail $ insertBody nest1 tK
                        [ BodyAccRead  nAcc tAcc (BName nAccVal tAcc)
                        , BodyAccWrite nAcc tAcc
                                (xWorker (XVar (UName nAccVal)) (XVar uInput)) ]

        -- Ends: read the final accumulator and store it in the reference.
        orFail  $ insertEnds nest2 tK
                [ EndAcc  nAccRes tAcc nAcc
                , EndStmt (BNone tUnit)
                          (xWrite tAcc (XVar uRef) (XVar (UName nAccRes))) ]

 | otherwise
 = Left $ ErrorUnsupported op

 where
        -- Lift an optional result into the scheduler's error monad,
        -- blaming the operator being scheduled.
        orFail :: Maybe a -> Either Error a
        orFail  = maybe (Left (ErrorUnsupported op)) Right

        -- Name of a named binder, or an error for any other binder.
        bindNameOrFail b
         = case b of
                BName n _ -> Right n
                _         -> Left (ErrorUnsupported op)

        -- Name of a named bound, or an error for any other bound.
        boundNameOrFail u
         = case u of
                UName n   -> Right n
                _         -> Left (ErrorUnsupported op)