{-# LANGUAGE FlexibleContexts #-}

-- | Alias analysis of a full Futhark program.  Takes as input a
-- program with an arbitrary lore and produces one with aliases.  This
-- module does not implement the aliasing logic itself, and derives
-- its information from definitions in
-- "Futhark.IR.Prop.Aliases" and
-- "Futhark.IR.Aliases".  The alias information computed
-- here will include transitive aliases (note that this is not what
-- the building blocks do).
module Futhark.Analysis.Alias
  ( aliasAnalysis,

    -- * Ad-hoc utilities
    analyseFun,
    analyseStms,
    analyseExp,
    analyseBody,
    analyseLambda,
  )
where

import Data.List (foldl')
import qualified Data.Map as M
import Futhark.IR.Aliases

-- | Perform alias analysis on a Futhark program.
-- | Perform alias analysis on a Futhark program.
--
-- The top-level constants are analysed with 'analyseStms', starting
-- from an empty alias table; only the analysed statements are kept and
-- the alias table produced for them is discarded.  Each function
-- definition is then analysed independently by 'analyseFun'.
aliasAnalysis ::
  (ASTLore lore, CanBeAliased (Op lore)) =>
  Prog lore ->
  Prog (Aliases lore)
aliasAnalysis (Prog consts funs) =
  Prog (fst (analyseStms mempty consts)) (map analyseFun funs)

-- | Perform alias analysis on a single function definition.
--
-- The body is analysed with 'analyseBody', starting from an empty
-- alias table.  The entry point, attributes, name, return types and
-- parameters are carried over unchanged.
analyseFun ::
  (ASTLore lore, CanBeAliased (Op lore)) =>
  FunDef lore ->
  FunDef (Aliases lore)
analyseFun (FunDef entry attrs fname restype params body) =
  FunDef entry attrs fname restype params body'
  where
    body' = analyseBody mempty body

-- | Perform alias analysis on a body, starting from the given alias
-- table.
--
-- The statements are analysed with 'analyseStms'.  The alias table
-- that call returns is not used here.  The aliased body is assembled
-- by 'mkAliasedBody' from three parts: the original body decoration,
-- the analysed statements and the unchanged result.
analyseBody ::
  ( ASTLore lore,
    CanBeAliased (Op lore)
  ) =>
  AliasTable ->
  Body lore ->
  Body (Aliases lore)
analyseBody atable (Body lore stms result) =
  let (stms', _atable') = analyseStms atable stms
   in mkAliasedBody lore stms' result

-- | Perform alias analysis on a sequence of statements.
--
-- The statements are processed left to right.  Each one is analysed
-- with the alias table accumulated so far.  The accumulator is then
-- extended by applying 'trackAliases' to the analysed statement.  The
-- consumed-names half of the accumulator starts out empty.
--
-- Returns the analysed statements together with the final
-- accumulator.
analyseStms ::
  (ASTLore lore, CanBeAliased (Op lore)) =>
  AliasTable ->
  Stms lore ->
  (Stms (Aliases lore), AliasesAndConsumed)
analyseStms orig_aliases =
  foldl' f (mempty, (orig_aliases, mempty)) . stmsToList
  where
    -- Accumulator: the statements analysed so far, paired with the
    -- aliasing/consumption information tracked up to this point.
    f (stms, aliases) stm =
      let stm' = analyseStm (fst aliases) stm
          atable' = trackAliases aliases stm'
       in (stms <> oneStm stm', atable')

-- | Perform alias analysis on a single statement.
--
-- The steps are:
--
-- * the expression is analysed with 'analyseExp';
-- * the pattern is annotated by 'addAliasesToPattern', based on the
--   analysed expression;
-- * the expression decoration is paired with the names that
--   'consumedInExp' reports for the analysed expression.
--
-- Certificates and attributes are kept as they are.
analyseStm ::
  (ASTLore lore, CanBeAliased (Op lore)) =>
  AliasTable ->
  Stm lore ->
  Stm (Aliases lore)
analyseStm aliases (Let pat (StmAux cs attrs dec) e) =
  let e' = analyseExp aliases e
      pat' = addAliasesToPattern pat e'
      lore' = (AliasDec $ consumedInExp e', dec)
   in Let pat' (StmAux cs attrs lore') e'

-- | Perform alias analysis on an expression.
--
-- @If@ expressions are handled specially.  Both branches are analysed
-- with the given alias table.  A result alias is then removed from both
-- branches' result aliases if it is consumed in either branch, or if
-- any name the alias table lists for it is consumed in either branch.
--
-- All other expressions are traversed with 'mapExp':
--
-- * nested bodies are analysed with the same alias table;
-- * operations are annotated with 'addOpAliases';
-- * all other components are passed through unchanged.
analyseExp ::
  (ASTLore lore, CanBeAliased (Op lore)) =>
  AliasTable ->
  Exp lore ->
  Exp (Aliases lore)
-- Would be better to put this in a BranchType annotation, but that
-- requires a lot of other work.
analyseExp aliases (If cond tb fb dec) =
  let Body ((tb_als, tb_cons), tb_dec) tb_stms tb_res = analyseBody aliases tb
      Body ((fb_als, fb_cons), fb_dec) fb_stms fb_res = analyseBody aliases fb
      -- Everything consumed in either branch.
      cons = tb_cons <> fb_cons
      -- A name is considered consumed if it, or anything the alias
      -- table records for it, is in the consumed set.
      isConsumed v =
        any (`nameIn` unAliases cons) $
          v : namesToList (M.findWithDefault mempty v aliases)
      -- Drop consumed names from a single result's alias set.
      notConsumed =
        AliasDec
          . namesFromList
          . filter (not . isConsumed)
          . namesToList
          . unAliases
      tb' = Body ((map notConsumed tb_als, tb_cons), tb_dec) tb_stms tb_res
      fb' = Body ((map notConsumed fb_als, fb_cons), fb_dec) fb_stms fb_res
   in If cond tb' fb' dec
analyseExp aliases e = mapExp analyse e
  where
    analyse =
      Mapper
        { mapOnSubExp = return,
          mapOnVName = return,
          -- The scope supplied by mapExp is not needed; nested bodies
          -- are analysed with the enclosing alias table.
          mapOnBody = const $ return . analyseBody aliases,
          mapOnRetType = return,
          mapOnBranchType = return,
          mapOnFParam = return,
          mapOnLParam = return,
          mapOnOp = return . addOpAliases aliases
        }

-- | Perform alias analysis on a lambda.
--
-- The lambda body is analysed with the given alias table via
-- 'analyseBody'.  The parameters are carried over unchanged.
analyseLambda ::
  (ASTLore lore, CanBeAliased (Op lore)) =>
  AliasTable ->
  Lambda lore ->
  Lambda (Aliases lore)
analyseLambda aliases lam =
  let body = analyseBody aliases $ lambdaBody lam
   in lam
        { lambdaBody = body,
          -- This record update changes the type from 'Lambda lore' to
          -- 'Lambda (Aliases lore)'.  The field's type
          -- ([LParam lore]) mentions lore, so the field must be set
          -- explicitly even though its value is unchanged.
          lambdaParams = lambdaParams lam
        }