{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ConstraintKinds #-}
-- | Safety condition 3 verification.
module Futhark.Optimise.MemoryBlockMerging.Coalescing.SafetyCondition3
  ( getVarUsesBetween
  ) where

import qualified Data.Set as S
import qualified Data.List as L
import Control.Monad
import Control.Monad.RWS

import Futhark.Representation.AST
import Futhark.Representation.ExplicitMemory (
  ExplicitMemory, ExplicitMemorish)
import Futhark.Representation.Kernels.Kernel

import Futhark.Optimise.MemoryBlockMerging.Miscellaneous


-- | Read-only parameters of a traversal: the two variables whose declarations
-- delimit the region being inspected.
data Context = Context
  { ctxSource :: VName
    -- ^ Variable whose declaration opens the region.
  , ctxDestination :: VName
    -- ^ Variable whose declaration closes the region.
  } deriving Show

-- | State threaded through a traversal.
data Current = Current
  { curHasReachedSource :: Bool
    -- ^ 'True' once a statement declaring 'ctxSource' has been visited.
  , curHasReachedDestination :: Bool
    -- ^ 'True' once a statement declaring 'ctxDestination' has been visited.
  , curVars :: Names
    -- ^ The variables collected so far.
  } deriving Show

-- | The traversal monad: a reader over 'Context' combined with state over
-- 'Current'.  The writer component is the unit type, and @lore@ is a phantom
-- parameter that does not occur in the wrapped computation.
newtype FindM lore a = FindM { unFindM :: RWS Context () Current a }
  deriving ( Functor
           , Applicative
           , Monad
           , MonadReader Context
           , MonadState Current
           )

-- | The constraints a lore has to satisfy for the traversal functions of this
-- module to be applicable to it.
type LoreConstraints lore = (ExplicitMemorish lore, FullWalk lore)

-- | Change the phantom lore of a 'FindM' computation.  The wrapped 'RWS'
-- action is reused as is.
coerce :: FindM flore a -> FindM tlore a
coerce (FindM action) = FindM action

-- | Apply a function to the set of collected variables, leaving the rest of the
-- state untouched.
modifyCurVars :: (Names -> Names) -> FindM lore ()
modifyCurVars f = modify update
  where
    update cur = cur { curVars = f (curVars cur) }

-- | Find all the variables present between the creation of the first variable
-- (the source) and the creation of the second variable (the destination) in the
-- body of a function.  Neither end point is inclusive.
--
-- The function body is walked from a state in which neither variable has been
-- seen and no variables have been collected; the result is the set of variables
-- held by the state once the walk has finished.
getVarUsesBetween :: FunDef ExplicitMemory
                  -> VName -> VName
                  -> Names
getVarUsesBetween fundef src dst = curVars final_state
  where
    context = Context { ctxSource = src, ctxDestination = dst }
    initial_state = Current
      { curHasReachedSource = False
      , curHasReachedDestination = False
      , curVars = S.empty
      }
    walk = unFindM (lookInBody (funDefBody fundef))
    (final_state, _) = execRWS walk context initial_state

-- | Visit the statements of a body in order.  The body result is not inspected.
lookInBody :: LoreConstraints lore =>
              Body lore -> FindM lore ()
lookInBody (Body _ stms _) = forM_ stms lookInStm

-- | Visit the statements of a kernel body in order.  The kernel body result is
-- not inspected.
lookInKernelBody :: LoreConstraints lore =>
                    KernelBody lore -> FindM lore ()
lookInKernelBody (KernelBody _ stms _) = forM_ stms lookInStm

-- | Visit a single statement and update the traversal state.
--
-- The destination is checked for before anything is recorded, and the source
-- is only marked as reached after recording, so neither the statement that
-- declares the source nor the one that declares the destination contributes to
-- 'curVars'.  Once the destination has been reached, statements (including
-- their sub-bodies) are skipped entirely.
lookInStm :: LoreConstraints lore =>
             Stm lore -> FindM lore ()
lookInStm stm@(Let _ _ e) = do
  let new_decls = newDeclarationsStm stm

  -- Mark the destination as reached *before* recording anything, so that the
  -- statement declaring it is excluded from the result.
  dst <- asks ctxDestination
  when (dst `L.elem` new_decls)
    $ modify $ \c -> c { curHasReachedDestination = True }

  is_after_source <- gets curHasReachedSource
  has_reached_destination <- gets curHasReachedDestination

  unless has_reached_destination $ do
    let e_free_vars = freeInExp e
        e_used_vars = S.union e_free_vars (S.fromList new_decls)

    -- If the source has been created, add the newly used variables.
    --
    -- Note that "used after creation" refers both to used in subsequent
    -- statements AND any statements in any sub-bodies (if and loop).
    when is_after_source
      $ modifyCurVars $ S.union e_used_vars

    -- If the source is present in the declarations, state that it has been
    -- created.  This happens after the recording above, so the statement
    -- declaring the source is itself not part of the result.
    src <- asks ctxSource
    when (src `L.elem` new_decls)
      $ modify $ \c -> c { curHasReachedSource = True }

    -- RECURSIVE BODY WALK.
    case e of
      If _ body0 body1 _ -> do
        -- This is not very If-specific, but rather specific to expressions with
        -- multiple, independent bodies, where If is just the only such
        -- expression.
        --
        -- We do not want the state (for safety condition 3) after traversing
        -- the first branch to be present when traversing the second branch,
        -- since they really will never both be run, so we compute them
        -- independently and then merge them at the end.
        before <- get
        lookInBody body0
        after0 <- get
        -- Restart the second branch from exactly the state the first branch
        -- started from.  Nothing from the first branch may leak into it, and
        -- that includes the "destination reached" flag: otherwise the second
        -- branch would be skipped whenever the destination is declared in the
        -- first one, making the result depend on the order of the branches.
        put before
        lookInBody body1
        after1 <- get
        -- Merge: a flag is set if it was set in either branch, and the
        -- collected variables are the union of both branches.
        put Current { curHasReachedSource =
                      curHasReachedSource after0 || curHasReachedSource after1
                    , curHasReachedDestination =
                      curHasReachedDestination after0 || curHasReachedDestination after1
                    , curVars =
                      S.union (curVars after0) (curVars after1)
                    }
      _ -> do
        -- In the general case, just look through any 'Body' you can find.  (This
        -- is the case for loops.)
        let walker = identityWalker { walkOnBody = lookInBody }
            walker_kernel = identityKernelWalker
              { walkOnKernelBody = coerce . lookInBody
              , walkOnKernelKernelBody = coerce . lookInKernelBody
              , walkOnKernelLambda = coerce . lookInBody . lambdaBody
              }
        fullWalkExpM walker walker_kernel e