{-# LANGUAGE TypeFamilies #-}

-- | We require that entry points return arrays with zero offset in
-- row-major order.  "Futhark.Pass.ExplicitAllocations" is
-- conservative and inserts copies to ensure this is the case.  After
-- simplification, it may turn out that those copies are redundant.
-- This pass removes them.  It's a pretty simple pass, as it only has
-- to look at the top level of entry points.
module Futhark.Optimise.EntryPointMem
  ( entryPointMemGPU,
    entryPointMemMC,
    entryPointMemSeq,
  )
where

import Data.List (find)
import Data.Map.Strict qualified as M
import Futhark.IR.GPUMem (GPUMem)
import Futhark.IR.MCMem (MCMem)
import Futhark.IR.Mem
import Futhark.IR.SeqMem (SeqMem)
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations.GPU ()
import Futhark.Transform.Substitute

-- | Lookup table from a variable name to a statement.  As populated by
-- 'mkTable', each name bound by a statement's pattern maps to that
-- statement.
type Table rep = M.Map VName (Stm rep)

-- | Build a 'Table' from a sequence of statements: every name bound by
-- the pattern of a statement is mapped to that statement.
mkTable :: Stms rep -> Table rep
mkTable = foldMap bindings
  where
    -- A statement may bind several names; each of them points at the
    -- whole statement.
    bindings stm = M.fromList [(name, stm) | name <- patNames (stmPat stm)]

-- | Look up the statement that binds @v@ in the table.  On success,
-- return the memory annotation of the pattern element named @v@
-- together with the expression of the binding statement.  Produces
-- 'Nothing' when @v@ is not in the table, or when no pattern element
-- of the statement found carries that name.
varInfo :: (Mem rep inner) => VName -> Table rep -> Maybe (LetDecMem, Exp rep)
varInfo v table = do
  Let pat _ e <- M.lookup v table
  -- The statement may bind several names; pick the element for @v@.
  PatElem _ info <- find ((== v) . patElemName) (patElems pat)
  Just (letDecMem info, e)

-- | Rewrite the result of a function body.  Whenever a result is a
-- variable bound by a 'Manifest' of some array, and the manifested
-- array and its source have equal LMADs, the result is changed to
-- refer to the source array and its memory block instead.
--
-- The first argument is the table of constants; it is combined with a
-- table built from the top-level statements of the function body.
-- Only the result of the body is changed; its statements are kept as
-- they are.
optimiseFun :: (Mem rep inner) => Table rep -> FunDef rep -> FunDef rep
optimiseFun consts_table fd =
  fd {funDefBody = onBody $ funDefBody fd}
  where
    -- Constants plus the top-level statements of the body.  '<>' on
    -- maps is left-biased, so entries from the constants win on a
    -- clash.
    table = consts_table <> mkTable (bodyStms (funDefBody fd))

    -- Substitutions induced by a single result: if @v0@ is the
    -- manifest of @v1@ and both arrays have the same LMAD, replace
    -- @v0@ by @v1@ and the memory block of @v0@ by that of @v1@.
    mkSubst (Var v0)
      | Just (MemArray _ _ _ (ArrayIn mem0 lmad0), BasicOp (Manifest _ v1)) <-
          varInfo v0 table,
        Just (MemArray _ _ _ (ArrayIn mem1 lmad1), _) <-
          varInfo v1 table,
        lmad0 == lmad1 =
          M.fromList [(mem0, mem1), (v0, v1)]
    mkSubst _ = mempty

    -- Gather the substitutions of all results and apply them to the
    -- result list only.
    onBody (Body dec stms res) =
      let substs = mconcat $ map (mkSubst . resSubExp) res
       in Body dec stms $ substituteNames substs res

-- | The representation-generic pass.  Each function definition is run
-- through 'optimiseFun', using a table built from the program
-- constants.  The constants themselves are passed through unchanged.
entryPointMem :: (Mem rep inner) => Pass rep rep
entryPointMem =
  Pass
    { passName = "Entry point memory optimisation",
      passDescription = "Remove redundant copies of entry point results.",
      passFunction = intraproceduralTransformationWithConsts pure onFun
    }
  where
    -- The transformation itself is pure; 'pure' lifts it into the
    -- pass monad.
    onFun consts fd = pure $ optimiseFun (mkTable consts) fd

-- | The pass for GPU representation.
entryPointMemGPU :: Pass GPUMem GPUMem
entryPointMemGPU = entryPointMem

-- | The pass for MC representation.
entryPointMemMC :: Pass MCMem MCMem
entryPointMemMC = entryPointMem

-- | The pass for Seq representation.
entryPointMemSeq :: Pass SeqMem SeqMem
entryPointMemSeq = entryPointMem