{-# LANGUAGE TypeFamilies #-}

-- | We require that entry points return arrays with zero offset in
-- row-major order.  "Futhark.Pass.ExplicitAllocations" is
-- conservative and inserts copies to ensure this is the case.  After
-- simplification, it may turn out that those copies are redundant.
-- This pass removes them.  It's a pretty simple pass, as it only has
-- to look at the top level of entry points.
module Futhark.Optimise.EntryPointMem
  ( entryPointMemGPU,
    entryPointMemMC,
    entryPointMemSeq,
  )
where

import Data.List (find)
import Data.Map.Strict qualified as M
import Futhark.IR.GPUMem (GPUMem)
import Futhark.IR.MCMem (MCMem)
import Futhark.IR.Mem
import Futhark.IR.SeqMem (SeqMem)
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations.GPU ()
import Futhark.Transform.Substitute

-- | Symbol table from each name bound by a statement's pattern to that
-- statement.
type Table rep = M.Map VName (Stm rep)

-- | Index a sequence of statements by the names they bind.  Every name in
-- a statement's pattern maps to the whole statement.  The per-statement
-- maps are combined with the left-biased 'Semigroup' instance of 'M.Map',
-- so if a name were bound more than once the earliest statement would win.
mkTable :: Stms rep -> Table rep
mkTable = foldMap bindingsOf
  where
    bindingsOf stm = M.fromList [(name, stm) | name <- patNames (stmPat stm)]

-- | Find the statement that binds a variable.  Return the memory
-- annotation of the pattern element naming that variable, together with
-- the bound expression.  Yields 'Nothing' when the variable has no entry
-- in the table, or when no pattern element carries its name.
varInfo :: (Mem rep inner) => VName -> Table rep -> Maybe (LetDecMem, Exp rep)
varInfo v table =
  M.lookup v table >>= \(Let pat _ e) ->
    (\(PatElem _ info) -> (letDecMem info, e))
      <$> find ((v ==) . patElemName) (patElems pat)

-- | Remove redundant copies from the results of a single function.
--
-- A result @v0@ is redirected to @v1@ when all three of these hold:
--
-- * @v0@ is bound by @Manifest _ v1@;
-- * @v1@ is itself bound by a statement in the table;
-- * both arrays have identical LMADs.
--
-- In that case the copy does not change the layout.  At the same time,
-- any occurrence of @v0@'s memory block in the result list is replaced by
-- @v1@'s memory block.  Only the result list is rewritten; the body's
-- statements are returned unchanged.
optimiseFun :: (Mem rep inner) => Table rep -> FunDef rep -> FunDef rep
optimiseFun consts_table fd =
  fd {funDefBody = redirectResults (funDefBody fd)}
  where
    -- Bindings visible to the results: program constants, then the
    -- top-level statements of this function's body.
    table = consts_table <> mkTable (bodyStms (funDefBody fd))

    redirectResults (Body dec stms res) =
      let renames = foldMap (copyElision . resSubExp) res
       in Body dec stms (substituteNames renames res)

    -- The renaming that bypasses a layout-preserving copy, or nothing.
    copyElision se = case redundantCopy se of
      Just renames -> renames
      Nothing -> mempty

    -- Any pattern mismatch below fails in 'Maybe', meaning "no rewrite".
    redundantCopy (Var v0) = do
      (MemArray _ _ _ (ArrayIn mem0 lmad0), BasicOp (Manifest _ v1)) <-
        varInfo v0 table
      (MemArray _ _ _ (ArrayIn mem1 lmad1), _) <- varInfo v1 table
      if lmad0 == lmad1
        then Just (M.fromList [(mem0, mem1), (v0, v1)])
        else Nothing
    redundantCopy _ = Nothing

-- | The copy-elision pass, generic over any memory representation.
-- Program constants are passed through untouched.  Each function is
-- optimised against a table built from the constants' bindings.
--
-- NOTE(review): the per-function step is handed to
-- 'intraproceduralTransformationWithConsts' without any entry-point
-- filter, so it presumably runs on every function, not only entry
-- points — confirm this is intended.
entryPointMem :: (Mem rep inner) => Pass rep rep
entryPointMem =
  Pass
    { passName = "Entry point memory optimisation",
      passDescription = "Remove redundant copies of entry point results.",
      passFunction = intraproceduralTransformationWithConsts pure elideCopies
    }
  where
    elideCopies consts = pure . optimiseFun (mkTable consts)

-- | Remove redundant copies of entry point results in the 'GPUMem'
-- representation.
entryPointMemGPU :: Pass GPUMem GPUMem
entryPointMemGPU = entryPointMem

-- | Remove redundant copies of entry point results in the 'MCMem'
-- representation.
entryPointMemMC :: Pass MCMem MCMem
entryPointMemMC = entryPointMem

-- | Remove redundant copies of entry point results in the 'SeqMem'
-- representation.
entryPointMemSeq :: Pass SeqMem SeqMem
entryPointMemSeq = entryPointMem