{-# LANGUAGE TypeFamilies #-}
{-# Language FlexibleInstances, FlexibleContexts #-}
-- | Helpers for alias information in the AST: functions computing the
-- alias sets of variables, subexpressions, basic operations,
-- expressions and patterns; functions computing the names consumed by
-- statements, expressions and lambdas; and the type classes
-- ('Aliased', 'AliasesOf', 'AliasedOp', 'CanBeAliased') through which
-- this information is obtained.
module Futhark.Representation.AST.Attributes.Aliases
( vnameAliases
, subExpAliases
, primOpAliases
, expAliases
, patternAliases
, Aliased (..)
, AliasesOf (..)
, consumedInStm
, consumedInExp
, consumedByLambda
, AliasedOp (..)
, CanBeAliased (..)
)
where
import Control.Arrow (first)
import Data.Monoid ((<>))
import qualified Data.Set as S
import Futhark.Representation.AST.Attributes (IsOp)
import Futhark.Representation.AST.Syntax
import Futhark.Representation.AST.Attributes.Patterns
import Futhark.Representation.AST.Attributes.Types
-- | Lores whose bodies can report their aliasing and consumption.  An
-- instance additionally requires the lore's operations to be
-- 'AliasedOp' and its let-attributes to be 'AliasesOf'.
class ( Annotations lore
      , AliasedOp (Op lore)
      , AliasesOf (LetAttr lore)
      ) => Aliased lore where
  -- | The alias sets of a body, as a list (presumably one entry per
  -- body result -- confirm against the instances).
  bodyAliases :: Body lore -> [Names]

  -- | The names consumed by a body.
  consumedInBody :: Body lore -> Names
-- | The alias set of a variable name: the singleton set containing
-- just that name.
vnameAliases :: VName -> Names
vnameAliases name = S.singleton name
-- | The alias set of a subexpression.  A variable has the aliases
-- given by 'vnameAliases'; a constant has none.
subExpAliases :: SubExp -> Names
subExpAliases se = case se of
  Var name   -> vnameAliases name
  Constant{} -> mempty
-- | The alias sets of a basic operation.  Every case yields a
-- single-element list:
--
-- * 'SubExp' and 'Opaque' carry the aliases of their subexpression;
--
-- * 'Index', 'Repeat', 'Reshape', 'Rearrange' and 'Rotate' carry the
--   aliases of the variable they operate on;
--
-- * every other operation listed here has an empty alias set.
primOpAliases :: BasicOp lore -> [Names]
primOpAliases op = case op of
    -- Operations that forward a subexpression.
    SubExp se       -> [subExpAliases se]
    Opaque se       -> [subExpAliases se]
    -- Operations whose result aliases the variable operated on.
    Index v _       -> aliasing v
    Repeat _ _ v    -> aliasing v
    Reshape _ v     -> aliasing v
    Rearrange _ v   -> aliasing v
    Rotate _ v      -> aliasing v
    -- Operations whose result has no aliases.
    ArrayLit _ _    -> noAliases
    BinOp{}         -> noAliases
    ConvOp{}        -> noAliases
    CmpOp{}         -> noAliases
    UnOp{}          -> noAliases
    Update{}        -> noAliases
    Iota{}          -> noAliases
    Replicate{}     -> noAliases
    Scratch{}       -> noAliases
    Concat{}        -> noAliases
    Copy{}          -> noAliases
    Manifest{}      -> noAliases
    Assert{}        -> noAliases
  where
    -- A single result with an empty alias set.
    noAliases :: [Names]
    noAliases = [mempty]

    -- A single result aliasing the given variable.
    aliasing :: VName -> [Names]
    aliasing v = [vnameAliases v]
-- | Combine the alias information of two branches.  Each argument is
-- a pair of a branch's alias sets and the names consumed in that
-- branch.  The alias sets are united pairwise (the result is as long
-- as the shorter of the two lists), and every name consumed in either
-- branch is then removed from each combined set.
ifAliases :: ([Names], Names) -> ([Names], Names) -> [Names]
ifAliases (then_als, then_cons) (else_als, else_cons) =
  [ (t <> e) `S.difference` consumed | (t, e) <- zip then_als else_als ]
  where
    -- Everything consumed by either branch.
    consumed = then_cons <> else_cons
-- | The alias sets of the results of a function call, given the
-- arguments (each paired with its diet) and the return types.  Each
-- argument is replaced by its alias set before being handed to
-- 'returnAliases'.
funcallAliases :: [(SubExp, Diet)] -> [TypeBase shape Uniqueness] -> [Names]
funcallAliases args rettypes =
  returnAliases rettypes $ map (first subExpAliases) args
-- | The alias sets of the results of an expression.
--
-- * 'BasicOp': as given by 'primOpAliases'.
--
-- * 'Op': as given by 'opAliases'.
--
-- * 'Apply': as given by 'funcallAliases' on the arguments and the
--   value part of the return type.
--
-- * 'If': the branch aliases are combined with 'ifAliases', and only
--   the trailing entries are kept, as many as 'ifReturns' reports
--   return types.
--
-- * 'DoLoop': the alias sets of the loop body, skipping as many
--   leading entries as there are context merge parameters, with the
--   names of all merge parameters removed from each set.
expAliases :: (Aliased lore) => Exp lore -> [Names]
expAliases e = case e of
    BasicOp op ->
      primOpAliases op
    Op op ->
      opAliases op
    Apply _ args ret _ ->
      funcallAliases args (retTypeValues ret)
    If _ tbranch fbranch attr ->
      let branch_als = ifAliases
                       (bodyAliases tbranch, consumedInBody tbranch)
                       (bodyAliases fbranch, consumedInBody fbranch)
          num_rets = length (ifReturns attr)
      in drop (length branch_als - num_rets) branch_als
    DoLoop ctxmerge valmerge _ loopbody ->
      let val_als = drop (length ctxmerge) (bodyAliases loopbody)
          bound = S.fromList [ paramName p | (p, _) <- ctxmerge ++ valmerge ]
      in [ als `S.difference` bound | als <- val_als ]
-- | The alias sets of a list of return types, given the alias set and
-- diet of every argument.  A non-unique array may alias every
-- argument that is merely observed (see 'maskAliases'); a unique
-- array or a primitive value aliases nothing.  A memory type is not
-- supported and raises an error.
returnAliases :: [TypeBase shaper Uniqueness] -> [(Names, Diet)] -> [Names]
returnAliases rts args = [ aliasesFor rt | rt <- rts ]
  where
    -- Union of the alias sets of all arguments after masking by diet.
    observed = mconcat [ maskAliases als d | (als, d) <- args ]

    aliasesFor rt = case rt of
      Prim _              -> mempty
      Array _ _ Unique    -> mempty
      Array _ _ Nonunique -> observed
      Mem{}               -> error "returnAliases Mem"
-- | Filter an alias set by the diet of the argument it belongs to: an
-- observed argument keeps its aliases, a consumed one contributes
-- none.
maskAliases :: Names -> Diet -> Names
maskAliases als diet = case diet of
  Observe -> als
  Consume -> mempty
-- | The names consumed by a statement, which are exactly those
-- consumed by its expression.
consumedInStm :: Aliased lore => Stm lore -> Names
consumedInStm stm = consumedInExp (stmExp stm)
-- | The names consumed by an expression.
--
-- * 'Apply': the aliases of every argument passed with diet 'Consume'.
--
-- * 'If': whatever either branch body consumes.
--
-- * 'DoLoop': the aliases of the subexpression paired with each merge
--   parameter whose declared type is unique.  Only the second list of
--   merge parameters is inspected.
--
-- * 'Update': the variable given as its first field.
--
-- * 'Op': as given by 'consumedInOp'.
--
-- * Anything else consumes nothing.
consumedInExp :: (Aliased lore) => Exp lore -> Names
consumedInExp e = case e of
    Apply _ args _ _ ->
      mconcat [ consumedArg (subExpAliases se) diet | (se, diet) <- args ]
    If _ tbranch fbranch _ ->
      mconcat [consumedInBody tbranch, consumedInBody fbranch]
    DoLoop _ merge _ _ ->
      mconcat [ subExpAliases se
              | (param, se) <- merge
              , unique (paramDeclType param)
              ]
    BasicOp (Update src _ _) ->
      S.singleton src
    Op op ->
      consumedInOp op
    _ ->
      mempty
  where
    -- Only consumed arguments contribute their aliases.
    consumedArg als Consume = als
    consumedArg _   Observe = mempty
-- | The names consumed by a lambda, which are exactly those consumed
-- by its body.
consumedByLambda :: Aliased lore => Lambda lore -> Names
consumedByLambda lam = consumedInBody (lambdaBody lam)
-- | The alias sets of a pattern: one per pattern element, taken from
-- the attribute of that element.
patternAliases :: AliasesOf attr => PatternT attr -> [Names]
patternAliases pat =
  [ aliasesOf (patElemAttr pe) | pe <- patternElements pat ]
-- | Things from which a set of aliases can be extracted.
class AliasesOf a where
  -- | The alias set carried by the value.
  aliasesOf :: a -> Names
-- | A set of names is its own alias set.
instance AliasesOf Names where
  aliasesOf names = names
-- | The aliases of a pattern element are those of its attribute.
instance AliasesOf attr => AliasesOf (PatElemT attr) where
  aliasesOf pe = aliasesOf (patElemAttr pe)
-- | Operations that can report their own aliasing and consumption.
class IsOp op => AliasedOp op where
  -- | The alias sets of the operation, as a list (presumably one per
  -- result -- confirm against the instances).
  opAliases :: op -> [Names]

  -- | The names consumed by the operation.
  consumedInOp :: op -> Names
-- | The unit operation has no alias sets and consumes nothing.
instance AliasedOp () where
  opAliases () = mempty
  consumedInOp () = S.empty
-- | Operations that have a counterpart type, 'OpWithAliases', which
-- is an 'AliasedOp', together with conversions in both directions.
class AliasedOp (OpWithAliases op) => CanBeAliased op where
  -- | The counterpart of the operation type that is an 'AliasedOp'.
  type OpWithAliases op :: *

  -- | Convert from the 'OpWithAliases' form back to the plain
  -- operation.
  removeOpAliases :: OpWithAliases op -> op

  -- | Convert a plain operation to its 'OpWithAliases' form.
  addOpAliases :: op -> OpWithAliases op
-- | The unit operation is its own 'OpWithAliases' form, so both
-- conversions return their argument unchanged.
instance CanBeAliased () where
  type OpWithAliases () = ()
  removeOpAliases op = op
  addOpAliases op = op