{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- | Miscellaneous helper functions.  Perpetually in need of a cleanup.
module Futhark.Optimise.MemoryBlockMerging.Miscellaneous
  ( makeCommutativeMap
  , insertOrUpdate
  , insertOrUpdateMany
  , insertOrNew
  , removeEmptyMaps
  , removeKeyFromMapElems
  , newDeclarationsStm
  , lookupEmptyable
  , fromJust
  , maybeFromBoolM
  , sortByKeyM
  , mapMaybeM
  , anyM
  , whenM
  , expandPrimExp
  , expandIxFun
  , mapFromListSetUnion
  , fixpointIterateMay
  , filterSetM
  , (<&&>), (<||>)

  , expandWithAliases
  , FullWalk(..)
  , fullWalkAliasesExpM
  , FullWalkAliases
  , FullMap
  , fullMapExpM
  ) where

import qualified Data.Map.Strict as M
import qualified Data.Set as S
import qualified Data.List as L
import Control.Monad
import Data.Maybe (fromMaybe, catMaybes)
import Data.Function (on)

import Futhark.Representation.AST
import Futhark.Representation.ExplicitMemory
       (ExplicitMemory, InKernel)
import qualified Futhark.Representation.ExplicitMemory as ExpMem
import Futhark.Representation.Kernels.Kernel
import Futhark.Representation.Kernels.KernelExp
import Futhark.Representation.Aliases
import Futhark.Analysis.PrimExp.Convert

import qualified Futhark.Representation.ExplicitMemory.IndexFunction as IxFun
import Futhark.Optimise.MemoryBlockMerging.Types


-- If a property is commutative in a map, build a map that reflects it.  A bit
-- crude.  We could also just use a function that calculates this whenever
-- needed.
makeCommutativeMap :: Ord v => M.Map v (S.Set v) -> M.Map v (S.Set v)
makeCommutativeMap m =
  let names = S.toList (S.union (M.keysSet m) (S.unions (M.elems m)))
      assocs = map (\n ->
                      let existing = lookupEmptyable n m
                          newly_found = S.unions $ map (\(k, v) ->
                                                          if S.member n v
                                                          then S.singleton k
                                                          else S.empty) $ M.assocs m
                          ns = S.union existing newly_found
                      in (n, ns)) names
  in M.fromList assocs

insertOrUpdate :: (Ord k, Ord v) => k -> v ->
                  M.Map k (S.Set v) -> M.Map k (S.Set v)
insertOrUpdate k v = M.alter (insertOrNew (S.singleton v)) k

insertOrUpdateMany :: (Ord k, Ord v) => k -> S.Set v ->
                      M.Map k (S.Set v) -> M.Map k (S.Set v)
insertOrUpdateMany k vs = M.alter (insertOrNew vs) k

insertOrNew :: Ord a => S.Set a -> Maybe (S.Set a) -> Maybe (S.Set a)
insertOrNew xs m = Just $ case m of
  Just s -> S.union xs s
  Nothing -> xs

removeEmptyMaps :: M.Map k (S.Set v) -> M.Map k (S.Set v)
removeEmptyMaps = M.filter (not . S.null)

removeKeyFromMapElems :: (Ord k) => M.Map k (S.Set k) -> M.Map k (S.Set k)
removeKeyFromMapElems = M.mapWithKey S.delete

newDeclarationsStm :: Stm lore -> [VName]
newDeclarationsStm (Let (Pattern patctxelems patvalelems) _ e) =
  let new_decls0 = map patElemName (patctxelems ++ patvalelems)
      new_decls1 = case e of
        DoLoop mergectxparams mergevalparams _loopform _body ->
          -- Technically not a declaration for the current expression, but very
          -- close.
          map (paramName . fst) (mergectxparams ++ mergevalparams)
        _ -> []
      new_decls = new_decls0 ++ new_decls1
  in new_decls

lookupEmptyable :: (Ord a, Monoid b) => a -> M.Map a b -> b
lookupEmptyable x m = fromMaybe mempty $ M.lookup x m

fromJust :: String -> Maybe a -> a
fromJust _ (Just x) = x
fromJust mistake Nothing = error ("error: " ++ mistake)

maybeFromBoolM :: Monad m => (a -> m Bool) -> (a -> m (Maybe a))
maybeFromBoolM f a = do
  res <- f a
  return $ if res
           then Just a
           else Nothing

expandWithAliases :: forall v. Ord v => MemAliases -> M.Map v Names -> M.Map v Names
expandWithAliases mem_aliases = fixpointIterate expand
  where expand :: M.Map v Names -> M.Map v Names
        expand mems_map =
          M.fromList (map (\(v, mems) ->
                             (v, S.unions (mems : map (`lookupEmptyable` mem_aliases)
                                           (S.toList mems))))
                      (M.assocs mems_map))

fixpointIterate :: Eq a => (a -> a) -> a -> a
fixpointIterate f x
  | f x == x = x
  | otherwise = fixpointIterate f (f x)

fixpointIterateMay :: (a -> Maybe a) -> a -> a
fixpointIterateMay f x = maybe x (fixpointIterateMay f) (f x)

mapFromListSetUnion :: (Ord k, Ord v) => [(k, S.Set v)] -> M.Map k (S.Set v)
mapFromListSetUnion = M.unionsWith S.union . map (uncurry M.singleton)

-- Replace variables with subtrees of their constituents wherever possible.  It
-- naively expands a PrimExp as much as the input map allows, and can enable
-- more expressions to have it in scope, since it will likely consist of fewer
-- variables.
expandPrimExp :: M.Map VName (ExpMem.PrimExp VName) -> ExpMem.PrimExp VName
              -> ExpMem.PrimExp VName
expandPrimExp var_to_pe = fixpointIterate (substituteInPrimExp var_to_pe)

expandIxFun :: M.Map VName (ExpMem.PrimExp VName) -> ExpMem.IxFun -> ExpMem.IxFun
expandIxFun var_to_pe = fixpointIterate (IxFun.substituteInIxFun var_to_pe)

(<&&>) :: Monad m => m Bool -> m Bool -> m Bool
m <&&> n = (&&) <$> m <*> n

(<||>) :: Monad m => m Bool -> m Bool -> m Bool
m <||> n = (||) <$> m <*> n

anyM :: Monad m => (a -> m Bool) -> [a] -> m Bool
anyM f xs = or <$> mapM f xs

whenM :: Monad m => m Bool -> m () -> m ()
whenM b m = do
  b' <- b
  when b' m

mapMaybeM :: Monad m => (a -> m (Maybe b)) -> [a] -> m [b]
mapMaybeM f xs = catMaybes <$> mapM f xs

sortByKeyM :: (Ord t, Monad m) => (a -> m t) -> [a] -> m [a]
sortByKeyM f xs =
  map fst . L.sortBy (compare `on` snd) . zip xs <$> mapM f xs

filterSetM :: (Ord a, Monad m) => (a -> m Bool) -> S.Set a -> m (S.Set a)
filterSetM f xs = S.fromList <$> filterM f (S.toList xs)

-- Map on both ExplicitMemory and InKernel.
class FullMap lore where
  fullMapExpM :: Monad m => Mapper lore lore m -> KernelMapper InKernel InKernel m
              -> Exp lore -> m (Exp lore)

instance FullMap ExplicitMemory where
  fullMapExpM mapper mapper_kernel e =
    case e of
      Op (ExpMem.Inner kernel) ->
        Op . ExpMem.Inner <$> mapKernelM mapper_kernel kernel
      _ -> mapExpM mapper e

instance FullMap InKernel where
  fullMapExpM mapper mapper_kernel e = case e of
    Op (ExpMem.Inner ke) -> Op . ExpMem.Inner <$> case ke of
      ExpMem.Combine a b c body ->
        ExpMem.Combine a b c <$> mapOnKernelBody mapper_kernel body
      ExpMem.GroupReduce a lambda b ->
        ExpMem.GroupReduce a
        <$> mapOnKernelLambda mapper_kernel lambda
        <*> pure b
      ExpMem.GroupScan a lambda b ->
        ExpMem.GroupScan a
        <$> mapOnKernelLambda mapper_kernel lambda
        <*> pure b
      ExpMem.GroupStream a b (ExpMem.GroupStreamLambda a1 b1 params0 params1 gsbody) c d ->
        ExpMem.GroupStream a b
        <$> (ExpMem.GroupStreamLambda a1 b1
             <$> mapM (mapOnKernelLParam mapper_kernel) params0
             <*> mapM (mapOnKernelLParam mapper_kernel) params1
             <*> mapOnKernelBody mapper_kernel gsbody
            )
        <*> pure c <*> pure d
      _ -> return ke
    _ -> mapExpM mapper e

-- Walk on both ExplicitMemory and InKernel.
class FullWalk lore where
  fullWalkExpM :: Monad m => Walker lore m -> KernelWalker InKernel m
               -> Exp lore -> m ()

-- FIXME: This can maybe be integrated into the above typeclass.
class FullWalkAliases lore where
  fullWalkAliasesExpM :: Monad m => Walker (Aliases lore) m
                      -> KernelWalker (Aliases InKernel) m
                      -> Exp (Aliases lore) -> m ()

instance FullWalk ExplicitMemory where
  fullWalkExpM walker walker_kernel e = do
    walkExpM walker e
    case e of
      Op (ExpMem.Inner kernel) ->
        walkKernelM walker_kernel kernel
      _ -> return ()

instance FullWalkAliases ExplicitMemory where
  fullWalkAliasesExpM walker walker_kernel e = do
    walkExpM walker e
    case e of
      Op (ExpMem.Inner kernel) ->
        walkKernelM walker_kernel kernel
      _ -> return ()

instance FullWalk InKernel where
  fullWalkExpM walker walker_kernel e = case e of
    Op (ExpMem.Inner ke) -> walkOnKernelExpM walker_kernel ke
    _ -> walkExpM walker e

instance FullWalkAliases InKernel where
  fullWalkAliasesExpM walker walker_kernel e = case e of
    Op (ExpMem.Inner ke) -> walkOnKernelExpM walker_kernel ke
    _ -> walkExpM walker e

walkOnKernelExpM :: Monad m => KernelWalker lore m ->
                    KernelExp lore -> m ()
walkOnKernelExpM walker_kernel ke = case ke of
  ExpMem.Combine _ _ _ body ->
    walkOnKernelBody walker_kernel body
  ExpMem.GroupReduce _ lambda _ ->
    walkOnKernelLambda walker_kernel lambda
  ExpMem.GroupScan _ lambda _ ->
    walkOnKernelLambda walker_kernel lambda
  ExpMem.GroupStream _ _ gslambda _ _ ->
    walkOnGroupStreamLambdaM walker_kernel gslambda
  _ -> return ()

walkOnGroupStreamLambdaM :: Monad m => KernelWalker lore m ->
                            GroupStreamLambda lore -> m ()
walkOnGroupStreamLambdaM walker_kernel (GroupStreamLambda _ _
                                        params0 params1 gsbody) = do
  mapM_ (walkOnKernelLParam walker_kernel) params0
  mapM_ (walkOnKernelLParam walker_kernel) params1
  walkOnKernelBody walker_kernel gsbody