{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeFamilies, FlexibleInstances, FlexibleContexts, MultiParamTypeClasses #-}
{-# LANGUAGE ConstraintKinds #-}
-- | This representation requires that every array is given
-- information about which memory block is it based in, and how array
-- elements map to memory block offsets.  The representation is based
-- on the kernels representation, so nested parallelism does not
-- occur.
--
-- There are two primary concepts you will need to understand:
--
--  1. Memory blocks, which are Futhark values of type 'Mem'
--     (parametrized with their size).  These correspond to arbitrary
--     blocks of memory, and are created using the 'Alloc' operation.
--
--  2. Index functions, which describe a mapping from the index space
--     of an array (eg. a two-dimensional space for an array of type
--     @[[int]]@) to a one-dimensional offset into a memory block.
--     Thus, index functions describe how arbitrary-dimensional arrays
--     are mapped to the single-dimensional world of memory.
--
-- At a conceptual level, imagine that we have a two-dimensional array
-- @a@ of 32-bit integers, consisting of @n@ rows of @m@ elements
-- each.  This array could be represented in classic row-major format
-- with an index function like the following:
--
-- @
--   f(i,j) = i * m + j
-- @
--
-- When we want to know the location of element @a[2,3]@, we simply
-- call the index function as @f(2,3)@ and obtain @2*m+3@.  We could
-- also have chosen another index function, one that represents the
-- array in column-major (or "transposed") format:
--
-- @
--   f(i,j) = j * n + i
-- @
--
-- Index functions are not Futhark-level functions, but a special
-- construct that the final code generator will eventually use to
-- generate concrete access code.  By modifying the index functions we
-- can change how an array is represented in memory, which can permit
-- memory access pattern optimisations.
--
-- Every time we bind an array, whether in a @let@-binding, @loop@
-- merge parameter, or @lambda@ parameter, we have an annotation
-- specifying a memory block and an index function.  In some cases,
-- such as @let@-bindings for many expressions, we are free to specify
-- an arbitrary index function and memory block - for example, we get
-- to decide where 'Copy' stores its result - but in other cases the
-- type rules of the expression chooses for us.  For example, 'Index'
-- always produces an array in the same memory block as its input, and
-- with the same index function, except with some indices fixed.
module Futhark.Representation.ExplicitMemory
       ( -- * The Lore definition
         ExplicitMemory
       , InKernel
       , MemOp (..)
       , MemInfo (..)
       , MemBound
       , MemBind (..)
       , MemReturn (..)
       , IxFun
       , ExtIxFun
       , isStaticIxFun
       , ExpReturns
       , BodyReturns
       , FunReturns
       , noUniquenessReturns
       , bodyReturnsToExpReturns
       , ExplicitMemorish
       , expReturns
       , extReturns
       , sliceInfo
       , lookupMemInfo
       , subExpMemInfo
       , lookupMemSize
       , lookupArraySummary
       , fullyLinear
       , ixFunMatchesInnerShape
       , existentialiseIxFun

         -- * Module re-exports
       , module Futhark.Representation.AST.Attributes
       , module Futhark.Representation.AST.Traversals
       , module Futhark.Representation.AST.Pretty
       , module Futhark.Representation.AST.Syntax
       , module Futhark.Representation.Kernels.Kernel
       , module Futhark.Representation.Kernels.KernelExp
       , module Futhark.Analysis.PrimExp.Convert
       )
where

import Data.Maybe
import Control.Monad.State
import Control.Monad.Reader
import Control.Monad.Except
import qualified Data.Map.Strict as M
import Data.Foldable (traverse_)
import Data.List
import Data.Monoid ((<>))

import Futhark.Analysis.Metrics
import Futhark.Representation.AST.Syntax
import Futhark.Representation.Kernels.Kernel
import Futhark.Representation.Kernels.KernelExp
import Futhark.Representation.AST.Attributes
import Futhark.Representation.AST.Attributes.Aliases
import Futhark.Representation.AST.Traversals
import Futhark.Representation.AST.Pretty
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import qualified Futhark.TypeCheck as TC
import qualified Futhark.Representation.ExplicitMemory.IndexFunction as IxFun
import Futhark.Analysis.PrimExp.Convert
import Futhark.Analysis.PrimExp.Simplify
import Futhark.Util
import Futhark.Util.IntegralExp
import qualified Futhark.Util.Pretty as PP
import qualified Futhark.Optimise.Simplify.Engine as Engine
import Futhark.Optimise.Simplify.Lore
import Futhark.Representation.Aliases
  (Aliases, removeScopeAliases, removeExpAliases, removePatternAliases)
import Futhark.Representation.AST.Attributes.Ranges
import Futhark.Analysis.Usage
import qualified Futhark.Analysis.SymbolTable as ST

-- | A lore containing explicit memory information.
data ExplicitMemory
data InKernel

type ExplicitMemorish lore = (SameScope lore ExplicitMemory,
                              RetType lore ~ FunReturns,
                              BranchType lore ~ BodyReturns,
                              CanBeAliased (Op lore),
                              Attributes lore, Annotations lore,
                              TC.Checkable lore,
                              OpReturns lore)

instance IsRetType FunReturns where
  primRetType = MemPrim
  applyRetType = applyFunReturns

instance IsBodyType BodyReturns where
  primBodyType = MemPrim

data MemOp inner = Alloc SubExp Space
                   -- ^ Allocate a memory block.  This really should not be an
                   -- expression, but what are you gonna do...
                 | Inner inner
            deriving (Eq, Ord, Show)

instance FreeIn inner => FreeIn (MemOp inner) where
  freeIn (Alloc size _) = freeIn size
  freeIn (Inner k) = freeIn k

instance TypedOp inner => TypedOp (MemOp inner) where
  opType (Alloc size space) = pure [Mem size space]
  opType (Inner k) = opType k

instance AliasedOp inner => AliasedOp (MemOp inner) where
  opAliases Alloc{} = [mempty]
  opAliases (Inner k) = opAliases k

  consumedInOp Alloc{} = mempty
  consumedInOp (Inner k) = consumedInOp k

instance CanBeAliased inner => CanBeAliased (MemOp inner) where
  type OpWithAliases (MemOp inner) = MemOp (OpWithAliases inner)
  removeOpAliases (Alloc se space) = Alloc se space
  removeOpAliases (Inner k) = Inner $ removeOpAliases k

  addOpAliases (Alloc se space) = Alloc se space
  addOpAliases (Inner k) = Inner $ addOpAliases k

instance RangedOp inner => RangedOp (MemOp inner) where
  opRanges (Alloc _ _) =
    [unknownRange]
  opRanges (Inner k) =
    opRanges k

instance CanBeRanged inner => CanBeRanged (MemOp inner) where
  type OpWithRanges (MemOp inner) = MemOp (OpWithRanges inner)
  removeOpRanges (Alloc size space) = Alloc size space
  removeOpRanges (Inner k) = Inner $ removeOpRanges k

  addOpRanges (Alloc size space) = Alloc size space
  addOpRanges (Inner k) = Inner $ addOpRanges k

instance Rename inner => Rename (MemOp inner) where
  rename (Alloc size space) = Alloc <$> rename size <*> pure space
  rename (Inner k) = Inner <$> rename k

instance Substitute inner => Substitute (MemOp inner) where
  substituteNames subst (Alloc size space) = Alloc (substituteNames subst size) space
  substituteNames subst (Inner k) = Inner $ substituteNames subst k

instance PP.Pretty inner => PP.Pretty (MemOp inner) where
  ppr (Alloc e DefaultSpace) = PP.text "alloc" <> PP.apply [PP.ppr e]
  ppr (Alloc e (Space sp)) = PP.text "alloc" <> PP.apply [PP.ppr e, PP.text sp]
  ppr (Inner k) = PP.ppr k

instance OpMetrics inner => OpMetrics (MemOp inner) where
  opMetrics Alloc{} = seen "Alloc"
  opMetrics (Inner k) = opMetrics k

instance IsOp inner => IsOp (MemOp inner) where
  safeOp Alloc{} = True
  safeOp (Inner k) = safeOp k
  cheapOp (Inner k) = cheapOp k
  cheapOp Alloc{} = True

instance UsageInOp inner => UsageInOp (MemOp inner) where
  usageInOp Alloc {} = mempty
  usageInOp (Inner k) = usageInOp k

instance CanBeWise inner => CanBeWise (MemOp inner) where
  type OpWithWisdom (MemOp inner) = MemOp (OpWithWisdom inner)
  removeOpWisdom (Alloc size space) = Alloc size space
  removeOpWisdom (Inner k) = Inner $ removeOpWisdom k

instance ST.IndexOp inner => ST.IndexOp (MemOp inner) where
  indexOp vtable k (Inner op) is = ST.indexOp vtable k op is
  indexOp _ _ _ _ = Nothing

instance Annotations ExplicitMemory where
  type LetAttr    ExplicitMemory = MemInfo SubExp NoUniqueness MemBind
  type FParamAttr ExplicitMemory = MemInfo SubExp Uniqueness MemBind
  type LParamAttr ExplicitMemory = MemInfo SubExp NoUniqueness MemBind
  type RetType    ExplicitMemory = FunReturns
  type BranchType ExplicitMemory = BodyReturns
  type Op         ExplicitMemory = MemOp (Kernel InKernel)

instance Annotations InKernel where
  type LetAttr    InKernel = MemInfo SubExp NoUniqueness MemBind
  type FParamAttr InKernel = MemInfo SubExp Uniqueness MemBind
  type LParamAttr InKernel = MemInfo SubExp NoUniqueness MemBind
  type RetType    InKernel = FunReturns
  type BranchType InKernel = BodyReturns
  type Op         InKernel = MemOp (KernelExp InKernel)

-- | The index function representation used for memory annotations.
type IxFun = IxFun.IxFun (PrimExp VName)

-- | An index function that may contain existential variables.
type ExtIxFun = IxFun.IxFun (PrimExp (Ext VName))

-- | A summary of the memory information for every let-bound
-- identifier, function parameter, and return value.  Parameterisered
-- over uniqueness, dimension, and auxiliary array information.
data MemInfo d u ret = MemPrim PrimType
                     -- ^ A primitive value.
                     | MemMem d Space
                     -- ^ A memory block.
                     | MemArray PrimType (ShapeBase d) u ret
                     -- ^ The array is stored in the named memory block,
                     -- and with the given index function.  The index
                     -- function maps indices in the array to /element/
                     -- offset, /not/ byte offsets!  To translate to byte
                     -- offsets, multiply the offset with the size of the
                     -- array element type.
                     deriving (Eq, Show, Ord) --- XXX Ord?

type MemBound u = MemInfo SubExp u MemBind

instance FixExt ret => DeclExtTyped (MemInfo ExtSize Uniqueness ret) where
  declExtTypeOf (MemPrim pt) = Prim pt
  declExtTypeOf (MemMem (Free size) space) = Mem size space
  declExtTypeOf (MemMem Ext{} space) = Mem (intConst Int32 0) space -- XXX
  declExtTypeOf (MemArray pt shape u _) = Array pt shape u

instance FixExt ret => ExtTyped (MemInfo ExtSize NoUniqueness ret) where
  extTypeOf (MemPrim pt) = Prim pt
  extTypeOf (MemMem (Free size) space) = Mem size space
  extTypeOf (MemMem Ext{} space) = Mem (intConst Int32 0) space -- XXX
  extTypeOf (MemArray pt shape u _) = Array pt shape u

instance FixExt ret => FixExt (MemInfo ExtSize u ret) where
  fixExt _ _ (MemPrim pt) = MemPrim pt
  fixExt i se (MemMem size space) = MemMem (fixExt i se size) space
  fixExt i se (MemArray pt shape u ret) =
    MemArray pt (fixExt i se shape) u (fixExt i se ret)

instance Typed (MemInfo SubExp Uniqueness ret) where
  typeOf = fromDecl . declTypeOf

instance Typed (MemInfo SubExp NoUniqueness ret) where
  typeOf (MemPrim pt) = Prim pt
  typeOf (MemMem size space) = Mem size space
  typeOf (MemArray bt shape u _) = Array bt shape u

instance DeclTyped (MemInfo SubExp Uniqueness ret) where
  declTypeOf (MemPrim bt) = Prim bt
  declTypeOf (MemMem size space) = Mem size space
  declTypeOf (MemArray bt shape u _) = Array bt shape u

instance (FreeIn d, FreeIn ret) => FreeIn (MemInfo d u ret) where
  freeIn (MemArray _ shape _ ret) = freeIn shape <> freeIn ret
  freeIn (MemMem size _) = freeIn size
  freeIn (MemPrim _) = mempty

instance (Substitute d, Substitute ret) => Substitute (MemInfo d u ret) where
  substituteNames subst (MemArray bt shape u ret) =
    MemArray bt
    (substituteNames subst shape) u
    (substituteNames subst ret)
  substituteNames substs (MemMem size space) =
    MemMem (substituteNames substs size) space
  substituteNames _ (MemPrim bt) =
    MemPrim bt

instance (Substitute d, Substitute ret) => Rename (MemInfo d u ret) where
  rename = substituteRename

simplifyIxFun :: Engine.SimplifiableLore lore =>
                 IxFun -> Engine.SimpleM lore IxFun
simplifyIxFun = traverse simplifyPrimExp

simplifyExtIxFun :: Engine.SimplifiableLore lore =>
                    ExtIxFun -> Engine.SimpleM lore ExtIxFun
simplifyExtIxFun = traverse simplifyExtPrimExp

isStaticIxFun :: ExtIxFun -> Maybe IxFun
isStaticIxFun = traverse $ traverse inst
  where inst Ext{} = Nothing
        inst (Free x) = Just x

instance (Engine.Simplifiable d, Engine.Simplifiable ret) =>
         Engine.Simplifiable (MemInfo d u ret) where
  simplify (MemPrim bt) =
    return $ MemPrim bt
  simplify (MemMem size space) =
    MemMem <$> Engine.simplify size <*> pure space
  simplify (MemArray bt shape u ret) =
    MemArray bt <$> Engine.simplify shape <*> pure u <*> Engine.simplify ret

instance (PP.Pretty (TypeBase (ShapeBase d) u),
          PP.Pretty d, PP.Pretty u, PP.Pretty ret) => PP.Pretty (MemInfo d u ret) where
  ppr (MemPrim bt) = PP.ppr bt
  ppr (MemMem s DefaultSpace) =
    PP.text "mem" <> PP.parens (PP.ppr s)
  ppr (MemMem s (Space sp)) =
    PP.text "mem" <> PP.parens (PP.ppr s) <> PP.text "@" <> PP.text sp
  ppr (MemArray bt shape u ret) =
    PP.ppr (Array bt shape u) <> PP.text "@" <> PP.ppr ret

instance PP.Pretty (Param (MemInfo SubExp Uniqueness ret)) where
  ppr = PP.ppr . fmap declTypeOf

instance PP.Pretty (Param (MemInfo SubExp NoUniqueness ret)) where
  ppr = PP.ppr . fmap typeOf

instance PP.Pretty (PatElemT (MemInfo SubExp NoUniqueness ret)) where
  ppr = PP.ppr . fmap typeOf

-- | Memory information for an array bound somewhere in the program.
data MemBind = ArrayIn VName IxFun
             -- ^ Located in this memory block with this index
             -- function.
             deriving (Show)

instance Eq MemBind where
  _ == _ = True

instance Ord MemBind where
  _ `compare` _ = EQ

instance Rename MemBind where
  rename = substituteRename

instance Substitute MemBind where
  substituteNames substs (ArrayIn ident ixfun) =
    ArrayIn (substituteNames substs ident) (substituteNames substs ixfun)

instance PP.Pretty MemBind where
  ppr (ArrayIn mem ixfun) =
    PP.text "@" <> PP.ppr mem <> PP.text "->" <> PP.ppr ixfun

instance FreeIn MemBind where
  freeIn (ArrayIn mem ixfun) = freeIn mem <> freeIn ixfun

-- | A description of the memory properties of an array being returned
-- by an operation.
data MemReturn = ReturnsInBlock VName ExtIxFun
                 -- ^ The array is located in a memory block that is
                 -- already in scope.
               | ReturnsNewBlock Space Int ExtSize ExtIxFun
                 -- ^ The operation returns a new (existential) block,
                 -- with an existential or known size.
               deriving (Show)

instance Eq MemReturn where
  _ == _ = True

instance Ord MemReturn where
  _ `compare` _ = EQ

instance Rename MemReturn where
  rename = substituteRename

instance Substitute MemReturn where
  substituteNames substs (ReturnsInBlock ident ixfun) =
    ReturnsInBlock (substituteNames substs ident) (substituteNames substs ixfun)
  substituteNames substs (ReturnsNewBlock space i size ixfun) =
    ReturnsNewBlock space i (substituteNames substs size) (substituteNames substs ixfun)

instance FixExt MemReturn where
  fixExt i (Var v) (ReturnsNewBlock _ j _ ixfun)
    | j == i = ReturnsInBlock v $ fixExtIxFun i
               (primExpFromSubExp int32 (Var v)) ixfun
  fixExt i se (ReturnsNewBlock space j size ixfun) =
    ReturnsNewBlock space j' (fixExt i se size)
    (fixExtIxFun i (primExpFromSubExp int32 se) ixfun)
    where j' | i < j     = j-1
             | otherwise = j
  fixExt i se (ReturnsInBlock mem ixfun) =
    ReturnsInBlock mem (fixExtIxFun i (primExpFromSubExp int32 se) ixfun)

fixExtIxFun :: Int -> PrimExp VName -> ExtIxFun -> ExtIxFun
fixExtIxFun i e = fmap $ replaceInPrimExp update
  where update (Ext j) t | j > i     = LeafExp (Ext $ j - 1) t
                         | j == i    = fmap Free e
                         | otherwise = LeafExp (Ext j) t
        update (Free x) t = LeafExp (Free x) t

leafExp :: Int -> PrimExp (Ext a)
leafExp i = LeafExp (Ext i) int32

existentialiseIxFun :: [VName] -> IxFun -> ExtIxFun
existentialiseIxFun ctx = IxFun.substituteInIxFun ctx' . fmap (fmap Free)
  where ctx' = M.map leafExp $ M.fromList $ zip (map Free ctx) [0..]

instance PP.Pretty MemReturn where
  ppr (ReturnsInBlock v ixfun) =
    PP.parens $ PP.text (pretty v) <> PP.text "->" <> PP.ppr ixfun
  ppr (ReturnsNewBlock space i size ixfun) =
    PP.text ("?" ++ show i) <> space' <> PP.parens (PP.ppr size)
    <> PP.text "->" <> PP.ppr ixfun
    where space' = case space of DefaultSpace -> mempty
                                 Space s -> PP.text $ "@" ++ s

instance FreeIn MemReturn where
  freeIn (ReturnsInBlock v ixfun) = freeIn v <> freeIn ixfun
  freeIn _                        = mempty

instance Engine.Simplifiable MemReturn where
  simplify (ReturnsNewBlock space i size ixfun) =
    ReturnsNewBlock space i <$> Engine.simplify size <*> simplifyExtIxFun ixfun
  simplify (ReturnsInBlock v ixfun) =
    ReturnsInBlock <$> Engine.simplify v <*> simplifyExtIxFun ixfun


instance Engine.Simplifiable MemBind where
  simplify (ArrayIn mem ixfun) =
    ArrayIn <$> Engine.simplify mem <*> simplifyIxFun ixfun

instance Engine.Simplifiable [FunReturns] where
  simplify = mapM Engine.simplify

-- | The memory return of an expression.  An array is annotated with
-- @Maybe MemReturn@, which can be interpreted as the expression
-- either dictating exactly where the array is located when it is
-- returned (if 'Just'), or able to put it whereever the binding
-- prefers (if 'Nothing').
--
-- This is necessary to capture the difference between an expression
-- that is just an array-typed variable, in which the array being
-- "returned" is located where it already is, and a @copy@ expression,
-- whose entire purpose is to store an existing array in some
-- arbitrary location.  This is a consequence of the design decision
-- never to have implicit memory copies.
type ExpReturns = MemInfo ExtSize NoUniqueness (Maybe MemReturn)

-- | The return of a body, which must always indicate where
-- returned arrays are located.
type BodyReturns = MemInfo ExtSize NoUniqueness MemReturn

-- | The memory return of a function, which must always indicate where
-- returned arrays are located.
type FunReturns = MemInfo ExtSize Uniqueness MemReturn

maybeReturns :: MemInfo d u r -> MemInfo d u (Maybe r)
maybeReturns (MemArray bt shape u ret) =
  MemArray bt shape u $ Just ret
maybeReturns (MemPrim bt) =
  MemPrim bt
maybeReturns (MemMem size space) =
  MemMem size space

noUniquenessReturns :: MemInfo d u r -> MemInfo d NoUniqueness r
noUniquenessReturns (MemArray bt shape _ r) =
  MemArray bt shape NoUniqueness r
noUniquenessReturns (MemPrim bt) =
  MemPrim bt
noUniquenessReturns (MemMem size space) =
  MemMem size space

funReturnsToExpReturns :: FunReturns -> ExpReturns
funReturnsToExpReturns = noUniquenessReturns . maybeReturns

bodyReturnsToExpReturns :: BodyReturns -> ExpReturns
bodyReturnsToExpReturns = noUniquenessReturns . maybeReturns

instance TC.CheckableOp ExplicitMemory where
  checkOp (Alloc size _) = TC.require [Prim int64] size
  checkOp (Inner k) = TC.subCheck $ typeCheckKernel k

instance TC.CheckableOp InKernel where
  checkOp (Alloc size _) = TC.require [Prim int64] size
  checkOp (Inner k) = TC.subCheck $ typeCheckKernelExp k

instance TC.Checkable ExplicitMemory where
  checkFParamLore = checkMemInfo
  checkLParamLore = checkMemInfo
  checkLetBoundLore = checkMemInfo
  checkRetType = mapM_ TC.checkExtType . retTypeValues
  primFParam name t = return $ Param name (MemPrim t)
  matchPattern = matchPatternToExp
  matchReturnType = matchFunctionReturnType
  matchBranchType = matchBranchReturnType

instance TC.Checkable InKernel where
  checkFParamLore = checkMemInfo
  checkLParamLore = checkMemInfo
  checkLetBoundLore = checkMemInfo
  checkRetType = mapM_ TC.checkExtType . retTypeValues
  primFParam name t = return $ Param name (MemPrim t)
  matchPattern = matchPatternToExp
  matchReturnType = matchFunctionReturnType
  matchBranchType = matchBranchReturnType

matchFunctionReturnType :: ExplicitMemorish lore =>
                           [FunReturns] -> Result -> TC.TypeM lore ()
matchFunctionReturnType rettype result = do
  TC.matchExtReturnType (fromDecl <$> ts) result
  scope <- askScope
  result_ts <- runReaderT (mapM subExpMemInfo result) $ removeScopeAliases scope
  matchReturnType rettype result result_ts
  mapM_ checkResultSubExp result
  where ts = map declExtTypeOf rettype
        checkResultSubExp Constant{} =
          return ()
        checkResultSubExp (Var v) = do
          attr <- varMemInfo v
          case attr of
            MemPrim _ -> return ()
            MemMem{} -> return ()
            MemArray _ _ _ (ArrayIn _ ixfun)
              | IxFun.isLinear ixfun ->
                return ()
              | otherwise ->
                  TC.bad $ TC.TypeError $
                  "Array " ++ pretty v ++
                  " returned by function, but has nontrivial index function " ++
                  pretty ixfun ++ " " ++ show ixfun

matchBranchReturnType :: ExplicitMemorish lore =>
                         [BodyReturns]
                      -> Body (Aliases lore)
                      -> TC.TypeM lore ()
matchBranchReturnType rettype (Body _ stms res) = do
  scope <- askScope
  ts <- runReaderT (mapM subExpMemInfo res) $ removeScopeAliases (scope <> scopeOf stms)
  matchReturnType rettype res ts

-- | Helper function for index function unification.
--
-- The first return value maps a VName (wrapped in 'Free') to its Int
-- (wrapped in 'Ext').  In case of duplicates, it is mapped to the
-- *first* Int that occurs.
--
-- The second return value maps each Int (wrapped in an 'Ext') to a
-- 'LeafExp' 'Ext' with the Int at which its associated VName first
-- occurs.
getExtMaps :: [(VName,Int)] -> (M.Map (Ext VName) (PrimExp (Ext VName)),
                                M.Map (Ext VName) (PrimExp (Ext VName)))
getExtMaps ctx_lst_ids =
  (M.map leafExp $ M.mapKeys Free $ M.fromListWith (flip const) ctx_lst_ids,
   M.fromList $
   mapMaybe (traverse (fmap (\i -> LeafExp (Ext i) int32) .
                       (`lookup` ctx_lst_ids)) .
             uncurry (flip (,)) . fmap Ext) ctx_lst_ids)

matchReturnType :: PP.Pretty u =>
                   [MemInfo ExtSize u MemReturn]
                -> [SubExp]
                -> [MemInfo SubExp NoUniqueness MemBind]
                -> TC.TypeM lore ()
matchReturnType rettype res ts = do
  let (ctx_ts, val_ts) = splitFromEnd (length rettype) ts
      (ctx_res, _val_res) = splitFromEnd (length rettype) res

      getId :: (SubExp,Int) -> Maybe (VName,Int)
      getId (Var ii, i) = Just (ii,i)
      getId (Constant _, _) = Nothing

      (ctx_map_ids, ctx_map_exts) =
        getExtMaps $ mapMaybe getId $ zip ctx_res [0..length ctx_res - 1]

      existentialiseIxFun0 :: IxFun -> ExtIxFun
      existentialiseIxFun0 = IxFun.substituteInIxFun ctx_map_ids . fmap (fmap Free)

      getCt :: (Int,SubExp) -> Maybe (Ext VName, PrimExp (Ext VName))
      getCt (_, Var _) = Nothing
      getCt (i, Constant c) = Just (Ext i, ValueExp c)

      ctx_map_cts = M.fromList $ mapMaybe getCt $
                    zip [0..length ctx_res - 1] ctx_res

      substConstsInExtIndFun :: ExtIxFun -> ExtIxFun
      substConstsInExtIndFun = IxFun.substituteInIxFun (ctx_map_cts<>ctx_map_exts)

      fetchCtx i = case maybeNth i $ zip ctx_res ctx_ts of
                     Nothing -> throwError $ "Cannot find context variable " ++
                                show i ++ " in context results: " ++ pretty ctx_res
                     Just (se, t) -> return (se, t)

      checkReturn (MemPrim x) (MemPrim y)
        | x == y = return ()
      checkReturn (MemMem x _) (MemMem y _) =
        checkDim x y
      checkReturn (MemArray x_pt x_shape _ x_ret)
                  (MemArray y_pt y_shape _ y_ret)
        | x_pt == y_pt, shapeRank x_shape == shapeRank y_shape = do
            zipWithM_ checkDim (shapeDims x_shape) (shapeDims y_shape)
            checkMemReturn x_ret y_ret
      checkReturn x y =
        throwError $ unwords ["Expected ", pretty x, " but got ", pretty y]

      checkDim (Free x) y
        | x == y = return ()
        | otherwise = throwError $ unwords ["Expected dim", pretty x,
                                            "but got", pretty y]
      checkDim (Ext i) y = do
        (x, _) <- fetchCtx i
        unless (x == y) $
          throwError $ unwords ["Expected ext dim", pretty i, "=>", pretty x,
                                "but got", pretty y]

      checkMemReturn (ReturnsInBlock x_mem x_ixfun) (ArrayIn y_mem y_ixfun)
          | x_mem == y_mem = do
              let x_ixfun' = substConstsInExtIndFun x_ixfun
                  y_ixfun' = existentialiseIxFun0   y_ixfun
              unless (x_ixfun' == y_ixfun') $
                throwError $ unwords  ["Index function unification fails1!",
                    "\nixfun of body result: ", pretty y_ixfun',
                    "\nixfun of return type: ", pretty x_ixfun',
                    "\nand context elements: ", pretty ctx_res]
      checkMemReturn (ReturnsNewBlock x_space x_ext x_mem_size x_ixfun)
                     (ArrayIn y_mem y_ixfun) = do
        (x_mem, x_mem_type)  <- fetchCtx x_ext
        let x_ixfun' = substConstsInExtIndFun x_ixfun
            y_ixfun' = existentialiseIxFun0   y_ixfun
        unless (x_ixfun' == y_ixfun') $
          throwError $ unwords  ["Index function unification fails2!",
            "\nixfun of body result: ", pretty y_ixfun',
            "\nixfun of return type: ", pretty x_ixfun',
            "\nand context elements: ", pretty ctx_res]
        case x_mem_type of
          MemMem y_mem_size y_space -> do
            unless (x_mem == Var y_mem) $
              throwError $ unwords ["Expected memory", pretty x_ext, "=>", pretty x_mem,
                                    "but got", pretty y_mem]
            unless (x_space == y_space) $
              throwError $ unwords ["Expected memory", pretty y_mem, "in space", pretty x_space,
                                    "but actually in space", pretty y_space]
            checkDim x_mem_size y_mem_size
          t ->
            throwError $ unwords ["Expected memory", pretty x_ext, "=>", pretty x_mem,
                                  "but but has type", pretty t]
      checkMemReturn x y =
        throwError $ unwords ["Expected array in", pretty x,
                              "but array returned in", pretty y]

      bad :: String -> TC.TypeM lore a
      bad s = TC.bad $ TC.TypeError $
              unlines [ "Return type"
                      , "  " ++ prettyTuple rettype
                      , "cannot match returns of results"
                      , "  " ++ prettyTuple ts
                      , s
                      ]

  either bad return =<< runExceptT (zipWithM_ checkReturn rettype val_ts)

matchPatternToExp :: (ExplicitMemorish lore) =>
                     Pattern (Aliases lore)
                  -> Exp (Aliases lore)
                  -> TC.TypeM lore ()
matchPatternToExp pat e = do
  scope <- asksScope removeScopeAliases
  rt <- runReaderT (expReturns $ removeExpAliases e) scope

  let (ctxs, vals) = bodyReturnsFromPattern $ removePatternAliases pat
      (ctx_ids, _ctx_ts) = unzip ctxs
      (_val_ids, val_ts) = unzip vals
      (ctx_map_ids, ctx_map_exts) =
        getExtMaps $ zip ctx_ids [0..length ctx_ids - 1]

  unless (length val_ts == length rt &&
          and (zipWith (matches ctx_map_ids ctx_map_exts) val_ts rt)) $
    TC.bad $ TC.TypeError $ "Expression type:\n  " ++ prettyTuple rt ++
                            "\ncannot match pattern type:\n  " ++ prettyTuple val_ts ++
                            "\nwith context elements: " ++ pretty ctx_ids
  where matches _ _ (MemPrim x) (MemPrim y) = x == y
        matches _ _ (MemMem x_size x_space) (MemMem y_size y_space) =
          x_size == y_size && x_space == y_space
        matches ctxids ctxexts (MemArray x_pt x_shape _ x_ret) (MemArray y_pt y_shape _ y_ret) =
          x_pt == y_pt && x_shape == y_shape &&
          case (x_ret, y_ret) of
            (ReturnsInBlock x_mem x_ixfun, Just (ReturnsInBlock y_mem y_ixfun)) ->
              let x_ixfun' = IxFun.substituteInIxFun ctxids  x_ixfun
                  y_ixfun' = IxFun.substituteInIxFun ctxexts y_ixfun
              in  x_mem == y_mem && x_ixfun' == y_ixfun'
            (ReturnsInBlock _ x_ixfun,
             Just (ReturnsNewBlock _ _ _ y_ixfun)) ->
              let x_ixfun' = IxFun.substituteInIxFun ctxids  x_ixfun
                  y_ixfun' = IxFun.substituteInIxFun ctxexts y_ixfun
              in  x_ixfun' == y_ixfun'
            (ReturnsNewBlock x_space x_i x_size x_ixfun,
             Just (ReturnsNewBlock y_space y_i y_size y_ixfun)) ->
              let x_ixfun' = IxFun.substituteInIxFun  ctxids x_ixfun
                  y_ixfun' = IxFun.substituteInIxFun ctxexts y_ixfun
              in  x_space == y_space && x_i == y_i &&
                  x_size == y_size && x_ixfun' == y_ixfun'
            (_, Nothing) -> True
            _ -> False
        matches _ _ _ _ = False

varMemInfo :: ExplicitMemorish lore =>
              VName -> TC.TypeM lore (MemInfo SubExp NoUniqueness MemBind)
varMemInfo name = do
  attr <- TC.lookupVar name

  case attr of
    LetInfo (_, summary) -> return summary
    FParamInfo summary -> return $ noUniquenessReturns summary
    LParamInfo summary -> return summary
    IndexInfo it -> return $ MemPrim $ IntType it

nameInfoToMemInfo :: ExplicitMemorish lore => NameInfo lore -> MemBound NoUniqueness
nameInfoToMemInfo info =
  case info of
    FParamInfo summary -> noUniquenessReturns summary
    LParamInfo summary -> summary
    LetInfo summary -> summary
    IndexInfo it -> MemPrim $ IntType it

lookupMemInfo :: (HasScope lore m, ExplicitMemorish lore) =>
                  VName -> m (MemInfo SubExp NoUniqueness MemBind)
lookupMemInfo = fmap nameInfoToMemInfo . lookupInfo

subExpMemInfo :: (HasScope lore m, Monad m, ExplicitMemorish lore) =>
                 SubExp -> m (MemInfo SubExp NoUniqueness MemBind)
subExpMemInfo (Var v) = lookupMemInfo v
subExpMemInfo (Constant v) = return $ MemPrim $ primValueType v

lookupArraySummary :: (ExplicitMemorish lore, HasScope lore m, Monad m) =>
                      VName -> m (VName, IxFun.IxFun (PrimExp VName))
lookupArraySummary name = do
  summary <- lookupMemInfo name
  case summary of
    MemArray _ _ _ (ArrayIn mem ixfun) ->
      return (mem, ixfun)
    _ ->
      fail $ "Variable " ++ pretty name ++ " does not look like an array."

lookupMemSize :: (HasScope lore m, Monad m) =>
                 VName -> m SubExp
lookupMemSize v = do
  t <- lookupType v
  case t of Mem size _ -> return size
            _ -> fail $ "lookupMemSize: " ++ pretty v ++ " is not a memory block."

checkMemInfo :: TC.Checkable lore =>
                 VName -> MemInfo SubExp u MemBind
             -> TC.TypeM lore ()
checkMemInfo _ (MemPrim _) = return ()
checkMemInfo _ (MemMem size _) =
  TC.require [Prim int64] size
checkMemInfo name (MemArray _ shape _ (ArrayIn v ixfun)) = do
  t <- lookupType v
  case t of
    Mem{} ->
      return ()
    _        ->
      TC.bad $ TC.TypeError $
      "Variable " ++ pretty v ++
      " used as memory block, but is of type " ++
      pretty t ++ "."

  TC.context ("in index function " ++ pretty ixfun) $ do
    traverse_ (TC.requirePrimExp int32) ixfun
    let ixfun_rank = IxFun.rank ixfun
        ident_rank = shapeRank shape
    unless (ixfun_rank == ident_rank) $
      TC.bad $ TC.TypeError $
      "Arity of index function (" ++ pretty ixfun_rank ++
      ") does not match rank of array " ++ pretty name ++
      " (" ++ show ident_rank ++ ")"

instance Attributes ExplicitMemory where
  expTypesFromPattern = return . map snd . snd . bodyReturnsFromPattern

instance Attributes InKernel where
  expTypesFromPattern = return . map snd . snd . bodyReturnsFromPattern

bodyReturnsFromPattern :: PatternT (MemBound NoUniqueness)
                       -> ([(VName,BodyReturns)], [(VName,BodyReturns)])
bodyReturnsFromPattern pat =
  (map asReturns $ patternContextElements pat,
   map asReturns $ patternValueElements pat)
  where ctx = patternContextElements pat

        ext (Var v)
          | Just (i, _) <- find ((==v) . patElemName . snd) $ zip [0..] ctx =
              Ext i
        ext se = Free se

        asReturns pe =
         (patElemName pe,
          case patElemAttr pe of
            MemPrim pt -> MemPrim pt
            MemMem size space -> MemMem (ext size) space
            MemArray pt shape u (ArrayIn mem ixfun) ->
              MemArray pt (Shape $ map ext $ shapeDims shape) u $
              case find ((==mem) . patElemName . snd) $ zip [0..] ctx  of
                Just (i, PatElem _ (MemMem size space)) ->
                  ReturnsNewBlock space i (ext size) $
                  existentialiseIxFun (map patElemName ctx) ixfun
                _ -> ReturnsInBlock mem $ existentialiseIxFun [] ixfun
         )

instance (PP.Pretty u, PP.Pretty r) => PrettyAnnot (PatElemT (MemInfo SubExp u r)) where
  ppAnnot = bindeeAnnot patElemName patElemAttr

instance (PP.Pretty u, PP.Pretty r) => PrettyAnnot (ParamT (MemInfo SubExp u r)) where
  ppAnnot = bindeeAnnot paramName paramAttr

instance PrettyLore ExplicitMemory where
instance PrettyLore InKernel where

bindeeAnnot :: (PP.Pretty u, PP.Pretty r) =>
               (a -> VName) -> (a -> MemInfo SubExp u r)
            -> a -> Maybe PP.Doc
bindeeAnnot bindeeName bindeeLore bindee =
  case bindeeLore bindee of
    attr@MemArray{} ->
      Just $
      PP.text "-- " <>
      PP.oneLine (PP.ppr (bindeeName bindee) <>
                  PP.text " : " <>
                  PP.ppr attr)
    MemMem {} ->
      Nothing
    MemPrim _ ->
      Nothing

extReturns :: [ExtType] -> [ExpReturns]
extReturns ts =
    evalState (mapM addAttr ts) 0
    where addAttr (Prim bt) =
            return $ MemPrim bt
          addAttr (Mem size space) =
            return $ MemMem (Free size) space
          addAttr t@(Array bt shape u)
            | existential t = do
              i <- get <* modify (+2)
              return $ MemArray bt shape u $ Just $
                ReturnsNewBlock DefaultSpace (i+1) (Ext i) $
                IxFun.iota $ map convert $ shapeDims shape
            | otherwise =
              return $ MemArray bt shape u Nothing
          convert (Ext i) = LeafExp (Ext i) int32
          convert (Free v) = Free <$> primExpFromSubExp int32 v

arrayVarReturns :: (HasScope lore m, Monad m, ExplicitMemorish lore) =>
                   VName
                -> m (PrimType, Shape, VName, IxFun.IxFun (PrimExp VName))
arrayVarReturns v = do
  summary <- lookupMemInfo v
  case summary of
    MemArray et shape _ (ArrayIn mem ixfun) ->
      return (et, Shape $ shapeDims shape, mem, ixfun)
    _ ->
      fail $ "arrayVarReturns: " ++ pretty v ++ " is not an array."

varReturns :: (HasScope lore m, Monad m, ExplicitMemorish lore) =>
              VName -> m ExpReturns
varReturns v = do
  summary <- lookupMemInfo v
  case summary of
    MemPrim bt ->
      return $ MemPrim bt
    MemArray et shape _ (ArrayIn mem ixfun) ->
      return $ MemArray et (fmap Free shape) NoUniqueness $
               Just $ ReturnsInBlock mem $ existentialiseIxFun [] ixfun
    MemMem size space ->
      return $ MemMem (Free size) space

-- | The return information of an expression.  This can be seen as the
-- "return type with memory annotations" of the expression.
expReturns :: (Monad m, HasScope lore m,
               ExplicitMemorish lore) =>
              Exp lore -> m [ExpReturns]

expReturns (BasicOp (SubExp (Var v))) =
  pure <$> varReturns v

expReturns (BasicOp (Opaque (Var v))) =
  pure <$> varReturns v

expReturns (BasicOp (Repeat outer_shapes inner_shape v)) = do
  t <- repeatDims outer_shapes inner_shape <$> lookupType v
  (et, _, mem, ixfun) <- arrayVarReturns v
  let outer_shapes' = map (map (primExpFromSubExp int32) . shapeDims) outer_shapes
      inner_shape' = map (primExpFromSubExp int32) $ shapeDims inner_shape
  return [MemArray et (Shape $ map Free $ arrayDims t) NoUniqueness $
          Just $ ReturnsInBlock mem $ existentialiseIxFun [] $
          IxFun.repeat ixfun outer_shapes' inner_shape']

expReturns (BasicOp (Reshape newshape v)) = do
  (et, _, mem, ixfun) <- arrayVarReturns v
  return [MemArray et (Shape $ map (Free . newDim) newshape) NoUniqueness $
          Just $ ReturnsInBlock mem $ existentialiseIxFun [] $
          IxFun.reshape ixfun $ map (fmap $ primExpFromSubExp int32) newshape]

expReturns (BasicOp (Rearrange perm v)) = do
  (et, Shape dims, mem, ixfun) <- arrayVarReturns v
  let ixfun' = IxFun.permute ixfun perm
      dims'  = rearrangeShape perm dims
  return [MemArray et (Shape $ map Free dims') NoUniqueness $
          Just $ ReturnsInBlock mem $ existentialiseIxFun [] ixfun']

expReturns (BasicOp (Rotate offsets v)) = do
  (et, Shape dims, mem, ixfun) <- arrayVarReturns v
  let offsets' = map (primExpFromSubExp int32) offsets
      ixfun' = IxFun.rotate ixfun offsets'
  return [MemArray et (Shape $ map Free dims) NoUniqueness $
          Just $ ReturnsInBlock mem $ existentialiseIxFun [] ixfun']

expReturns (BasicOp (Index v slice)) = do
  info <- sliceInfo v slice
  case info of
    MemArray et shape u (ArrayIn mem ixfun) ->
      return [MemArray et (fmap Free shape) u $
              Just $ ReturnsInBlock mem $ existentialiseIxFun [] ixfun]
    MemPrim pt -> return [MemPrim pt]
    MemMem d space -> return [MemMem (Free d) space]

expReturns (BasicOp (Update v _ _)) =
  pure <$> varReturns v

expReturns (BasicOp op) =
  extReturns . staticShapes <$> primOpType op

expReturns (DoLoop ctx val _ _) =
  zipWithM typeWithAttr
  (loopExtType (map (paramIdent . fst) ctx) (map (paramIdent . fst) val)) $ map fst val
    where typeWithAttr t p =
            case (t, paramAttr p) of
              (Array bt shape u, MemArray _ _ _ (ArrayIn mem ixfun))
                | Just (i, mem_p) <- isMergeVar mem,
                  Mem mem_size space <- paramType mem_p ->
                    let ext_size
                          | Just (j, _) <- isMergeVar =<< subExpVar mem_size = Ext j
                          | otherwise                                        = Free mem_size
                    in return $ MemArray bt shape u $ Just $ ReturnsNewBlock space i ext_size ixfun'
                | otherwise ->
                  return (MemArray bt shape u $
                          Just $ ReturnsInBlock mem ixfun')
                  where ixfun' = existentialiseIxFun (map paramName mergevars) ixfun
              (Array{}, _) ->
                fail "expReturns: Array return type but not array merge variable."
              (Prim bt, _) ->
                return $ MemPrim bt
              (Mem{}, _) ->
                fail "expReturns: loop returns memory block explicitly."
          isMergeVar v = find ((==v) . paramName . snd) $ zip [0..] mergevars
          mergevars = map fst $ ctx ++ val

expReturns (Apply _ _ ret _) =
  return $ map funReturnsToExpReturns ret

expReturns (If _ _ _ (IfAttr ret _)) =
  return $ map bodyReturnsToExpReturns ret

expReturns (Op op) =
  opReturns op

sliceInfo :: (Monad m, HasScope lore m, ExplicitMemorish lore) =>
             VName
          -> Slice SubExp -> m (MemInfo SubExp NoUniqueness MemBind)
sliceInfo v slice = do
  (et, _, mem, ixfun) <- arrayVarReturns v
  case sliceDims slice of
    [] -> return $ MemPrim et
    dims ->
      return $ MemArray et (Shape dims) NoUniqueness $
      ArrayIn mem $ IxFun.slice ixfun
      (map (fmap (primExpFromSubExp int32)) slice)

class TypedOp (Op lore) => OpReturns lore where
  opReturns :: (Monad m, HasScope lore m) =>
               Op lore -> m [ExpReturns]
  opReturns op = extReturns <$> opType op

instance OpReturns ExplicitMemory where
  opReturns (Alloc size space) =
    return [MemMem (Free size) space]
  opReturns (Inner k@(Kernel _ _ _ body)) =
    zipWithM correct (kernelBodyResult body) =<< (extReturns <$> opType k)
    where correct (WriteReturn _ arr _) _ = varReturns arr
          correct (KernelInPlaceReturn arr) _ =
            extendedScope (varReturns arr)
            (castScope $ scopeOf $ kernelBodyStms body)
          correct _ ret = return ret
  opReturns k =
    extReturns <$> opType k

instance OpReturns InKernel where
  opReturns (Alloc size space) =
    return [MemMem (Free size) space]

  opReturns (Inner (GroupStream _ _ lam _ _)) =
    forM (groupStreamAccParams lam) $ \param ->
      case paramAttr param of
        MemPrim bt ->
          return $ MemPrim bt
        MemArray et shape _ (ArrayIn mem ixfun) ->
          return $ MemArray et (Shape $ map Free $ shapeDims shape) NoUniqueness $
          Just $ ReturnsInBlock mem $ existentialiseIxFun [] ixfun
        MemMem size space ->
          return $ MemMem (Free size) space

  opReturns (Inner (GroupScan _ _ input)) =
    mapM varReturns arrs
    where arrs = map snd input

  opReturns (Inner (GroupGenReduce _ dests _ _ _ _)) =
    mapM varReturns dests

  opReturns (Inner (Barrier res)) = mapM f res
    where f (Var v) = varReturns v
          f (Constant v) = return $ MemPrim $ primValueType v

  opReturns (Inner (Combine (CombineSpace scatter cspace) ts _ _)) =
    (++) <$> mapM varReturns as <*>
    pure (extReturns $ staticShapes $ map (`arrayOfShape` shape) $ drop (sum ns*2) ts)
    where (_, ns, as) = unzip3 scatter
          shape = Shape $ map snd cspace

  opReturns k =
    extReturns <$> opType k

applyFunReturns :: Typed attr =>
                   [FunReturns]
                -> [Param attr]
                -> [(SubExp,Type)]
                -> Maybe [FunReturns]
applyFunReturns rets params args
  | Just _ <- applyRetType rettype params args =
      Just $ map correctDims rets
  | otherwise =
      Nothing
  where rettype = map declExtTypeOf rets
        parammap :: M.Map VName (SubExp, Type)
        parammap = M.fromList $
                   zip (map paramName params) args

        substSubExp (Var v)
          | Just (se,_) <- M.lookup v parammap = se
        substSubExp se = se

        correctDims (MemPrim t) =
          MemPrim t
        correctDims (MemMem (Free se) space) =
          MemMem (Free $ substSubExp se) space
        correctDims (MemMem (Ext d) space) =
          MemMem (Ext d) space
        correctDims (MemArray et shape u memsummary) =
          MemArray et (correctShape shape) u $
          correctSummary memsummary

        correctShape = Shape . map correctDim . shapeDims
        correctDim (Ext i)   = Ext i
        correctDim (Free se) = Free $ substSubExp se

        correctSummary (ReturnsNewBlock space i size ixfun) =
          ReturnsNewBlock space i size ixfun
        correctSummary (ReturnsInBlock mem ixfun) =
          -- FIXME: we should also do a replacement in ixfun here.
          ReturnsInBlock mem' ixfun
          where mem' = case M.lookup mem parammap of
                  Just (Var v, _) -> v
                  _               -> mem

-- | Is an array of the given shape stored fully flat row-major with
-- the given index function?
fullyLinear :: (Eq num, IntegralExp num) =>
               ShapeBase num -> IxFun.IxFun num -> Bool
fullyLinear shape ixfun =
  IxFun.isLinear ixfun && ixFunMatchesInnerShape shape ixfun

ixFunMatchesInnerShape :: (Eq num, IntegralExp num) =>
                          ShapeBase num -> IxFun.IxFun num -> Bool
ixFunMatchesInnerShape shape ixfun =
  drop 1 (IxFun.shape ixfun) == drop 1 (shapeDims shape)