module DDC.Core.Flow.Transform.Schedule.Kernel
( scheduleKernel
, Error (..)
, Lifting (..))
where
import DDC.Core.Flow.Transform.Schedule.Nest
import DDC.Core.Flow.Transform.Schedule.Error
import DDC.Core.Flow.Transform.Schedule.Lifting
import DDC.Core.Flow.Transform.Schedule.Base
import DDC.Core.Flow.Process
import DDC.Core.Flow.Procedure
import DDC.Core.Flow.Compounds
import DDC.Core.Flow.Exp
import DDC.Core.Flow.Prim
import Control.Monad
import Data.Maybe
-- | Schedule a process into a procedure whose body is a single loop nest.
--
--   The procedure keeps the name and type parameters of the process.
--   Its value parameters are the process value parameters with their
--   types passed through 'lowerSeriesRate' (left unchanged where that
--   returns 'Nothing').
--
--   Returns 'Left' when:
--
--   * 'slurpRateOfParamTypes' fails on the value parameter types,
--
--   * there are no type parameters ('ErrorNoRateParameters'),
--
--   * the first type parameter is not a named binder of kind 'kRate'
--     whose variable equals the slurped rate ('ErrorPrimaryRateMismatch'),
--
--   * the element type of a series parameter cannot be lifted
--     ('ErrorCannotLiftType'),
--
--   * scheduling any operator fails.
scheduleKernel :: Lifting -> Process -> Either Error Procedure
scheduleKernel
        lifting
        (Process { processName          = name
                 , processParamTypes    = bsParamTypes
                 , processParamValues   = bsParamValues
                 , processOperators     = operators })
 = do
        -- Rate type slurped from the types of the value parameters.
        tK      <- slurpRateOfParamTypes (map typeOfBind bsParamValues)

        -- The first type parameter must be the rate variable found above.
        case bsParamTypes of
          []            -> Left ErrorNoRateParameters
          BName n k : _
           | k == kRate
           , TVar (UName n) == tK
                        -> return ()
          _             -> Left ErrorPrimaryRateMismatch

        -- Lower the rate in the types of the value parameters.
        -- NOTE(review): the lambda only matches named binders; an
        -- anonymous or wildcard value parameter would fail the pattern
        -- match here -- confirm processes never carry such parameters.
        let bsParamValues_lowered
                = map (\(BName n t)
                        -> let t' = fromMaybe t $ lowerSeriesRate lifting t
                           in  BName n t')
                $ bsParamValues

        -- Value parameters that have series type.
        let bsSeries
                = [ b   | b <- bsParamValues
                        , isSeriesType (typeOfBind b) ]

        let c   = liftingFactor lifting

        -- One body statement per named series parameter, binding
        -- "<series>.elem" to the next chunk of elements of that series.
        -- A failure to lift the element type is reported as
        -- 'ErrorCannotLiftType', matching 'liftTypeOfBindM', instead of
        -- surfacing later as a pattern match failure.
        -- NOTE(review): 'elemTypeOfSeriesType' is still matched partially;
        -- it is presumably total on types accepted by 'isSeriesType'.
        ssBody  <- sequence
                [ case liftType lifting tElem of
                   Just tElem_lifted
                    -> return
                     $ BodyStmt
                        (BName (NameVarMod nS "elem") tElem_lifted)
                        (xNextC c tK tElem (XVar (UName nS)) (XVar uIndex))

                   Nothing
                    -> Left $ ErrorCannotLiftType tElem

                | BName nS tS    <- bsSeries
                , let Just tElem = elemTypeOfSeriesType tS
                , let uIndex     = UIx 0 ]

        -- Initial loop nest running at the down-sampled rate.
        let nest0
                = NestLoop
                { nestRate      = tDown c tK
                , nestStart     = []
                , nestBody      = ssBody
                , nestInner     = NestEmpty
                , nestEnd       = []
                , nestResult    = xUnit }

        -- Schedule every operator into the nest, in order.
        nest'   <- foldM (scheduleOperator lifting bsParamValues)
                         nest0 operators

        return  $ Procedure
                { procedureName         = name
                , procedureParamTypes   = bsParamTypes
                , procedureParamValues  = bsParamValues_lowered
                , procedureNest         = nest' }
-- | Schedule a single operator into an existing loop nest.
--
--   Handles 'OpMap', 'OpFill', 'OpReduce', 'OpGather' and 'OpScatter';
--   any other operator yields 'ErrorUnsupported'. Worker parameters whose
--   types cannot be lifted yield 'ErrorCannotLiftType' via
--   'liftTypeOfBindM', and failures from 'liftWorker' are propagated.
--
--   NOTE(review): the @let Just ... =@ bindings below are partial. If the
--   nest has no loop at the expected rate, or a series binder / bound has
--   no element form, the failure is a pattern match error rather than a
--   'Left'. No suitable 'Error' constructor is visible from this module
--   section, so they are left as they were.
scheduleOperator
        :: Lifting      -- ^ Lifting configuration, supplies the lifting factor.
        -> ScalarEnv    -- ^ Scalar environment passed on to 'liftWorker'.
        -> Nest         -- ^ Nest to insert the new statements into.
        -> Operator     -- ^ Operator to schedule.
        -> Either Error Nest

scheduleOperator lifting envScalar nest op

 -- Map -----------------------------------------
 | OpMap{}      <- op
 = do   let c           = liftingFactor lifting
        let tK          = opInputRate op
        let tK_down     = tDown c tK

        -- Binder for a lifted element of the result series.
        let Just bResultE
                = elemBindOfSeriesBind (opResultSeries op)
                        >>= liftTypeOfBind lifting

        -- Bounds for the elements of each input series.
        let Just usInput
                = sequence
                $ map elemBoundOfSeriesBound
                $ opInputSeriess op

        -- Lift the worker parameters and the worker body.
        let bsParam     = opWorkerParams op
        bsParam_lifted  <- mapM (liftTypeOfBindM lifting) bsParam
        let envLift     = zip bsParam bsParam_lifted

        xWorker_lifted  <- liftWorker lifting envScalar envLift
                        $  opWorkerBody op

        -- Wrap the lifted worker in one lambda per parameter, applied
        -- to the matching input element. Pairing stops at the shorter
        -- of the two lists.
        let xBody
                = foldl (\x (b, p) -> XApp (XLam b x) p)
                        xWorker_lifted
                        (zip bsParam_lifted (map XVar usInput))

        let Just nest2
                = insertBody nest tK_down
                $ [ BodyStmt bResultE xBody ]

        return nest2

 -- Fill ----------------------------------------
 | OpFill{}     <- op
 = do   let c           = liftingFactor lifting
        let tK          = opInputRate op
        let tK_down     = tDown c tK

        -- Bound for an element of the input series.
        let Just uInput
                = elemBoundOfSeriesBound
                $ opInputSeries op

        -- Write the current element into the target vector.
        let Just nest2
                = insertBody nest tK_down
                $ [ BodyStmt (BNone tUnit)
                        (xWriteVectorC c
                                (opElemType op)
                                (XVar $ opTargetVector op)
                                (XVar $ UIx 0)
                                (XVar $ uInput)) ]

        -- Bind the result of the operator to unit.
        let Just nest3
                = insertEnds nest2 tK_down
                $ [ EndStmt (opResultBind op)
                            xUnit ]

        return nest3

 -- Reduce --------------------------------------
 | OpReduce{}   <- op
 = do   let c           = liftingFactor lifting
        let tK          = opInputRate op
        let tK_down     = tDown c tK

        -- Type of the accumulator, taken from the element parameter.
        let tA          = typeOfBind $ opWorkerParamElem op

        -- Names derived from the target reference.
        -- NOTE(review): partial; only a named target reference matches.
        let UName nRef  = opTargetRef op

        let nAccZero    = NameVarMod nRef "zero"
        let bAccZero    = BName nAccZero tA
        let uAccZero    = UName nAccZero

        let nAccVec     = NameVarMod nRef "vec"
        let uAccVec     = UName nAccVec

        -- Start: bind the zero value and initialise a vector
        -- accumulator by replicating it.
        let Just nest2
                = insertStarts nest tK_down
                $ [ StartStmt bAccZero (opZero op)
                  , StartAcc  nAccVec
                        (tVec c tA)
                        (xvRep c tA (XVar uAccZero)) ]

        -- Bound for an element of the input series.
        let Just uInput
                = elemBoundOfSeriesBound
                $ opInputSeries op

        let nAccVal     = NameVarMod nRef "val"
        let uAccVal     = UName nAccVal
        let bAccVal     = BName nAccVal (tVec c tA)

        -- Lift the worker parameters and the worker body.
        let bsParam     = [ opWorkerParamAcc op, opWorkerParamElem op ]
        bsParam_lifted  <- mapM (liftTypeOfBindM lifting) bsParam
        let envLift     = zip bsParam bsParam_lifted

        xWorker_lifted  <- liftWorker lifting envScalar envLift
                        $  opWorkerBody op

        -- Lifted worker applied to an accumulator and an element.
        let xBody_lifted x1 x2
                = XApp (XApp ( XLam (opWorkerParamAcc  op)
                             $ XLam (opWorkerParamElem op)
                                    xWorker_lifted)
                             x1)
                       x2

        -- Body: read the vector accumulator, combine it with the current
        -- input using the lifted worker, and write it back.
        let Just nest3
                = insertBody nest2 tK_down
                $ [ BodyAccRead  nAccVec (tVec c tA) bAccVal
                  , BodyAccWrite nAccVec (tVec c tA)
                        (xBody_lifted (XVar uAccVal) (XVar uInput)) ]

        let nAccResult  = NameVarMod nRef "res"
        let bAccResult  = BName nAccResult (tVec c tA)
        let uAccResult  = UName nAccResult

        -- Binder and bound for the i-th partial result.
        let bPart (i :: Int) = BName (NameVarMod nAccResult (show i)) tA
        let uPart (i :: Int) = UName (NameVarMod nAccResult (show i))

        let nAccInit    = NameVarMod nRef "init"

        -- Unlifted worker applied to an accumulator and an element.
        let xBody x1 x2
                = XApp (XApp ( XLam (opWorkerParamAcc  op)
                             $ XLam (opWorkerParamElem op)
                                    (opWorkerBody op))
                             x1)
                       x2

        -- End: read the vector accumulator and the current value of the
        -- target reference, then fold the unlifted worker over the
        -- projections 0 .. c - 1 of the vector. Part 0 starts from the
        -- value read from the reference; part i continues from part i - 1.
        let Just nest4
                = insertEnds nest3 tK_down
                $  [ EndStmt bAccResult
                        (xRead (tVec c tA) (XVar uAccVec))
                   , EndStmt (BName nAccInit tA)
                        (xRead tA (XVar $ opTargetRef op)) ]
                ++ [ EndStmt (bPart 0)
                        (xBody  (XVar $ UName nAccInit)
                                (xvProj c 0 tA (XVar uAccResult))) ]
                ++ [ EndStmt (bPart i)
                        (xBody  (XVar (uPart (i - 1)))
                                (xvProj c i tA (XVar uAccResult)))
                   | i <- [1 .. c - 1] ]

        -- Write the last partial result back to the target reference.
        let Just nest5
                = insertEnds nest4 tK_down
                $ [ EndStmt (BNone tUnit)
                        (xWrite tA (XVar $ opTargetRef op)
                                   (XVar $ uPart (c - 1))) ]

        -- Bind the result of the operator to unit.
        let Just nest6
                = insertEnds nest5 tK_down
                $ [ EndStmt (opResultBind op)
                            xUnit ]

        return nest6

 -- Gather --------------------------------------
 | OpGather{}   <- op
 = do   let c           = liftingFactor lifting
        let tK          = opInputRate op
        let tK_down     = tDown c tK

        -- Binder for a lifted element of the result series.
        let Just bResultE
                = elemBindOfSeriesBind (opResultBind op)
                        >>= liftTypeOfBind lifting

        -- Bound for an element of the index series.
        let Just uIndex
                = elemBoundOfSeriesBound (opSourceIndices op)

        -- Gather from the source vector at the current indices.
        let Just nest2
                = insertBody nest tK_down
                $ [ BodyStmt bResultE
                        (xvGather c
                                (opElemType op)
                                (XVar $ opSourceVector op)
                                (XVar $ uIndex)) ]

        return nest2

 -- Scatter -------------------------------------
 | OpScatter{}  <- op
 = do   let c           = liftingFactor lifting
        let tK          = opInputRate op
        let tK_down     = tDown c tK

        -- Bounds for the elements of the index and element series.
        let Just uIndex
                = elemBoundOfSeriesBound (opSourceIndices op)

        let Just uElem
                = elemBoundOfSeriesBound (opSourceElems op)

        -- Scatter the current elements into the target vector.
        let Just nest2
                = insertBody nest tK_down
                $ [ BodyStmt (BNone tUnit)
                        (xvScatter c
                                (opElemType op)
                                (XVar $ opTargetVector op)
                                (XVar $ uIndex) (XVar $ uElem)) ]

        -- Bind the result of the operator to unit.
        let Just nest3
                = insertEnds nest2 tK_down
                $ [ EndStmt (opResultBind op)
                            xUnit ]

        return nest3

 -- Unsupported ---------------------------------
 | otherwise
 = Left $ ErrorUnsupported op
-- | Lift the type of a binder, reporting a failure to lift as
--   'ErrorCannotLiftType' carrying the binder's original type.
liftTypeOfBindM :: Lifting -> Bind Name -> Either Error (Bind Name)
liftTypeOfBindM lifting b
 = maybe (Left cannotLift) Right
 $ liftTypeOfBind lifting b
 where  cannotLift = ErrorCannotLiftType (typeOfBind b)