{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.Optimise.ArrayShortCircuiting.DataStructs
  ( Coalesced (..),
    CoalescedKind (..),
    ArrayMemBound (..),
    AllocTab,
    HasMemBlock,
    ScalarTab,
    CoalsTab,
    ScopeTab,
    CoalsEntry (..),
    FreeVarSubsts,
    LmadRef,
    MemRefs (..),
    AccessSummary (..),
    BotUpEnv (..),
    InhibitTab,
    unionCoalsEntry,
    vnameToPrimExp,
    getArrMemAssocFParam,
    getScopeMemInfo,
    createsNewArrOK,
    getArrMemAssoc,
    getUniqueMemFParam,
    markFailedCoal,
    accessSubtract,
    markSuccessCoal,
  )
where

import Control.Applicative
import Data.Functor ((<&>))
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Futhark.IR.Aliases
import Futhark.IR.GPUMem as GPU
import Futhark.IR.MCMem as MC
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.IR.SeqMem
import Futhark.Util.Pretty hiding (line, sep, (</>))
import Prelude

-- | Maps variable names to their binding information (types, memory blocks,
-- index functions, etc.) in the aliased representation.
type ScopeTab rep = Scope (Aliases rep)

-- | An LMAD whose components are typed 64-bit integer 'TPrimExp's.
type LmadRef = IxFun.LMAD (TPrimExp Int64 VName)

-- | Summary of all memory accesses at a given point in the code.
data AccessSummary
  = -- | The access summary could not be determined statically, for instance
    -- because of multiple LMADs. In this case we conservatively avoid all
    -- coalescing.
    Undeterminable
  | -- | A conservative estimate of the set of accesses up until this point.
    Set (S.Set LmadRef)

-- | Union of the access sets; 'Undeterminable' absorbs everything.
instance Semigroup AccessSummary where
  Set a <> Set b = Set (a `S.union` b)
  _ <> _ = Undeterminable

-- | The empty access set.
instance Monoid AccessSummary where
  mempty = Set S.empty

-- | 'Undeterminable' has no free variables; a 'Set' has those of its LMADs.
instance FreeIn AccessSummary where
  freeIn' summary = case summary of
    Undeterminable -> mempty
    Set lmads -> freeIn' lmads

-- | @accessSubtract a b@ removes every LMAD of @b@ from @a@. The result is
-- 'Undeterminable' whenever either argument is.
accessSubtract :: AccessSummary -> AccessSummary -> AccessSummary
accessSubtract (Set s1) (Set s2) = Set (S.difference s1 s2)
accessSubtract _ _ = Undeterminable

-- | Access summaries for the destination and the source of a coalescing entry.
data MemRefs = MemRefs
  { -- | The access summary of all references (reads and writes) to the
    -- destination of a coalescing entry.
    dstrefs :: AccessSummary,
    -- | The access summary of all writes to the source of a coalescing entry.
    srcwrts :: AccessSummary
  }

-- | Field-wise combination of the two access summaries.
instance Semigroup MemRefs where
  m1 <> m2 =
    MemRefs
      { dstrefs = dstrefs m1 <> dstrefs m2,
        srcwrts = srcwrts m1 <> srcwrts m2
      }

-- | Both summaries empty.
instance Monoid MemRefs where
  mempty = MemRefs {dstrefs = mempty, srcwrts = mempty}

-- | The syntactic situation in which a coalescing opportunity arises.
data CoalescedKind
  = -- | @let x    = copy b^{lu}@
    CopyCoal
  | -- | @let x[i] = b^{lu}@
    InPlaceCoal
  | -- | @let x    = concat(a, b^{lu})@
    ConcatCoal
  | -- | Transitive, i.e., other variables aliased with @b@.
    TransitiveCoal
  | -- | NOTE(review): undocumented in the original; presumably coalescing
    -- related to a @map@ result (pretty-printed as "Map") — confirm.
    MapCoal

-- | Information about a memory block: type, shape, name and ixfun.
data ArrayMemBound = MemBlock
  { -- | Primitive element type of the array.
    primType :: PrimType,
    -- | Shape of the array.
    shape :: Shape,
    -- | Name of the memory block the array lives in.
    memName :: VName,
    -- | Index function of the array in 'memName'.
    ixfun :: IxFun
  }

-- | Free-variable substitutions, mapping names to typed 64-bit integer
-- expressions.
type FreeVarSubsts = M.Map VName (TPrimExp Int64 VName)

-- | Coalesced access entry.
data Coalesced
  = Coalesced
      CoalescedKind
      -- ^ The kind of coalescing.
      ArrayMemBound
      -- ^ Destination memory block info @f_m_x[i]@ (must be ArrayMem).
      FreeVarSubsts
      -- ^ Substitutions for free variables in the index function.

-- | Information about coalescing a source memory block into a destination
-- memory block. Field order matters: some code constructs and matches this
-- record positionally.
data CoalsEntry = CoalsEntry
  { -- | Destination memory block.
    dstmem :: VName,
    -- | Index function of the destination (used for rebasing).
    dstind :: IxFun,
    -- | Aliased destination memory blocks, which can appear due to repeated
    -- (optimistic) coalescing.
    alsmem :: Names,
    -- | Per-variable coalesced entries.
    vartab :: M.Map VName Coalesced,
    -- | Keys are variable names and values are memory-block names. This
    -- records optimistically added coalesced nodes, e.g., in the case of
    -- if-then-else expressions. For example:
    --
    --       @x    = map f a@
    --       @.. use of y ..@
    --       @b    = map g a@
    --       @x[i] = b      @
    --       @y[k] = x      @
    --
    -- The coalescing of @b@ in @x[i]@ succeeds, but it depends on the
    -- success of the coalescing of @x@ in @y[k]@. That coalescing fails here,
    -- because @y@ is used before the new array creation of @x = map f@.
    -- Hence @optdeps@ of the @m_b@ 'CoalsEntry' records @x -> m_x@. At the
    -- end of the analysis, the entry is removed from the successfully
    -- coalesced table if @m_x@ is unsuccessful.
    --
    -- Storing only @m_x@ would probably suffice if memory were never reused
    -- (e.g., by register allocation on arrays). The variable @x@
    -- discriminates between memory being reused across semantically
    -- different arrays (it is looked up in the @vartab@ field).
    optdeps :: M.Map VName VName,
    -- | Access summaries of uses of the destination and writes to the source,
    -- respectively.
    memrefs :: MemRefs,
    -- | Certificates of the destination, which must be propagated to the
    -- source. When short-circuiting reaches the array creation point, we must
    -- check whether the certs are in scope for short-circuiting to succeed.
    certs :: Certs
  }

-- | The allocated memory blocks, mapped to their memory space.
type AllocTab = M.Map VName Space

-- | Maps a variable name to its scalar 'PrimExp' expression.
type ScalarTab = M.Map VName (PrimExp VName)

-- | Maps a memory-block name to a 'CoalsEntry'. Among other things, the entry
-- contains @vartab@, a map binding each variable associated with that memory
-- block to its 'Coalesced' info.
type CoalsTab = M.Map VName CoalsEntry

-- | Inhibited memory-block mergings, from the key (a memory block) to the
-- value (a set of memory blocks).
type InhibitTab = M.Map VName Names

-- | State carried by the short-circuiting analysis: scalar expansions plus the
-- active, successful, and inhibited coalescings.
data BotUpEnv = BotUpEnv
  { -- | Maps scalar variables to their 'PrimExp' expansion.
    scals :: ScalarTab,
    -- | Optimistic coalescing info: the memory blocks we are currently trying
    -- to coalesce.
    activeCoals :: CoalsTab,
    -- | Committed (successful) coalescing info: the memory blocks that have
    -- been successfully coalesced.
    successCoals :: CoalsTab,
    -- | The coalescing failures from this pass. We will no longer try to
    -- merge these memory blocks.
    inhibit :: InhibitTab
  }

-- | Prints the table as its list of (memory block, entry) pairs.
instance Pretty CoalsTab where
  pretty tab = pretty (M.toList tab)

instance Pretty AccessSummary where
  pretty summary = case summary of
    Undeterminable -> "Undeterminable"
    Set lmads -> "Access-Set:" <+> pretty (S.toList lmads) <+> " "

instance Pretty MemRefs where
  pretty (MemRefs uses writes) =
    "( Use-Sum:" <+> usesDoc <+> "Write-Sum:" <+> writesDoc <> ")"
    where
      usesDoc = pretty uses
      writesDoc = pretty writes

instance Pretty CoalescedKind where
  pretty kind = case kind of
    CopyCoal -> "Copy"
    InPlaceCoal -> "InPlace"
    ConcatCoal -> "Concat"
    TransitiveCoal -> "Transitive"
    MapCoal -> "Map"

-- | Renders as @{type, shape, mem, ixfun}@.
instance Pretty ArrayMemBound where
  pretty (MemBlock elem_type arr_shape mem ixf) =
    "{" <> fields <> "}"
    where
      fields =
        pretty elem_type
          <> ","
          <+> pretty arr_shape
          <> ","
          <+> pretty mem
          <> ","
          <+> pretty ixf

-- | Prints the kind and destination memory info; the free-variable
-- substitutions are not printed.
instance Pretty Coalesced where
  pretty (Coalesced kind membound _) =
    body <+> "\n"
    where
      body = "(Kind:" <+> pretty kind <> ", membds:" <+> pretty membound <> ")"

-- | Debug rendering of a 'CoalsEntry'. It prints the destination memory,
-- aliased memories, optimistic dependencies, access summaries and the vartab.
-- Every ", label:" separator is attached directly (with '<>') to the
-- preceding item, so no stray space appears before the commas. The original
-- used '<+>' for all but the first separator.
instance Pretty CoalsEntry where
  pretty etry =
    "{"
      <+> "Dstmem:"
      <+> pretty (dstmem etry)
      <> ", AliasMems:"
      <+> pretty (alsmem etry)
      <> ", optdeps:"
      <+> pretty (M.toList $ optdeps etry)
      <> ", memrefs:"
      <+> pretty (memrefs etry)
      <> ", vartab:"
      <+> pretty (M.toList $ vartab etry)
      <+> "}"
      <+> "\n"

-- | Compute the union of two 'CoalsEntry'.
--
-- The entries are merged only when they have the same destination memory
-- block /and/ the same destination index function. Otherwise the first entry
-- is returned unchanged.
--
-- When merging, all other fields come from the first entry except
-- @alsmem@, @optdeps@, @vartab@, @memrefs@ and @certs@, which are combined
-- with '<>'. For the maps this is a left-biased union, so the first entry
-- wins on key clashes.
unionCoalsEntry :: CoalsEntry -> CoalsEntry -> CoalsEntry
unionCoalsEntry etry1 etry2
  | dstmem etry1 == dstmem etry2,
    dstind etry1 == dstind etry2 =
      etry1
        { alsmem = alsmem etry1 <> alsmem etry2,
          optdeps = optdeps etry1 <> optdeps etry2,
          vartab = vartab etry1 <> vartab etry2,
          memrefs = memrefs etry1 <> memrefs etry2,
          certs = certs etry1 <> certs etry2
        }
  | otherwise = etry1

-- | Get the names of array 'PatElem's in a 'Pat' and the corresponding
-- 'ArrayMemBound' information for each array. Non-array pattern elements
-- (memory blocks, scalars, accumulators) are skipped.
getArrMemAssoc :: Pat (aliases, LetDecMem) -> [(VName, ArrayMemBound)]
getArrMemAssoc pat =
  [ (patElemName pe, MemBlock tp shp mem ixf)
    | pe <- patElems pat,
      MemArray tp shp _ (ArrayIn mem ixf) <- [snd (patElemDec pe)]
  ]

-- | Get the names of arrays in a list of 'FParam' and the corresponding
-- uniqueness and 'ArrayMemBound' information for each array. Non-array
-- parameters are skipped.
getArrMemAssocFParam :: [Param FParamMem] -> [(VName, Uniqueness, ArrayMemBound)]
getArrMemAssocFParam params =
  [ (paramName p, u, MemBlock tp shp mem ixf)
    | p <- params,
      MemArray tp shp u (ArrayIn mem ixf) <- [paramDec p]
  ]

-- | Get the memory-block parameters in a list of 'FParam' that hold some
-- 'Unique' array parameter of the same list, together with their 'Space'.
getUniqueMemFParam :: [Param FParamMem] -> M.Map VName Space
getUniqueMemFParam params =
  M.restrictKeys memBlocks uniqueArrayMems
  where
    -- Every memory-block parameter, keyed by name.
    memBlocks =
      M.fromList [(paramName p, sp) | p <- params, MemMem sp <- [paramDec p]]
    -- Memory blocks backing unique array parameters.
    uniqueArrayMems =
      S.fromList
        [mem | p <- params, MemArray _ _ Unique (ArrayIn mem _) <- [paramDec p]]

-- | Representations whose scope records memory-block information for arrays.
class HasMemBlock rep where
  -- | Look up a 'VName' in the given scope. If it is bound to a 'MemArray',
  -- return the 'ArrayMemBound' information for that array; otherwise
  -- 'Nothing'.
  getScopeMemInfo :: VName -> Scope rep -> Maybe ArrayMemBound

-- | Array bindings may come from let-bindings, function parameters or lambda
-- parameters; all three carry a 'MemArray'.
instance HasMemBlock (Aliases SeqMem) where
  getScopeMemInfo v scope = do
    info <- M.lookup v scope
    case info of
      LetName (_, MemArray tp shp _ (ArrayIn m idx)) -> Just $ MemBlock tp shp m idx
      FParamName (MemArray tp shp _ (ArrayIn m idx)) -> Just $ MemBlock tp shp m idx
      LParamName (MemArray tp shp _ (ArrayIn m idx)) -> Just $ MemBlock tp shp m idx
      _ -> Nothing

-- | Array bindings may come from let-bindings, function parameters or lambda
-- parameters; all three carry a 'MemArray'.
instance HasMemBlock (Aliases GPUMem) where
  getScopeMemInfo v scope = do
    info <- M.lookup v scope
    case info of
      LetName (_, MemArray tp shp _ (ArrayIn m idx)) -> Just $ MemBlock tp shp m idx
      FParamName (MemArray tp shp _ (ArrayIn m idx)) -> Just $ MemBlock tp shp m idx
      LParamName (MemArray tp shp _ (ArrayIn m idx)) -> Just $ MemBlock tp shp m idx
      _ -> Nothing

-- | Array bindings may come from let-bindings, function parameters or lambda
-- parameters; all three carry a 'MemArray'.
instance HasMemBlock (Aliases MCMem) where
  getScopeMemInfo v scope = do
    info <- M.lookup v scope
    case info of
      LetName (_, MemArray tp shp _ (ArrayIn m idx)) -> Just $ MemBlock tp shp m idx
      FParamName (MemArray tp shp _ (ArrayIn m idx)) -> Just $ MemBlock tp shp m idx
      LParamName (MemArray tp shp _ (ArrayIn m idx)) -> Just $ MemBlock tp shp m idx
      _ -> Nothing

-- | @True@ if the expression returns a "fresh" array. Only the basic
-- operations 'Replicate', 'Iota', 'Manifest', 'Concat', 'ArrayLit' and
-- 'Scratch' qualify; every other expression yields @False@.
createsNewArrOK :: Exp rep -> Bool
createsNewArrOK (BasicOp op) = case op of
  Replicate {} -> True
  Iota {} -> True
  Manifest {} -> True
  Concat {} -> True
  ArrayLit {} -> True
  Scratch {} -> True
  _ -> False
createsNewArrOK _ = False

-- | Mark the coalescing of @src_mem@ as failed.
--
-- Removal of a memory block from the active-coalescing table should only go
-- through this function; otherwise it is easy to run into non-termination.
-- The fix-point iteration of the coalescing transformation assumes that
-- every failed coalescing is recorded in the @inhibit@ table.
--
-- If @src_mem@ has an entry, that entry is deleted from the coalescing table.
-- Its destination memory is then added to the set of inhibited merges for
-- @src_mem@. If there is no entry, both tables are returned unchanged.
markFailedCoal ::
  (CoalsTab, InhibitTab) ->
  VName ->
  (CoalsTab, InhibitTab)
markFailedCoal (coal_tab, inhb_tab) src_mem =
  case M.lookup src_mem coal_tab of
    Nothing -> (coal_tab, inhb_tab)
    Just coale ->
      ( M.delete src_mem coal_tab,
        -- insertWith passes (new, old), so this is new <> old.
        M.insertWith (<>) src_mem (oneName (dstmem coale)) inhb_tab
      )

-- | Promote @m_b@ from the active to the successful coalescing table.
--
-- Such promotions should go through this function, for clarity. The block is
-- deleted from the active table and @info_b@ is merged into the successful
-- table via 'appendCoalsInfo'.
markSuccessCoal ::
  (CoalsTab, CoalsTab) ->
  VName ->
  CoalsEntry ->
  (CoalsTab, CoalsTab)
markSuccessCoal (actv, succc) m_b info_b =
  let actv' = M.delete m_b actv
      succc' = appendCoalsInfo m_b info_b succc
   in (actv', succc')

-- | Merge an entry into the coalesced table.
--
-- If @mb@ is not yet present, @info_new@ is inserted. Otherwise the stored
-- entry becomes @unionCoalsEntry info_old info_new@, so the old entry is kept
-- if the two are incompatible.
appendCoalsInfo :: VName -> CoalsEntry -> CoalsTab -> CoalsTab
appendCoalsInfo mb info_new coalstab =
  -- insertWith calls its combiner as (new, old); flip restores (old, new).
  M.insertWith (flip unionCoalsEntry) mb info_new coalstab

-- | Attempt to convert a 'VName' to a 'PrimExp'.
--
-- First look in 'ScalarTab' to see if we have recorded the scalar value of
-- the argument. Otherwise look up the type of the argument in the scope and
-- return a 'LeafExp' if it is a primitive type. Returns 'Nothing' if neither
-- lookup succeeds.
vnameToPrimExp ::
  (AliasableRep rep) =>
  ScopeTab rep ->
  ScalarTab ->
  VName ->
  Maybe (PrimExp VName)
vnameToPrimExp scopetab scaltab v =
  case M.lookup v scaltab of
    Just known -> Just known
    Nothing -> do
      info <- M.lookup v scopetab
      pt <- toPrimType (typeOf info)
      pure (LeafExp v pt)

-- | Attempt to extract the 'PrimType' from a 'TypeBase'. Only 'Prim' types
-- succeed; every other type yields 'Nothing'.
toPrimType :: TypeBase shp u -> Maybe PrimType
toPrimType t = case t of
  Prim pt -> Just pt
  _ -> Nothing