module DDC.Core.Flow.Transform.Prep
(prepModule)
where
import DDC.Core.Flow.Prim
import DDC.Core.Flow.Prim.TyConPrim
import DDC.Core.Compounds
import DDC.Core.Module
import DDC.Core.Exp
import Control.Monad.State.Strict
import Data.Map (Map)
import qualified Data.Map as Map
import DDC.Type.Env (TypeEnv)
import qualified DDC.Type.Env as Env
-- | Run the preparation pass over a module.
--
--   Returns the rewritten module paired with the final pass state: a map
--   from worker names to the parameter types recorded for them while the
--   module body was traversed. The pass starts from an empty map.
prepModule
        :: Module a Name
        -> (Module a Name, Map Name [Type Name])
prepModule
        = flip runState Map.empty . prepModuleM
-- | Prepare the body of a module, starting from an empty type environment,
--   and store the rewritten body back into the module record.
prepModuleM :: Module a Name -> PrepM (Module a Name)
prepModuleM mm
 = fmap withBody $ prepX Env.empty (moduleBody mm)
 where  withBody xNewBody = mm { moduleBody = xNewBody }
-- | Traverse an expression. For the flow combinators recognised below,
--   record the parameter types of their worker functions in the pass state,
--   and eta-expand the worker of a single-input @map@ in place.
--
--   The 'TypeEnv' holds the value/witness binders of enclosing @let@s; in
--   this function it is only extended in the 'XLet' case. A worker passed
--   to a @map@ in the first case is recorded only when its name is in this
--   environment.
--
--   A recognised combinator application is returned as-is, or with its
--   worker eta-expanded. Its arguments are not traversed further.
prepX :: TypeEnv Name -> Exp a Name -> PrepM (Exp a Name)
prepX tenv xx
 = case xx of
        -- map over n inputs, applied to a let-bound named worker.
        -- Skip the leading argument (presumably the rate type), take the
        -- next n+1 arguments, and record the first n type arguments among
        -- them as the worker's parameter types.
        XApp{}
         | Just (XVar _ uOp, xsAll)               <- takeXApps xx
         , UPrim (NameOpFlow (OpFlowMap arity)) _ <- uOp
         , _xRate : xsRest                        <- xsAll
         , (xsTypeArgs, xsAfter)                  <- splitAt (arity + 1) xsRest
         , tsTypeArgs                             <- [t | XType t <- xsTypeArgs]
         , XVar _ (UName nWorker) : _             <- xsAfter
         , Env.member (UName nWorker) tenv
         -> record nWorker (take arity tsTypeArgs)

        -- Single-input map whose worker is any other variable:
        -- eta-expand the worker in place over the element type tA.
        XApp{}
         | Just ( xOp@(XVar _ uOp)
                , [ xRate, xTypeA@(XType tA), xTypeB@(XType _)
                  , xWorker@(XVar aW _), xSeries ])
                <- takeXApps xx
         , UPrim (NameOpFlow (OpFlowMap 1)) _ <- uOp
         -> return
          $ xApps aW xOp
                [ xRate, xTypeA, xTypeB
                , xEtaExpand aW xWorker [tA]
                , xSeries ]

        -- fold / foldIndex applied to a named worker. The parameter
        -- types to record depend on which combinator it is.
        XApp{}
         | Just (XVar _ uOp, [_, XType tA, XType tB, XVar _ (UName nWorker), _, _])
                <- takeXApps xx
         , UPrim (NameOpFlow opFold) _ <- uOp
         , Just tsParams <- foldWorkerParams opFold tA tB
         -> record nWorker tsParams

        -- mkSel applied to a named worker: record an empty parameter list.
        XApp{}
         | Just (XVar _ uOp, [XType _, XType _, _, XVar _ (UName nWorker)])
                <- takeXApps xx
         , UPrim (NameOpFlow (OpFlowMkSel _)) _ <- uOp
         -> record nWorker []

        -- Leaves are returned unchanged.
        XVar{}          -> return xx
        XCon{}          -> return xx
        XType{}         -> return xx
        XWitness{}      -> return xx

        -- Structural recursion. Binders are kept and the environment is
        -- unchanged.
        XLAM a b x      -> fmap (XLAM a b) (down x)
        XLam a b x      -> fmap (XLam a b) (down x)
        XApp a x1 x2    -> liftM2 (XApp a) (down x1) (down x2)
        XCast a c x     -> fmap (XCast a c) (down x)
        XCase a x alts  -> liftM2 (XCase a) (down x) (mapM (prepAlt tenv) alts)

        -- Visit the let body before the bindings. This way, workers used in
        -- the body are already recorded when 'prepLts' looks them up.
        XLet a lts x
         -> do  xBody'  <- prepX (Env.extends (valwitBindsOfLets lts) tenv) x
                lts'    <- prepLts tenv a lts
                return $ XLet a lts' xBody'

 where  down    = prepX tenv

        -- Record a worker's parameter types and leave the expression as-is.
        record nWorker tsParams
                = addWorkerArgs nWorker tsParams >> return xx

        -- Parameter types recorded for the worker of each fold-like
        -- combinator. Nothing means the operator is not a fold.
        foldWorkerParams OpFlowFold      tA tB = Just [tA, tB]
        foldWorkerParams OpFlowFoldIndex tA tB = Just [tInt, tA, tB]
        foldWorkerParams _               _  _  = Nothing
-- | Prepare a group of let-bindings.
--
--   * A non-recursive binding of a named variable: the bound expression is
--     prepared first. Then, if a non-empty list of worker parameter types
--     has been recorded for that name (e.g. by 'prepX' while visiting the
--     let body), the expression is eta-expanded over those types.
--   * Any other non-recursive binding: the bound expression is prepared.
--   * Recursive bindings: every right-hand side is prepared with all the
--     group's binders in scope.
--   * Region bindings are returned unchanged.
--
--   The annotation @a@ is used for the lambdas and applications that
--   eta-expansion introduces.
prepLts :: TypeEnv Name -> a -> Lets a Name -> PrepM (Lets a Name)
prepLts tenv a lts
 = case lts of
        LLet b@(BName n _) x
         -> do  x'    <- prepX tenv x
                mArgs <- lookupWorkerArgs n
                case mArgs of
                 -- An empty parameter list (e.g. recorded for mkSel workers)
                 -- needs no expansion. 'null' checks this in O(1), whereas
                 -- 'length' would walk the whole list.
                 Just tsArgs
                  | not (null tsArgs)
                  -> return $ LLet b $ xEtaExpand a x' tsArgs
                 _ -> return $ LLet b x'

        LLet b x
         -> do  x'    <- prepX tenv x
                return $ LLet b x'

        LRec bxs
         -> do  let (bs, xs) = unzip bxs
                -- Recursive binders are visible in every right-hand side.
                let tenv'    = Env.extends bs tenv
                xs'   <- mapM (prepX tenv') xs
                return $ LRec $ zip bs xs'

        LLetRegions{}   -> return lts
        LWithRegion{}   -> return lts
-- | Prepare the right-hand side of a case alternative. The pattern is kept
--   unchanged, and its binders are not added to the type environment.
prepAlt :: TypeEnv Name -> Alt a Name -> PrepM (Alt a Name)
prepAlt tenv (AAlt pat xRhs)
 = do   xRhs' <- prepX tenv xRhs
        return $ AAlt pat xRhs'
-- | Eta-expand an expression over the given parameter types.
--
--   Wraps the expression in one anonymous lambda per type and applies it to
--   the lambda-bound variables in binding order. For @n@ types the
--   arguments are the de Bruijn indices @n-1, n-2, ..., 0@. This assumes
--   'xLams' makes the first binder the outermost one, so that binder has
--   the largest index.
--
--   The annotation @a@ is attached to every new lambda, application and
--   variable.
xEtaExpand :: a -> Exp a Name -> [Type Name] -> Exp a Name
xEtaExpand a x tys
 = xLams a (map BAnon tys)
 $ xApps a x [ XVar a (UIx ix) | ix <- reverse [0 .. length tys - 1] ]
-- | State of the preparation pass: for each worker name, the list of
--   parameter types recorded for it.
type PrepS = Map.Map Name [Type Name]

-- | The preparation monad: a strict state monad over 'PrepS'.
type PrepM = State PrepS
-- | Record the parameter types of a worker. Any entry previously recorded
--   for the same name is replaced.
addWorkerArgs :: Name -> [Type Name] -> PrepM ()
addWorkerArgs name tsParam
 = do   workers <- get
        put (Map.insert name tsParam workers)
-- | Look up the parameter types recorded for a worker, if any.
lookupWorkerArgs :: Name -> PrepM (Maybe [Type Name])
lookupWorkerArgs name
 = gets (Map.lookup name)