{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
module Futhark.Transform.Rename
(
renameProg
, renameExp
, renameStm
, renameBody
, renameLambda
, renameFun
, renamePattern
, RenameM
, substituteRename
, bindingForRename
, renamingStms
, Rename (..)
, Renameable
)
where
import Control.Monad.State
import Control.Monad.Reader
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Data.Maybe
import Futhark.Representation.AST.Syntax
import Futhark.Representation.AST.Traversals
import Futhark.Representation.AST.Attributes.Patterns
import Futhark.FreshNames
import Futhark.MonadFreshNames (MonadFreshNames(..), modifyNameSource)
import Futhark.Transform.Substitute
-- | Run a renaming computation with the given name source as initial
-- state, returning the result together with the final name source.
-- The computation starts with no names bound (an empty name map) and
-- uses 'newName' to generate fresh names.
runRenamer :: RenameM a -> VNameSource -> (a, VNameSource)
runRenamer m src =
  let initialEnv = RenameEnv { envNameMap = M.empty
                             , envNameFn = newName
                             }
  in runReader (runStateT m src) initialEnv
-- | Rename the bound names of every function in a program to fresh
-- names drawn from the monad's name source.  Free names are left
-- unchanged.
renameProg :: (Renameable lore, MonadFreshNames m) =>
              Prog lore -> m (Prog lore)
renameProg prog =
  modifyNameSource . runRenamer $
    Prog <$> traverse rename (progFunctions prog)
-- | Rename the names bound inside an expression to fresh names drawn
-- from the monad's name source.  Names that are not bound within the
-- expression (free variables) are left unchanged.
renameExp :: (Renameable lore, MonadFreshNames m) =>
             Exp lore -> m (Exp lore)
renameExp e = modifyNameSource $ runRenamer $ rename e
-- | Rename the names bound inside a statement's expression, using
-- fresh names from the monad's name source.  Only the expression is
-- renamed; the statement's pattern and all other fields are returned
-- exactly as they were.
renameStm :: (Renameable lore, MonadFreshNames m) =>
             Stm lore -> m (Stm lore)
renameStm stm = withExp <$> renameExp (stmExp stm)
  where withExp e' = stm { stmExp = e' }
-- | Rename the names bound inside a body to fresh names drawn from the
-- monad's name source.  Free names are left unchanged.
renameBody :: (Renameable lore, MonadFreshNames m) =>
              Body lore -> m (Body lore)
renameBody body = modifyNameSource $ runRenamer $ rename body
-- | Rename the parameters of a lambda, and the names bound inside its
-- body, to fresh names drawn from the monad's name source.  Free names
-- are left unchanged.
renameLambda :: (Renameable lore, MonadFreshNames m) =>
                Lambda lore -> m (Lambda lore)
renameLambda lam = modifyNameSource $ runRenamer $ rename lam
-- | Rename the parameters of a function definition, and the names bound
-- in its body, to fresh names drawn from the monad's name source.  The
-- function's own name is not changed.
renameFun :: (Renameable lore, MonadFreshNames m) =>
             FunDef lore -> m (FunDef lore)
renameFun fundef = modifyNameSource $ runRenamer $ rename fundef
-- | Give every name bound by a pattern a fresh name from the monad's
-- name source.  The pattern's names are bound before the pattern is
-- renamed, so occurrences of those names in the pattern's attributes
-- are replaced as well.
renamePattern :: (Rename attr, MonadFreshNames m) =>
                 PatternT attr -> m (PatternT attr)
renamePattern pat =
  modifyNameSource $ runRenamer $ bind (patternNames pat) (rename pat)
-- | Read-only environment of the renamer.
data RenameEnv = RenameEnv
  { envNameMap :: M.Map VName VName
    -- ^ Maps each currently bound original name to its fresh replacement.
  , envNameFn :: VNameSource -> VName -> (VName, VNameSource)
    -- ^ Given a name source and an original name, produce a fresh name
    -- and the advanced name source.
  }
-- | The renaming monad: the name source is threaded as mutable state,
-- on top of a read-only 'RenameEnv' holding the current renaming.
type RenameM =
  StateT VNameSource (Reader RenameEnv)
-- | The current renaming (original name to fresh name) as a
-- substitution map.
renamerSubstitutions :: RenameM Substitutions
renamerSubstitutions = lift (envNameMap <$> ask)
-- | Apply the renamer's current name mapping to a value via
-- 'substituteNames', without binding any new names.
substituteRename :: Substitute a => a -> RenameM a
substituteRename x = flip substituteNames x <$> renamerSubstitutions
-- | Produce a fresh replacement for the given name, using the
-- environment's name-generating function and storing the advanced
-- name source back into the state.
new :: VName -> RenameM VName
new k = do
  mkName <- asks envNameFn
  state $ \src -> mkName src k
-- | Values whose names can be renamed.
class Rename a where
  -- | Replace each name in the value that is bound in the current
  -- renaming environment with its fresh replacement.  Instances for
  -- binding constructs (bodies, lambdas, function definitions, loops)
  -- also introduce fresh names for the names they bind.
  rename :: a -> RenameM a
-- | Look the name up in the current renaming; names that are not bound
-- (free variables) are returned unchanged.
instance Rename VName where
  rename name = asks (M.findWithDefault name name . envNameMap)
-- | Rename each element in order.
instance Rename a => Rename [a] where
  rename = traverse rename
-- | Rename both components, first then second.
instance (Rename a, Rename b) => Rename (a,b) where
  rename (x, y) = do
    x' <- rename x
    y' <- rename y
    return (x', y')
-- | Rename all three components, left to right.
instance (Rename a, Rename b, Rename c) => Rename (a,b,c) where
  rename (x, y, z) = (,,) <$> rename x <*> rename y <*> rename z
-- | Rename the contained value, if any.
instance Rename a => Rename (Maybe a) where
  rename = traverse rename
-- | Booleans contain no names; returned unchanged.
instance Rename Bool where
  rename b = pure b
-- | Rename both the identifier's name and its type.
instance Rename Ident where
  rename (Ident name tp) = Ident <$> rename name <*> rename tp
-- | Run a renaming action with the given names bound to fresh names.
-- Exported alias of the module-internal 'bind'.
bindingForRename :: [VName] -> RenameM a -> RenameM a
bindingForRename names action = bind names action
-- | Generate a fresh name for each given name (in list order), then run
-- the action in an environment where those names map to their fresh
-- replacements.  The new mappings take precedence over (shadow) any
-- existing mappings for the same names; if a name appears more than
-- once in the list, the mapping for its last occurrence is kept.
bind :: [VName] -> RenameM a -> RenameM a
bind vars body = do
  replacements <- traverse new vars
  let extend env =
        env { envNameMap = M.union (M.fromList (zip vars replacements))
                                   (envNameMap env) }
  local extend body
-- | Rename a sequence of statements front to back.  Before each
-- statement is renamed, the names bound by its pattern are given fresh
-- names, and they stay bound for all later statements.  Once every
-- statement is renamed, the continuation is called with the renamed
-- statements.  It runs with all of their pattern names still bound,
-- so it can rename whatever depends on them (e.g. a body's result).
renamingStms :: Renameable lore => Stms lore -> (Stms lore -> RenameM a) -> RenameM a
renamingStms stms m = go mempty stms
  where
    -- 'done' holds the already-renamed statements, 'todo' the rest.
    go done todo =
      maybe (m done) (renameNext done) (stmsHead todo)
    renameNext done (stm, rest) =
      bind (patternNames (stmPattern stm)) $
        rename stm >>= \stm' -> go (done <> oneStm stm') rest
-- | Bind the parameters to fresh names, then rename the parameters,
-- the body and the return type (in that order) within that scope.  The
-- entry-point information and the function name are kept as-is.
instance Renameable lore => Rename (FunDef lore) where
  rename (FunDef entry fname ret params body) =
    bind (map paramName params) $
      (\params' body' ret' -> FunDef entry fname ret' params' body')
        <$> mapM rename params
        <*> rename body
        <*> rename ret
-- | Variables are renamed; constants contain no names.
instance Rename SubExp where
  rename se = case se of
    Var v -> Var <$> rename v
    Constant c -> pure (Constant c)
-- | Rename the parameter's name and then its attribute.  Binding the
-- parameter name is the caller's responsibility (see 'bind').
instance Rename attr => Rename (ParamT attr) where
  rename (Param name attr) = do
    name' <- rename name
    attr' <- rename attr
    return (Param name' attr')
-- | Rename the context elements, then the value elements.
instance Rename attr => Rename (PatternT attr) where
  rename (Pattern context values) = do
    context' <- rename context
    values' <- rename values
    return (Pattern context' values')
-- | Rename the element's name, then its attribute.
instance Rename attr => Rename (PatElemT attr) where
  rename (PatElem ident attr) = do
    ident' <- rename ident
    attr' <- rename attr
    return (PatElem ident' attr')
-- | Rename every certificate name.
instance Rename Certificates where
  rename (Certificates cs) = fmap Certificates (rename cs)
-- | Rename the certificates, then the auxiliary attribute.
instance Rename attr => Rename (StmAux attr) where
  rename (StmAux cs attr) = do
    cs' <- rename cs
    attr' <- rename attr
    return (StmAux cs' attr')
-- | The body attribute is renamed in the enclosing scope.  The
-- statements are then renamed with 'renamingStms', and the result is
-- renamed inside the scope of every name the statements bind.
instance Renameable lore => Rename (Body lore) where
  rename (Body attr stms res) = do
    attr' <- rename attr
    let withStms stms' = do
          res' <- rename res
          pure (Body attr' stms' res')
    renamingStms stms withStms
-- | Rename the pattern, the auxiliary annotation and the expression, in
-- that order.  Binding the pattern's names is left to the caller (see
-- 'renamingStms').
instance Renameable lore => Rename (Stm lore) where
  rename (Let pat aux e) = do
    pat' <- rename pat
    aux' <- rename aux
    e' <- rename e
    pure (Let pat' aux' e')
-- | Loops get special treatment because they bind names (their
-- parameters and, for for-loops, the loop variable).  Every other
-- expression is traversed generically with 'mapExpM', applying
-- 'rename' to each component.  Components that bind names (such as
-- nested bodies) do their own binding through their 'Rename'
-- instances.
instance Renameable lore => Rename (Exp lore) where
  rename (DoLoop ctx val form loopbody) = do
    let (ctxparams, ctxinit) = unzip ctx
        (valparams, valinit) = unzip val
    -- Initial values are renamed in the enclosing scope, before any
    -- loop parameter is bound.
    ctxinit' <- mapM rename ctxinit
    valinit' <- mapM rename valinit
    case form of
      ForLoop loopvar it boundexp loop_vars -> do
        let (loop_params, loop_arrs) = unzip loop_vars
        -- The bound and the arrays being looped over are also renamed
        -- in the enclosing scope.
        boundexp' <- rename boundexp
        loop_arrs' <- rename loop_arrs
        -- Context, value and per-iteration array parameters all get
        -- fresh names for the parameter lists and the body.
        bind (map paramName (ctxparams++valparams) ++
              map paramName loop_params) $ do
          ctxparams' <- mapM rename ctxparams
          valparams' <- mapM rename valparams
          -- NOTE(review): loop_params are renamed before loopvar is
          -- bound below, so any mention of loopvar in their attributes
          -- would not be replaced; confirm this cannot occur.
          loop_params' <- mapM rename loop_params
          -- The loop variable is bound in an inner scope that covers
          -- only the body.
          bind [loopvar] $ do
            loopvar' <- rename loopvar
            loopbody' <- rename loopbody
            return $ DoLoop
              (zip ctxparams' ctxinit') (zip valparams' valinit')
              (ForLoop loopvar' it boundexp' $
               zip loop_params' loop_arrs') loopbody'
      WhileLoop cond ->
        bind (map paramName $ ctxparams++valparams) $ do
          ctxparams' <- mapM rename ctxparams
          valparams' <- mapM rename valparams
          loopbody' <- rename loopbody
          -- The condition is renamed inside the parameters' scope, so a
          -- reference to a loop parameter becomes its fresh name.
          cond' <- rename cond
          return $ DoLoop
            (zip ctxparams' ctxinit') (zip valparams' valinit')
            (WhileLoop cond') loopbody'
  rename e = mapExpM mapper e
    where mapper = Mapper {
                       -- The scope argument passed for bodies is ignored;
                       -- the body's own instance handles its bindings.
                       mapOnBody = const rename
                     , mapOnSubExp = rename
                     , mapOnVName = rename
                     , mapOnCertificates = rename
                     , mapOnRetType = rename
                     , mapOnBranchType = rename
                     , mapOnFParam = rename
                     , mapOnLParam = rename
                     , mapOnOp = rename
                     }
-- | Primitive types contain no names.  For arrays, only the shape is
-- renamed; for memory types, only the first field is renamed and the
-- space is kept.
instance Rename shape =>
         Rename (TypeBase shape u) where
  rename t = case t of
    Prim et -> pure (Prim et)
    Array et size u -> (\size' -> Array et size' u) <$> rename size
    Mem e space -> flip Mem space <$> rename e
-- | Bind the lambda's parameters to fresh names, then rename the
-- parameters, the body and the return types (in that order) within
-- that scope.
instance Renameable lore => Rename (Lambda lore) where
  rename (Lambda params body ret) =
    bind (map paramName params) $
      Lambda <$> mapM rename params
             <*> rename body
             <*> mapM rename ret
-- | Rename every name in the set, rebuilding the set from the results.
instance Rename Names where
  rename names = S.fromList <$> traverse rename (S.toList names)
-- | A rank contains no names; returned unchanged.
instance Rename Rank where
  rename r = pure r
-- | Rename each dimension of the shape in order.
instance Rename d => Rename (ShapeBase d) where
  rename (Shape dims) = fmap Shape (traverse rename dims)
-- | Existential sizes ('Ext') hold only an index and are kept as-is;
-- known sizes ('Free') are renamed.
instance Rename ExtSize where
  rename (Ext i) = pure (Ext i)
  rename (Free se) = Free <$> rename se
-- | The unit value contains no names; returned unchanged.
instance Rename () where
  rename u = pure u
-- | Rename the index, or the components of a slice in the order
-- they appear in the constructor.
instance Rename d => Rename (DimIndex d) where
  rename (DimFix i) = DimFix <$> rename i
  rename (DimSlice i n s) = do
    i' <- rename i
    n' <- rename n
    s' <- rename s
    pure (DimSlice i' n' s')
-- | Constraint collecting what a lore must provide for its programs,
-- functions, bodies, statements and expressions to be renameable:
-- each lore-specific annotation type, the return and branch types,
-- and the lore's operation type must all be instances of 'Rename'.
type Renameable lore =
  ( Rename (LetAttr lore)
  , Rename (ExpAttr lore)
  , Rename (BodyAttr lore)
  , Rename (FParamAttr lore)
  , Rename (LParamAttr lore)
  , Rename (RetType lore)
  , Rename (BranchType lore)
  , Rename (Op lore)
  )