{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ConstraintKinds #-}
-- | Find all variable-to-memory mappings, so that other modules can look up
-- the relation.  Maps array names to the memory blocks (together with index
-- functions and shapes) in which they are stored.

module Futhark.Optimise.MemoryBlockMerging.VariableMemory
  ( findVarMemMappings
  ) where

import qualified Data.Map.Strict as M
import Control.Monad.Writer

import Futhark.Representation.AST
import Futhark.Representation.ExplicitMemory
       (ExplicitMemorish, ExplicitMemory)
import qualified Futhark.Representation.ExplicitMemory as ExpMem
import Futhark.Representation.Kernels.Kernel

import Futhark.Optimise.MemoryBlockMerging.Miscellaneous
import Futhark.Optimise.MemoryBlockMerging.Types


-- | Collection monad: a 'Writer' accumulating the variable-to-memory map.
-- The @lore@ parameter does not occur in the representation (it is a
-- phantom), so 'coerce' can re-tag a computation with a different lore.
newtype FindM lore a = FindM
  { unFindM :: Writer (VarMemMappings MemorySrc) a
  }
  deriving ( Functor
           , Applicative
           , Monad
           , MonadWriter (VarMemMappings MemorySrc)
           )

-- | Constraints every lore traversed by this module must satisfy: it must
-- carry explicit memory information and support a full expression walk.
type LoreConstraints lore =
  ( ExplicitMemorish lore
  , FullWalk lore
  )

-- | Emit a single @variable -> memory location@ entry into the writer log.
recordMapping :: VName -> MemorySrc -> FindM lore ()
recordMapping var = tell . M.singleton var

-- | Re-tag a computation with another lore.  Only the phantom type
-- parameter changes; the underlying writer computation is reused unchanged.
coerce :: FindM flore a -> FindM tlore a
coerce (FindM inner) = FindM inner

-- | Find all variable-memory block mappings in a function definition.
--
-- The function's parameters are inspected first, then its body (including
-- nested bodies, lambdas and kernels reached by the walkers in 'lookInStm').
-- Only array-typed bindings whose memory is 'ExpMem.ArrayIn' are recorded.
findVarMemMappings :: FunDef ExplicitMemory -> VarMemMappings MemorySrc
findVarMemMappings fundef = execWriter $ unFindM collect
  where collect = do
          mapM_ lookInFParam (funDefParams fundef)
          lookInBody (funDefBody fundef)

-- | Record the memory location of a function parameter when it is an array
-- residing in a memory block ('ExpMem.ArrayIn'); anything else is ignored.
lookInFParam :: LoreConstraints lore =>
                FParam lore -> FindM lore ()
lookInFParam (Param name attr) =
  case attr of
    ExpMem.MemArray _ shape _ (ExpMem.ArrayIn mem ixfun) ->
      recordMapping name $ MemorySrc mem ixfun shape
    _ -> return ()

-- | Record the memory location of a lambda parameter when it is an array
-- residing in a memory block ('ExpMem.ArrayIn'); anything else is ignored.
lookInLParam :: LoreConstraints lore =>
                LParam lore -> FindM lore ()
lookInLParam (Param name attr) =
  case attr of
    ExpMem.MemArray _ shape _ (ExpMem.ArrayIn mem ixfun) ->
      recordMapping name $ MemorySrc mem ixfun shape
    _ -> return ()

-- | Visit every statement of a body in order; the body result is not
-- inspected.
lookInBody :: LoreConstraints lore =>
              Body lore -> FindM lore ()
lookInBody (Body _ stms _) = forM_ stms lookInStm

-- | Visit every statement of a kernel body in order; the kernel results are
-- not inspected.
lookInKernelBody :: LoreConstraints lore =>
                    KernelBody lore -> FindM lore ()
lookInKernelBody (KernelBody _ stms _) = forM_ stms lookInStm

-- | Record the value pattern elements of a statement, then walk its
-- expression so that nested bodies, parameters, kernel bodies and lambdas
-- are visited with the same collectors.  Context pattern elements are not
-- inspected.
--
-- Kernel-level callbacks run in a different lore, so their results are
-- re-tagged with 'coerce'.
lookInStm :: LoreConstraints lore =>
             Stm lore -> FindM lore ()
lookInStm (Let (Pattern _ valElems) _ expr) =
  mapM_ lookInPatValElem valElems >> fullWalkExpM expWalker kernelWalker expr
  where
    expWalker =
      identityWalker
        { walkOnBody = lookInBody
        , walkOnFParam = lookInFParam
        , walkOnLParam = lookInLParam
        }
    kernelWalker =
      identityKernelWalker
        { walkOnKernelBody = coerce . lookInBody
        , walkOnKernelKernelBody = coerce . lookInKernelBody
        , walkOnKernelLambda = coerce . lookInLambda
        , walkOnKernelLParam = lookInLParam
        }

-- | Record the memory location of a bound pattern element when it is an
-- array residing in a memory block ('ExpMem.ArrayIn'); anything else is
-- ignored.
lookInPatValElem :: LoreConstraints lore =>
                    PatElem lore -> FindM lore ()
lookInPatValElem (PatElem name attr) =
  case attr of
    ExpMem.MemArray _ shape _ (ExpMem.ArrayIn mem ixfun) ->
      recordMapping name $ MemorySrc mem ixfun shape
    _ -> return ()

-- | Record the lambda's parameters, then descend into its body.  The
-- lambda's return type is not inspected.
lookInLambda :: LoreConstraints lore =>
                Lambda lore -> FindM lore ()
lookInLambda (Lambda lparams lbody _) =
  mapM_ lookInLParam lparams >> lookInBody lbody