{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleInstances #-}
module Futhark.Representation.Aliases
(
Aliases
, Names' (..)
, VarAliases
, ConsumedInExp
, BodyAliasing
, module Futhark.Representation.AST.Attributes.Aliases
, module Futhark.Representation.AST.Attributes
, module Futhark.Representation.AST.Traversals
, module Futhark.Representation.AST.Pretty
, module Futhark.Representation.AST.Syntax
, addAliasesToPattern
, mkAliasedLetStm
, mkAliasedBody
, mkPatternAliases
, mkBodyAliases
, removeProgAliases
, removeFunDefAliases
, removeExpAliases
, removeBodyAliases
, removeStmAliases
, removeLambdaAliases
, removePatternAliases
, removeScopeAliases
, AliasesAndConsumed
, trackAliases
, consumedInStms
)
where
import Control.Monad.Identity
import Control.Monad.Reader
import Data.Foldable
import Data.Maybe
import Data.Monoid ((<>))
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Futhark.Representation.AST.Syntax
import Futhark.Representation.AST.Attributes
import Futhark.Representation.AST.Attributes.Aliases
import Futhark.Representation.AST.Traversals
import Futhark.Representation.AST.Pretty
import Futhark.Transform.Rename
import Futhark.Binder
import Futhark.Transform.Substitute
import Futhark.Analysis.Rephrase
import Futhark.Representation.AST.Attributes.Ranges()
import qualified Futhark.Util.Pretty as PP
data Aliases lore
newtype Names' = Names' { unNames :: Names }
deriving (Show)
instance Semigroup Names' where
x <> y = Names' $ unNames x <> unNames y
instance Monoid Names' where
mempty = Names' mempty
instance Eq Names' where
_ == _ = True
instance Ord Names' where
_ `compare` _ = EQ
instance Rename Names' where
rename (Names' names) = Names' <$> rename names
instance Substitute Names' where
substituteNames substs (Names' names) = Names' $ substituteNames substs names
instance FreeIn Names' where
freeIn = const mempty
instance PP.Pretty Names' where
ppr = PP.commasep . map PP.ppr . S.toList . unNames
type VarAliases = Names'
type ConsumedInExp = Names'
type BodyAliasing = ([VarAliases], ConsumedInExp)
instance (Annotations lore, CanBeAliased (Op lore)) =>
Annotations (Aliases lore) where
type LetAttr (Aliases lore) = (VarAliases, LetAttr lore)
type ExpAttr (Aliases lore) = (ConsumedInExp, ExpAttr lore)
type BodyAttr (Aliases lore) = (BodyAliasing, BodyAttr lore)
type FParamAttr (Aliases lore) = FParamAttr lore
type LParamAttr (Aliases lore) = LParamAttr lore
type RetType (Aliases lore) = RetType lore
type BranchType (Aliases lore) = BranchType lore
type Op (Aliases lore) = OpWithAliases (Op lore)
instance AliasesOf (VarAliases, attr) where
aliasesOf = unNames . fst
instance FreeAttr Names' where
withoutAliases :: (HasScope (Aliases lore) m, Monad m) =>
ReaderT (Scope lore) m a -> m a
withoutAliases m = do
scope <- asksScope removeScopeAliases
runReaderT m scope
instance (Attributes lore, CanBeAliased (Op lore)) => Attributes (Aliases lore) where
expTypesFromPattern =
withoutAliases . expTypesFromPattern . removePatternAliases
instance (Attributes lore, CanBeAliased (Op lore)) => Aliased (Aliases lore) where
bodyAliases = map unNames . fst . fst . bodyAttr
consumedInBody = unNames . snd . fst . bodyAttr
instance PrettyAnnot (PatElemT attr) =>
PrettyAnnot (PatElemT (VarAliases, attr)) where
ppAnnot (PatElem name (Names' als, attr)) =
let alias_comment = PP.oneLine <$> aliasComment name als
in case (alias_comment, ppAnnot (PatElem name attr)) of
(_, Nothing) ->
alias_comment
(Just alias_comment', Just inner_comment) ->
Just $ alias_comment' PP.</> inner_comment
(Nothing, Just inner_comment) ->
Just inner_comment
instance (Attributes lore, CanBeAliased (Op lore)) => PrettyLore (Aliases lore) where
ppExpLore (consumed, inner) e =
maybeComment $ catMaybes [expAttr,
mergeAttr,
ppExpLore inner $ removeExpAliases e]
where mergeAttr =
case e of
DoLoop _ merge _ body ->
let mergeParamAliases fparam als
| primType (paramType fparam) =
Nothing
| otherwise =
resultAliasComment (paramName fparam) als
in maybeComment $ catMaybes $
zipWith mergeParamAliases (map fst merge) $
bodyAliases body
_ -> Nothing
expAttr = case S.toList $ unNames consumed of
[] -> Nothing
als -> Just $ PP.oneLine $
PP.text "-- Consumes " <> PP.commasep (map PP.ppr als)
maybeComment :: [PP.Doc] -> Maybe PP.Doc
maybeComment [] = Nothing
maybeComment cs = Just $ PP.folddoc (PP.</>) cs
aliasComment :: (PP.Pretty a, PP.Pretty b) =>
a -> S.Set b -> Maybe PP.Doc
aliasComment name als =
case S.toList als of
[] -> Nothing
als' -> Just $ PP.oneLine $
PP.text "-- " <> PP.ppr name <> PP.text " aliases " <>
PP.commasep (map PP.ppr als')
resultAliasComment :: (PP.Pretty a, PP.Pretty b) =>
a -> S.Set b -> Maybe PP.Doc
resultAliasComment name als =
case S.toList als of
[] -> Nothing
als' -> Just $ PP.oneLine $
PP.text "-- Result of " <> PP.ppr name <> PP.text " aliases " <>
PP.commasep (map PP.ppr als')
removeAliases :: CanBeAliased (Op lore) => Rephraser Identity (Aliases lore) lore
removeAliases = Rephraser { rephraseExpLore = return . snd
, rephraseLetBoundLore = return . snd
, rephraseBodyLore = return . snd
, rephraseFParamLore = return
, rephraseLParamLore = return
, rephraseRetType = return
, rephraseBranchType = return
, rephraseOp = return . removeOpAliases
}
removeScopeAliases :: Scope (Aliases lore) -> Scope lore
removeScopeAliases = M.map unAlias
where unAlias (LetInfo (_, attr)) = LetInfo attr
unAlias (FParamInfo attr) = FParamInfo attr
unAlias (LParamInfo attr) = LParamInfo attr
unAlias (IndexInfo it) = IndexInfo it
removeProgAliases :: CanBeAliased (Op lore) =>
Prog (Aliases lore) -> Prog lore
removeProgAliases = runIdentity . rephraseProg removeAliases
removeFunDefAliases :: CanBeAliased (Op lore) =>
FunDef (Aliases lore) -> FunDef lore
removeFunDefAliases = runIdentity . rephraseFunDef removeAliases
removeExpAliases :: CanBeAliased (Op lore) =>
Exp (Aliases lore) -> Exp lore
removeExpAliases = runIdentity . rephraseExp removeAliases
removeBodyAliases :: CanBeAliased (Op lore) =>
Body (Aliases lore) -> Body lore
removeBodyAliases = runIdentity . rephraseBody removeAliases
removeStmAliases :: CanBeAliased (Op lore) =>
Stm (Aliases lore) -> Stm lore
removeStmAliases = runIdentity . rephraseStm removeAliases
removeLambdaAliases :: CanBeAliased (Op lore) =>
Lambda (Aliases lore) -> Lambda lore
removeLambdaAliases = runIdentity . rephraseLambda removeAliases
removePatternAliases :: PatternT (Names', a)
-> PatternT a
removePatternAliases = runIdentity . rephrasePattern (return . snd)
addAliasesToPattern :: (Attributes lore, CanBeAliased (Op lore), Typed attr) =>
PatternT attr -> Exp (Aliases lore)
-> PatternT (VarAliases, attr)
addAliasesToPattern pat e =
uncurry Pattern $ mkPatternAliases pat e
mkAliasedBody :: (Attributes lore, CanBeAliased (Op lore)) =>
BodyAttr lore -> Stms (Aliases lore) -> Result -> Body (Aliases lore)
mkAliasedBody innerlore bnds res =
Body (mkBodyAliases bnds res, innerlore) bnds res
mkPatternAliases :: (Attributes lore, Aliased lore, Typed attr) =>
PatternT attr -> Exp lore
-> ([PatElemT (VarAliases, attr)],
[PatElemT (VarAliases, attr)])
mkPatternAliases pat e =
let als = expAliases e ++ repeat mempty
context_als = mkContextAliases pat e
in (zipWith annotateBindee (patternContextElements pat) context_als,
zipWith annotateBindee (patternValueElements pat) als)
where annotateBindee bindee names =
bindee `setPatElemLore` (Names' names', patElemAttr bindee)
where names' =
case patElemType bindee of
Array {} -> names
Mem _ _ -> names
_ -> mempty
mkContextAliases :: (Attributes lore, Aliased lore) =>
PatternT attr -> Exp lore
-> [Names]
mkContextAliases pat (DoLoop ctxmerge valmerge _ body) =
let ctx = loopResultContext (map fst ctxmerge) (map fst valmerge)
init_als = zip mergenames $ map (subExpAliases . snd) $ ctxmerge ++ valmerge
expand als = als <> S.unions (mapMaybe (`lookup` init_als) (S.toList als))
merge_als = zip mergenames $
map ((`S.difference` mergenames_set) . expand) $
bodyAliases body
in if length ctx == length (patternContextElements pat)
then map (fromMaybe mempty . flip lookup merge_als . paramName) ctx
else map (const mempty) $ patternContextElements pat
where mergenames = map (paramName . fst) $ ctxmerge ++ valmerge
mergenames_set = S.fromList mergenames
mkContextAliases pat (If _ tbranch fbranch _) =
take (length $ patternContextNames pat) $
zipWith (<>) (bodyAliases tbranch) (bodyAliases fbranch)
mkContextAliases pat _ =
replicate (length $ patternContextElements pat) mempty
mkBodyAliases :: Aliased lore =>
Stms lore
-> Result
-> BodyAliasing
mkBodyAliases bnds res =
let (aliases, consumed) = mkStmsAliases bnds res
boundNames =
fold $ fmap (S.fromList . patternNames . stmPattern) bnds
bound = (`S.member` boundNames)
aliases' = map (S.filter (not . bound)) aliases
consumed' = S.filter (not . bound) consumed
in (map Names' aliases', Names' consumed')
mkStmsAliases :: Aliased lore =>
Stms lore -> [SubExp]
-> ([Names], Names)
mkStmsAliases bnds res = delve mempty $ stmsToList bnds
where delve (aliasmap, consumed) [] =
(map (aliasClosure aliasmap . subExpAliases) res,
consumed)
delve (aliasmap, consumed) (bnd:bnds') =
delve (trackAliases (aliasmap, consumed) bnd) bnds'
aliasClosure aliasmap names =
names `S.union` mconcat (map look $ S.toList names)
where look k = M.findWithDefault mempty k aliasmap
consumedInStms :: Aliased lore => Stms lore -> [SubExp] -> Names
consumedInStms bnds res = snd $ mkStmsAliases bnds res
type AliasesAndConsumed = (M.Map VName Names,
Names)
trackAliases :: Aliased lore =>
AliasesAndConsumed -> Stm lore
-> AliasesAndConsumed
trackAliases (aliasmap, consumed) bnd =
let pat = stmPattern bnd
als = M.fromList $
zip (patternNames pat) (map addAliasesOfAliases $ patternAliases pat)
aliasmap' = als <> aliasmap
consumed' = consumed <> addAliasesOfAliases (consumedInStm bnd)
in (aliasmap', consumed')
where addAliasesOfAliases names = names <> aliasesOfAliases names
aliasesOfAliases = mconcat . map look . S.toList
look k = M.findWithDefault mempty k aliasmap
mkAliasedLetStm :: (Attributes lore, CanBeAliased (Op lore)) =>
Pattern lore
-> StmAux (ExpAttr lore) -> Exp (Aliases lore)
-> Stm (Aliases lore)
mkAliasedLetStm pat (StmAux cs attr) e =
Let (addAliasesToPattern pat e)
(StmAux cs (Names' $ consumedInExp e, attr))
e
instance (Bindable lore, CanBeAliased (Op lore)) => Bindable (Aliases lore) where
mkExpAttr pat e =
let attr = mkExpAttr (removePatternAliases pat) $ removeExpAliases e
in (Names' $ consumedInExp e, attr)
mkExpPat ctx val e =
addAliasesToPattern (mkExpPat ctx val $ removeExpAliases e) e
mkLetNames names e = do
env <- asksScope removeScopeAliases
flip runReaderT env $ do
Let pat attr _ <- mkLetNames names $ removeExpAliases e
return $ mkAliasedLetStm pat attr e
mkBody bnds res =
let Body bodylore _ _ = mkBody (fmap removeStmAliases bnds) res
in mkAliasedBody bodylore bnds res
instance (Attributes (Aliases lore), Bindable (Aliases lore)) => BinderOps (Aliases lore) where
mkBodyB = bindableMkBodyB
mkExpAttrB = bindableMkExpAttrB
mkLetNamesB = bindableMkLetNamesB