module DDC.Core.Flow.Transform.Schedule.Lifting
( Lifting (..)
, ScalarEnv
, LiftEnv
, liftType
, liftTypeOfBind
, liftWorker
, lowerSeriesRate)
where
import DDC.Core.Flow.Transform.Schedule.Error
import DDC.Core.Flow.Compounds
import DDC.Core.Flow.Exp
import DDC.Core.Flow.Prim
import Control.Monad
import Data.List
-- | Parameters controlling how a scalar worker is lifted to vectors.
data Lifting
        = Lifting
        { -- | Number of scalar elements packed into each vector
          --   (used as the width of 'tVec' and of the replicate primitive).
          liftingFactor :: Int
        }
        deriving (Eq, Show)
-- | Binders of scalar values that 'liftWorker' replicates across all
--   lanes of a vector when they are referenced in a lifted expression.
type ScalarEnv = [BindF]
-- | Association of scalar binders (first) with the binders of their
--   lifted counterparts (second); 'liftWorker' looks variables up by the
--   first component and substitutes the second.
type LiftEnv = [(BindF, BindF)]
-- | Compute the vector type that corresponds to a scalar element type.
--
--   With a lifting factor of 1 the type is returned unchanged. Otherwise
--   only the primitive numeric types listed in @liftable@ (32/64-bit
--   floats, 8/16/32/64-bit words, Int and Nat) are lifted, to a vector
--   with 'liftingFactor' lanes; any other type yields 'Nothing'.
liftType :: Lifting -> TypeF -> Maybe TypeF
liftType lifting tScalar
 = if width == 1
        then Just tScalar
        else if tScalar `elem` liftable
                then Just (tVec width tScalar)
                else Nothing
 where
        width    = liftingFactor lifting

        -- Element types this function knows how to lift.
        liftable = [ tFloat 32, tFloat 64
                   , tWord 8,  tWord 16, tWord 32, tWord 64
                   , tInt
                   , tNat ]
-- | Lift the type annotation of a binder with 'liftType', keeping the
--   binder's name (or anonymity) as it is.
--
--   Yields 'Nothing' when the binder's type cannot be lifted.
liftTypeOfBind :: Lifting -> BindF -> Maybe BindF
liftTypeOfBind lifting bind
 = case bind of
        BName n t -> fmap (BName n) (lift t)
        BAnon t   -> fmap BAnon     (lift t)
        BNone t   -> fmap BNone     (lift t)
 where
        lift = liftType lifting
-- | Lift a worker expression so that it computes on vectors of
--   'liftingFactor' elements instead of on single scalar elements.
--
--   * Variables bound in the 'LiftEnv' are replaced by their lifted
--     counterpart.
--   * Variables bound in the 'ScalarEnv' are replicated into a vector with
--     the @PrimVecRep@ primitive, at the variable's own declared type.
--   * 32-bit float literals are replicated in the same way.
--   * Arithmetic primitives applied to an element type are replaced by the
--     vector primitive returned by 'liftPrimArithToVec'.
--   * Applications are lifted component-wise.
--
--   Any other expression cannot be lifted and yields 'ErrorCannotLiftExp'.
liftWorker
        :: Lifting      -- ^ Lifting to apply (gives the vector width).
        -> ScalarEnv    -- ^ Scalar binders that must be replicated.
        -> LiftEnv      -- ^ Scalar binders paired with their lifted versions.
        -> ExpF         -- ^ Worker expression to lift.
        -> Either Error ExpF

liftWorker lifting envScalar envLift xx
 = let  down    = liftWorker lifting envScalar envLift

        -- Wrap an expression of element type tElem in the replicate
        -- primitive for the current lifting factor.
        replicateAs tElem x
         = let  nPrim   = PrimVecRep (liftingFactor lifting)
                tPrim   = typePrimVec nPrim
           in   XApp (XApp (XVar (UPrim (NamePrimVec nPrim) tPrim))
                           (XType tElem))
                     x

        -- Declared type of a scalar binder.
        typeOfScalar b
         = case b of
                BName _ t -> t
                BAnon t   -> t
                BNone t   -> t

   in case xx of
        -- Variables that have a lifted counterpart are replaced by it.
        XVar u
         | Just (_, bL) <- find (\(bS', _) -> boundMatchesBind u bS') envLift
         , Just uL      <- takeSubstBoundOfBind bL
         -> Right (XVar uL)

         -- Scalar variables are replicated at their own type. This used
         -- to be hard-coded to Float32, which gave an ill-typed replicate
         -- for scalars of any other element type.
         | Just bS      <- find (boundMatchesBind u) envScalar
         -> Right $ replicateAs (typeOfScalar bS) xx

        -- 32-bit float literals are replicated across all lanes.
        XCon dc
         | DaConPrim (NameLitFloat _ 32) _ <- dc
         -> Right $ replicateAs (tFloat 32) xx

        -- Scalar arithmetic primitives become their vector versions.
        XApp (XVar (UPrim (NamePrimArith prim) _)) (XType tElem)
         | Just prim' <- liftPrimArithToVec (liftingFactor lifting) prim
         -> Right $ XApp (XVar (UPrim (NamePrimVec prim') (typePrimVec prim')))
                         (XType tElem)

        XApp x1 x2
         -> do  x1'     <- down x1
                x2'     <- down x2
                return $ XApp x1' x2'

        _ -> Left (ErrorCannotLiftExp xx)
-- | Rewrite the rate of a series type to account for lifting.
--
--   A type @Series k a@ becomes @Series (tDown c k) a@, where @c@ is the
--   'liftingFactor' (presumably the rate of the same series consumed @c@
--   elements at a time; confirm against 'tDown'). Any type that is not
--   @Series@ applied to exactly two arguments yields 'Nothing'.
lowerSeriesRate :: Lifting -> TypeF -> Maybe TypeF
lowerSeriesRate lifting tt
 = case takePrimTyConApps tt of
        Just (NameTyConFlow TyConFlowSeries, [tK, tA])
         -> Just $ tSeries (tDown (liftingFactor lifting) tK) tA

        _ -> Nothing