{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
-- | This module exports facilities for transforming array accesses in
-- a list of 'Stm's (intended to be the bindings in a body).  The
-- idea is that you can state that some variable @x@ is in fact an
-- array indexing @v[i0,i1,...]@.
module Futhark.Optimise.InPlaceLowering.SubstituteIndices
       (
         substituteIndices
       , IndexSubstitution
       , IndexSubstitutions
       ) where

import Control.Monad
import qualified Data.Map.Strict as M
import qualified Data.Set as S

import Futhark.Representation.AST.Attributes.Aliases
import Futhark.Representation.AST
import Futhark.Construct
import Futhark.Tools (fullSlice)
import Futhark.Util

-- | A single substitution: the variable it is attached to stands for
-- an indexing of an array.  The components are, in order: the
-- certificates attached to any statement generated for the
-- substitution, the name of the array being indexed, the
-- let-attribute of that array (its type is read from this), and the
-- slice to index with.  An empty slice means the variable stands for
-- the array itself.
type IndexSubstitution attr = (Certificates, VName, attr, Slice SubExp)
-- | An association list from variable names to the indexing each one
-- stands for.  Entries are found with 'lookup', so the first match
-- for a name wins.
type IndexSubstitutions attr = [(VName, IndexSubstitution attr)]

-- | Build a scope in which the source array of every substitution is
-- bound (as a let-bound name) to the attribute recorded for it.
typeEnvFromSubstitutions :: LetAttr lore ~ attr =>
                            IndexSubstitutions attr -> Scope lore
typeEnvFromSubstitutions substs =
  M.fromList [ (src, LetInfo src_attr)
             | (_, (_, src, src_attr, _)) <- substs ]

-- | Apply the given index substitutions to a sequence of statements.
-- The rewriting runs in a binder whose initial scope is derived from
-- the substitutions themselves; the result is the substitutions as
-- they stand after the last statement, together with the statements
-- produced by the binder.
substituteIndices :: (MonadFreshNames m, BinderOps lore, Bindable lore,
                      Aliased lore, LetAttr lore ~ attr) =>
                     IndexSubstitutions attr -> Stms lore
                  -> m (IndexSubstitutions attr, Stms lore)
substituteIndices substs bnds =
  runBinderT (substituteIndicesInStms substs bnds) $
  typeEnvFromSubstitutions substs

-- | Rewrite each statement in turn, left to right, threading the
-- substitutions returned by one statement into the next.
substituteIndicesInStms :: (MonadBinder m, Bindable (Lore m), Aliased (Lore m)) =>
                           IndexSubstitutions (LetAttr (Lore m))
                        -> Stms (Lore m)
                        -> m (IndexSubstitutions (LetAttr (Lore m)))
substituteIndicesInStms substs stms =
  foldM substituteIndicesInStm substs stms

-- | Rewrite a single statement: first its expression, then its
-- pattern, and finally add the rebuilt statement to the binder.
-- Returns the substitutions produced by the pattern step.
substituteIndicesInStm :: (MonadBinder m, Bindable (Lore m), Aliased (Lore m)) =>
                          IndexSubstitutions (LetAttr (Lore m))
                       -> Stm (Lore m)
                       -> m (IndexSubstitutions (LetAttr (Lore m)))
substituteIndicesInStm substs (Let pat aux e) = do
  -- The expression goes first, so any statements it needs are
  -- emitted before the statement itself.
  new_e <- substituteIndicesInExp substs e
  (new_substs, new_pat) <- substituteIndicesInPattern substs pat
  new_substs <$ addStm (Let new_pat aux new_e)

-- | Thread the substitutions through the context elements and then
-- the value elements of a pattern, rebuilding the pattern from the
-- results.  The per-element step currently returns both the
-- substitutions and the element untouched.
substituteIndicesInPattern :: (MonadBinder m, LetAttr (Lore m) ~ attr) =>
                              IndexSubstitutions (LetAttr (Lore m))
                           -> PatternT attr
                           -> m (IndexSubstitutions (LetAttr (Lore m)), PatternT attr)
substituteIndicesInPattern substs pat = do
  (ctx_substs, ctx_elems) <- passThrough substs (patternContextElements pat)
  (val_substs, val_elems) <- passThrough ctx_substs (patternValueElements pat)
  return (val_substs, Pattern ctx_elems val_elems)
  where passThrough acc elems =
          mapAccumLM (\acc' pat_elem -> return (acc', pat_elem)) acc elems

-- | Rewrite an expression under the given substitutions.
--
-- Before the expression is traversed, every substituted variable
-- that the expression consumes is materialised: the indexing it
-- stands for is bound to a fresh name, that result is copied, and
-- the variable's substitution is replaced by one that refers to the
-- copy with an empty slice.  The expression is then mapped over with
-- the adjusted substitutions applied to its subexpressions, variable
-- names and bodies.
substituteIndicesInExp :: (MonadBinder m, Bindable (Lore m), Aliased (Lore m),
                           LetAttr (Lore m) ~ attr) =>
                          IndexSubstitutions (LetAttr (Lore m))
                       -> Exp (Lore m)
                       -> m (Exp (Lore m))
substituteIndicesInExp substs e = do
  substs' <- foldM copyIfConsumed substs $ S.toList $ consumedInExp e
  let mapper = identityMapper
        { mapOnSubExp = substituteIndicesInSubExp substs'
        , mapOnVName  = substituteIndicesInVar substs'
        , mapOnBody   = const $ substituteIndicesInBody substs'
        }
  mapExpM mapper e
  where
    -- Look the consumed name up in the /original/ substitutions, but
    -- record the replacement in the accumulated ones.
    copyIfConsumed acc v =
      case lookup v substs of
        Nothing ->
          return acc
        Just (src_cs, src, src_attr, src_is) -> do
          row <- certifying src_cs $
                 letExp (baseString v ++ "_row") $
                 BasicOp $ Index src $ fullSlice (typeOf src_attr) src_is
          row_copy <- letExp (baseString v ++ "_row_copy") $
                      BasicOp $ Copy row
          -- NOTE(review): this strips one dimension per slice
          -- component, which presumes every component is a fixed
          -- index rather than a range -- confirm.
          let row_attr = src_attr `setType`
                         stripArray (length src_is) (typeOf src_attr)
          return $ update v v (mempty, row_copy, row_attr, []) acc

-- | Apply the substitutions to a subexpression.  Only variables can
-- be affected; anything else is returned unchanged.
substituteIndicesInSubExp :: MonadBinder m =>
                             IndexSubstitutions (LetAttr (Lore m))
                          -> SubExp
                          -> m SubExp
substituteIndicesInSubExp substs se =
  case se of
    Var v -> Var <$> substituteIndicesInVar substs v
    _     -> return se

-- | Apply the substitutions to a variable name.
--
-- A name without a substitution is returned as it is.  For a name
-- with a substitution, a fresh binding is emitted under the
-- substitution's certificates and its name returned: a plain
-- reference to the source array when the slice is empty, otherwise
-- an index into the source array with the slice extended to the
-- full rank of the source.
substituteIndicesInVar :: MonadBinder m =>
                          IndexSubstitutions (LetAttr (Lore m))
                       -> VName
                       -> m VName
substituteIndicesInVar substs v =
  case lookup v substs of
    Nothing ->
      return v
    Just (src_cs, src, _, []) ->
      certifying src_cs $
      letExp (baseString src) $ BasicOp $ SubExp $ Var src
    Just (src_cs, src, src_attr, src_is) ->
      certifying src_cs $
      letExp "idx" $ BasicOp $ Index src $ fullSlice (typeOf src_attr) src_is

-- | Rewrite a body.  Its statements are rewritten first, collecting
-- whatever they emit; the result subexpressions are then rewritten
-- with the substitutions left over from the statements, and any
-- statements that step emits are appended after the rewritten body
-- statements.
substituteIndicesInBody :: (MonadBinder m, Bindable (Lore m), Aliased (Lore m)) =>
                           IndexSubstitutions (LetAttr (Lore m))
                        -> Body (Lore m)
                        -> m (Body (Lore m))
substituteIndicesInBody substs body = do
  let stms = bodyStms body
  (body_substs, new_stms) <-
    inScopeOf stms $ collectStms $ substituteIndicesInStms substs stms
  (res, res_stms) <-
    inScopeOf new_stms $ collectStms $
    mapM (substituteIndicesInSubExp body_substs) (bodyResult body)
  mkBodyM (new_stms <> res_stms) res

-- | Replace the first entry keyed by @needle@ with @(name, subst)@,
-- leaving every other entry and the order of the list intact.  It is
-- an error for @needle@ not to occur in the list.
update :: VName -> VName -> IndexSubstitution attr -> IndexSubstitutions attr
       -> IndexSubstitutions attr
update needle name subst = go
  where go [] =
          error $ "Cannot find substitution for " ++ pretty needle
        go (entry@(key, _) : rest)
          | key == needle = (name, subst) : rest
          | otherwise     = entry : go rest