{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
-- | Perform various kernel optimisations, mostly related to memory access coalescing.
module Futhark.Pass.KernelBabysitting
       ( babysitKernels
       , nonlinearInMemory
       )
       where

import Control.Arrow (first)
import Control.Monad.State.Strict
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Data.Foldable
import Data.List
import Data.Maybe

import Futhark.MonadFreshNames
import Futhark.Representation.AST
import Futhark.Representation.Kernels
       hiding (Prog, Body, Stm, Pattern, PatElem,
               BasicOp, Exp, Lambda, FunDef, FParam, LParam, RetType)
import Futhark.Tools
import Futhark.Pass
import Futhark.Util

babysitKernels :: Pass Kernels Kernels
babysitKernels = Pass "babysit kernels"
                 "Transpose kernel input arrays for better performance." $
                 intraproceduralTransformation transformFunDef

transformFunDef :: MonadFreshNames m => FunDef Kernels -> m (FunDef Kernels)
transformFunDef fundec = do
  (body', _) <- modifyNameSource $ runState (runBinderT m M.empty)
  return fundec { funDefBody = body' }
  where m = inScopeOf fundec $
            transformBody mempty $ funDefBody fundec

type BabysitM = Binder Kernels

transformBody :: ExpMap -> Body Kernels -> BabysitM (Body Kernels)
transformBody expmap (Body () bnds res) = insertStmsM $ do
  foldM_ transformStm expmap bnds
  return $ resultBody res

-- | Map from variable names to their defining statement.  We use this to
-- hackily determine whether something is transposed or otherwise
-- funky in memory (and we'd prefer it not to be).  If we cannot find
-- it in the map, we just assume it's all good.  HACK and FIXME, I
-- suppose.  We really should do this at the memory level.
type ExpMap = M.Map VName (Stm Kernels)

nonlinearInMemory :: VName -> ExpMap -> Maybe (Maybe [Int])
nonlinearInMemory name m =
  case M.lookup name m of
    Just (Let _ _ (BasicOp (Rearrange perm _))) -> Just $ Just $ rearrangeInverse perm
    Just (Let _ _ (BasicOp (Reshape _ arr))) -> nonlinearInMemory arr m
    Just (Let _ _ (BasicOp (Manifest perm _))) -> Just $ Just perm
    Just (Let pat _ (Op (Kernel _ _ ts _))) ->
      nonlinear =<< find ((==name) . patElemName . fst)
      (zip (patternElements pat) ts)
    _ -> Nothing
  where nonlinear (pe, t)
          | inner_r <- arrayRank t, inner_r > 0 = do
              let outer_r = arrayRank (patElemType pe) - inner_r
              return $ Just $ rearrangeInverse $ [inner_r..inner_r+outer_r-1] ++ [0..inner_r-1]
          | otherwise = Nothing
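
-- An illustrative sketch of what 'nonlinearInMemory' reports (the names
-- @xs@ and @xs_tr@ are hypothetical): if the map contains a binding of
-- @xs_tr@ to @BasicOp (Rearrange [2,0,1] xs)@, then
--
-- > nonlinearInMemory xs_tr expmap == Just (Just [1,2,0])
--
-- assuming 'rearrangeInverse' yields the inverse permutation.  A name not
-- present in the map gives 'Nothing', i.e. it is assumed to already be in
-- row-major order.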

transformStm :: ExpMap -> Stm Kernels -> BabysitM ExpMap

transformStm expmap (Let pat aux ke@(Op (Kernel desc space ts kbody))) = do
  -- Go spelunking for accesses to arrays that are defined outside the
  -- kernel body and where the indices are kernel thread indices.
  scope <- askScope
  let thread_gids = map fst $ spaceDimensions space
      thread_local = S.fromList $ spaceGlobalId space : spaceLocalId space : thread_gids
      free_ker_vars = freeInExp ke `S.difference` getKerVariantIds space
  kbody'' <- evalStateT (traverseKernelBodyArrayIndexes
                         free_ker_vars
                         thread_local
                         (castScope scope <> scopeOfKernelSpace space)
                         (ensureCoalescedAccess expmap (spaceDimensions space) num_threads)
                         kbody)
             mempty

  let bnd' = Let pat aux $ Op $ Kernel desc space ts kbody''
  addStm bnd'
  return $ M.fromList [ (name, bnd') | name <- patternNames pat ] <> expmap
  where num_threads = spaceNumThreads space
        getKerVariantIds (KernelSpace glb_id loc_id grp_id _ _ _ (FlatThreadSpace strct)) =
            let (gids, _) = unzip strct
            in  S.fromList $ [glb_id, loc_id, grp_id] ++ gids
        getKerVariantIds (KernelSpace glb_id loc_id grp_id _ _ _ (NestedThreadSpace strct)) =
            let (gids, _, lids, _) = unzip4 strct
            in  S.fromList $ [glb_id, loc_id, grp_id] ++ gids ++ lids

transformStm expmap (Let pat aux e) = do
  e' <- mapExpM (transform expmap) e
  let bnd' = Let pat aux e'
  addStm bnd'
  return $ M.fromList [ (name, bnd') | name <- patternNames pat ] <> expmap

transform :: ExpMap -> Mapper Kernels Kernels BabysitM
transform expmap =
  identityMapper { mapOnBody = \scope -> localScope scope . transformBody expmap }

type ArrayIndexTransform m =
  Names ->
  (VName -> Bool) ->           -- thread local?
  (VName -> SubExp -> Bool) -> -- variant to a certain gid (given as first param)?
  (SubExp -> Maybe SubExp) ->  -- split substitution?
  Scope InKernel ->            -- type environment
  VName -> Slice SubExp -> m (Maybe (VName, Slice SubExp))

traverseKernelBodyArrayIndexes :: (Applicative f, Monad f) =>
                                  Names
                               -> Names
                               -> Scope InKernel
                               -> ArrayIndexTransform f
                               -> KernelBody InKernel
                               -> f (KernelBody InKernel)
traverseKernelBodyArrayIndexes free_ker_vars thread_variant outer_scope f (KernelBody () kstms kres) =
  KernelBody () . stmsFromList <$>
  mapM (onStm (varianceInStms mempty kstms,
               mkSizeSubsts kstms,
               outer_scope)) (stmsToList kstms) <*>
  pure kres
  where onLambda (variance, szsubst, scope) lam =
          (\body' -> lam { lambdaBody = body' }) <$>
          onBody (variance, szsubst, scope') (lambdaBody lam)
          where scope' = scope <> scopeOfLParams (lambdaParams lam)

        onStreamLambda (variance, szsubst, scope) lam =
          (\body' -> lam { groupStreamLambdaBody = body' }) <$>
          onBody (variance, szsubst, scope') (groupStreamLambdaBody lam)
          where scope' = scope <> scopeOf lam

        onBody (variance, szsubst, scope) (Body battr stms bres) = do
          stms' <- stmsFromList <$> mapM (onStm (variance', szsubst', scope')) (stmsToList stms)
          Body battr stms' <$> pure bres
          where variance' = varianceInStms variance stms
                szsubst' = mkSizeSubsts stms <> szsubst
                scope' = scope <> scopeOf stms

        onStm (variance, szsubst, _) (Let pat attr (BasicOp (Index arr is))) =
          Let pat attr . oldOrNew <$> f free_ker_vars isThreadLocal isGidVariant sizeSubst outer_scope arr is
          where oldOrNew Nothing =
                  BasicOp $ Index arr is
                oldOrNew (Just (arr', is')) =
                  BasicOp $ Index arr' is'

                isGidVariant gid (Var v) =
                  gid == v || S.member gid (M.findWithDefault (S.singleton v) v variance)
                isGidVariant _ _ = False

                isThreadLocal v =
                  not $ S.null $
                  thread_variant `S.intersection`
                  M.findWithDefault (S.singleton v) v variance

                sizeSubst (Constant v) = Just $ Constant v
                sizeSubst (Var v)
                  | v `M.member` outer_scope      = Just $ Var v
                  | Just v' <- M.lookup v szsubst = sizeSubst v'
                  | otherwise                      = Nothing

        onStm (variance, szsubst, scope) (Let pat attr e) =
          Let pat attr <$> mapExpM (mapper (variance, szsubst, scope)) e

        mapper ctx = identityMapper { mapOnBody = const (onBody ctx)
                                    , mapOnOp = onOp ctx
                                    }

        onOp ctx (GroupReduce w lam input) =
          GroupReduce w <$> onLambda ctx lam <*> pure input
        onOp ctx (GroupStream w maxchunk lam accs arrs) =
           GroupStream w maxchunk <$> onStreamLambda ctx lam <*> pure accs <*> pure arrs
        onOp _ stm = pure stm

        mkSizeSubsts = fold . fmap mkStmSizeSubst
          where mkStmSizeSubst (Let (Pattern [] [pe]) _ (Op (SplitSpace _ _ _ elems_per_i))) =
                  M.singleton (patElemName pe) elems_per_i
                mkStmSizeSubst _ = mempty
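
-- An illustrative note on the size substitution above (the name
-- @chunk_size@ is hypothetical): if @chunk_size@ is bound by an
-- @Op (SplitSpace _ _ _ elems_per_thread)@ statement, then
-- @sizeSubst (Var chunk_size)@ resolves to @elems_per_thread@, provided
-- that it is a constant or is itself bound outside the kernel.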

-- Not a hashmap, as SubExp is not hashable.
type Replacements = M.Map (VName, Slice SubExp) VName

ensureCoalescedAccess :: MonadBinder m =>
                         ExpMap
                      -> [(VName,SubExp)]
                      -> SubExp
                      -> ArrayIndexTransform (StateT Replacements m)
ensureCoalescedAccess expmap thread_space num_threads free_ker_vars isThreadLocal
                      isGidVariant sizeSubst outer_scope arr slice = do
  seen <- gets $ M.lookup (arr, slice)

  case (seen, isThreadLocal arr, typeOf <$> M.lookup arr outer_scope) of
    -- Already took care of this case elsewhere.
    (Just arr', _, _) ->
      pure $ Just (arr', slice)

    (Nothing, False, Just t)
      -- We are fully indexing the array with thread IDs, but the
      -- indices are in a permuted order.
      | Just is <- sliceIndices slice,
        length is == arrayRank t,
        Just is' <- coalescedIndexes free_ker_vars isGidVariant (map Var thread_gids) is,
        Just perm <- is' `isPermutationOf` is ->
          replace =<< lift (rearrangeInput (nonlinearInMemory arr expmap) perm arr)

      -- Check whether the access is already coalesced because of a
      -- previous rearrange being applied to the current array:
      -- 1. get the permutation of the source-array rearrange
      -- 2. apply it to the slice
      -- 3. check that the innermost index is actually the gid
      --    of the innermost kernel dimension.
      -- If so, the access is already coalesced, nothing to do!
      -- (Cosmin's Heuristic.)
      | Just (Let _ _ (BasicOp (Rearrange perm _))) <- M.lookup arr expmap,
        not $ null perm,
        not $ null thread_gids,
        inner_gid <- last thread_gids,
        length slice >= length perm,
        slice' <- map (slice !!) perm,
        DimFix inner_ind <- last slice',
        isGidVariant inner_gid inner_ind ->
          return Nothing

      -- We are not fully indexing the array, but the remaining slice
      -- is invariant to the innermost kernel dimension.  We assume
      -- the remaining slice will be streamed sequentially, so tiling
      -- will be applied later and will take care of coalescing.
      -- Hence there is nothing to do at this point.  (Cosmin's Heuristic.)
      | (is, rem_slice) <- splitSlice slice,
        not $ null rem_slice,
        allDimAreSlice rem_slice,
        Nothing <- M.lookup arr expmap,
        not $ tooSmallSlice (primByteSize (elemType t)) rem_slice,
        is /= map Var (take (length is) thread_gids) || length is == length thread_gids,
        not (null thread_gids || null is),
        not ( S.member (last thread_gids) (S.union (freeIn is) (freeIn rem_slice)) ) ->
          return Nothing

      -- We are not fully indexing the array, and the indices are not
      -- a proper prefix of the thread indices, and some indices are
      -- thread local, so we assume (HEURISTIC!) that the remaining
      -- dimensions will be traversed sequentially.
      | (is, rem_slice) <- splitSlice slice,
        not $ null rem_slice,
        not $ tooSmallSlice (primByteSize (elemType t)) rem_slice,
        is /= map Var (take (length is) thread_gids) || length is == length thread_gids,
        any isThreadLocal (S.toList $ freeIn is) -> do
          let perm = coalescingPermutation (length is) $ arrayRank t
          replace =<< lift (rearrangeInput (nonlinearInMemory arr expmap) perm arr)

      -- We are taking a slice of the array with a unit stride.  We
      -- assume that the slice will be traversed sequentially.
      --
      -- We would really like to treat the sliced dimension as two
      -- dimensions so we can transpose them.  This may require
      -- padding.
      | (is, rem_slice) <- splitSlice slice,
        and $ zipWith (==) is $ map Var thread_gids,
        DimSlice offset len (Constant stride):_ <- rem_slice,
        isThreadLocalSubExp offset,
        Just {} <- sizeSubst len,
        oneIsh stride -> do
          let num_chunks = if null is
                           then primExpFromSubExp int32 num_threads
                           else coerceIntPrimExp Int32 $
                                product $ map (primExpFromSubExp int32) $
                                drop (length is) thread_gdims
          replace =<< lift (rearrangeSlice (length is) (arraySize (length is) t) num_chunks arr)

      -- Everything is fine... assuming that the array is in row-major
      -- order!  Make sure that is the case.
      | Just{} <- nonlinearInMemory arr expmap ->
          case sliceIndices slice of
            Just is | Just _ <- coalescedIndexes free_ker_vars isGidVariant (map Var thread_gids) is ->
                        replace =<< lift (rowMajorArray arr)
                    | otherwise ->
                        return Nothing
            _ -> replace =<< lift (rowMajorArray arr)

    _ -> return Nothing

  where (thread_gids, thread_gdims) = unzip thread_space

        replace arr' = do
          modify $ M.insert (arr, slice) arr'
          return $ Just (arr', slice)

        isThreadLocalSubExp (Var v) = isThreadLocal v
        isThreadLocalSubExp Constant{} = False
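
-- An illustrative example of the first case above (hypothetical names): in
-- a kernel with thread dimensions @[(gtid_y, h), (gtid_x, w)]@, a read
-- @xs[gtid_x, gtid_y]@ of an array defined outside the kernel is fully
-- indexed by thread gids, but in permuted order.  The pass then inserts a
-- manifested (transposed) copy of @xs@ and redirects the access to that
-- copy, so that consecutive threads (varying in @gtid_x@) touch
-- consecutive elements in memory.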

-- Heuristic for avoiding rearranging too small arrays.
tooSmallSlice :: Int32 -> Slice SubExp -> Bool
tooSmallSlice bs = fst . foldl comb (True,bs) . sliceDims
  where comb (True, x) (Constant (IntValue (Int32Value d))) = (d*x < 4, d*x)
        comb (_, x)     _                                   = (False, x)
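
-- For instance, a remaining slice of two i8 elements (2 bytes in total) is
-- considered too small to be worth rearranging, whereas any slice of i32
-- elements is not.  A slice with a non-constant dimension never counts as
-- too small.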

splitSlice :: Slice SubExp -> ([SubExp], Slice SubExp)
splitSlice [] = ([], [])
splitSlice (DimFix i:is) = first (i:) $ splitSlice is
splitSlice is = ([], is)
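
-- For example, @splitSlice [DimFix i, DimFix j, DimSlice o n s]@ yields
-- @([i, j], [DimSlice o n s])@ (for arbitrary subexpressions @i@, @j@,
-- @o@, @n@, @s@): the longest prefix of fixed indices, followed by the
-- remaining slice.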

allDimAreSlice :: Slice SubExp -> Bool
allDimAreSlice [] = True
allDimAreSlice (DimFix _:_) = False
allDimAreSlice (_:is) = allDimAreSlice is

-- Try to move thread indexes into their proper position.
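--
-- For example (hypothetical names): with thread gids @[gtid_y, gtid_x]@
-- (innermost last) and indices @[Var gtid_x, Var gtid_y]@, and assuming the
-- innermost index is not variant to @gtid_x@, the result is
-- @Just [Var gtid_y, Var gtid_x]@, i.e. the index variant to the innermost
-- gid is moved to the innermost position.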
coalescedIndexes :: Names -> (VName -> SubExp -> Bool) -> [SubExp] -> [SubExp] -> Maybe [SubExp]
coalescedIndexes free_ker_vars isGidVariant tgids is
  -- Do nothing if:
  -- 1. any of the indices is a constant or a variable free in the kernel
  --    (because it would transpose a bigger array than needed -- big overhead).
  -- 2. the innermost index is variant to the innermost-thread gid
  --    (because the access is then likely already coalesced).
  | any isCt is =
        Nothing
  | any (`S.member` free_ker_vars) (mapMaybe mbVarId is) =
        Nothing
  | not (null tgids),
    not (null is),
    Var innergid <- last tgids,
    num_is > 0 && isGidVariant innergid (last is) =
        Just is
  -- 3. Otherwise, try to fix coalescing.
  | otherwise =
      Just $ reverse $ foldl move (reverse is) $ zip [0..] (reverse tgids)
  where num_is = length is

        move is_rev (i, tgid)
          -- If tgid is in is_rev anywhere but at position i, and
          -- position i exists, we move it to position i instead.
          | Just j <- elemIndex tgid is_rev, i /= j, i < num_is =
              swap i j is_rev
          | otherwise =
              is_rev

        swap i j l
          | Just ix <- maybeNth i l,
            Just jx <- maybeNth j l =
              update i jx $ update j ix l
          | otherwise =
              error $ "coalescedIndexes swap: invalid indices " ++ show (i, j, l)

        update 0 x (_:ys) = x : ys
        update i x (y:ys) = y : update (i-1) x ys
        update _ _ []     = error "coalescedIndexes: update"

        isCt :: SubExp -> Bool
        isCt (Constant _) = True
        isCt (Var      _) = False

        mbVarId (Constant _) = Nothing
        mbVarId (Var v) = Just v

coalescingPermutation :: Int -> Int -> [Int]
coalescingPermutation num_is rank =
  [num_is..rank-1] ++ [0..num_is-1]
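
-- For example:
--
-- > coalescingPermutation 2 4 == [2,3,0,1]
--
-- which is intended to place the two thread-indexed dimensions innermost
-- in the manifested copy.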

rearrangeInput :: MonadBinder m =>
                  Maybe (Maybe [Int]) -> [Int] -> VName -> m VName
rearrangeInput (Just (Just current_perm)) perm arr
  | current_perm == perm = return arr -- Already has desired representation.

rearrangeInput Nothing perm arr
  | sort perm == perm = return arr -- We don't know the current
                                   -- representation, but the indexing
                                   -- is linear, so let's hope the
                                   -- array is too.
rearrangeInput (Just Just{}) perm arr
  | sort perm == perm = rowMajorArray arr -- We just want a row-major array, no tricks.
rearrangeInput manifest perm arr = do
  -- We may first manifest the array to ensure that it is flat in
  -- memory.  This is sometimes unnecessary, in which case the copy
  -- will hopefully be removed by the simplifier.
  manifested <- if isJust manifest then rowMajorArray arr else return arr
  letExp (baseString arr ++ "_coalesced") $
    BasicOp $ Manifest perm manifested

rowMajorArray :: MonadBinder m =>
                 VName -> m VName
rowMajorArray arr = do
  rank <- arrayRank <$> lookupType arr
  letExp (baseString arr ++ "_rowmajor") $ BasicOp $ Manifest [0..rank-1] arr

rearrangeSlice :: MonadBinder m =>
                  Int -> SubExp -> PrimExp VName -> VName
               -> m VName
rearrangeSlice d w num_chunks arr = do
  num_chunks' <- letSubExp "num_chunks" =<< toExp num_chunks

  (w_padded, padding) <- paddedScanReduceInput w num_chunks'

  per_chunk <- letSubExp "per_chunk" $ BasicOp $ BinOp (SQuot Int32) w_padded num_chunks'
  arr_t <- lookupType arr
  arr_padded <- padArray w_padded padding arr_t
  rearrange num_chunks' w_padded per_chunk (baseString arr) arr_padded arr_t

  where padArray w_padded padding arr_t = do
          let arr_shape = arrayShape arr_t
              padding_shape = setDim d arr_shape padding
          arr_padding <-
            letExp (baseString arr <> "_padding") $
            BasicOp $ Scratch (elemType arr_t) (shapeDims padding_shape)
          letExp (baseString arr <> "_padded") $
            BasicOp $ Concat d arr [arr_padding] w_padded

        rearrange num_chunks' w_padded per_chunk arr_name arr_padded arr_t = do
          let arr_dims = arrayDims arr_t
              pre_dims = take d arr_dims
              post_dims = drop (d+1) arr_dims
              extradim_shape = Shape $ pre_dims ++ [num_chunks', per_chunk] ++ post_dims
              tr_perm = [0..d-1] ++ map (+d) ([1] ++ [2..shapeRank extradim_shape-1-d] ++ [0])
          arr_extradim <-
            letExp (arr_name <> "_extradim") $
            BasicOp $ Reshape (map DimNew $ shapeDims extradim_shape) arr_padded
          arr_extradim_tr <-
            letExp (arr_name <> "_extradim_tr") $
            BasicOp $ Manifest tr_perm arr_extradim
          arr_inv_tr <- letExp (arr_name <> "_inv_tr") $
            BasicOp $ Reshape (map DimCoercion pre_dims ++ map DimNew (w_padded : post_dims))
            arr_extradim_tr
          letExp (arr_name <> "_inv_tr_init") =<<
            eSliceArray d arr_inv_tr (eSubExp $ constant (0::Int32)) (eSubExp w)
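
-- As an illustrative sketch: for @d = 0@ and a one-dimensional input of
-- size @w@, 'rearrangeSlice' pads the array to @w_padded@ (a multiple of
-- @num_chunks@), reshapes it to @[num_chunks][per_chunk]@, manifests it
-- with permutation @[1,0]@ so the chunk dimension becomes innermost in
-- memory, reshapes back to @[w_padded]@, and finally slices out the first
-- @w@ elements.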

paddedScanReduceInput :: MonadBinder m =>
                         SubExp -> SubExp
                      -> m (SubExp, SubExp)
paddedScanReduceInput w stride = do
  w_padded <- letSubExp "padded_size" =<<
              eRoundToMultipleOf Int32 (eSubExp w) (eSubExp stride)
  padding <- letSubExp "padding" $ BasicOp $ BinOp (Sub Int32) w_padded w
  return (w_padded, padding)
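
-- For example, assuming 'eRoundToMultipleOf' rounds upwards, @w = 10@ and
-- @stride = 4@ give @w_padded = 12@ and @padding = 2@.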

--- Computing variance.

type VarianceTable = M.Map VName Names

varianceInStms :: VarianceTable -> Stms InKernel -> VarianceTable
varianceInStms t = foldl varianceInStm t . stmsToList

varianceInStm :: VarianceTable -> Stm InKernel -> VarianceTable
varianceInStm variance bnd =
  foldl' add variance $ patternNames $ stmPattern bnd
  where add variance' v = M.insert v binding_variance variance'
        look variance' v = S.insert v $ M.findWithDefault mempty v variance'
        binding_variance = mconcat $ map (look variance) $ S.toList (freeInStm bnd)
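
-- For example, for a statement @let c = a + b@ where the table already
-- maps @a@ to @{gtid}@, the name @c@ becomes mapped to @{a, b, gtid}@: a
-- bound name is variant to every free variable of its defining statement,
-- plus everything those variables are themselves variant to.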