-- | Generic one-level traversals over expressions of the Futhark AST.
--
-- A traversal is described by a record of callbacks, one per kind of
-- component that may occur in an expression.  A 'Mapper' rebuilds the
-- expression from the results of its callbacks, whereas a 'Walker'
-- only runs its callbacks for their effects.
module Futhark.Representation.AST.Traversals
(
-- * Mapping
Mapper(..)
, identityMapper
, mapBody
, mapExpM
, mapExp
, mapOnType
, mapOnLoopForm
, mapOnExtType
-- * Walking
, Walker(..)
, identityWalker
, walkExpM
, walkExp
)
where
import Control.Monad
import Control.Monad.Identity
import qualified Data.Traversable
import Data.Monoid ((<>))
import Futhark.Representation.AST.Syntax
import Futhark.Representation.AST.Attributes.Scope
-- | A record of monadic callbacks describing how to transform each
-- kind of component found inside an expression.  The mapper goes from
-- lore @flore@ to lore @tlore@, so the callbacks for lore-dependent
-- components may change the type of what they are given.
data Mapper flore tlore m =
  Mapper
    { mapOnSubExp :: SubExp -> m SubExp
      -- ^ Applied to subexpression operands.
    , mapOnBody :: Scope tlore -> Body flore -> m (Body tlore)
      -- ^ Applied to bodies.  The first argument is a scope supplied
      -- by the traversal: 'mapExpM' passes 'mempty' for the branches
      -- of an @If@, and the scope of the mapped loop form and merge
      -- parameters for the body of a @DoLoop@.
    , mapOnVName :: VName -> m VName
      -- ^ Applied to variable names.
    , mapOnCertificates :: Certificates -> m Certificates
      -- ^ Applied to certificates.  Not invoked by 'mapExpM'.
    , mapOnRetType :: RetType flore -> m (RetType tlore)
      -- ^ Applied to each return type of an @Apply@.
    , mapOnBranchType :: BranchType flore -> m (BranchType tlore)
      -- ^ Applied to each branch type of an @If@.
    , mapOnFParam :: FParam flore -> m (FParam tlore)
      -- ^ Applied to the merge parameters of a @DoLoop@.
    , mapOnLParam :: LParam flore -> m (LParam tlore)
      -- ^ Applied to the per-array parameters of a @ForLoop@.
    , mapOnOp :: Op flore -> m (Op tlore)
      -- ^ Applied to the payload of an @Op@ expression.
    }
-- | A 'Mapper' in which every callback simply returns its argument
-- unchanged (the body callback additionally discards the scope it is
-- given).  Intended as a starting point: override only the fields of
-- interest using record update syntax.
identityMapper :: Monad m => Mapper lore lore m
identityMapper =
  Mapper
    { mapOnSubExp = return
    , mapOnBody = \_scope body -> return body
    , mapOnVName = return
    , mapOnCertificates = return
    , mapOnRetType = return
    , mapOnBranchType = return
    , mapOnFParam = return
    , mapOnLParam = return
    , mapOnOp = return
    }
-- | Apply a pure function to every statement of a body.  The body
-- attribute and the body result are carried over untouched.
mapBody :: (Stm lore -> Stm lore) -> Body lore -> Body lore
mapBody onStm (Body attr stms res) = Body attr stms' res
  where
    stms' = onStm <$> stms
-- | Map a monadic action across the immediate children of an
-- expression.  Each child is handed to the matching callback of the
-- 'Mapper' and the expression is rebuilt from the results.  The
-- traversal is not recursive: bodies are passed to 'mapOnBody', and
-- it is up to that callback to descend into them.
--
-- The branches of an @If@ are mapped with an empty scope.  The body
-- of a @DoLoop@ is mapped with the scope of the already-mapped loop
-- form and merge parameters.
--
-- Callbacks are run left to right in the order the components appear
-- in the constructor, except for @DoLoop@, where all merge parameters
-- and the loop form are mapped before the initial values and the body.
mapExpM :: Monad m =>
           Mapper flore tlore m -> Exp flore -> m (Exp tlore)
mapExpM tv (BasicOp bop) = BasicOp <$> onBasicOp bop
  where
    onSubExp = mapOnSubExp tv
    onVName = mapOnVName tv
    onShape = mapOnShape tv

    onBasicOp (SubExp se) =
      SubExp <$> onSubExp se
    onBasicOp (ArrayLit els rowt) =
      ArrayLit <$> traverse onSubExp els <*> mapOnType onSubExp rowt
    onBasicOp (BinOp op x y) =
      BinOp op <$> onSubExp x <*> onSubExp y
    onBasicOp (CmpOp op x y) =
      CmpOp op <$> onSubExp x <*> onSubExp y
    onBasicOp (ConvOp conv x) =
      ConvOp conv <$> onSubExp x
    onBasicOp (UnOp op x) =
      UnOp op <$> onSubExp x
    onBasicOp (Index arr slice) =
      Index <$> onVName arr <*> traverse (traverse onSubExp) slice
    onBasicOp (Update arr slice se) =
      Update <$> onVName arr
             <*> traverse (traverse onSubExp) slice
             <*> onSubExp se
    onBasicOp (Iota n x s et) =
      Iota <$> onSubExp n <*> onSubExp x <*> onSubExp s <*> pure et
    onBasicOp (Replicate shape v) =
      Replicate <$> onShape shape <*> onSubExp v
    onBasicOp (Repeat shapes innershape v) =
      Repeat <$> traverse onShape shapes
             <*> onShape innershape
             <*> onVName v
    onBasicOp (Scratch t shape) =
      Scratch t <$> traverse onSubExp shape
    onBasicOp (Reshape shape arr) =
      Reshape <$> traverse (Data.Traversable.traverse onSubExp) shape
              <*> onVName arr
    onBasicOp (Rearrange perm arr) =
      Rearrange perm <$> onVName arr
    onBasicOp (Rotate es arr) =
      Rotate <$> traverse onSubExp es <*> onVName arr
    onBasicOp (Concat i x ys size) =
      Concat i <$> onVName x <*> traverse onVName ys <*> onSubExp size
    onBasicOp (Copy arr) =
      Copy <$> onVName arr
    onBasicOp (Manifest perm arr) =
      Manifest perm <$> onVName arr
    onBasicOp (Assert e msg loc) =
      Assert <$> onSubExp e <*> traverse onSubExp msg <*> pure loc
    onBasicOp (Opaque e) =
      Opaque <$> onSubExp e
mapExpM tv (If c texp fexp (IfAttr ts s)) =
  If <$> mapOnSubExp tv c
     <*> mapOnBody tv mempty texp
     <*> mapOnBody tv mempty fexp
     <*> (IfAttr <$> traverse (mapOnBranchType tv) ts <*> pure s)
mapExpM tv (Apply fname args ret loc) =
  Apply fname <$> traverse onArg args
              <*> traverse (mapOnRetType tv) ret
              <*> pure loc
  where
    -- Only the argument itself is mapped; its annotation is kept.
    onArg (arg, d) = (\arg' -> (arg', d)) <$> mapOnSubExp tv arg
mapExpM tv (DoLoop ctxmerge valmerge form loopbody) = do
  -- Parameters and loop form must be mapped first, because the scope
  -- handed to the body callback is built from their mapped versions.
  ctxparams' <- traverse (mapOnFParam tv) ctxparams
  valparams' <- traverse (mapOnFParam tv) valparams
  form' <- mapOnLoopForm tv form
  let scope = scopeOf form' <> scopeOfFParams (ctxparams' ++ valparams')
  DoLoop <$> (zip ctxparams' <$> traverse (mapOnSubExp tv) ctxinits)
         <*> (zip valparams' <$> traverse (mapOnSubExp tv) valinits)
         <*> pure form'
         <*> mapOnBody tv scope loopbody
  where
    (ctxparams, ctxinits) = unzip ctxmerge
    (valparams, valinits) = unzip valmerge
mapExpM tv (Op op) =
  Op <$> mapOnOp tv op
-- | Run the subexpression callback of the mapper on every dimension
-- of a shape, in order, and rebuild the shape from the results.
mapOnShape :: Monad m => Mapper flore tlore m -> Shape -> m Shape
mapOnShape tv (Shape dims) = do
  dims' <- mapM (mapOnSubExp tv) dims
  return $ Shape dims'
-- | Run the subexpression callback of the mapper on the subexpressions
-- found in a type with an existential shape.  For arrays, only 'Free'
-- dimensions are mapped; 'Ext' dimensions are kept as they are.  For
-- memory blocks the size is mapped and the space is kept.  Primitive
-- types are returned unchanged.
mapOnExtType :: Monad m =>
                Mapper flore tlore m -> TypeBase ExtShape u -> m (TypeBase ExtShape u)
mapOnExtType tv t =
  case t of
    Array bt (Shape dims) u -> do
      dims' <- mapM onExtSize dims
      return $ Array bt (Shape dims') u
    Prim bt ->
      return $ Prim bt
    Mem size space -> do
      size' <- mapOnSubExp tv size
      return $ Mem size' space
  where
    onExtSize (Ext x) = return $ Ext x
    onExtSize (Free se) = fmap Free (mapOnSubExp tv se)
-- | Map over the components of a loop form.  For a 'ForLoop' the
-- callbacks run in this order: the iterator name, the bound, every
-- loop parameter, and finally every array name being looped over.
-- The iterator's integer type is kept unchanged.  For a 'WhileLoop'
-- only the condition variable is mapped.
mapOnLoopForm :: Monad m =>
                 Mapper flore tlore m -> LoopForm flore -> m (LoopForm tlore)
mapOnLoopForm tv (ForLoop i it bound loop_vars) = do
  i' <- mapOnVName tv i
  bound' <- mapOnSubExp tv bound
  lparams' <- mapM (mapOnLParam tv) (map fst loop_vars)
  arrs' <- mapM (mapOnVName tv) (map snd loop_vars)
  return $ ForLoop i' it bound' (zip lparams' arrs')
mapOnLoopForm tv (WhileLoop cond) =
  fmap WhileLoop (mapOnVName tv cond)
-- | Pure variant of 'mapExpM': runs the traversal in the 'Identity'
-- monad and unwraps the result.
mapExp :: Mapper flore tlore Identity -> Exp flore -> Exp tlore
mapExp tv e = runIdentity (mapExpM tv e)
-- | Apply a monadic function to the subexpressions found in a type:
-- every dimension of an array shape, or the size of a memory block.
-- Primitive types contain no subexpressions and are returned as is.
mapOnType :: Monad m =>
             (SubExp -> m SubExp) -> Type -> m Type
mapOnType onSubExp t =
  case t of
    Prim bt ->
      return $ Prim bt
    Mem size space -> do
      size' <- onSubExp size
      return $ Mem size' space
    Array bt shape u -> do
      dims' <- mapM onSubExp (shapeDims shape)
      return $ Array bt (Shape dims') u
-- | A record of monadic callbacks that are run purely for their
-- effects on each kind of component found inside an expression.  It
-- has the same set of fields as a 'Mapper', but every callback returns
-- unit, and the body callback receives no scope.
data Walker lore m =
  Walker
    { walkOnSubExp :: SubExp -> m ()
      -- ^ Run on subexpression operands.
    , walkOnBody :: Body lore -> m ()
      -- ^ Run on bodies.
    , walkOnVName :: VName -> m ()
      -- ^ Run on variable names.
    , walkOnCertificates :: Certificates -> m ()
      -- ^ Run on certificates.
    , walkOnRetType :: RetType lore -> m ()
      -- ^ Run on return types.
    , walkOnBranchType :: BranchType lore -> m ()
      -- ^ Run on branch types.
    , walkOnFParam :: FParam lore -> m ()
      -- ^ Run on function-style (loop merge) parameters.
    , walkOnLParam :: LParam lore -> m ()
      -- ^ Run on lambda-style (for-loop array) parameters.
    , walkOnOp :: Op lore -> m ()
      -- ^ Run on the payload of an @Op@ expression.
    }
-- | A 'Walker' whose callbacks all ignore their argument and do
-- nothing.  Intended as a starting point: override only the fields of
-- interest using record update syntax.
identityWalker :: Monad m => Walker lore m
identityWalker =
  Walker
    { walkOnSubExp = ignore
    , walkOnBody = ignore
    , walkOnVName = ignore
    , walkOnCertificates = ignore
    , walkOnRetType = ignore
    , walkOnBranchType = ignore
    , walkOnFParam = ignore
    , walkOnLParam = ignore
    , walkOnOp = ignore
    }
  where
    -- Written as a function binding so it can be used at every field type.
    ignore _ = return ()
-- | Turn a 'Walker' into a 'Mapper' over the same lore: each callback
-- runs the corresponding walker action on its argument and then
-- returns that argument unchanged.  The scope passed to the body
-- callback is discarded, since walkers do not receive one.
walkMapper :: Monad m => Walker lore m -> Mapper lore lore m
walkMapper f =
  Mapper
    { mapOnSubExp = observe (walkOnSubExp f)
    , mapOnBody = \_scope -> observe (walkOnBody f)
    , mapOnVName = observe (walkOnVName f)
    , mapOnCertificates = observe (walkOnCertificates f)
    , mapOnRetType = observe (walkOnRetType f)
    , mapOnBranchType = observe (walkOnBranchType f)
    , mapOnFParam = observe (walkOnFParam f)
    , mapOnLParam = observe (walkOnLParam f)
    , mapOnOp = observe (walkOnOp f)
    }
  where
    -- Run an effect on a value, then hand the value back untouched.
    observe walk x = walk x >> return x
-- | Run the callbacks of a 'Walker' on the immediate children of an
-- expression, for their effects only.  Implemented by converting the
-- walker to a mapper, traversing with 'mapExpM', and throwing away the
-- rebuilt expression.
walkExpM :: Monad m => Walker lore m -> Exp lore -> m ()
walkExpM f e = void $ mapExpM (walkMapper f) e
-- | Variant of 'walkExpM' specialised to the 'Identity' monad, with
-- the result unwrapped to plain unit.
walkExp :: Walker lore Identity -> Exp lore -> ()
walkExp f e = runIdentity (walkExpM f e)