{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ConstraintKinds #-}
-- | Find all existential variables.
module Futhark.Optimise.MemoryBlockMerging.Existentials
  ( findExistentials
  ) where

import qualified Data.Set as S
import qualified Data.List as L
import Control.Monad
import Control.Monad.Writer

import Futhark.Representation.AST
import Futhark.Representation.ExplicitMemory (ExplicitMemorish)
import qualified Futhark.Representation.ExplicitMemory as ExpMem
import Futhark.Representation.Kernels.Kernel

import Futhark.Optimise.MemoryBlockMerging.Miscellaneous


newtype FindM lore a = FindM { unFindM :: Writer Names a }
  deriving (Monad, Functor, Applicative,
            MonadWriter Names)

type LoreConstraints lore = (ExplicitMemorish lore,
                             FullWalk lore)

record :: VName -> FindM lore ()
record = tell . S.singleton

coerce :: FindM flore a -> FindM tlore a
coerce = FindM . unFindM

findExistentials :: LoreConstraints lore =>
                    FunDef lore -> Names
findExistentials fundef =
  let m = unFindM $ lookInBody $ funDefBody fundef
      existentials = execWriter m
  in existentials

lookInBody :: LoreConstraints lore =>
              Body lore -> FindM lore ()
lookInBody (Body _ bnds _res) =
  mapM_ lookInStm bnds

lookInKernelBody :: LoreConstraints lore =>
                    KernelBody lore -> FindM lore ()
lookInKernelBody (KernelBody _ bnds _res) =
  mapM_ lookInStm bnds

lookInStm :: LoreConstraints lore =>
             Stm lore -> FindM lore ()
lookInStm (Let (Pattern patctxelems patvalelems) _ e) = do
  forM_ patvalelems $ \(PatElem var membound) ->
    case membound of
      ExpMem.MemArray _ _ _ (ExpMem.ArrayIn mem _) ->
        when (mem `L.elem` map patElemName patctxelems)
        $ record var
      _ -> return ()

  case e of
    DoLoop mergectxparams mergevalparams _loopform _body ->
      forM_ mergevalparams $ \(Param var membound, _) ->
        case membound of
          ExpMem.MemArray _ _ _ (ExpMem.ArrayIn mem _) ->
            when (mem `L.elem` map (paramName . fst) mergectxparams)
            $ record var
          _ -> return ()
    _ -> return ()

  fullWalkExpM walker walker_kernel e
  where walker = identityWalker
          { walkOnBody = lookInBody }
        walker_kernel = identityKernelWalker
          { walkOnKernelBody = coerce . lookInBody
          , walkOnKernelKernelBody = coerce . lookInKernelBody
          , walkOnKernelLambda = coerce . lookInBody . lambdaBody
          }