{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.InPlaceLowering.LowerIntoStm
  ( lowerUpdateKernels
  , LowerUpdate
  , DesiredUpdate (..)
  ) where

import Control.Monad
import Control.Monad.Writer
import Data.List (find)
import Data.Maybe (mapMaybe)
import Data.Either
import qualified Data.Map as M

import Futhark.Representation.AST.Attributes.Aliases
import Futhark.Representation.Aliases
import Futhark.Representation.Kernels
import Futhark.Construct
import Futhark.Optimise.InPlaceLowering.SubstituteIndices

-- | An in-place update that the lowering functions try to move into
-- the statement producing the written value.  Conceptually it is
--
-- @
-- let updateName = updateSource with [updateIndices] <- updateValue
-- @
--
-- where @attr@ is the attribute (type information) of the result.
data DesiredUpdate attr =
  DesiredUpdate
  { updateName :: VName
    -- ^ Name of result.
  , updateType :: attr
    -- ^ Type of result.
  , updateCertificates :: Certificates
    -- ^ Certificates that must be attached to the statement performing
    -- the update.
  , updateSource :: VName
    -- ^ The array being updated.
  , updateIndices :: Slice SubExp
    -- ^ Where in 'updateSource' the value is written.
  , updateValue :: VName
    -- ^ The variable whose value is written into 'updateSource'.
  }
  deriving Show

-- | Mapping over a 'DesiredUpdate' transforms only its 'updateType'
-- attribute; every other field is carried over unchanged.
instance Functor DesiredUpdate where
  fmap f u@DesiredUpdate { updateType = attr } = u { updateType = f attr }

-- | Is the given variable the value written by this update
-- (its 'updateValue')?
updateHasValue :: VName -> DesiredUpdate attr -> Bool
updateHasValue name update = updateValue update == name

-- | The type of a lowering function.  It is given the current scope, a
-- statement, and the desired updates, and either declines ('Nothing')
-- or returns an action producing the statements that replace it.
type LowerUpdate lore m =
  Scope (Aliases lore) ->
  Stm (Aliases lore) ->
  [DesiredUpdate (LetAttr (Aliases lore))] ->
  Maybe (m [Stm (Aliases lore)])

-- | Generic lowering for any lore whose 'LetAttr' is a plain 'Type'.
--
-- * A @loop@ statement: the updates are pushed into the loop with
--   'lowerUpdateIntoLoop'.  The statements it asks for before and after
--   the loop are placed around the rebuilt loop, which keeps the
--   original statement's certificates.
--
-- * A copy @let src = v@ with exactly one update of @src@: the update
--   is performed directly on @v@, carrying both the statement's and the
--   update's certificates.
--
-- Every other statement is declined.
lowerUpdate :: (MonadFreshNames m, Bindable lore,
                LetAttr lore ~ Type, CanBeAliased (Op lore)) => LowerUpdate lore m
lowerUpdate scope (Let pat aux (DoLoop ctx val form body)) updates =
  fmap rebuild <$> lowerUpdateIntoLoop scope updates pat ctx val form body
  where
    rebuild (before, after, ctx_idents, val_idents, ctx', val', body') =
      let loop = mkLet ctx_idents val_idents $ DoLoop ctx' val' form body'
      in before ++ certify (stmAuxCerts aux) loop : after
lowerUpdate _ (Let pat aux (BasicOp (SubExp (Var v))))
            [DesiredUpdate bindee_nm bindee_attr cs src is val]
  | patternNames pat == [src] =
      let bindee_t = typeOf bindee_attr
          full_is = fullSlice bindee_t is
          in_place = mkLet [] [Ident bindee_nm bindee_t] $
                     BasicOp $ Update v full_is $ Var val
      in Just $ return [certify (stmAuxCerts aux <> cs) in_place]
lowerUpdate _ _ _ = Nothing

-- | Lowering for the 'Kernels' representation.
--
-- A single desired update whose written value is the only result of a
-- 'SegMap' is handled like this:
--
-- * the 'SegMap' is rewritten to write directly into the update's
--   source array (see 'lowerUpdateIntoKernel');
-- * the original value is then re-read from the updated array with an
--   'Index'.
--
-- All other statements are delegated to 'lowerUpdate'.
--
-- The kernel lowering is refused when the kernel body refers to the
-- source array, or to anything the scope records as aliasing it.  After
-- lowering, the kernel writes that array in place, so such a read would
-- race with the writes of other threads.  This follows the spirit of
-- the @usedInBody@ set that 'lowerUpdateIntoLoop' computes for loop
-- bodies.
lowerUpdateKernels :: MonadFreshNames m => LowerUpdate Kernels m
lowerUpdateKernels scope
  (Let (Pattern [] [PatElem v v_attr]) aux (Op (SegOp (SegMap lvl space ts kbody))))
  [update@(DesiredUpdate bindee_nm bindee_attr cs src is val)]
  | v == val, not source_used_in_kbody = do
    kbody' <- lowerUpdateIntoKernel update space kbody
    let is' = fullSlice (typeOf bindee_attr) is
    Just $ return [certify (stmAuxCerts aux <> cs) $
                    mkLet [] [Ident bindee_nm $ typeOf bindee_attr] $
                    Op $ SegOp $ SegMap lvl space ts kbody',
                   mkLet [] [Ident v $ typeOf v_attr] $ BasicOp $ Index bindee_nm is']
  where
    -- A name together with everything the scope records it as aliasing.
    withAliases name =
      case M.lookup name scope of
        Just (LetInfo attr) -> oneName name <> aliasesOf attr
        _ -> oneName name
    src_names = withAliases src
    kbody_names = mconcat $ map withAliases $ namesToList $ freeIn kbody
    -- Does the kernel body (transitively, via aliases) touch the array
    -- that the lowered kernel would be writing in place?
    source_used_in_kbody =
      any (`nameIn` kbody_names) $ namesToList src_names
lowerUpdateKernels scope stm updates = lowerUpdate scope stm updates

-- | Rewrite a kernel body so that it writes its value straight into the
-- update's source array.
--
-- The body must have exactly one 'Returns' result, and every update
-- index must be a fixed index.  That result is replaced by a
-- 'WriteReturns' into 'updateSource'.  Each thread writes at the fixed
-- update indices followed by its own segment (thread) indices.
--
-- Returns 'Nothing' if the body has any other result shape, or if any
-- update index is not a fixed index.
lowerUpdateIntoKernel :: DesiredUpdate (LetAttr (Aliases Kernels))
                      -> SegSpace -> KernelBody (Aliases Kernels)
                      -> Maybe (KernelBody (Aliases Kernels))
lowerUpdateIntoKernel update kspace kbody =
  case kernelBodyResult kbody of
    [Returns _ se] -> do
      fixed_is <- mapM dimFix $ updateIndices update
      let dest_dims = arrayDims $ snd $ updateType update
          thread_is = fixed_is ++ map (Var . fst) (unSegSpace kspace)
          ret = WriteReturns dest_dims (updateSource update) [(thread_is, se)]
      Just kbody { kernelBodyResult = [ret] }
    _ -> Nothing

lowerUpdateIntoLoop :: (Bindable lore, BinderOps lore,
                        Aliased lore, LetAttr lore ~ (als, Type),
                        MonadFreshNames m) =>
                       Scope lore
                    -> [DesiredUpdate (LetAttr lore)]
                    -> Pattern lore
                    -> [(FParam lore, SubExp)]
                    -> [(FParam lore, SubExp)]
                    -> LoopForm lore
                    -> Body lore
                    -> Maybe (m ([Stm lore],
                                 [Stm lore],
                                 [Ident],
                                 [Ident],
                                 [(FParam lore, SubExp)],
                                 [(FParam lore, SubExp)],
                                 Body lore))
lowerUpdateIntoLoop :: Scope lore
-> [DesiredUpdate (LetAttr lore)]
-> Pattern lore
-> [(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> Body lore
-> Maybe
     (m ([Stm lore], [Stm lore], [Ident], [Ident],
         [(FParam lore, SubExp)], [(FParam lore, SubExp)], Body lore))
lowerUpdateIntoLoop Scope lore
scope [DesiredUpdate (LetAttr lore)]
updates Pattern lore
pat [(FParam lore, SubExp)]
ctx [(FParam lore, SubExp)]
val LoopForm lore
form Body lore
body = do
  -- Algorithm:
  --
  --   0) Map each result of the loop body to a corresponding in-place
  --      update, if one exists.
  --
  --   1) Create new merge variables corresponding to the arrays being
  --      updated; extend the pattern and the @res@ list with these,
  --      and remove the parts of the result list that have a
  --      corresponding in-place update.
  --
  --      (The creation of the new merge variable identifiers is
  --      actually done at the same time as step (0)).
  --
  --   2) Create in-place updates at the end of the loop body.
  --
  --   3) Create index expressions that read back the values written
  --      in (2).  If the merge parameter corresponding to this value
  --      is unique, also @copy@ this value.
  --
  --   4) Update the result of the loop body to properly pass the new
  --      arrays and indexed elements to the next iteration of the
  --      loop.
  --
  -- We also check that the merge parameters we work with have
  -- loop-invariant shapes.

  -- Safety condition (8).
  [((Param DeclType, SubExp), Names)]
-> (((Param DeclType, SubExp), Names) -> Maybe ()) -> Maybe ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(Param DeclType, SubExp)]
-> [Names] -> [((Param DeclType, SubExp), Names)]
forall a b. [a] -> [b] -> [(a, b)]
zip [(Param DeclType, SubExp)]
[(FParam lore, SubExp)]
val ([Names] -> [((Param DeclType, SubExp), Names)])
-> [Names] -> [((Param DeclType, SubExp), Names)]
forall a b. (a -> b) -> a -> b
$ Body lore -> [Names]
forall lore. Aliased lore => Body lore -> [Names]
bodyAliases Body lore
body) ((((Param DeclType, SubExp), Names) -> Maybe ()) -> Maybe ())
-> (((Param DeclType, SubExp), Names) -> Maybe ()) -> Maybe ()
forall a b. (a -> b) -> a -> b
$ \((Param DeclType
p, SubExp
_), Names
als) ->
    Bool -> Maybe ()
forall (f :: * -> *). Alternative f => Bool -> f ()
guard (Bool -> Maybe ()) -> Bool -> Maybe ()
forall a b. (a -> b) -> a -> b
$ Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ Param DeclType -> VName
forall attr. Param attr -> VName
paramName Param DeclType
p VName -> Names -> Bool
`nameIn` Names
als

  m [LoopResultSummary (als, Type)]
mk_in_place_map <- [DesiredUpdate (als, Type)]
-> Names
-> [(SubExp, Ident)]
-> [(Param DeclType, SubExp)]
-> Maybe (m [LoopResultSummary (als, Type)])
forall (m :: * -> *) als.
MonadFreshNames m =>
[DesiredUpdate (als, Type)]
-> Names
-> [(SubExp, Ident)]
-> [(Param DeclType, SubExp)]
-> Maybe (m [LoopResultSummary (als, Type)])
summariseLoop [DesiredUpdate (als, Type)]
[DesiredUpdate (LetAttr lore)]
updates Names
usedInBody [(SubExp, Ident)]
resmap [(Param DeclType, SubExp)]
[(FParam lore, SubExp)]
val

  m ([Stm lore], [Stm lore], [Ident], [Ident],
   [(Param DeclType, SubExp)], [(Param DeclType, SubExp)], Body lore)
-> Maybe
     (m ([Stm lore], [Stm lore], [Ident], [Ident],
         [(Param DeclType, SubExp)], [(Param DeclType, SubExp)], Body lore))
forall a. a -> Maybe a
Just (m ([Stm lore], [Stm lore], [Ident], [Ident],
    [(Param DeclType, SubExp)], [(Param DeclType, SubExp)], Body lore)
 -> Maybe
      (m ([Stm lore], [Stm lore], [Ident], [Ident],
          [(Param DeclType, SubExp)], [(Param DeclType, SubExp)],
          Body lore)))
-> m ([Stm lore], [Stm lore], [Ident], [Ident],
      [(Param DeclType, SubExp)], [(Param DeclType, SubExp)], Body lore)
-> Maybe
     (m ([Stm lore], [Stm lore], [Ident], [Ident],
         [(Param DeclType, SubExp)], [(Param DeclType, SubExp)], Body lore))
forall a b. (a -> b) -> a -> b
$ do
    [LoopResultSummary (als, Type)]
in_place_map <- m [LoopResultSummary (als, Type)]
mk_in_place_map
    ([(Param DeclType, SubExp)]
val',[Stm lore]
prebnds,[Stm lore]
postbnds) <- [LoopResultSummary (als, Type)]
-> m ([(Param DeclType, SubExp)], [Stm lore], [Stm lore])
forall (m :: * -> *) lore als.
(MonadFreshNames m, Bindable lore) =>
[LoopResultSummary (als, Type)]
-> m ([(Param DeclType, SubExp)], [Stm lore], [Stm lore])
mkMerges [LoopResultSummary (als, Type)]
in_place_map
    let ([Ident]
ctxpat,[Ident]
valpat) = [LoopResultSummary (als, Type)] -> ([Ident], [Ident])
mkResAndPat [LoopResultSummary (als, Type)]
in_place_map
        idxsubsts :: IndexSubstitutions (als, Type)
idxsubsts = [LoopResultSummary (als, Type)] -> IndexSubstitutions (als, Type)
forall attr. [LoopResultSummary attr] -> IndexSubstitutions attr
indexSubstitutions [LoopResultSummary (als, Type)]
in_place_map
    (IndexSubstitutions (als, Type)
idxsubsts', Stms lore
newbnds) <- IndexSubstitutions (als, Type)
-> Stms lore -> m (IndexSubstitutions (als, Type), Stms lore)
forall (m :: * -> *) lore attr.
(MonadFreshNames m, BinderOps lore, Bindable lore, Aliased lore,
 LetAttr lore ~ attr) =>
IndexSubstitutions attr
-> Stms lore -> m (IndexSubstitutions attr, Stms lore)
substituteIndices IndexSubstitutions (als, Type)
idxsubsts (Stms lore -> m (IndexSubstitutions (als, Type), Stms lore))
-> Stms lore -> m (IndexSubstitutions (als, Type), Stms lore)
forall a b. (a -> b) -> a -> b
$ Body lore -> Stms lore
forall lore. BodyT lore -> Stms lore
bodyStms Body lore
body
    ([SubExp]
body_res, Stms lore
res_bnds) <- [LoopResultSummary (LetAttr lore)]
-> IndexSubstitutions (LetAttr lore) -> m ([SubExp], Stms lore)
forall lore (m :: * -> *).
(Bindable lore, MonadFreshNames m) =>
[LoopResultSummary (LetAttr lore)]
-> IndexSubstitutions (LetAttr lore) -> m ([SubExp], Stms lore)
manipulateResult [LoopResultSummary (als, Type)]
[LoopResultSummary (LetAttr lore)]
in_place_map IndexSubstitutions (als, Type)
IndexSubstitutions (LetAttr lore)
idxsubsts'
    let body' :: Body lore
body' = Stms lore -> [SubExp] -> Body lore
forall lore. Bindable lore => Stms lore -> [SubExp] -> Body lore
mkBody (Stms lore
newbndsStms lore -> Stms lore -> Stms lore
forall a. Semigroup a => a -> a -> a
<>Stms lore
res_bnds) [SubExp]
body_res
    ([Stm lore], [Stm lore], [Ident], [Ident],
 [(Param DeclType, SubExp)], [(Param DeclType, SubExp)], Body lore)
-> m ([Stm lore], [Stm lore], [Ident], [Ident],
      [(Param DeclType, SubExp)], [(Param DeclType, SubExp)], Body lore)
forall (m :: * -> *) a. Monad m => a -> m a
return ([Stm lore]
prebnds, [Stm lore]
postbnds, [Ident]
ctxpat, [Ident]
valpat, [(Param DeclType, SubExp)]
[(FParam lore, SubExp)]
ctx, [(Param DeclType, SubExp)]
val', Body lore
body')
  where usedInBody :: Names
usedInBody = [Names] -> Names
forall a. Monoid a => [a] -> a
mconcat ([Names] -> Names) -> [Names] -> Names
forall a b. (a -> b) -> a -> b
$ (VName -> Names) -> [VName] -> [Names]
forall a b. (a -> b) -> [a] -> [b]
map VName -> Names
expandAliases ([VName] -> [Names]) -> [VName] -> [Names]
forall a b. (a -> b) -> a -> b
$ Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$ Body lore -> Names
forall a. FreeIn a => a -> Names
freeIn Body lore
body Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> LoopForm lore -> Names
forall a. FreeIn a => a -> Names
freeIn LoopForm lore
form
        expandAliases :: VName -> Names
expandAliases VName
v = case VName -> Scope lore -> Maybe (NameInfo lore)
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
v Scope lore
scope of
                            Just (LetInfo LetAttr lore
attr) -> VName -> Names
oneName VName
v Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> (als, Type) -> Names
forall a. AliasesOf a => a -> Names
aliasesOf (als, Type)
LetAttr lore
attr
                            Maybe (NameInfo lore)
_ -> VName -> Names
oneName VName
v
        resmap :: [(SubExp, Ident)]
resmap = [SubExp] -> [Ident] -> [(SubExp, Ident)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Body lore -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult Body lore
body) ([Ident] -> [(SubExp, Ident)]) -> [Ident] -> [(SubExp, Ident)]
forall a b. (a -> b) -> a -> b
$ PatternT (als, Type) -> [Ident]
forall attr. Typed attr => PatternT attr -> [Ident]
patternValueIdents PatternT (als, Type)
Pattern lore
pat

        mkMerges :: (MonadFreshNames m, Bindable lore) =>
                    [LoopResultSummary (als, Type)]
                 -> m ([(Param DeclType, SubExp)], [Stm lore], [Stm lore])
        mkMerges :: [LoopResultSummary (als, Type)]
-> m ([(Param DeclType, SubExp)], [Stm lore], [Stm lore])
mkMerges [LoopResultSummary (als, Type)]
summaries = do
          (([(Param DeclType, SubExp)]
origmerge, [(Param DeclType, SubExp)]
extramerge), ([Stm lore]
prebnds, [Stm lore]
postbnds)) <-
            WriterT
  ([Stm lore], [Stm lore])
  m
  ([(Param DeclType, SubExp)], [(Param DeclType, SubExp)])
-> m (([(Param DeclType, SubExp)], [(Param DeclType, SubExp)]),
      ([Stm lore], [Stm lore]))
forall w (m :: * -> *) a. WriterT w m a -> m (a, w)
runWriterT (WriterT
   ([Stm lore], [Stm lore])
   m
   ([(Param DeclType, SubExp)], [(Param DeclType, SubExp)])
 -> m (([(Param DeclType, SubExp)], [(Param DeclType, SubExp)]),
       ([Stm lore], [Stm lore])))
-> WriterT
     ([Stm lore], [Stm lore])
     m
     ([(Param DeclType, SubExp)], [(Param DeclType, SubExp)])
-> m (([(Param DeclType, SubExp)], [(Param DeclType, SubExp)]),
      ([Stm lore], [Stm lore]))
forall a b. (a -> b) -> a -> b
$ [Either (Param DeclType, SubExp) (Param DeclType, SubExp)]
-> ([(Param DeclType, SubExp)], [(Param DeclType, SubExp)])
forall a b. [Either a b] -> ([a], [b])
partitionEithers ([Either (Param DeclType, SubExp) (Param DeclType, SubExp)]
 -> ([(Param DeclType, SubExp)], [(Param DeclType, SubExp)]))
-> WriterT
     ([Stm lore], [Stm lore])
     m
     [Either (Param DeclType, SubExp) (Param DeclType, SubExp)]
-> WriterT
     ([Stm lore], [Stm lore])
     m
     ([(Param DeclType, SubExp)], [(Param DeclType, SubExp)])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (LoopResultSummary (als, Type)
 -> WriterT
      ([Stm lore], [Stm lore])
      m
      (Either (Param DeclType, SubExp) (Param DeclType, SubExp)))
-> [LoopResultSummary (als, Type)]
-> WriterT
     ([Stm lore], [Stm lore])
     m
     [Either (Param DeclType, SubExp) (Param DeclType, SubExp)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM LoopResultSummary (als, Type)
-> WriterT
     ([Stm lore], [Stm lore])
     m
     (Either (Param DeclType, SubExp) (Param DeclType, SubExp))
forall (m :: * -> *) lore lore a.
(MonadFreshNames m, MonadWriter ([Stm lore], [Stm lore]) m,
 Bindable lore, Bindable lore) =>
LoopResultSummary (a, Type)
-> m (Either (Param DeclType, SubExp) (Param DeclType, SubExp))
mkMerge [LoopResultSummary (als, Type)]
summaries
          ([(Param DeclType, SubExp)], [Stm lore], [Stm lore])
-> m ([(Param DeclType, SubExp)], [Stm lore], [Stm lore])
forall (m :: * -> *) a. Monad m => a -> m a
return ([(Param DeclType, SubExp)]
origmerge [(Param DeclType, SubExp)]
-> [(Param DeclType, SubExp)] -> [(Param DeclType, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(Param DeclType, SubExp)]
extramerge, [Stm lore]
prebnds, [Stm lore]
postbnds)

        mkMerge :: LoopResultSummary (a, Type)
-> m (Either (Param DeclType, SubExp) (Param DeclType, SubExp))
mkMerge LoopResultSummary (a, Type)
summary
          | Just (DesiredUpdate (a, Type)
update, VName
mergename, (a, Type)
mergeattr) <- LoopResultSummary (a, Type)
-> Maybe (DesiredUpdate (a, Type), VName, (a, Type))
forall attr.
LoopResultSummary attr -> Maybe (DesiredUpdate attr, VName, attr)
relatedUpdate LoopResultSummary (a, Type)
summary = do
            VName
source <- String -> m VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"modified_source"
            let source_t :: Type
source_t = (a, Type) -> Type
forall a b. (a, b) -> b
snd ((a, Type) -> Type) -> (a, Type) -> Type
forall a b. (a -> b) -> a -> b
$ DesiredUpdate (a, Type) -> (a, Type)
forall attr. DesiredUpdate attr -> attr
updateType DesiredUpdate (a, Type)
update
                elmident :: Ident
elmident = VName -> Type -> Ident
Ident (DesiredUpdate (a, Type) -> VName
forall attr. DesiredUpdate attr -> VName
updateValue DesiredUpdate (a, Type)
update) (Type -> Ident) -> Type -> Ident
forall a b. (a -> b) -> a -> b
$ Type -> Type
forall shape u.
ArrayShape shape =>
TypeBase shape u -> TypeBase shape u
rowType Type
source_t
            ([Stm lore], [Stm lore]) -> m ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell ([[Ident] -> [Ident] -> Exp lore -> Stm lore
forall lore.
Bindable lore =>
[Ident] -> [Ident] -> Exp lore -> Stm lore
mkLet [] [VName -> Type -> Ident
Ident VName
source Type
source_t] (Exp lore -> Stm lore) -> Exp lore -> Stm lore
forall a b. (a -> b) -> a -> b
$ BasicOp lore -> Exp lore
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp lore -> Exp lore) -> BasicOp lore -> Exp lore
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> SubExp -> BasicOp lore
forall lore. VName -> Slice SubExp -> SubExp -> BasicOp lore
Update
                   (DesiredUpdate (a, Type) -> VName
forall attr. DesiredUpdate attr -> VName
updateSource DesiredUpdate (a, Type)
update)
                   (Type -> Slice SubExp -> Slice SubExp
fullSlice Type
source_t (Slice SubExp -> Slice SubExp) -> Slice SubExp -> Slice SubExp
forall a b. (a -> b) -> a -> b
$ DesiredUpdate (a, Type) -> Slice SubExp
forall attr. DesiredUpdate attr -> Slice SubExp
updateIndices DesiredUpdate (a, Type)
update) (SubExp -> BasicOp lore) -> SubExp -> BasicOp lore
forall a b. (a -> b) -> a -> b
$
                   (Param DeclType, SubExp) -> SubExp
forall a b. (a, b) -> b
snd ((Param DeclType, SubExp) -> SubExp)
-> (Param DeclType, SubExp) -> SubExp
forall a b. (a -> b) -> a -> b
$ LoopResultSummary (a, Type) -> (Param DeclType, SubExp)
forall attr. LoopResultSummary attr -> (Param DeclType, SubExp)
mergeParam LoopResultSummary (a, Type)
summary],
                  [[Ident] -> [Ident] -> Exp lore -> Stm lore
forall lore.
Bindable lore =>
[Ident] -> [Ident] -> Exp lore -> Stm lore
mkLet [] [Ident
elmident] (Exp lore -> Stm lore) -> Exp lore -> Stm lore
forall a b. (a -> b) -> a -> b
$ BasicOp lore -> Exp lore
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp lore -> Exp lore) -> BasicOp lore -> Exp lore
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp lore
forall lore. VName -> Slice SubExp -> BasicOp lore
Index
                   (DesiredUpdate (a, Type) -> VName
forall attr. DesiredUpdate attr -> VName
updateName DesiredUpdate (a, Type)
update) (Type -> Slice SubExp -> Slice SubExp
fullSlice ((a, Type) -> Type
forall t. Typed t => t -> Type
typeOf ((a, Type) -> Type) -> (a, Type) -> Type
forall a b. (a -> b) -> a -> b
$ DesiredUpdate (a, Type) -> (a, Type)
forall attr. DesiredUpdate attr -> attr
updateType DesiredUpdate (a, Type)
update) (Slice SubExp -> Slice SubExp) -> Slice SubExp -> Slice SubExp
forall a b. (a -> b) -> a -> b
$ DesiredUpdate (a, Type) -> Slice SubExp
forall attr. DesiredUpdate attr -> Slice SubExp
updateIndices DesiredUpdate (a, Type)
update)])
            Either (Param DeclType, SubExp) (Param DeclType, SubExp)
-> m (Either (Param DeclType, SubExp) (Param DeclType, SubExp))
forall (m :: * -> *) a. Monad m => a -> m a
return (Either (Param DeclType, SubExp) (Param DeclType, SubExp)
 -> m (Either (Param DeclType, SubExp) (Param DeclType, SubExp)))
-> Either (Param DeclType, SubExp) (Param DeclType, SubExp)
-> m (Either (Param DeclType, SubExp) (Param DeclType, SubExp))
forall a b. (a -> b) -> a -> b
$ (Param DeclType, SubExp)
-> Either (Param DeclType, SubExp) (Param DeclType, SubExp)
forall a b. b -> Either a b
Right (VName -> DeclType -> Param DeclType
forall attr. VName -> attr -> Param attr
Param
                            VName
mergename
                            (Type -> Uniqueness -> DeclType
forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
toDecl ((a, Type) -> Type
forall t. Typed t => t -> Type
typeOf (a, Type)
mergeattr) Uniqueness
Unique),
                            VName -> SubExp
Var VName
source)
          | Bool
otherwise = Either (Param DeclType, SubExp) (Param DeclType, SubExp)
-> m (Either (Param DeclType, SubExp) (Param DeclType, SubExp))
forall (m :: * -> *) a. Monad m => a -> m a
return (Either (Param DeclType, SubExp) (Param DeclType, SubExp)
 -> m (Either (Param DeclType, SubExp) (Param DeclType, SubExp)))
-> Either (Param DeclType, SubExp) (Param DeclType, SubExp)
-> m (Either (Param DeclType, SubExp) (Param DeclType, SubExp))
forall a b. (a -> b) -> a -> b
$ (Param DeclType, SubExp)
-> Either (Param DeclType, SubExp) (Param DeclType, SubExp)
forall a b. a -> Either a b
Left ((Param DeclType, SubExp)
 -> Either (Param DeclType, SubExp) (Param DeclType, SubExp))
-> (Param DeclType, SubExp)
-> Either (Param DeclType, SubExp) (Param DeclType, SubExp)
forall a b. (a -> b) -> a -> b
$ LoopResultSummary (a, Type) -> (Param DeclType, SubExp)
forall attr. LoopResultSummary attr -> (Param DeclType, SubExp)
mergeParam LoopResultSummary (a, Type)
summary

        -- Build the new loop pattern.  The context identifiers of the original
        -- pattern come first.  They are followed by the identifiers of results
        -- that have no related update, and then by the update names of results
        -- that do (as classified by mkResAndPat').
        mkResAndPat summaries =
          ( patternContextIdents pat
          , uncurry (++) $ partitionEithers $ map mkResAndPat' summaries
          )

        -- Classify one loop result for the new pattern.
        --
        -- Right: the result has a related update.  It is bound to the update's
        -- name, typed by the second component of the update's attribute.
        --
        -- Left: the result is left alone and keeps the identifier it had in the
        -- original pattern.
        mkResAndPat' summary =
          case relatedUpdate summary of
            Just (update, _, _) ->
              Right $ Ident (updateName update) $ snd $ updateType update
            Nothing ->
              Left $ inPatternAs summary

-- | Summarise every result of a loop with respect to the desired in-place
-- updates.
--
-- Each @(result, pattern ident)@ pair in @resmap@ is zipped positionally with
-- the corresponding merge parameter in @merge@.  The whole summary is
-- 'Nothing' as soon as any single result cannot be lowered.  Otherwise the
-- returned action allocates a fresh @lowered_array@ name per result.
summariseLoop :: MonadFreshNames m =>
                 [DesiredUpdate (als, Type)]
              -> Names
              -> [(SubExp, Ident)]
              -> [(Param DeclType, SubExp)]
              -> Maybe (m [LoopResultSummary (als, Type)])
summariseLoop updates usedInBody resmap merge =
  fmap sequence $ zipWithM summarise resmap merge
  where
    summarise (se, v) (fparam, mergeinit) = do
      -- A result with no matching desired update aborts the lowering.
      -- XXX: conservative; but this entire pass is going away.
      update <- find (updateHasValue (identName v)) updates
      -- Give up when the update's source array is used inside the loop body,
      -- or when the merge parameter's shape can change between iterations.
      if updateSource update `nameIn` usedInBody
           || not (hasLoopInvariantShape fparam)
        then Nothing
        else Just $ do
          lowered_array <- newVName "lowered_array"
          return LoopResultSummary
            { resultSubExp  = se
            , inPatternAs   = v
            , mergeParam    = (fparam, mergeinit)
            , relatedUpdate = Just (update, lowered_array, updateType update)
            }

    -- A parameter's shape is loop-invariant when no dimension refers to one
    -- of the loop's merge parameters.
    hasLoopInvariantShape :: Param DeclType -> Bool
    hasLoopInvariantShape fparam =
      all loopInvariant $ arrayDims $ paramType fparam

    merge_param_names :: [VName]
    merge_param_names = [ paramName p | (p, _) <- merge ]

    loopInvariant :: SubExp -> Bool
    loopInvariant dim =
      case dim of
        Var v      -> v `notElem` merge_param_names
        Constant{} -> True

-- | What the in-place lowering pass records about one result of a loop.
data LoopResultSummary attr =
  LoopResultSummary
  { resultSubExp :: SubExp
    -- ^ The loop result itself (first component of its @resmap@ pair).
  , inPatternAs :: Ident
    -- ^ The identifier the result is bound to in the loop's pattern.
  , mergeParam :: (Param DeclType, SubExp)
    -- ^ The corresponding merge parameter together with its initial value.
  , relatedUpdate :: Maybe (DesiredUpdate attr, VName, attr)
    -- ^ The desired update associated with this result, the fresh name of
    -- the lowered array, and the update's attribute.  'Nothing' when the
    -- result is not being lowered.
  }
  deriving (Show)

-- | One index substitution per loop result that has a related update.
--
-- Each substitution maps the name of the merge parameter to four things:
-- the update's certificates, the lowered-array name, the update's attribute,
-- and the update's indices.  Results without a related update contribute
-- nothing.
indexSubstitutions :: [LoopResultSummary attr]
                   -> IndexSubstitutions attr
indexSubstitutions summaries =
  [ (paramName (fst (mergeParam summary)), (cs, nm, attr, is))
  | summary <- summaries
  , Just ( DesiredUpdate { updateCertificates = cs, updateIndices = is }
         , nm
         , attr
         ) <- [relatedUpdate summary]
  ]

-- | Rewrite the loop's result after lowering.
--
-- Results with no related update are kept unchanged and come first.  The
-- remaining results follow, each paired positionally with an entry of
-- @substs@:
--
-- * A result that is exactly the substituted merge parameter is replaced by
--   the lowered array's name.
-- * Any other result is written into the lowered array by a (certified)
--   'Update' statement, and the updated array becomes the result.
--
-- The generated statements are returned alongside the new result.
manipulateResult :: (Bindable lore, MonadFreshNames m) =>
                    [LoopResultSummary (LetAttr lore)]
                 -> IndexSubstitutions (LetAttr lore)
                 -> m (Result, Stms lore)
manipulateResult summaries substs = do
  let (kept, touched) = partitionEithers $ map classify summaries
  (rewritten, stms) <- runWriterT $ zipWithM writeBack touched substs
  return (kept ++ rewritten, stmsFromList stms)
  where
    -- Left: the result passes through untouched.
    -- Right: the result must be written back into a lowered array.
    classify summary =
      let se = resultSubExp summary
      in maybe (Left se) (const $ Right se) $ relatedUpdate summary

    writeBack res_se (subst_v, (cs, nm, attr, is))
      | Var res_v <- res_se, res_v == subst_v =
          -- The result is the substituted merge parameter itself, so the
          -- lowered array can be returned directly.
          return $ Var nm
      | otherwise = do
          let arr_t = typeOf attr
          v' <- newIdent' (++ "_updated") $ Ident nm arr_t
          tell [certify cs $ mkLet [] [v'] $ BasicOp $
                Update nm (fullSlice arr_t is) res_se]
          return $ Var $ identName v'