{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE LambdaCase #-}
-- | Find memory block aliases.  The conceptual difference from variable aliases
-- is that if a variable x has an alias y, it means that x and y use the same
-- memory block, but if a memory block xmem has an alias ymem, it means that
-- xmem and ymem refer to the same *memory*.  This is not commutative.
module Futhark.Optimise.MemoryBlockMerging.MemoryAliases
  ( findMemAliases
  ) where

import Data.Maybe (mapMaybe)
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import qualified Data.List as L
import Control.Monad.RWS

import Futhark.Representation.AST
import Futhark.Representation.Aliases
import Futhark.Representation.ExplicitMemory
       (ExplicitMemorish, ExplicitMemory)
import qualified Futhark.Representation.ExplicitMemory as ExpMem
import Futhark.Representation.Kernels.Kernel
import Futhark.Analysis.Alias (analyseFun)

import Futhark.Optimise.MemoryBlockMerging.Miscellaneous
import Futhark.Optimise.MemoryBlockMerging.Types


newtype FindM lore a = FindM { unFindM :: RWS (VarMemMappings MemorySrc) [MemAliases] () a }
  deriving (Monad, Functor, Applicative,
            MonadReader (VarMemMappings MemorySrc),
            MonadWriter [MemAliases])

type LoreConstraints lore = (ExplicitMemorish lore,
                             FullWalkAliases lore)

recordMapping :: MName -> MNames -> FindM lore ()
recordMapping mem mems = tell [M.singleton mem (S.delete mem mems)]

coerce :: FindM flore a -> FindM tlore a
coerce = FindM . unFindM

lookupMems :: Names -> FindM lore MNames
lookupMems var_aliases = do
  var_to_mem <- ask
  return $ S.fromList $ mapMaybe ((memSrcName <$>) . flip M.lookup var_to_mem)
    $ S.toList var_aliases

-- | Find all memory aliases in a function definition.
findMemAliases :: FunDef ExplicitMemory -> VarMemMappings MemorySrc -> MemAliases
findMemAliases fundef var_to_mem =
  let fundef' = analyseFun fundef
      m = unFindM $ lookInBody $ funDefBody fundef'
      mem_aliases = M.unionsWith S.union $ snd $ evalRWS m var_to_mem ()
      mem_aliases' = removeEmptyMaps $ expandWithAliases mem_aliases mem_aliases
  in mem_aliases'

lookInBody :: LoreConstraints lore =>
              Body (Aliases lore) -> FindM lore ()
lookInBody (Body _ bnds _res) =
  mapM_ lookInStm bnds

lookInKernelBody :: LoreConstraints lore =>
                    KernelBody (Aliases lore) -> FindM lore ()
lookInKernelBody (KernelBody _ bnds _res) =
  mapM_ lookInStm bnds

lookInStm :: LoreConstraints lore =>
             Stm (Aliases lore) -> FindM lore ()
lookInStm (Let (Pattern patctxelems patvalelems) _ e) = do
  forM_ (patctxelems ++ patvalelems) lookInPatElem

  case e of
    DoLoop mergectxparams mergevalparams _loopform body -> do
      -- There are most likely more body results than
      -- mergectxparams, but we are only interested in the first
      -- body results anyway (those that have a matching location
      -- with the mergectxparams).
      zipWithM_ lookInMergeCtxParam mergectxparams (bodyResult body)
      zipWithM_ lookInCtx patctxelems mergectxparams
      mapM_ (lookInMergeValParam body) mergevalparams
      mapM_ (lookInBodyTuples patctxelems (map snd mergectxparams) (bodyResult body))
        patvalelems
    If _ body_then body_else _ -> do
      -- Alias everything.  FIXME: This is maybe more conservative than
      -- necessary if the If works on tuples of arrays.
      let ress = mapMaybe subExpVar
                 (bodyResult body_then ++ bodyResult body_else)
      var_to_mem <- ask
      let mems = map memSrcName $ mapMaybe (`M.lookup` var_to_mem) ress
      forM_ patctxelems $ \case
          (PatElem patmem (_, ExpMem.MemMem{})) ->
            recordMapping patmem $ S.fromList mems
          _ -> return ()
    _ -> return ()

  fullWalkAliasesExpM walker walker_kernel e
  where walker = identityWalker
          { walkOnBody = lookInBody
          }
        walker_kernel = identityKernelWalker
          { walkOnKernelBody = coerce . lookInBody
          , walkOnKernelKernelBody = coerce . lookInKernelBody
          , walkOnKernelLambda = coerce . lookInBody . lambdaBody
          }

lookInCtx :: LoreConstraints lore =>
             PatElem (Aliases lore) -> (FParam (Aliases lore), SubExp)
          -> FindM lore ()
lookInCtx (PatElem patmem (_, ExpMem.MemMem{})) (Param parammem ExpMem.MemMem{}, _) = do
  recordMapping patmem (S.singleton parammem)
  recordMapping parammem (S.singleton patmem)
lookInCtx _ _ = return ()

lookInMergeCtxParam :: LoreConstraints lore =>
                       (FParam (Aliases lore), SubExp) -> SubExp -> FindM lore ()
lookInMergeCtxParam (Param xmem ExpMem.MemMem{}, Var param_mem) (Var body_mem_res) = do
  let aliases = S.fromList [param_mem, body_mem_res]
  recordMapping xmem aliases
lookInMergeCtxParam _ _ = return ()

lookInMergeValParam :: LoreConstraints lore =>
                       Body (Aliases lore) -> (FParam (Aliases lore), SubExp)
                    -> FindM lore ()
lookInMergeValParam body (Param _ (ExpMem.MemArray _ _ _ (ExpMem.ArrayIn mem _)), _t) = do
  -- FIXME: This is maybe more conservative than necessary in case you have more
  -- than one loop array.  Fixing this would require either changing the Aliases
  -- representation, or building something on top of it.
  aliases <- S.unions
             <$> mapM (lookupMems . unNames) (fst $ fst $ bodyAttr body)
  recordMapping mem aliases
lookInMergeValParam _ _ = return ()

lookInBodyTuples :: LoreConstraints lore =>
                    [PatElem (Aliases lore)]
                 -> [SubExp] -> [SubExp]
                 -> PatElem (Aliases lore)
                 -> FindM lore ()
-- When a parameter refers to a existential memory, we want to find
-- which return memory in the loop that the existential memory refers
-- to.
lookInBodyTuples patctxelems body_params body_results
  (PatElem _ (_, ExpMem.MemArray _ _ _ (ExpMem.ArrayIn mem _))) = do
  let zipped = zip3 patctxelems body_params body_results
  case L.find ((== mem) . patElemName . (\(x, _, _) -> x)) zipped of
    Just (_, Var param_mem, Var res_mem) ->
      recordMapping mem (S.fromList [param_mem, res_mem])
    _ -> return ()
lookInBodyTuples _ _ _ _ = return ()

lookInPatElem :: LoreConstraints lore =>
                 PatElem (Aliases lore) -> FindM lore ()
lookInPatElem (PatElem _ (names', ExpMem.MemArray _ _ _ (ExpMem.ArrayIn xmem _))) = do
  aliases <- lookupMems $ unNames names'
  recordMapping xmem aliases
lookInPatElem (PatElem xmem (names', ExpMem.MemMem {})) = do
  aliases <- lookupMems $ unNames names'
  recordMapping xmem aliases
lookInPatElem _ = return ()