{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ViewPatterns #-}
-- | Perform horizontal and vertical fusion of SOACs.
module Futhark.Optimise.Fusion ( fuseSOACs )
  where

import Control.Monad.State
import Control.Monad.Reader
import Control.Monad.Except
import Data.Maybe
import qualified Data.Map.Strict as M
import qualified Data.Set      as S
import qualified Data.List         as L

import Futhark.Representation.AST.Attributes.Aliases
import Futhark.Representation.SOACS hiding (SOAC(..))
import qualified Futhark.Representation.Aliases as Aliases
import qualified Futhark.Representation.SOACS as Futhark
import Futhark.MonadFreshNames
import Futhark.Representation.SOACS.Simplify
import Futhark.Optimise.Fusion.LoopKernel
import Futhark.Construct
import qualified Futhark.Analysis.HORepresentation.SOAC as SOAC
import qualified Futhark.Analysis.Alias as Alias
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Pass

-- | What the fusion environment records about a variable in scope.
data VarEntry = IsArray VName (NameInfo SOACS) Names SOAC.Input
              -- ^ An array: a name, its scope information, the set of
              -- names it aliases, and the SOAC input it stands for.
              | IsNotArray VName (NameInfo SOACS)
              -- ^ Anything else: a name and its scope information.

-- | Extract the scope information ('NameInfo') stored in a 'VarEntry',
-- regardless of whether the entry describes an array.
varEntryType :: VarEntry -> NameInfo SOACS
varEntryType entry = case entry of
  IsArray _ info _ _ -> info
  IsNotArray _ info  -> info

-- | The aliases recorded for a 'VarEntry'.  Only array entries carry
-- aliases; every other entry yields the empty set.
varEntryAliases :: VarEntry -> Names
varEntryAliases entry = case entry of
  IsArray _ _ als _ -> als
  IsNotArray{}      -> mempty

-- | The read-only environment carried by the 'FusionGM' monad.
data FusionGEnv = FusionGEnv {
    soacs      :: M.Map VName [VName]
  -- ^ Mapping from variable name to its entire family.
  , varsInScope:: M.Map VName VarEntry
  -- ^ The variables currently in scope, together with what is known
  -- about each of them (see 'VarEntry').
  , fusedRes   :: FusedRes
  -- ^ The fusion result made available to the computation (installed
  -- with 'bindRes').
  }

-- | Look up the SOAC input associated with a variable.  Yields
-- 'Nothing' if the variable is not in scope or is not an array.
lookupArr :: VName -> FusionGEnv -> Maybe SOAC.Input
lookupArr v env =
  case M.lookup v (varsInScope env) of
    Just (IsArray _ _ _ inp) -> Just inp
    _                        -> Nothing

-- | An error raised by the fusion pass, carrying a human-readable message.
newtype Error = Error String

-- | Renders the message prefixed by a fixed "Fusion error" header line.
instance Show Error where
  show (Error msg) = "Fusion error:\n" <> msg

-- | The monad in which fusion runs.  It layers error handling
-- ('Error') over a 'VNameSource' state (used for fresh names) over a
-- reader of the 'FusionGEnv' environment.
newtype FusionGM a = FusionGM (ExceptT Error (StateT VNameSource (Reader FusionGEnv)) a)
  deriving (Monad, Applicative, Functor,
            MonadError Error,
            MonadState VNameSource,
            MonadReader FusionGEnv)

-- | The name source is exactly the state component of 'FusionGM'.
instance MonadFreshNames FusionGM where
  getNameSource = FusionGM get
  putNameSource = FusionGM . put

-- | The scope is derived from 'varsInScope' by keeping only the
-- 'NameInfo' of every entry.
instance HasScope SOACS FusionGM where
  askScope = asks (M.map varEntryType . varsInScope)

------------------------------------------------------------------------
--- Monadic Helpers: bind/new/runFusionGatherM, etc
------------------------------------------------------------------------

-- | Insert a variable into 'varsInScope'.  Array-typed identifiers
-- become 'IsArray' entries whose alias set is the given aliases
-- extended with the aliases already recorded for each of them;
-- identifiers of any other type become 'IsNotArray' entries.
bindVar :: FusionGEnv -> (Ident, Names) -> FusionGEnv
bindVar env (ident@(Ident name t), aliases) =
  env { varsInScope = M.insert name entry scope }
  where scope = varsInScope env
        -- Aliases previously recorded for @v@ (empty if @v@ is unknown).
        aliasesOf v = maybe mempty varEntryAliases $ M.lookup v scope
        all_aliases = aliases <> foldMap aliasesOf (S.toList aliases)
        entry
          | Array{} <- t = IsArray name (LetInfo t) all_aliases $ SOAC.identInput ident
          | otherwise    = IsNotArray name $ LetInfo t

-- | Insert several variables, left to right, using 'bindVar'.
bindVars :: FusionGEnv -> [(Ident, Names)] -> FusionGEnv
bindVars env = foldl bindVar env

-- | Run an action in an environment extended with the given
-- identifiers and their aliases.
binding :: [(Ident, Names)] -> FusionGM a -> FusionGM a
binding vs = local $ \env -> bindVars env vs

-- | Run an action with the identifiers of a statement's pattern in
-- scope.  Context elements get no aliases; the remaining elements
-- get the aliases computed for the expression by alias analysis.
gatherStmPattern :: Pattern -> Exp -> FusionGM FusedRes -> FusionGM FusedRes
gatherStmPattern pat e =
  binding $ zip (patternIdents pat) (ctx_aliases ++ val_aliases)
  where ctx_aliases = map (const mempty) $ patternContextNames pat
        val_aliases = expAliases $ Alias.analyseExp e

-- | Run an action with all identifiers of a pattern in scope, each
-- with an empty alias set.
bindingPat :: Pattern -> FusionGM a -> FusionGM a
bindingPat pat = binding [ (ident, mempty) | ident <- patternIdents pat ]

-- | Run an action with the given parameters in scope, each with an
-- empty alias set.
bindingParams :: Typed t => [Param t] -> FusionGM a -> FusionGM a
bindingParams params = binding [ (paramIdent p, mempty) | p <- params ]

-- | Record that an identifier belongs to the given family of
-- SOAC-produced names, and put it in scope as an array with no
-- aliases whose input is the identifier itself.
bindingFamilyVar :: [VName] -> FusionGEnv -> Ident -> FusionGEnv
bindingFamilyVar faml env ident@(Ident nm t) =
  env { soacs       = M.insert nm faml (soacs env)
      , varsInScope = M.insert nm entry (varsInScope env)
      }
  where entry = IsArray nm (LetInfo t) mempty (SOAC.identInput ident)

-- | The name itself plus whatever aliases the environment records
-- for it (none if the name is not in scope).
varAliases :: VName -> FusionGM Names
varAliases v = do
  entry <- asks $ M.lookup v . varsInScope
  return $ S.insert v $ maybe mempty varEntryAliases entry

-- | The union of 'varAliases' over every name in the set.
varsAliases :: Names -> FusionGM Names
varsAliases names = mconcat <$> mapM varAliases (S.toList names)

-- | Account for in-place updates.  For an 'Update' expression the
-- updated array and every variable free in the index slice are made
-- infusible, and the aliases of the updated array are added to the
-- @inplace@ set of every kernel in the result.  Any other expression
-- leaves the result unchanged.
checkForUpdates :: FusedRes -> Exp -> FusionGM FusedRes
checkForUpdates res e = case e of
  BasicOp (Update src is _) -> do
    let index_vars = S.toList $ foldMap freeIn is
    res' <- foldM addVarToInfusible res (src : index_vars)
    src_aliases <- varAliases src
    let markInplace ker = ker { inplace = src_aliases <> inplace ker }
    return res' { kernels = M.map markInplace $ kernels res' }
  _ -> return res

-- | Run an action in an environment where every identifier of the
-- pattern (i) is mapped, in @soacs@, to the names of all pattern
-- elements (its family), and (ii) is in scope as an array variable.
bindingFamily :: Pattern -> FusionGM FusedRes -> FusionGM FusedRes
bindingFamily pat = local $ \env ->
  foldl (bindingFamilyVar (patternNames pat)) env (patternIdents pat)

-- | Run an action with the pattern element bound as a transformed
-- view of @srcname@.  If @srcname@ is a known array, the new entry
-- reuses its underlying name and input (with @trns@ added) and
-- aliases @srcname@ in addition to its aliases.  Otherwise the
-- element is bound like an ordinary variable via 'bindVar'.
bindingTransform :: PatElem -> VName -> SOAC.ArrayTransform -> FusionGM a -> FusionGM a
bindingTransform pe srcname trns = local extend
  where vname = patElemName pe
        extend env
          | Just (IsArray src' _ aliases input) <- M.lookup srcname scope =
              let entry = IsArray src' (LetInfo $ patElemAttr pe)
                          (S.insert srcname aliases)
                          (SOAC.addTransform trns input)
              in env { varsInScope = M.insert vname entry scope }
          | otherwise =
              bindVar env (patElemIdent pe, S.singleton vname)
          where scope = varsInScope env

-- | Run an action with the given fusion result installed as the
-- @fusedRes@ of the environment.
bindRes :: FusedRes -> FusionGM a -> FusionGM a
bindRes res = local $ \env -> env { fusedRes = res }

-- | Run a 'FusionGM' action inside any monad that can supply fresh
-- names.  The name source of the outer monad seeds the state of the
-- action and is replaced by the final state afterwards; the reader
-- component is the supplied 'FusionGEnv'; failures are returned as
-- 'Left'.
runFusionGatherM :: MonadFreshNames m =>
                    FusionGM a -> FusionGEnv -> m (Either Error a)
runFusionGatherM (FusionGM m) env = modifyNameSource run
  where run src = runReader (runStateT (runExceptT m) src) env

------------------------------------------------------------------------
--- Fusion Entry Points: gather the to-be-fused kernels@pgm level    ---
---    and fuse them in a second pass!                               ---
------------------------------------------------------------------------

-- | The fusion pass: fuse within every function, then rename the
-- program, then run the SOACS simplifier.
fuseSOACs :: Pass SOACS SOACS
fuseSOACs =
  Pass { passName = "Fuse SOACs"
       , passDescription = "Perform higher-order optimisation, i.e., fusion."
       , passFunction = fuseThenTidy
       }
  where fuseThenTidy prog =
          intraproceduralTransformation fuseFun prog
          >>= renameProg
          >>= simplifySOACS

-- | Fuse a single function.  A first run gathers the fusion result;
-- only if that result reports success is a second run performed to
-- rewrite the function, otherwise the function is returned untouched.
-- Both runs start from the same empty environment.
fuseFun :: FunDef -> PassM FunDef
fuseFun fun = do
  gathered <- liftEitherM $ runFusionGatherM (fusionGatherFun fun) initial_env
  let k = cleanFusionResult gathered
  if rsucc k
    then liftEitherM $ runFusionGatherM (fuseInFun k fun) initial_env
    else return fun
  where initial_env = FusionGEnv { soacs = M.empty
                                 , varsInScope = M.empty
                                 , fusedRes = mempty
                                 }

-- | Gather the fusion result for a function body, starting from an
-- empty result and with the function parameters in scope.
fusionGatherFun :: FunDef -> FusionGM FusedRes
fusionGatherFun fundec =
  bindingParams params $ fusionGatherBody mempty body
  where params = funDefParams fundec
        body   = funDefBody fundec

-- | Rewrite the body of a function according to a fusion result,
-- with the function parameters in scope and the result installed in
-- the environment.
fuseInFun :: FusedRes -> FunDef -> FusionGM FunDef
fuseInFun res fundec =
  setBody <$> bindingParams (funDefParams fundec)
              (bindRes res $ fuseInBody $ funDefBody fundec)
  where setBody body' = fundec { funDefBody = body' }

---------------------------------------------------
---------------------------------------------------
---- RESULT's Data Structure
---------------------------------------------------
---------------------------------------------------

-- | A type used for (hopefully) uniquely referring to a producer
-- SOAC kernel.  It is a thin wrapper around a 'VName'.
-- NOTE(review): 'addNewKerWithInfusible' creates these from a fresh
-- name ("ker") rather than from the first output array of the SOAC,
-- as an earlier version of this comment stated — confirm intent.
newtype KernName = KernName { unKernName :: VName }
  deriving (Eq, Ord, Show)

-- | The result of the gather phase of fusion.
data FusedRes = FusedRes {
    rsucc :: Bool
  -- ^ Whether we have fused something anywhere.

  , outArr     :: M.Map VName KernName
  -- ^ Associates an array with the name of the
  -- SOAC kernel that has produced it.

  , inpArr     :: M.Map VName (S.Set KernName)
  -- ^ Associates an array with the names of the
  -- SOAC kernels that use it. These sets include
  -- only the SOAC input arrays used as full variables, i.e., no @a[i]@.

  , infusible  :: Names
  -- ^ The (names of) arrays that are not fusible, i.e.,
  --
  --   1. they are used other than as input to SOAC kernels, or
  --
  --   2. they are used as input to at least two different kernels that
  --      are not located on disjoint control-flow branches, or
  --
  --   3. they are used in the lambda expression of SOACs.

  , kernels    :: M.Map KernName FusedKer
  -- ^ The kernels gathered so far, indexed by kernel name.
  }

-- | Field-wise combination: success flags are or-ed, maps are
-- left-biased unions (the per-array kernel sets of @inpArr@ are
-- themselves unioned), and the infusible sets are unioned.
instance Semigroup FusedRes where
  res1 <> res2 =
    FusedRes { rsucc     = rsucc res1 || rsucc res2
             , outArr    = M.union (outArr res1) (outArr res2)
             , inpArr    = M.unionWith S.union (inpArr res1) (inpArr res2)
             , infusible = S.union (infusible res1) (infusible res2)
             , kernels   = M.union (kernels res1) (kernels res2)
             }

-- | The empty result: nothing fused, no arrays, no kernels.
instance Monoid FusedRes where
  mempty = FusedRes False M.empty M.empty S.empty M.empty

-- | Is the array used as input by some kernel of the result other
-- than the given ones?
isInpArrInResModKers :: FusedRes -> S.Set KernName -> VName -> Bool
isInpArrInResModKers ress kers nm =
  maybe False (any (`S.notMember` kers)) $ M.lookup nm (inpArr ress)

-- | The names of all kernels that use at least one of the given
-- arrays as input.
getKersWithInpArrs :: FusedRes -> [VName] -> S.Set KernName
getKersWithInpArrs ress arrs =
  S.unions [ kers | Just kers <- map (`M.lookup` inpArr ress) arrs ]

-- | Expand a list of names so that every name produced by a SOAC is
-- replaced by its whole family (found by querying the @soacs@ table);
-- names without a family are kept as they are.  Order is preserved.
expandSoacInpArr :: [VName] -> FusionGM [VName]
expandSoacInpArr nms = do
  families <- asks soacs
  return $ concatMap (\nm -> M.findWithDefault [nm] nm families) nms

----------------------------------------------------------------------
----------------------------------------------------------------------

-- | Split the inputs of a SOAC with @getIdentArr@ and expand each of
-- the two resulting name lists with 'expandSoacInpArr'.
soacInputs :: SOAC -> FusionGM ([VName], [VName])
soacInputs soac = do
  let (inp_idds, other_idds) = getIdentArr $ SOAC.inputs soac
  (,) <$> expandSoacInpArr inp_idds <*> expandSoacInpArr other_idds

-- | Add a fresh kernel for the given SOAC to the result.  The kernel
-- gets a newly generated name; its outputs are registered in
-- @outArr@, its input arrays in @inpArr@, and the infusible set is
-- replaced by @ufs@.  The success flag is left unchanged.
addNewKerWithInfusible :: FusedRes -> ([Ident], Certificates, SOAC, Names) -> Names -> FusionGM FusedRes
addNewKerWithInfusible res (idd, cs, soac, consumed) ufs = do
  nm_ker <- KernName <$> newVName "ker"
  scope <- askScope
  let out_nms  = map identName idd
      inp_arrs = map SOAC.inputArray $ SOAC.inputs soac
      new_ker  = newKernel cs soac consumed out_nms scope
      produced = M.fromList $ zip out_nms $ repeat nm_ker
      used     = M.fromList $ zip inp_arrs $ repeat $ S.singleton nm_ker
  return FusedRes { rsucc     = rsucc res
                  , outArr    = produced `M.union` outArr res
                  , inpArr    = M.unionWith S.union used (inpArr res)
                  , infusible = ufs
                  , kernels   = M.insert nm_ker new_ker $ kernels res
                  }

-- | Monadic version of 'lookupArr' that reads the environment.
lookupInput :: VName -> FusionGM (Maybe SOAC.Input)
lookupInput = asks . lookupArr

-- | If the array underlying an input is itself known to stand for
-- another input, substitute that input, placing its transformations
-- before those of the original.  Otherwise return the input as is.
inlineSOACInput :: SOAC.Input -> FusionGM SOAC.Input
inlineSOACInput inp@(SOAC.Input ts v _) =
  maybe inp prepend <$> lookupInput v
  where prepend (SOAC.Input ts2 v2 t2) = SOAC.Input (ts2 <> ts) v2 t2

-- | Apply 'inlineSOACInput' to every input of a SOAC.
inlineSOACInputs :: SOAC -> FusionGM SOAC
inlineSOACInputs soac =
  (`SOAC.setInputs` soac) <$> mapM inlineSOACInput (SOAC.inputs soac)


-- | Attempts to fuse the current SOAC with kernels already gathered. Input:
--   @rem_bnds@ are the bindings of the current body from the SOAC's own
--   binding onwards (the caller passes @bnd : bnds@).
--   @lam_used_nms@ the names used in the lambda(s) of the SOAC.
--   @res@ the fusion result (before processing the current soac)
--   @orig_soac@ and @out_idds@ the current SOAC and its binding pattern
--   @cs@ are the certificates of the SOAC binding.
--   @consumed@ is the set of names consumed by the SOAC.
--   Output: a new Fusion Result (after processing the current SOAC binding).
--   If no fusion takes place, a new kernel is added for the SOAC instead.
greedyFuse :: [Stm] -> Names -> FusedRes -> (Pattern, Certificates, SOAC, Names)
           -> FusionGM FusedRes
greedyFuse rem_bnds lam_used_nms res (out_idds, cs, orig_soac, consumed) = do
  -- Replace inputs that are known transformed views by their sources.
  soac <- inlineSOACInputs orig_soac
  (inp_nms, other_nms) <- soacInputs soac
  -- Assumption: the free vars in lambda are already in @infusible res@.
  let out_nms     = patternNames out_idds
      isInfusible = (`S.member` infusible res)
      is_screma  = case orig_soac of
                       SOAC.Screma _ form _ ->
                         (isJust (isRedomapSOAC form) || isJust (isScanomapSOAC form)) &&
                         not (isJust (isReduceSOAC form) || isJust (isScanSOAC form))
                       _ -> False
  --
  -- Choice of fusion strategy:
  -- If the original soac is a Screma recognised as a redomap/scanomap
  --    (but not as a plain reduce/scan), OR any of @out_idds@ belongs
  --    to the infusible set,
  -- THEN try applying horizontal        fusion
  -- ELSE try applying producer-consumer fusion
  -- (without duplicating computation in both cases)

  (ok_kers_compat, fused_kers, fused_nms, old_kers, oldker_nms) <-
        if   is_screma || any isInfusible out_nms
        then horizontGreedyFuse rem_bnds res (out_idds, cs, soac, consumed)
        else prodconsGreedyFuse          res (out_idds, cs, soac, consumed)
  --
  -- (ii) check whether fusing @soac@ will violate any in-place update
  --      restriction, e.g., would move an input array past its in-place update.
  let all_used_names = S.toList $ S.unions [lam_used_nms, S.fromList inp_nms, S.fromList other_nms]
      has_inplace ker = any (`S.member` inplace ker) all_used_names
      ok_inplace = not $ any has_inplace old_kers
  --
  -- (iii)  there are some kernels that use some of `out_idds' as inputs
  -- (iv)   and producer-consumer or horizontal fusion succeeds with those.
  let fusible_ker = not (null old_kers) && ok_inplace && ok_kers_compat
  --
  -- Start constructing the fusion's result:
  --  (i) inparr ids other than vars will be added to infusible list,
  -- (ii) will also become part of the infusible set the inparr vars
  --         that also appear as inparr of another kernel,
  --         BUT which said kernel is not the one we are fusing with (now)!
  let mod_kerS  = if fusible_ker then S.fromList oldker_nms else S.empty
  let used_inps = filter (isInpArrInResModKers res mod_kerS) inp_nms
  let ufs       = S.unions [infusible res, S.fromList used_inps,
                             S.fromList other_nms `S.difference`
                             S.fromList (map SOAC.inputArray $ SOAC.inputs soac)]
  let comb      = M.unionWith S.union

  if not fusible_ker then
    -- Nothing was fused: the SOAC becomes a kernel of its own.
    addNewKerWithInfusible res (patternIdents out_idds, cs, soac, consumed) ufs
  else do
     -- Need to suitably update `inpArr':
     --   (i) first remove the inpArr bindings of the old kernel
     let inpArr' =
            foldl (\inpa (kold, knm) ->
                    S.foldl'
                        (\inpp nm ->
                           case M.lookup nm inpp of
                             Nothing -> inpp
                             Just s  -> let new_set = S.delete knm s
                                        in if S.null new_set
                                           then M.delete nm         inpp
                                           else M.insert nm new_set inpp
                        )
                    inpa $ arrInputs kold
                 )
            (inpArr res) (zip old_kers oldker_nms)
     --  (ii) then add the inpArr bindings of the new kernel
     let fused_ker_nms = zip fused_nms fused_kers
         inpArr''= foldl (\inpa' (knm, knew) ->
                             M.fromList [ (k, S.singleton knm)
                                         | k <- S.toList $ arrInputs knew ]
                             `comb` inpa'
                         )
                   inpArr' fused_ker_nms
     -- Update the kernels map; the fused kernels replace the entries
     -- with the same names (left-biased union).
     -- (why not delete the ones that have been fused?)
     let kernels' = M.fromList fused_ker_nms `M.union` kernels res
     -- nothing to do for `outArr' (since we have not added a new kernel)
     -- DO IMPROVEMENT: attempt to fuse the resulting kernel AGAIN until it fails,
     --                 but make sure NOT to add a new kernel!
     return $ FusedRes True (outArr res) inpArr'' ufs kernels'

-- | Attempt producer-consumer fusion of the given SOAC into every
-- kernel that uses one of its outputs as input.  Fusion must succeed
-- for all such kernels; otherwise the flag is 'False' and no fused
-- kernels are returned.  The result is: the success flag, the fused
-- kernels (with the certificates of the SOAC added), their names, the
-- original consumer kernels, and the names of those.
prodconsGreedyFuse :: FusedRes -> (Pattern, Certificates, SOAC, Names)
                   -> FusionGM (Bool, [FusedKer], [KernName], [FusedKer], [KernName])
prodconsGreedyFuse res (out_idds, cs, soac, consumed) = do
  let out_nms      = patternNames out_idds
      -- The kernels that consume some output of the SOAC.
      to_fuse_knms = S.toList $ getKersWithInpArrs res out_nms
  to_fuse_kers <- forM to_fuse_knms $ \k ->
    maybe (throwError $ Error
           ("In Fusion.hs, greedyFuse, comp of to_fuse_kers: "
            ++ "kernel name not found in kernels field!"))
          return $ M.lookup k (kernels res)
  -- Try producer-consumer fusion with each consumer in turn.
  attempts <- mapM (attemptFusion S.empty out_nms soac consumed) to_fuse_kers
  let (ok_kers_compat, fused_kers) =
        case sequence attempts of
          Nothing   -> (False, [])
          Just kers -> (True, map addCerts kers)
  return (ok_kers_compat, fused_kers, to_fuse_knms, to_fuse_kers, to_fuse_knms)
  where addCerts k = k { certificates = certificates k <> cs }

-- | Attempt horizontal fusion of the given SOAC with kernels found
-- later in the body.  Candidate kernels are those that share an
-- input/output array with the SOAC or have the same width; they are
-- ordered by the position (in @rem_bnds@) of the binding that
-- produces them, and fusion is attempted one by one until the first
-- failure.
--
-- The result is: whether at least one kernel was fused, the
-- resulting fused kernel (as a singleton list), the name under which
-- it should be stored (the name of the last kernel fused into, if
-- any), the kernels that were fused into, and the names of those.
horizontGreedyFuse :: [Stm] -> FusedRes -> (Pattern, Certificates, SOAC, Names)
                   -> FusionGM (Bool, [FusedKer], [KernName], [FusedKer], [KernName])
horizontGreedyFuse rem_bnds res (out_idds, cs, soac, consumed) = do
  (inp_nms, _) <- soacInputs soac
  let out_nms        = patternNames out_idds
      infusible_nms  = S.fromList $ filter (`S.member` infusible res) out_nms
      out_arr_nms    = case soac of
                        -- the accumulator result cannot be fused!
                        SOAC.Screma _ (ScremaForm (_, scan_nes) (_, _, red_nes) _) _ ->
                          drop (length scan_nes + length red_nes) out_nms
                        SOAC.Stream _ frm _ _ -> drop (length $ getStreamAccums frm) out_nms
                        _ -> out_nms
      to_fuse_knms1  = S.toList $ getKersWithInpArrs res (out_arr_nms++inp_nms)
      to_fuse_knms2  = getKersWithSameInpSize (SOAC.width soac) res
      -- Going through a set removes duplicates between the two lists.
      to_fuse_knms   = S.toList $ S.fromList $ to_fuse_knms1 ++ to_fuse_knms2
      lookupKernel k  = case M.lookup k (kernels res) of
                          Nothing  -> throwError $ Error
                                      ("In Fusion.hs, greedyFuse, comp of to_fuse_kers: "
                                       ++ "kernel name not found in kernels field!")
                          Just ker -> return ker

  -- for each kernel get the index in the bindings where the kernel is located
  -- and sort based on the index so that partial fusion may succeed.
  let bnd_nms = map (patternNames . stmPattern) rem_bnds
  kernminds <- forM to_fuse_knms $ \ker_nm -> do
    ker <- lookupKernel ker_nm
    -- Identify the kernel by its first non-accumulator output, if any.
    let out_nm  = case fsoac ker of
                    SOAC.Stream _ frm _ _
                      | x:_ <- drop (length $ getStreamAccums frm) $ outNames ker ->
                        x
                    SOAC.Screma _ (ScremaForm (_, scan_nes) (_, _, red_nes) _) _
                      | x:_ <- drop (length scan_nes + length red_nes) $ outNames ker ->
                        x
                    _ -> head $ outNames ker
    case L.findIndex (elem out_nm) bnd_nms of
      Nothing -> return Nothing
      Just i  -> return $ Just (ker,ker_nm,i)

  scope <- askScope
  let kernminds' = L.sortBy (\(_,_,i1) (_,_,i2)->compare i1 i2) $ catMaybes kernminds
      soac_kernel = newKernel cs soac consumed out_nms scope
  -- now try to fuse kernels one by one (in a fold); @ok_ind@ is the index of the
  -- kernel until which fusion succeded, and @fused_ker@ is the resulted kernel.
  (_,ok_ind,_,fused_ker,_) <-
      foldM (\(cur_ok,n,prev_ind,cur_ker,ufus_nms) (ker, _ker_nm, bnd_ind) -> do
                -- check that we still try fusion and that the intermediate
                -- bindings do not use the results of cur_ker
                let curker_outnms  = outNames cur_ker
                    curker_outset  = S.fromList curker_outnms
                    new_ufus_nms   = S.fromList $ outNames ker ++ S.toList ufus_nms
                    -- disable horizontal fusion in the case when an output array of
                    -- producer SOAC is a non-trivially transformed input of the consumer
                    out_transf_ok  = let ker_inp = SOAC.inputs $ fsoac ker
                                         unfuse1 = S.fromList (map SOAC.inputArray ker_inp) `S.difference`
                                                   S.fromList (mapMaybe SOAC.isVarInput ker_inp)
                                         unfuse2 = S.intersection curker_outset ufus_nms
                                     in  S.null $ S.intersection unfuse1 unfuse2
                    -- Disable horizontal fusion if consumer has any
                    -- output transforms.
                    cons_no_out_transf = SOAC.nullTransforms $ outputTransform ker

                consumer_ok   <- do let consumer_bnd   = rem_bnds !! bnd_ind
                                    maybesoac <- SOAC.fromExp $ stmExp consumer_bnd
                                    case maybesoac of
                                      -- check that consumer's lambda body does not use
                                      -- directly the produced arrays (e.g., see noFusion3.fut).
                                      Right conssoac -> return $ S.null $ S.intersection curker_outset $
                                                                 freeInBody $ lambdaBody $ SOAC.lambda conssoac
                                      Left _         -> return True

                let -- An in-between binding is acceptable if
                    -- (i)  it does not use the result of the current kernel, OR
                    -- (ii) its pattern binds a result of the current kernel;
                    --      in the latter case it corresponds to a kernel
                    --      that has already been fused, hence it should
                    --      be ignored.
                    interm_bnd_ok bnd =
                      S.null (S.intersection curker_outset $
                              freeInExp (stmExp bnd)) ||
                      not (null $ curker_outnms `L.intersect`
                           patternNames (stmPattern bnd))
                    -- Every in-between binding must be acceptable.  'all'
                    -- is used so that one unacceptable binding cannot be
                    -- masked by a later acceptable one (the previous
                    -- fold read @ok && (i) || (ii)@, which parses as
                    -- @(ok && (i)) || (ii)@).
                    interm_bnds_ok = cur_ok && consumer_ok && out_transf_ok && cons_no_out_transf &&
                      all interm_bnd_ok (drop (prev_ind+1) $ take bnd_ind rem_bnds)
                if not interm_bnds_ok then return (False,n,bnd_ind,cur_ker,S.empty)
                else do new_ker <- attemptFusion ufus_nms (outNames cur_ker)
                                   (fsoac cur_ker) (fusedConsumed cur_ker) ker
                        case new_ker of
                          Nothing -> return (False, n,bnd_ind,cur_ker,S.empty)
                          Just krn-> return (True,n+1,bnd_ind,krn,new_ufus_nms)
            ) (True,0,0,soac_kernel,infusible_nms) kernminds'

  -- Find the kernels we have fused into and the name of the last such
  -- kernel (if any).
  let (to_fuse_kers', to_fuse_knms',_) = unzip3 $ take ok_ind kernminds'
      new_kernms = drop (ok_ind-1) to_fuse_knms'

  return (ok_ind>0, [fused_ker], new_kernms, to_fuse_kers', to_fuse_knms')

  where -- The names of all kernels whose SOAC has the given width.
        getKersWithSameInpSize :: SubExp -> FusedRes -> [KernName]
        getKersWithSameInpSize sz ress =
            map fst $ filter (\ (_,ker) -> sz == SOAC.width (fsoac ker)) $ M.toList $ kernels ress

------------------------------------------------------------------------
------------------------------------------------------------------------
------------------------------------------------------------------------
--- Fusion Gather for EXPRESSIONS and BODIES,                        ---
--- i.e., where work is being done:                                  ---
---    i) bottom-up AbSyn traversal (backward analysis)              ---
---   ii) soacs are fused greedily iff does not duplicate computation---
--- E.g., (y1, y2, y3) = mapT(f, x1, x2[i])                          ---
---       (z1, z2)     = mapT(g1, y1, y2)                            ---
---       (q1, q2)     = mapT(g2, y3, z1, a, y3)                     ---
---       res          = reduce(op, ne, q1, q2, z2, y1, y3)          ---
--- can be fused if y1,y2,y3, z1,z2, q1,q2 are not used elsewhere:   ---
---       res = redomap(op, \(x1,x2i,a)->                            ---
---                             let (y1,y2,y3) = f (x1, x2i)       in---
---                             let (z1,z2)    = g1(y1, y2)        in---
---                             let (q1,q2)    = g2(y3, z1, a, y3) in---
---                             (q1, q2, z2, y1, y3)                 ---
---                     x1, x2[i], a)                                ---
------------------------------------------------------------------------
------------------------------------------------------------------------
------------------------------------------------------------------------

-- | Gather fusion information from a body.  The statements after the
-- first one are processed before the first one itself (the recursive
-- call on the rest of the body happens first), so that when a SOAC
-- binding is reached the kernels that may consume it are already
-- known.  The first argument is the result accumulated so far.
fusionGatherBody :: FusedRes -> Body -> FusionGM FusedRes

-- Some forms of do-loops can profitably be considered streamSeqs.  We
-- are careful to ensure that the generated nested loop cannot itself
-- be considered a stream, to avoid infinite recursion.
-- (This case only applies to for-loops without context merge
-- parameters and with at least one loop-variant array; the generated
-- inner loop has no such arrays.)
fusionGatherBody fres (Body blore (stmsToList ->
                                    Let (Pattern [] pes) bndtp
                                    (DoLoop [] merge (ForLoop i it w loop_vars) body)
                                    :bnds) res) | not $ null loop_vars = do
  let (merge_params,merge_init) = unzip merge
      (loop_params,loop_arrs) = unzip loop_vars
  chunk_size <- newVName "chunk_size"
  offset <- newVName "offset"
  let chunk_param = Param chunk_size $ Prim int32
      offset_param = Param offset $ Prim $ IntType it

  -- One accumulator parameter per merge parameter of the loop.
  acc_params <- forM merge_params $ \p ->
    Param <$> newVName (baseString (paramName p) ++ "_outer") <*>
    pure (paramType p)

  -- One chunk parameter per loop-variant array.
  chunked_params <- forM loop_vars $ \(p,arr) ->
    Param <$> newVName (baseString arr ++ "_chunk") <*>
    pure (paramType p `arrayOfRow` Futhark.Var chunk_size)

  let lam_params = chunk_param : acc_params ++ [offset_param] ++ chunked_params

  lam_body <- runBodyBinder $ localScope (scopeOfLParams lam_params) $ do
    let merge' = zip merge_params $ map (Futhark.Var . paramName) acc_params
    j <- newVName "j"
    loop_body <- runBodyBinder $ do
      -- Bind the original loop parameters to element @j@ of each chunk.
      forM_ (zip loop_params chunked_params) $ \(p,a_p) ->
        letBindNames_ [paramName p] $ BasicOp $ Index (paramName a_p) $
        fullSlice (paramType a_p) [DimFix $ Futhark.Var j]
      -- The original loop counter is the chunk offset plus @j@.
      letBindNames_ [i] $ BasicOp $ BinOp (Add it) (Futhark.Var offset) (Futhark.Var j)
      return body
    -- NOTE(review): the offset parameter has type @IntType it@ and the
    -- inner loop uses @it@, but @chunk_size@ is declared int32 and the
    -- new offset is computed with @Add Int32@; this looks like it
    -- assumes @it@ is Int32 — confirm.
    eBody [pure $
           DoLoop [] merge' (ForLoop j it (Futhark.Var chunk_size) []) loop_body,
           pure $
           BasicOp $ BinOp (Add Int32) (Futhark.Var offset) (Futhark.Var chunk_size)]
  let lam = Lambda { lambdaParams = lam_params
                   , lambdaBody = lam_body
                   , lambdaReturnType = map paramType $ acc_params ++ [offset_param]
                   }
      stream = Futhark.Stream w (Sequential $ merge_init ++ [intConst it 0]) lam loop_arrs

  -- It is important that the (discarded) final-offset is not the
  -- first element in the pattern, as we use the first element to
  -- identify the SOAC in the second phase of fusion.
  -- NOTE(review): the discarded element is given type int32 while the
  -- corresponding lambda result has type @IntType it@ — confirm.
  discard <- newVName "discard"
  let discard_pe = PatElem discard $ Prim int32

  -- Process the body again with the loop replaced by the stream.
  fusionGatherBody fres $ Body blore
    (oneStm (Let (Pattern [] (pes<>[discard_pe])) bndtp (Op stream))<>stmsFromList bnds) res

fusionGatherBody fres (Body _ (stmsToList -> (bnd@(Let pat _ e):bnds)) res) = do
  maybesoac <- SOAC.fromExp e
  case maybesoac of
    Right soac@(SOAC.Scatter _len lam _ivs _as) -> do
      -- We put the variables produced by Scatter into the infusible
      -- set to force horizontal fusion.  It is not possible to
      -- producer/consumer-fuse Scatter anyway.
      fres' <- addNamesToInfusible fres $ S.fromList $ patternNames pat
      mapLike fres' soac lam

    Right soac@(SOAC.GenReduce _ _ lam _) -> do
      -- We put the variables produced by GenReduce into the infusible
      -- set to force horizontal fusion.  It is not possible to
      -- producer/consumer-fuse GenReduce anyway.
      fres' <- addNamesToInfusible fres $ S.fromList $ patternNames pat
      mapLike fres' soac lam

    Right soac@(SOAC.Screma _ (ScremaForm (scan_lam, scan_nes)
                                              (_, reduce_lam, reduce_nes)
                                              map_lam) _) ->
      reduceLike soac [scan_lam, reduce_lam, map_lam] $ scan_nes <> reduce_nes

    Right soac@(SOAC.Stream _ form lam _) -> do
      -- A stream is handled like a reduction: its lambdas (including
      -- the reduction lambda of a parallel stream) are inspected and
      -- its accumulators are treated as neutral elements.  It does
      -- not necessarily start a new kernel, e.g.,
      -- @let a= reduce(+,0,A) in ... bnds ... in let B = map(f,A)@
      -- can be fused into a redomap that replaces the @map@, if @a@
      -- and @B@ are defined in the same scope and @bnds@ does not use @a@.
      let lambdas = case form of
                        Parallel _ _ lout _ -> [lout, lam]
                        _                   -> [lam]
      reduceLike soac lambdas $ getStreamAccums form

    -- Not a SOAC: either an array transformation of a single value,
    -- which is recorded in the environment, or an ordinary expression.
    _ | [pe] <- patternValueElements pat,
        Just (src,trns) <- SOAC.transformFromExp (stmCerts bnd) e ->
          bindingTransform pe src trns $ fusionGatherBody fres body
      | otherwise -> do
          let pat_vars = map (BasicOp . SubExp . Var) $ patternNames pat
          bres <- gatherStmPattern pat e $ fusionGatherBody fres body
          bres' <- checkForUpdates bres e
          foldM fusionGatherExp bres' (e:pat_vars)

  where body = mkBody (stmsFromList bnds) res
        cs = stmCerts bnd
        rem_bnds = bnd : bnds
        consumed = consumedInExp $ Alias.analyseExp e

        -- Inspect the lambdas first, then the rest of the body, then
        -- the neutral elements, and finally try to fuse the SOAC.
        reduceLike soac lambdas nes = do
          (used_lam, lres)  <- foldM fusionGatherLam (S.empty, fres) lambdas
          bres  <- bindingFamily pat $ fusionGatherBody lres body
          bres' <- foldM fusionGatherSubExp bres nes
          consumed' <- varsAliases consumed
          greedyFuse rem_bnds used_lam bres' (pat, cs, soac, consumed')

        -- Inspect the rest of the body first, then the lambda, and
        -- finally try to fuse the SOAC.
        mapLike fres' soac lambda = do
          bres  <- bindingFamily pat $ fusionGatherBody fres' body
          (used_lam, blres) <- fusionGatherLam (S.empty, bres) lambda
          consumed' <- varsAliases consumed
          greedyFuse rem_bnds used_lam blres (pat, cs, soac, consumed')

-- No statements left: only the results of the body remain.
fusionGatherBody fres (Body _ _ res) =
  foldM fusionGatherExp fres $ map (BasicOp . SubExp) res

-- | Gather fusion information from a single expression, extending the
-- given 'FusedRes'.
--
-- * Loops: the free names of the loop form and merge parameters are made
--   infusible; the loop body is analysed starting from an empty result
--   with the loop-bound identifiers in scope, and every input array of
--   that result is made infusible before merging.
--
-- * Conditionals: both branches are analysed starting from empty
--   results, combined, and merged with the incoming result after the
--   condition has been made infusible.
--
-- * A 'Futhark.Screma' or 'Futhark.Scatter' reaching this function is
--   reported as an error via 'errorIllegal'.
--
-- * Anything else: all free names of the expression become infusible.
fusionGatherExp :: FusedRes -> Exp -> FusionGM FusedRes

-----------------------------------------
---- Loop/If     ----
-----------------------------------------

fusionGatherExp fres (DoLoop ctx val form loop_body) = do
  fres' <- addNamesToInfusible fres $ freeIn form <> freeIn ctx <> freeIn val
  let form_idents =
        case form of
          ForLoop i it _ loopvars ->
            -- Bind the loop index at its declared integer type, exactly
            -- as 'fuseInExp' does, instead of assuming int32.
            Ident i (Prim $ IntType it) : map (paramIdent . fst) loopvars
          WhileLoop{} -> []

  new_res <- binding (zip (form_idents ++ map (paramIdent . fst) (ctx<>val)) $
                      repeat mempty) $
    fusionGatherBody mempty loop_body
  -- make the inpArr infusible, so that they
  -- cannot be fused from outside the loop:
  let inp_arrs = S.fromList $ M.keys $ inpArr new_res
      new_res' = new_res { infusible = infusible new_res `S.union` inp_arrs }
  -- merge new_res' with fres'
  return $ new_res' <> fres'

fusionGatherExp fres (If cond e_then e_else _) = do
    then_res <- fusionGatherBody mempty e_then
    else_res <- fusionGatherBody mempty e_else
    let both_res = then_res <> else_res
    fres'    <- fusionGatherSubExp fres cond
    mergeFusionRes fres' both_res

-----------------------------------------------------------------------------------
--- Errors: all SOACs, (because normalization ensures they appear
--- directly in let exp, i.e., let x = e)
-----------------------------------------------------------------------------------

fusionGatherExp _ (Op Futhark.Screma{}) = errorIllegal "screma"
fusionGatherExp _ (Op Futhark.Scatter{}) = errorIllegal "write"

-----------------------------------
---- Generic Traversal         ----
-----------------------------------

fusionGatherExp fres e =
  addNamesToInfusible fres $ freeInExp e

-- | Mark the variable behind a sub-expression as infusible.  Anything
-- that is not a variable leaves the result untouched.
fusionGatherSubExp :: FusedRes -> SubExp -> FusionGM FusedRes
fusionGatherSubExp fres se =
  case se of
    Var v -> addVarToInfusible fres v
    _     -> return fres

-- | Mark every name in the given set as infusible, one at a time, via
-- 'addVarToInfusible'.
addNamesToInfusible :: FusedRes -> Names -> FusionGM FusedRes
addNamesToInfusible fres names =
  foldM addVarToInfusible fres (S.toList names)

-- | Mark a single variable as infusible.  If the environment knows the
-- variable as an array with an associated 'SOAC.Input', the array named
-- by that input is inserted instead of the variable itself.
addVarToInfusible :: FusedRes -> VName -> FusionGM FusedRes
addVarToInfusible fres name = do
  origin <- asks $ maybe name inputOrigin . lookupArr name
  return fres { infusible = S.insert origin (infusible fres) }
  where inputOrigin (SOAC.Input _ orig _) = orig

-- A lambda opens a new scope.  Its body is analysed starting from an
-- empty result; the arrays that result reads, together with whatever it
-- already marked infusible, are then made infusible (restricted to names
-- visible in the enclosing scope) so nothing is fused into the lambda
-- from outside.  The returned set accumulates those names across calls.
fusionGatherLam :: (Names, FusedRes) -> Lambda -> FusionGM (S.Set VName, FusedRes)
fusionGatherLam (u_set,fres) (Lambda idds body _) = do
    lam_res  <- bindingParams idds $ fusionGatherBody mempty body
    in_scope <- asks $ M.keysSet . varsInScope
    let lam_inputs = M.keysSet $ inpArr lam_res
        -- Only names bound outside the lambda are of interest to the caller.
        unfus      = (infusible lam_res `S.union` lam_inputs)
                     `S.intersection` in_scope
        lam_res'   = lam_res { infusible = unfus }
    return (u_set `S.union` unfus, lam_res' <> fres)

-------------------------------------------------------------
-------------------------------------------------------------
--- FINALLY, Substitute the kernels in function
-------------------------------------------------------------
-------------------------------------------------------------

-- | Rebuild a body statement by statement.  The remainder of the body is
-- processed first, with the pattern of the leading statement in scope;
-- the leading statement is then passed to 'replaceSOAC' and the
-- statements it yields are put in front of the processed remainder.
fuseInBody :: Body -> FusionGM Body
fuseInBody (Body _ stms res) =
  case stmsToList stms of
    Let pat aux e : rest -> do
      rest' <- bindingPat pat $ fuseInBody $ mkBody (stmsFromList rest) res
      replacement <- replaceSOAC pat aux e
      return $ insertStms replacement rest'
    _ ->
      return $ Body () mempty res

-- | Apply fusion inside an expression.
fuseInExp :: Exp -> FusionGM Exp

-- Loops get special treatment: the loop index, the loop variables and
-- the merge parameters must be in scope while the body is processed.
fuseInExp (DoLoop ctx val form loopbody) =
  let loop_idents =
        case form of
          ForLoop i it _ loopvars ->
            Ident i (Prim $ IntType it) : map (paramIdent . fst) loopvars
          WhileLoop{} ->
            []
      merge_params = map fst $ ctx ++ val
  in binding [ (ident, mempty) | ident <- loop_idents ] $
     bindingParams merge_params $
     DoLoop ctx val form <$> fuseInBody loopbody

-- Every other expression is traversed generically.
fuseInExp e = mapExpM fuseIn e

-- | The mapper used by 'fuseInExp': bodies go through 'fuseInBody' (the
-- first argument handed to 'mapOnBody' is ignored) and lambdas inside
-- SOAC operations go through 'fuseInLambda'.
fuseIn :: Mapper SOACS SOACS FusionGM
fuseIn =
  identityMapper
    { mapOnBody = \_ body -> fuseInBody body
    , mapOnOp   = mapSOACM soacMapper
    }
  where soacMapper = identitySOACMapper { mapOnSOACLambda = fuseInLambda }

-- | Apply fusion inside the body of a lambda, with its parameters in
-- scope.  Parameters and return type are kept as they are.
fuseInLambda :: Lambda -> FusionGM Lambda
fuseInLambda (Lambda params body rtp) =
  rebuild <$> bindingParams params (fuseInBody body)
  where rebuild body' = Lambda params body' rtp

-- | Produce the statements that replace a single let-binding.
--
-- A pattern without value elements yields no statements.  Otherwise the
-- name of the first value element is looked up in the output-array map
-- of the current fusion result: if it is absent, the binding is kept and
-- only its expression is processed with 'fuseInExp'; if it maps to a
-- kernel, the binding is replaced by that kernel's SOAC.  An output
-- array that refers to an unknown kernel, or to a kernel with no fused
-- variables, is reported as an error.
replaceSOAC :: Pattern -> StmAux () -> Exp -> FusionGM (Stms SOACS)
replaceSOAC (Pattern _ []) _ _ = return mempty
replaceSOAC pat@(Pattern _ (patElem : _)) aux e = do
  fres <- asks fusedRes
  case M.lookup (patElemName patElem) (outArr fres) of
    Nothing ->
      oneStm . Let pat aux <$> fuseInExp e
    Just knm ->
      maybe (missingKernel knm) emitKernel $ M.lookup knm (kernels fres)
  where
    missingKernel knm =
      throwError $ Error
        ("In Fusion.hs, replaceSOAC, outArr in ker_name "
         ++"which is not in Res: "++pretty (unKernName knm))

    emitKernel ker = do
      when (null $ fusedVars ker) $
        throwError $ Error
          ("In Fusion.hs, replaceSOAC, unfused kernel "
           ++"still in result: "++pretty (patternIdents pat))
      insertKerSOAC (outNames ker) ker

-- | Build the statements for a fused kernel: finalise its SOAC, bind the
-- results to fresh identifiers (named after the given output names), and
-- then apply the kernel's output transform to connect those identifiers
-- to the output names.
insertKerSOAC :: [VName] -> FusedKer -> FusionGM (Stms SOACS)
insertKerSOAC names ker = do
  finalised <- finaliseSOAC (fsoac ker)
  runBinder_ $ do
    soac_op <- SOAC.toSOAC finalised
    -- Fusion can make the kernel consume more than the SOACs it was
    -- built from did (see issue #224); copies are inserted to
    -- compensate.
    soac_op' <- copyNewlyConsumed (fusedConsumed ker) (addOpAliases soac_op)
    res_idents <- zipWithM newIdent (map baseString names) (SOAC.typeOf finalised)
    letBind_ (basicPattern [] res_idents) (Op soac_op')
    transformOutput (outputTransform ker) names res_idents

-- | Simplify and fuse inside the lambdas of a SOAC.  For a screma the
-- scan, reduce and map lambdas are processed, in that order; for
-- scatter, gen-reduce and stream only the lambda held directly by the
-- constructor is processed.
finaliseSOAC :: SOAC.SOAC SOACS -> FusionGM (SOAC.SOAC SOACS)
finaliseSOAC (SOAC.Screma w (ScremaForm (scan_lam, scan_nes)
                                        (comm, red_lam, red_nes)
                                        map_lam) arrs) = do
  scan_lam' <- simplifyAndFuseInLambda scan_lam
  red_lam'  <- simplifyAndFuseInLambda red_lam
  map_lam'  <- simplifyAndFuseInLambda map_lam
  let form' = ScremaForm (scan_lam', scan_nes) (comm, red_lam', red_nes) map_lam'
  return $ SOAC.Screma w form' arrs
finaliseSOAC (SOAC.Scatter w lam inps dests) =
  (\lam' -> SOAC.Scatter w lam' inps dests) <$> simplifyAndFuseInLambda lam
finaliseSOAC (SOAC.GenReduce w ops lam arrs) =
  (\lam' -> SOAC.GenReduce w ops lam' arrs) <$> simplifyAndFuseInLambda lam
finaliseSOAC (SOAC.Stream w form lam inps) =
  (\lam' -> SOAC.Stream w form lam' inps) <$> simplifyAndFuseInLambda lam

-- | Simplify a lambda (passing 'Nothing' for each of its parameters),
-- gather fusion information for the simplified lambda starting from a
-- fresh result, clean that result, and apply fusion inside the lambda
-- with the cleaned result installed.
simplifyAndFuseInLambda :: Lambda -> FusionGM Lambda
simplifyAndFuseInLambda lam = do
  lam' <- simplifyLambda lam $ map (const Nothing) (lambdaParams lam)
  (_, gathered) <- fusionGatherLam (S.empty, mkFreshFusionRes) lam'
  bindRes (cleanFusionResult gathered) $ fuseInLambda lam'

-- | Insert copies for whatever a SOAC consumes beyond the given set of
-- previously consumed names, and strip alias information from the SOAC.
--
-- Only screma is rewritten: input arrays that are newly consumed are
-- replaced by copies bound outside the SOAC, and free variables consumed
-- by the map lambda are copied at the start of the lambda body, with the
-- body renamed to use the copies.  Any other SOAC only has its alias
-- information removed.
copyNewlyConsumed :: Names
                  -> Futhark.SOAC (Aliases.Aliases SOACS)
                  -> Binder SOACS (Futhark.SOAC SOACS)
copyNewlyConsumed was_consumed soac =
  case soac of
    Futhark.Screma w (Futhark.ScremaForm
                         (scan_lam, scan_nes)
                         (comm, reduce_lam, reduce_nes)
                         map_lam) arrs -> do
      -- Input arrays first, then the map lambda.
      arrs' <- mapM copyIfNewlyConsumed arrs
      map_lam' <- copyConsumedFreeVars map_lam
      let form' = Futhark.ScremaForm
                    (Aliases.removeLambdaAliases scan_lam, scan_nes)
                    (comm, Aliases.removeLambdaAliases reduce_lam, reduce_nes)
                    map_lam'
      return $ Futhark.Screma w form' arrs'

    _ -> return $ removeOpAliases soac
  where newly_consumed = consumedInOp soac `S.difference` was_consumed

        -- Replace an input array by a fresh copy if the SOAC consumes
        -- it and it was not consumed before.
        copyIfNewlyConsumed arr
          | arr `S.member` newly_consumed =
            letExp (baseString arr <> "_copy") $ BasicOp $ Copy arr
          | otherwise = return arr

        -- Copy every consumed variable that is not a parameter of the
        -- lambda, and make the lambda body refer to the copies.
        copyConsumedFreeVars lam = do
          let param_names   = S.fromList $ map paramName $ lambdaParams lam
              free_consumed = consumedByLambda lam `S.difference` param_names
          (copies, renames) <-
            foldM addCopy (mempty, mempty) $ S.toList free_consumed
          let plain_lam = Aliases.removeLambdaAliases lam
          return $ if null copies
                   then plain_lam
                   else plain_lam { lambdaBody =
                                      insertStms copies $
                                      substituteNames renames $
                                      lambdaBody plain_lam
                                  }

        -- Extend the accumulated copy statements and renaming with a
        -- copy of one more variable.
        addCopy (copies, renames) v = do
          v_copy <- newVName $ baseString v <> "_copy"
          copy_stm <- mkLetNamesM [v_copy] $ BasicOp $ Copy v
          return (oneStm copy_stm<>copies, M.insert v v_copy renames)

---------------------------------------------------
---------------------------------------------------
---- HELPERS
---------------------------------------------------
---------------------------------------------------

-- | An empty fusion result: no success recorded, no kernels, no input
--   or output arrays and nothing infusible.  Used when entering a new
--   scope, e.g., a new lambda or a new loop.
mkFreshFusionRes :: FusedRes
mkFreshFusionRes =
    FusedRes { rsucc     = False
             , outArr    = M.empty
             , inpArr    = M.empty
             , infusible = S.empty
             , kernels   = M.empty
             }

-- | Merge two fusion results.  Besides taking the union of each
-- component, the arrays that are inputs in both results (after expansion
-- by 'expandSoacInpArr') are added to the infusible set.
mergeFusionRes :: FusedRes -> FusedRes -> FusionGM FusedRes
mergeFusionRes res1 res2 = do
    shared_inputs <- expandSoacInpArr $ M.keys $
                     inpArr res1 `M.intersection` inpArr res2
    let base_unfus = infusible res1 `S.union` infusible res2
    return FusedRes
      { rsucc     = rsucc res1 || rsucc res2
      , outArr    = outArr res1 `M.union` outArr res2
      , inpArr    = M.unionWith S.union (inpArr res1) (inpArr res2)
      , infusible = L.foldl' (flip S.insert) base_unfus shared_inputs
      , kernels   = kernels res1 `M.union` kernels res2
      }


-- | Split SOAC inputs into the arrays that are used directly (no
--   transforms) and the arrays that are used through some transform.
--   The first component holds the former, the second the latter.
--
--   Each list is built by prepending, so names appear in the reverse
--   of the order in which the inputs were given.
getIdentArr :: [SOAC.Input] -> ([VName], [VName])
getIdentArr = L.foldl' classify ([], [])
  where classify (plain, transformed) inp@(SOAC.Input ts arr _)
          | SOAC.nullTransforms ts = (arr : plain, transformed)
          | otherwise = (plain, SOAC.inputArray inp : transformed)

-- | Drop every kernel that has no fused variables, and remove the
-- references to dropped kernels from the output-array and input-array
-- maps.  Input arrays keep their entry even if its kernel set becomes
-- empty.
cleanFusionResult :: FusedRes -> FusedRes
cleanFusionResult fres =
    fres { outArr = live_outs, inpArr = live_inps, kernels = live_kers }
  where live_kers = M.filter (not . null . fusedVars) (kernels fres)
        isLive k  = k `M.member` live_kers
        live_outs = M.filter isLive (outArr fres)
        live_inps = M.map (S.filter isLive) (inpArr fres)

--------------
--- Errors ---
--------------

-- | Fail with a fusion 'Error' reporting that the named SOAC appears
-- somewhere it is not allowed.
errorIllegal :: String -> FusionGM FusedRes
errorIllegal soac_name = throwError (Error msg)
  where msg = concat [ "In Fusion.hs, soac "
                     , soac_name
                     , " appears illegally in pgm!"
                     ]