{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ConstraintKinds #-}
-- | Get a mapping from statement name to PrimExp (if the statement has a
-- primitive expression) for all statements.
module Futhark.Optimise.MemoryBlockMerging.PrimExps
  ( findPrimExpsFunDef
  ) where

import qualified Data.Map.Strict as M
import Data.Maybe (mapMaybe)
import Control.Monad
import Control.Monad.RWS

import Futhark.Representation.AST
import Futhark.Representation.ExplicitMemory
       (ExplicitMemorish, ExplicitMemory)
import qualified Futhark.Representation.ExplicitMemory as ExpMem
import Futhark.Representation.Kernels.Kernel
import Futhark.Tools

import Futhark.Optimise.MemoryBlockMerging.Miscellaneous


type CurrentTypes = M.Map VName PrimType
type PrimExps = M.Map VName (PrimExp VName)

newtype FindM lore a = FindM { unFindM :: RWS () PrimExps CurrentTypes a }
  deriving (Monad, Functor, Applicative,
            MonadWriter PrimExps,
            MonadState CurrentTypes)

type LoreConstraints lore = (ExplicitMemorish lore,
                             FullWalk lore)

coerce :: FindM flore a -> FindM tlore a
coerce = FindM . unFindM

-- Find/construct all 'PrimExp's in a function definition.
findPrimExpsFunDef :: FunDef ExplicitMemory -> PrimExps
findPrimExpsFunDef fundef =
  let m = unFindM $ do
        lookInFParams $ funDefParams fundef
        lookInBody $ funDefBody fundef
      res = snd $ evalRWS m () M.empty
  in res

lookInFParams :: LoreConstraints lore =>
                 [FParam lore] -> FindM lore ()
lookInFParams params = forM_ params $ \(Param var membound) -> do
  case typeOf membound of
    Prim pt -> modify $ M.insert var pt
    _ -> return ()

  case membound of
    ExpMem.MemArray pt shape _ (ExpMem.ArrayIn mem _) -> do
      let matchingSizeVar (Param mem1 (ExpMem.MemMem (Var mem_size) _))
            | mem1 == mem = Just mem_size
          matchingSizeVar _ = Nothing
      case mapMaybe matchingSizeVar params of
        [mem_size] -> do
          let prod_i32 = product (map (primExpFromSubExp (IntType Int32)) (shapeDims shape))
          let prod_i64 = ConvOpExp (SExt Int32 Int64) prod_i32
          let pe = prod_i64 * primByteSize pt
          tell $ M.singleton mem_size pe
        _ -> return ()
    _ -> return ()

lookInBody :: LoreConstraints lore =>
              Body lore -> FindM lore ()
lookInBody (Body _ bnds _res) =
  mapM_ lookInStm bnds

lookInKernelBody :: LoreConstraints lore =>
                    KernelBody lore -> FindM lore ()
lookInKernelBody (KernelBody _ bnds _res) =
  mapM_ lookInStm bnds

lookInStm :: LoreConstraints lore =>
             Stm lore -> FindM lore ()
lookInStm (Let (Pattern _patctxelems patvalelems) _ e) = do
  prim_types <- get
  let varUse v = ExpMem.LeafExp v <$> M.lookup v prim_types

  case patvalelems of
    [PatElem dst _] ->
      forM_ (primExpFromExp varUse e) $ tell . M.singleton dst
    _ -> return ()

  forM_ patvalelems $ \(PatElem var membound) ->
    case typeOf membound of
      Prim pt ->
        modify $ M.insert var pt
      _ -> return ()

  -- Recursive body walk.
  fullWalkExpM walker walker_kernel e
  where walker = identityWalker
          { walkOnBody = lookInBody }
        walker_kernel = identityKernelWalker
          { walkOnKernelBody = coerce . lookInBody
          , walkOnKernelKernelBody = coerce . lookInKernelBody
          , walkOnKernelLambda = coerce . lookInBody . lambdaBody
          }