{-# LANGUAGE ConstraintKinds     #-}
{-# LANGUAGE FlexibleInstances   #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE KindSignatures      #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE PatternGuards       #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell     #-}
{-# LANGUAGE TupleSections       #-}
{-# LANGUAGE TypeApplications    #-}
{-# LANGUAGE TypeFamilies        #-}
{-# LANGUAGE TypeOperators       #-}
{-# LANGUAGE ViewPatterns        #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module      : Data.Array.Accelerate.Trafo.Substitution
-- Copyright   : [2012..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.Trafo.Substitution (

  -- ** Renaming & Substitution
  inline, inlineVars, compose,
  subTop, subAtop,

  -- ** Weakening
  (:>), Sink(..), SinkExp(..), weakenVars,

  -- ** Strengthening
  (:?>), strengthen, strengthenE,

  -- ** Rebuilding terms
  RebuildAcc, Rebuildable(..), RebuildableAcc,
  RebuildableExp(..), rebuildWeakenVar, rebuildLHS,
  OpenAccFun(..), OpenAccExp(..),

  -- ** Checks
  isIdentity, isIdentityIndexing, extractExpVars,
  bindingIsTrivial,

) where

import Data.Array.Accelerate.AST
import Data.Array.Accelerate.AST.LeftHandSide
import Data.Array.Accelerate.AST.Var
import Data.Array.Accelerate.AST.Idx
import Data.Array.Accelerate.AST.Environment
import Data.Array.Accelerate.Analysis.Match
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Representation.Array
import qualified Data.Array.Accelerate.Debug.Stats      as Stats

import Data.Kind
import Control.Applicative                              hiding ( Const )
import Control.Monad
import Prelude                                          hiding ( exp, seq )


-- NOTE: [Renaming and Substitution]
--
-- To do things like renaming and substitution, we need some operation on
-- variables that we push structurally through terms, applying to each variable.
-- We have a type preserving but environment changing operation:
--
--   v :: forall t. Idx env t -> f env' aenv t
--
-- The crafty bit is that 'f' might represent variables (for renaming) or terms
-- (for substitutions). The demonic forall, --- which is to say that the
-- quantifier is in a position which gives us obligation, not opportunity ---
-- forces us to respect type: when pattern matching detects the variable we care
-- about, happily we discover that it has the type we must respect. The demon is
-- not so free to mess with us as one might fear at first.
--
-- We then lift this to an operation which traverses terms and rebuild them
-- after applying 'v' to the variables:
--
--   rebuildPartial v :: OpenExp env aenv t -> OpenExp env' aenv t
--
-- The Syntactic class tells us what we need to know about 'f' if we want to be
-- able to rebuildPartial terms. In essence, the crucial functionality is to propagate
-- a class of operations on variables that is closed under shifting.
--
infixr `compose`
-- infixr `substitute`

lhsFullVars :: forall s a env1 env2. LeftHandSide s a env1 env2 -> Maybe (Vars s env2 a)
lhsFullVars = fmap snd . go weakenId
  where
    go :: forall env env' b. (env' :> env2) -> LeftHandSide s b env env' -> Maybe (env :> env2, Vars s env2 b)
    go k (LeftHandSideWildcard TupRunit) = Just (k, TupRunit)
    go k (LeftHandSideSingle s) = Just $ (weakenSucc $ k, TupRsingle $ Var s $ k >:> ZeroIdx)
    go k (LeftHandSidePair l1 l2)
      | Just (k',  v2) <- go k  l2
      , Just (k'', v1) <- go k' l1 = Just (k'', TupRpair v1 v2)
    go _ _ = Nothing

bindingIsTrivial :: LeftHandSide s a env1 env2 -> Vars s env2 b -> Maybe (a :~: b)
bindingIsTrivial lhs vars
  | Just lhsVars <- lhsFullVars lhs
  , Just Refl    <- matchVars vars lhsVars
  = Just Refl
bindingIsTrivial _ _ = Nothing

isIdentity :: OpenFun env aenv (a -> b) -> Maybe (a :~: b)
isIdentity (Lam lhs (Body (extractExpVars -> Just vars))) = bindingIsTrivial lhs vars
isIdentity _ = Nothing

-- Detects whether the function is of the form \ix -> a ! ix
isIdentityIndexing :: OpenFun env aenv (a -> b) -> Maybe (ArrayVar aenv (Array a b))
isIdentityIndexing (Lam lhs (Body body))
  | Index avar ix <- body
  , Just vars     <- extractExpVars ix
  , Just Refl     <- bindingIsTrivial lhs vars
  = Just avar
isIdentityIndexing _ = Nothing

-- | Replace the first variable with the given expression. The environment
-- shrinks.
--
inline :: OpenExp (env, s) aenv t
       -> OpenExp env      aenv s
       -> OpenExp env      aenv t
inline f g = Stats.substitution "inline" $ rebuildE (subTop g) f

inlineVars :: forall env env' aenv t1 t2.
              ELeftHandSide t1 env env'
           ->        OpenExp env' aenv t2
           ->        OpenExp env  aenv t1
           -> Maybe (OpenExp env  aenv t2)
inlineVars lhsBound expr bound
  | Just vars <- lhsFullVars lhsBound = substitute (strengthenWithLHS lhsBound) weakenId vars expr
  where
    substitute
        :: forall env1 env2 t.
           env1 :?> env2
        -> env :> env2
        -> ExpVars env1 t1
        -> OpenExp env1 aenv t
        -> Maybe (OpenExp env2 aenv t)
    substitute _ k2 vars (extractExpVars -> Just vars')
      | Just Refl <- matchVars vars vars' = Just $ weakenE k2 bound
    substitute k1 k2 vars topExp = case topExp of
      Let lhs e1 e2
        | Exists lhs' <- rebuildLHS lhs
                          -> Let lhs' <$> travE e1 <*> substitute (strengthenAfter lhs lhs' k1) (weakenWithLHS lhs' .> k2) (weakenWithLHS lhs `weakenVars` vars) e2
      Evar (Var t ix)     -> Evar . Var t <$> k1 ix
      Foreign tp asm f e1 -> Foreign tp asm f <$> travE e1
      Pair e1 e2          -> Pair <$> travE e1 <*> travE e2
      Nil                 -> Just Nil
      VecPack   vec e1    -> VecPack   vec <$> travE e1
      VecUnpack vec e1    -> VecUnpack vec <$> travE e1
      IndexSlice si e1 e2 -> IndexSlice si <$> travE e1 <*> travE e2
      IndexFull  si e1 e2 -> IndexFull  si <$> travE e1 <*> travE e2
      ToIndex   shr e1 e2 -> ToIndex   shr <$> travE e1 <*> travE e2
      FromIndex shr e1 e2 -> FromIndex shr <$> travE e1 <*> travE e2
      Case e1 rhs def     -> Case <$> travE e1 <*> mapM (\(t,c) -> (t,) <$> travE c) rhs <*> travMaybeE def
      Cond e1 e2 e3       -> Cond <$> travE e1 <*> travE e2 <*> travE e3
      While f1 f2 e1      -> While <$> travF f1 <*> travF f2 <*> travE e1
      Const t c           -> Just $ Const t c
      PrimConst c         -> Just $ PrimConst c
      PrimApp p e1        -> PrimApp p <$> travE e1
      Index a e1          -> Index a <$> travE e1
      LinearIndex a e1    -> LinearIndex a <$> travE e1
      Shape a             -> Just $ Shape a
      ShapeSize shr e1    -> ShapeSize shr <$> travE e1
      Undef t             -> Just $ Undef t
      Coerce t1 t2 e1     -> Coerce t1 t2 <$> travE e1

      where
        travE :: OpenExp env1 aenv s -> Maybe (OpenExp env2 aenv s)
        travE = substitute k1 k2 vars

        travF :: OpenFun env1 aenv s -> Maybe (OpenFun env2 aenv s)
        travF = substituteF k1 k2 vars

        travMaybeE :: Maybe (OpenExp env1 aenv s) -> Maybe (Maybe (OpenExp env2 aenv s))
        travMaybeE Nothing  = pure Nothing
        travMaybeE (Just x) = Just <$> travE x

    substituteF :: forall env1 env2 t.
               env1 :?> env2
            -> env :> env2
            -> ExpVars env1 t1
            -> OpenFun env1 aenv t
            -> Maybe (OpenFun env2 aenv t)
    substituteF k1 k2 vars (Body e) = Body <$> substitute k1 k2 vars e
    substituteF k1 k2 vars (Lam lhs f)
      | Exists lhs' <- rebuildLHS lhs = Lam lhs' <$> substituteF (strengthenAfter lhs lhs' k1) (weakenWithLHS lhs' .> k2) (weakenWithLHS lhs `weakenVars` vars) f

inlineVars _ _ _ = Nothing


-- | Replace an expression that uses the top environment variable with another.
-- The result of the first is let bound into the second.
--
{- substitute' :: OpenExp (env, b) aenv c
            -> OpenExp (env, a) aenv b
            -> OpenExp (env, a) aenv c
substitute' f g
  | Stats.substitution "substitute" False = undefined
  | isIdentity f = g -- don't rebind an identity function
  | isIdentity g = f
  | otherwise = Let g $ rebuildE split f
  where
    split :: Idx (env,b) c -> OpenExp ((env,a),b) aenv c
    split ZeroIdx       = Var ZeroIdx
    split (SuccIdx ix)  = Var (SuccIdx (SuccIdx ix))

substitute :: LeftHandSide b env envb
           -> OpenExp envb c
           -> LeftHandSide a env enva
           -> OpenExp enva b
-}

-- | Composition of unary functions.
--
compose :: HasCallStack
        => OpenFun env aenv (b -> c)
        -> OpenFun env aenv (a -> b)
        -> OpenFun env aenv (a -> c)
compose f@(Lam lhsB (Body c)) g@(Lam lhsA (Body b))
  | Stats.substitution "compose" False = undefined
  | Just Refl <- isIdentity f = g -- don't rebind an identity function
  | Just Refl <- isIdentity g = f

  | Exists lhsB' <- rebuildLHS lhsB
  = Lam lhsA
  $ Body
  $ Let lhsB' b
  $ weakenE (sinkWithLHS lhsB lhsB' $ weakenWithLHS lhsA) c
  -- = Stats.substitution "compose" . Lam lhs2 . Body $ substitute' f g
compose _
  _ = error "compose: impossible evaluation"

subTop :: OpenExp env aenv s -> ExpVar (env, s) t -> OpenExp env aenv t
subTop s (Var _  ZeroIdx     ) = s
subTop _ (Var tp (SuccIdx ix)) = Evar $ Var tp ix

subAtop :: PreOpenAcc acc aenv t -> ArrayVar (aenv, t) (Array sh2 e2) -> PreOpenAcc acc aenv (Array sh2 e2)
subAtop t (Var _    ZeroIdx      ) = t
subAtop _ (Var repr (SuccIdx idx)) = Avar $ Var repr idx

data Identity a = Identity { runIdentity :: a }

instance Functor Identity where
  {-# INLINE fmap #-}
  fmap f (Identity a) = Identity (f a)

instance Applicative Identity where
  {-# INLINE (<*>) #-}
  {-# INLINE pure  #-}
  Identity f <*> Identity a = Identity (f a)
  pure a                    = Identity a

-- A class for rebuilding terms.
--
class Rebuildable f where
  {-# MINIMAL rebuildPartial #-}
  type AccClo f :: Type -> Type -> Type

  rebuildPartial :: (Applicative f', SyntacticAcc fa)
                 => (forall sh e. ArrayVar aenv (Array sh e) -> f' (fa (AccClo f) aenv' (Array sh e)))
                 -> f aenv  a
                 -> f' (f aenv' a)

  {-# INLINEABLE rebuildA #-}
  rebuildA :: (SyntacticAcc fa)
           => (forall sh e. ArrayVar aenv (Array sh e) -> fa (AccClo f) aenv' (Array sh e))
           -> f aenv  a
           -> f aenv' a
  rebuildA av = runIdentity . rebuildPartial (Identity . av)

-- A class for rebuilding scalar terms.
--
class RebuildableExp f where
  {-# MINIMAL rebuildPartialE #-}
  rebuildPartialE :: (Applicative f', SyntacticExp fe)
                  => (forall e'. ExpVar env e' -> f' (fe env' aenv e'))
                  -> f env aenv e
                  -> f' (f env' aenv e)

  {-# INLINEABLE rebuildE #-}
  rebuildE :: SyntacticExp fe
           => (forall e'. ExpVar env e' -> fe env' aenv e')
           -> f env  aenv e
           -> f env' aenv e
  rebuildE v = runIdentity . rebuildPartialE (Identity . v)

-- Terms that are rebuildable and also recursive closures
--
type RebuildableAcc acc = (Rebuildable acc, AccClo acc ~ acc)

-- Wrappers which add the 'acc' type argument
--
data OpenAccExp (acc :: Type -> Type -> Type) env aenv a where
  OpenAccExp :: { unOpenAccExp :: OpenExp env aenv a } -> OpenAccExp acc env aenv a

data OpenAccFun (acc :: Type -> Type -> Type) env aenv a where
  OpenAccFun :: { unOpenAccFun :: OpenFun env aenv a } -> OpenAccFun acc env aenv a

-- We can use the same plumbing to rebuildPartial all the things we want to rebuild.
--
instance Rebuildable (OpenAccExp acc env) where
  type AccClo (OpenAccExp acc env) = acc
  {-# INLINEABLE rebuildPartial #-}
  rebuildPartial v (OpenAccExp e) = OpenAccExp <$> Stats.substitution "rebuild" (rebuildOpenExp (pure . IE) (reindexAvar v) e)

instance Rebuildable (OpenAccFun acc env) where
  type AccClo (OpenAccFun acc env) = acc
  {-# INLINEABLE rebuildPartial #-}
  rebuildPartial v (OpenAccFun f) = OpenAccFun <$> Stats.substitution "rebuild" (rebuildFun (pure . IE) (reindexAvar v) f)

instance RebuildableAcc acc => Rebuildable (PreOpenAcc acc) where
  type AccClo (PreOpenAcc acc) = acc
  {-# INLINEABLE rebuildPartial #-}
  rebuildPartial x = Stats.substitution "rebuild" $ rebuildPreOpenAcc rebuildPartial x

instance RebuildableAcc acc => Rebuildable (PreOpenAfun acc) where
  type AccClo (PreOpenAfun acc) = acc
  {-# INLINEABLE rebuildPartial #-}
  rebuildPartial x = Stats.substitution "rebuild" $ rebuildAfun rebuildPartial x

instance Rebuildable OpenAcc where
  type AccClo OpenAcc = OpenAcc
  {-# INLINEABLE rebuildPartial #-}
  rebuildPartial x = Stats.substitution "rebuild" $ rebuildOpenAcc x

instance RebuildableExp OpenExp where
  {-# INLINEABLE rebuildPartialE #-}
  rebuildPartialE v x = Stats.substitution "rebuild" $ rebuildOpenExp v (ReindexAvar pure) x

instance RebuildableExp OpenFun where
  {-# INLINEABLE rebuildPartialE #-}
  rebuildPartialE v x = Stats.substitution "rebuild" $ rebuildFun v (ReindexAvar pure) x

-- NOTE: [Weakening]
--
-- Weakening is something we usually take for granted: every time you learn a
-- new word, old sentences still make sense. If a conclusion is justified by a
-- hypothesis, it is still justified if you add more hypotheses. Similarly, a
-- term remains in scope if you bind more (fresh) variables. Weakening is the
-- operation of shifting things from one scope to a larger scope in which new
-- things have become meaningful, but no old things have vanished.
--
-- When we use a named representation (or HOAS) we get weakening for free. But
-- in the de Bruijn representation weakening takes work: you have to shift all
-- variable references to make room for the new bindings.
--

class Sink f where
  weaken :: env :> env' -> f env t -> f env' t

  -- TLM: We can't use this default instance because it doesn't lead to
  --      specialised code. Perhaps the INLINEABLE pragma is ignored: GHC bug?
  --
  -- {-# INLINEABLE weaken #-}
  -- default weaken :: Rebuildable f => env :> env' -> f env t -> f env' t
  -- weaken k = Stats.substitution "weaken" . rebuildA rebuildWeakenVar

--instance Rebuildable f => Sink f where -- undecidable, incoherent
--  weaken k = Stats.substitution "weaken" . rebuildA rebuildWeakenVar

instance Sink Idx where
  {-# INLINEABLE weaken #-}
  weaken = (>:>)

instance Sink (Var s) where
  {-# INLINEABLE weaken #-}
  weaken k (Var s ix) = Var s (k >:> ix)

weakenVars :: env :> env' -> Vars s env t -> Vars s env' t
weakenVars _  TupRunit      = TupRunit
weakenVars k (TupRsingle v) = TupRsingle $ weaken k v
weakenVars k (TupRpair v w) = TupRpair (weakenVars k v) (weakenVars k w)

rebuildWeakenVar :: env :> env' -> ArrayVar env (Array sh e) -> PreOpenAcc acc env' (Array sh e)
rebuildWeakenVar k (Var s idx) = Avar $ Var s $ k >:> idx

rebuildWeakenEvar :: env :> env' -> ExpVar env t -> OpenExp env' aenv t
rebuildWeakenEvar k (Var s idx) = Evar $ Var s $ k >:> idx

instance RebuildableAcc acc => Sink (PreOpenAcc acc) where
  {-# INLINEABLE weaken #-}
  weaken k = Stats.substitution "weaken" . rebuildA (rebuildWeakenVar k)

instance RebuildableAcc acc => Sink (PreOpenAfun acc) where
  {-# INLINEABLE weaken #-}
  weaken k = Stats.substitution "weaken" . rebuildA (rebuildWeakenVar k)

instance Sink (OpenExp env) where
  {-# INLINEABLE weaken #-}
  weaken k = Stats.substitution "weaken" . runIdentity . rebuildOpenExp (Identity . Evar) (ReindexAvar (Identity . weaken k))

instance Sink (OpenFun env) where
  {-# INLINEABLE weaken #-}
  weaken k = Stats.substitution "weaken" . runIdentity . rebuildFun (Identity . Evar) (ReindexAvar (Identity . weaken k))

instance Sink Boundary where
  {-# INLINEABLE weaken #-}
  weaken k bndy =
    case bndy of
      Clamp      -> Clamp
      Mirror     -> Mirror
      Wrap       -> Wrap
      Constant c -> Constant c
      Function f -> Function (weaken k f)

instance Sink OpenAcc where
  {-# INLINEABLE weaken #-}
  weaken k = Stats.substitution "weaken" . rebuildA (rebuildWeakenVar k)

-- This rewrite rule is disabled because 'weaken' is now part of a type class.
-- As such, we cannot attach a NOINLINE pragma because it has many definitions.
-- {-# RULES
-- "weaken/weaken" forall a (v1 :: env' :> env'') (v2 :: env :> env').
--     weaken v1 (weaken v2 a) = weaken (v1 . v2) a
--  #-}

class SinkExp f where
  weakenE :: env :> env' -> f env aenv t -> f env' aenv t

  -- See comment in 'weaken'
  --
  -- {-# INLINEABLE weakenE #-}
  -- default weakenE :: RebuildableExp f => env :> env' -> f env aenv t -> f env' aenv t
  -- weakenE v = Stats.substitution "weakenE" . rebuildE (IE . v)

instance SinkExp OpenExp where
  {-# INLINEABLE weakenE #-}
  weakenE v = Stats.substitution "weakenE" . rebuildE (rebuildWeakenEvar v)

instance SinkExp OpenFun where
  {-# INLINEABLE weakenE #-}
  weakenE v = Stats.substitution "weakenE" . rebuildE (rebuildWeakenEvar v)

-- See above for why this is disabled.
-- {-# RULES
-- "weakenE/weakenE" forall a (v1 :: env' :> env'') (v2 :: env :> env').
--    weakenE v1 (weakenE v2 a) = weakenE (v1 . v2) a
--  #-}

-- NOTE: [Strengthening]
--
-- Strengthening is the dual of weakening. Shifting terms from one scope to a
-- smaller scope. Of course this is not always possible. If the term contains
-- any variables not in the new environment, then it cannot be strengthened.
-- This partial behaviour is captured with 'Maybe'.
--

-- The type of partially shifting terms from one context into another.
type env :?> env' = forall t'. Idx env t' -> Maybe (Idx env' t')

{-# INLINEABLE strengthen #-}
strengthen :: forall f env env' t. Rebuildable f => env :?> env' -> f env t -> Maybe (f env' t)
strengthen k x = Stats.substitution "strengthen" $ rebuildPartial @f @Maybe @IdxA (\(Var s ix) -> fmap (IA . Var s) $ k ix) x

{-# INLINEABLE strengthenE #-}
strengthenE :: forall f env env' aenv t. RebuildableExp f => env :?> env' -> f env aenv t -> Maybe (f env' aenv t)
strengthenE k x = Stats.substitution "strengthenE" $ rebuildPartialE @f @Maybe @IdxE (\(Var tp ix) -> fmap (IE . Var tp) $ k ix) x

strengthenWithLHS :: LeftHandSide s t env1 env2 -> env2 :?> env1
strengthenWithLHS (LeftHandSideWildcard _) = Just
strengthenWithLHS (LeftHandSideSingle _)   = \ix -> case ix of
  ZeroIdx   -> Nothing
  SuccIdx i -> Just i
strengthenWithLHS (LeftHandSidePair l1 l2) = strengthenWithLHS l2 >=> strengthenWithLHS l1

strengthenAfter :: LeftHandSide s t env1 env2 -> LeftHandSide s t env1' env2' -> env1 :?> env1' -> env2 :?> env2'
strengthenAfter (LeftHandSideWildcard _) (LeftHandSideWildcard _) k = k
strengthenAfter (LeftHandSideSingle _)   (LeftHandSideSingle _)   k = \ix -> case ix of
  ZeroIdx   -> Just ZeroIdx
  SuccIdx i -> SuccIdx <$> k i
strengthenAfter (LeftHandSidePair l1 l2) (LeftHandSidePair l1' l2') k =
  strengthenAfter l2 l2' $ strengthenAfter l1 l1' k
strengthenAfter _ _ _ = error "Substitution.strengthenAfter: left hand sides do not match"

-- Simultaneous Substitution ===================================================
--

-- The scalar environment
-- ------------------

-- SEE: [Renaming and Substitution]
-- SEE: [Weakening]
--
class SyntacticExp f where
  varIn         :: ExpVar env t -> f env aenv t
  expOut        :: f env aenv t -> OpenExp env aenv t
  weakenExp     :: f env aenv t -> f (env, s) aenv t

newtype IdxE env aenv t = IE { unIE :: ExpVar env t }

instance SyntacticExp IdxE where
  varIn          = IE
  expOut         = Evar . unIE
  weakenExp (IE (Var tp ix)) = IE $ Var tp $ SuccIdx ix

instance SyntacticExp OpenExp where
  varIn          = Evar
  expOut         = id
  weakenExp      = runIdentity . rebuildOpenExp (Identity . weakenExp . IE) (ReindexAvar Identity)

{-# INLINEABLE shiftE #-}
shiftE
    :: (Applicative f, SyntacticExp fe)
    => RebuildEvar f fe env      env'      aenv
    -> RebuildEvar f fe (env, s) (env', s) aenv
shiftE _ (Var tp ZeroIdx)      = pure $ varIn (Var tp ZeroIdx)
shiftE v (Var tp (SuccIdx ix)) = weakenExp <$> v (Var tp ix)

{-# INLINEABLE shiftE' #-}
shiftE'
    :: (Applicative f, SyntacticExp fa)
    => ELeftHandSide t env1 env1'
    -> ELeftHandSide t env2 env2'
    -> RebuildEvar f fa env1  env2  aenv
    -> RebuildEvar f fa env1' env2' aenv
shiftE' (LeftHandSideWildcard _) (LeftHandSideWildcard _) v = v
shiftE' (LeftHandSideSingle _)   (LeftHandSideSingle _)   v = shiftE v
shiftE' (LeftHandSidePair a1 b1) (LeftHandSidePair a2 b2) v = shiftE' b1 b2 $ shiftE' a1 a2 v
shiftE' _ _ _ = error "Substitution: left hand sides do not match"

{-# INLINEABLE rebuildMaybeExp #-}
rebuildMaybeExp
    :: (HasCallStack, Applicative f, SyntacticExp fe)
    => RebuildEvar f fe env env' aenv'
    -> ReindexAvar f aenv aenv'
    -> Maybe (OpenExp env  aenv t)
    -> f (Maybe (OpenExp env' aenv' t))
rebuildMaybeExp _ _  Nothing  = pure Nothing
rebuildMaybeExp v av (Just x) = Just <$> rebuildOpenExp v av x

{-# INLINEABLE rebuildOpenExp #-}
rebuildOpenExp
    :: (HasCallStack, Applicative f, SyntacticExp fe)
    => RebuildEvar f fe env env' aenv'
    -> ReindexAvar f aenv aenv'
    -> OpenExp env  aenv t
    -> f (OpenExp env' aenv' t)
rebuildOpenExp v av@(ReindexAvar reindex) exp =
  case exp of
    Const t c           -> pure $ Const t c
    PrimConst c         -> pure $ PrimConst c
    Undef t             -> pure $ Undef t
    Evar var            -> expOut          <$> v var
    Let lhs a b
      | Exists lhs' <- rebuildLHS lhs
                        -> Let lhs'        <$> rebuildOpenExp v av a  <*> rebuildOpenExp (shiftE' lhs lhs' v) av b
    Pair e1 e2          -> Pair            <$> rebuildOpenExp v av e1 <*> rebuildOpenExp v av e2
    Nil                 -> pure Nil
    VecPack   vec e     -> VecPack   vec   <$> rebuildOpenExp v av e
    VecUnpack vec e     -> VecUnpack vec   <$> rebuildOpenExp v av e
    IndexSlice x ix sh  -> IndexSlice x    <$> rebuildOpenExp v av ix <*> rebuildOpenExp v av sh
    IndexFull x ix sl   -> IndexFull x     <$> rebuildOpenExp v av ix <*> rebuildOpenExp v av sl
    ToIndex shr sh ix   -> ToIndex shr     <$> rebuildOpenExp v av sh <*> rebuildOpenExp v av ix
    FromIndex shr sh ix -> FromIndex shr   <$> rebuildOpenExp v av sh <*> rebuildOpenExp v av ix
    Case e rhs def      -> Case            <$> rebuildOpenExp v av e  <*> sequenceA [ (t,) <$> rebuildOpenExp v av c | (t,c) <- rhs ] <*> rebuildMaybeExp v av def
    Cond p t e          -> Cond            <$> rebuildOpenExp v av p  <*> rebuildOpenExp v av t  <*> rebuildOpenExp v av e
    While p f x         -> While           <$> rebuildFun v av p      <*> rebuildFun v av f      <*> rebuildOpenExp v av x
    PrimApp f x         -> PrimApp f       <$> rebuildOpenExp v av x
    Index a sh          -> Index           <$> reindex a              <*> rebuildOpenExp v av sh
    LinearIndex a i     -> LinearIndex     <$> reindex a              <*> rebuildOpenExp v av i
    Shape a             -> Shape           <$> reindex a
    ShapeSize shr sh    -> ShapeSize shr   <$> rebuildOpenExp v av sh
    Foreign tp ff f e   -> Foreign tp ff f <$> rebuildOpenExp v av e
    Coerce t1 t2 e      -> Coerce t1 t2    <$> rebuildOpenExp v av e

{-# INLINEABLE rebuildFun #-}
rebuildFun
    :: (HasCallStack, Applicative f, SyntacticExp fe)
    => RebuildEvar f fe env env' aenv'
    -> ReindexAvar f aenv aenv'
    -> OpenFun env  aenv  t
    -> f (OpenFun env' aenv' t)
rebuildFun v av fun =
  case fun of
    Body e -> Body <$> rebuildOpenExp v av e
    Lam lhs f
      | Exists lhs' <- rebuildLHS lhs
        -> Lam lhs' <$> rebuildFun (shiftE' lhs lhs' v) av f

-- The array environment
-- -----------------

type RebuildAcc acc =
  forall aenv aenv' f fa a. (HasCallStack, Applicative f, SyntacticAcc fa)
    => RebuildAvar f fa acc aenv aenv'
    -> acc aenv a
    -> f (acc aenv' a)

newtype IdxA (acc :: Type -> Type -> Type) aenv t = IA { unIA :: ArrayVar aenv t }

class SyntacticAcc f where
  avarIn        :: ArrayVar aenv (Array sh e) -> f acc aenv (Array sh e)
  accOut        :: f acc aenv (Array sh e) -> PreOpenAcc acc aenv (Array sh e)
  weakenAcc     :: RebuildAcc acc -> f acc aenv (Array sh e) -> f acc (aenv, s) (Array sh e)

instance SyntacticAcc IdxA where
  avarIn                       = IA
  accOut                       = Avar . unIA
  weakenAcc _ (IA (Var s idx)) = IA $ Var s $ SuccIdx idx

instance SyntacticAcc PreOpenAcc where
  avarIn        = Avar
  accOut        = id
  weakenAcc k   = runIdentity . rebuildPreOpenAcc k (Identity . weakenAcc k . IA)

type RebuildAvar f (fa :: (Type -> Type -> Type) -> Type -> Type -> Type) acc aenv aenv'
    = forall sh e. ArrayVar aenv (Array sh e) -> f (fa acc aenv' (Array sh e))

type RebuildEvar f fe env env' aenv' =
  forall t'. ExpVar env t' -> f (fe env' aenv' t')

newtype ReindexAvar f aenv aenv' =
  ReindexAvar (forall sh e. ArrayVar aenv (Array sh e) -> f (ArrayVar aenv' (Array sh e)))

reindexAvar
    :: forall f fa acc aenv aenv'.
       (HasCallStack, Applicative f, SyntacticAcc fa)
    => RebuildAvar f fa acc aenv aenv'
    -> ReindexAvar f        aenv aenv'
reindexAvar v = ReindexAvar f where
  f :: forall sh e. ArrayVar aenv (Array sh e) -> f (ArrayVar aenv' (Array sh e))
  f var = g <$> v var

  g :: fa acc aenv' (Array sh e) -> ArrayVar aenv' (Array sh e)
  g fa = case accOut fa of
    Avar var' -> var'
    _ -> internalError "An Avar which was used in an Exp was mapped to an array term other than Avar. This mapping is invalid as an Exp can only contain array variables."


{-# INLINEABLE shiftA #-}
shiftA
    :: (HasCallStack, Applicative f, SyntacticAcc fa)
    => RebuildAcc acc
    -> RebuildAvar f fa acc aenv aenv'
    -> ArrayVar  (aenv,  s) (Array sh e)
    -> f (fa acc (aenv', s) (Array sh e))
shiftA _ _ (Var s ZeroIdx)      = pure $ avarIn $ Var s ZeroIdx
shiftA k v (Var s (SuccIdx ix)) = weakenAcc k <$> v (Var s ix)

shiftA'
    :: (HasCallStack, Applicative f, SyntacticAcc fa)
    => ALeftHandSide t aenv1 aenv1'
    -> ALeftHandSide t aenv2 aenv2'
    -> RebuildAcc acc
    -> RebuildAvar f fa acc aenv1  aenv2
    -> RebuildAvar f fa acc aenv1' aenv2'
shiftA' (LeftHandSideWildcard _) (LeftHandSideWildcard _) _ v = v
shiftA' (LeftHandSideSingle _)   (LeftHandSideSingle _)   k v = shiftA k v
shiftA' (LeftHandSidePair a1 b1) (LeftHandSidePair a2 b2) k v = shiftA' b1 b2 k $ shiftA' a1 a2 k v
shiftA' _ _ _ _ = internalError "left hand sides do not match"

{-# INLINEABLE rebuildOpenAcc #-}
rebuildOpenAcc
    :: (HasCallStack, Applicative f, SyntacticAcc fa)
    => (forall sh e. ArrayVar aenv (Array sh e) -> f (fa OpenAcc aenv' (Array sh e)))
    -> OpenAcc aenv  t
    -> f (OpenAcc aenv' t)
rebuildOpenAcc av (OpenAcc acc) = OpenAcc <$> rebuildPreOpenAcc rebuildOpenAcc av acc

{-# INLINEABLE rebuildPreOpenAcc #-}
rebuildPreOpenAcc
    :: (HasCallStack, Applicative f, SyntacticAcc fa)
    => RebuildAcc acc
    -> RebuildAvar f fa acc aenv aenv'
    -> PreOpenAcc acc aenv  t
    -> f (PreOpenAcc acc aenv' t)
rebuildPreOpenAcc k av acc =
  case acc of
    Use repr a                -> pure $ Use repr a
    Alet lhs a b              -> rebuildAlet k av lhs a b
    Avar ix                   -> accOut          <$> av ix
    Apair as bs               -> Apair           <$> k av as <*> k av bs
    Anil                      -> pure Anil
    Apply repr f a            -> Apply repr      <$> rebuildAfun k av f <*> k av a
    Acond p t e               -> Acond           <$> rebuildOpenExp (pure . IE) av' p <*> k av t <*> k av e
    Awhile p f a              -> Awhile          <$> rebuildAfun k av p <*> rebuildAfun k av f <*> k av a
    Unit tp e                 -> Unit tp         <$> rebuildOpenExp (pure . IE) av' e
    Reshape shr e a           -> Reshape shr     <$> rebuildOpenExp (pure . IE) av' e <*> k av a
    Generate repr e f         -> Generate repr   <$> rebuildOpenExp (pure . IE) av' e <*> rebuildFun (pure . IE) av' f
    Transform repr sh ix f a  -> Transform repr  <$> rebuildOpenExp (pure . IE) av' sh <*> rebuildFun (pure . IE) av' ix <*> rebuildFun (pure . IE) av' f <*> k av a
    Replicate sl slix a       -> Replicate sl    <$> rebuildOpenExp (pure . IE) av' slix <*> k av a
    Slice sl a slix           -> Slice sl        <$> k av a <*> rebuildOpenExp (pure . IE) av' slix
    Map tp f a                -> Map tp          <$> rebuildFun (pure . IE) av' f <*> k av a
    ZipWith tp f a1 a2        -> ZipWith tp      <$> rebuildFun (pure . IE) av' f <*> k av a1 <*> k av a2
    Fold f z a                -> Fold            <$> rebuildFun (pure . IE) av' f <*> rebuildMaybeExp (pure . IE) av' z <*> k av a
    FoldSeg itp f z a s       -> FoldSeg itp     <$> rebuildFun (pure . IE) av' f <*> rebuildMaybeExp (pure . IE) av' z <*> k av a <*> k av s
    Scan  d f z a             -> Scan  d         <$> rebuildFun (pure . IE) av' f <*> rebuildMaybeExp (pure . IE) av' z <*> k av a
    Scan' d f z a             -> Scan' d         <$> rebuildFun (pure . IE) av' f <*> rebuildOpenExp (pure . IE) av' z <*> k av a
    Permute f1 a1 f2 a2       -> Permute         <$> rebuildFun (pure . IE) av' f1 <*> k av a1 <*> rebuildFun (pure . IE) av' f2 <*> k av a2
    Backpermute shr sh f a    -> Backpermute shr <$> rebuildOpenExp (pure . IE) av' sh <*> rebuildFun (pure . IE) av' f <*> k av a
    Stencil sr tp f b a       -> Stencil sr tp   <$> rebuildFun (pure . IE) av' f <*> rebuildBoundary av' b  <*> k av a
    Stencil2 s1 s2 tp f b1 a1 b2 a2
                              -> Stencil2 s1 s2 tp <$> rebuildFun (pure . IE) av' f <*> rebuildBoundary av' b1 <*> k av a1 <*> rebuildBoundary av' b2 <*> k av a2
    Aforeign repr ff afun as  -> Aforeign repr ff afun <$> k av as
    -- Collect seq             -> Collect      <$> rebuildSeq k av seq
  where
    av' = reindexAvar av

{-# INLINEABLE rebuildAfun #-}
rebuildAfun
    :: (HasCallStack, Applicative f, SyntacticAcc fa)
    => RebuildAcc acc
    -> RebuildAvar f fa acc aenv aenv'
    -> PreOpenAfun acc aenv  t
    -> f (PreOpenAfun acc aenv' t)
rebuildAfun k av (Abody b) = Abody <$> k av b
rebuildAfun k av (Alam lhs1 f)
  | Exists lhs2 <- rebuildLHS lhs1
  = Alam lhs2 <$> rebuildAfun k (shiftA' lhs1 lhs2 k av) f

rebuildAlet
    :: forall f fa acc aenv1 aenv1' aenv2 bndArrs arrs. (HasCallStack, Applicative f, SyntacticAcc fa)
    => RebuildAcc acc
    -> RebuildAvar f fa acc aenv1 aenv2
    -> ALeftHandSide bndArrs aenv1 aenv1'
    -> acc aenv1  bndArrs
    -> acc aenv1' arrs
    -> f (PreOpenAcc acc aenv2 arrs)
rebuildAlet k av lhs1 bind1 body1
  | Exists lhs2 <- rebuildLHS lhs1
  = Alet lhs2 <$> k av bind1 <*> k (shiftA' lhs1 lhs2 k av) body1

{-# INLINEABLE rebuildLHS #-}
rebuildLHS :: LeftHandSide s t aenv1 aenv1' -> Exists (LeftHandSide s t aenv2)
rebuildLHS (LeftHandSideWildcard r) = Exists $ LeftHandSideWildcard r
rebuildLHS (LeftHandSideSingle s)   = Exists $ LeftHandSideSingle s
rebuildLHS (LeftHandSidePair as bs)
  | Exists as' <- rebuildLHS as
  , Exists bs' <- rebuildLHS bs
  = Exists $ LeftHandSidePair as' bs'

{-# INLINEABLE rebuildBoundary #-}
rebuildBoundary
    :: Applicative f
    => ReindexAvar f aenv aenv'
    -> Boundary aenv t
    -> f (Boundary aenv' t)
rebuildBoundary av bndy =
  case bndy of
    Clamp       -> pure Clamp
    Mirror      -> pure Mirror
    Wrap        -> pure Wrap
    Constant v  -> pure (Constant v)
    Function f  -> Function <$> rebuildFun (pure . IE) av f

{--
{-# INLINEABLE rebuildSeq #-}
rebuildSeq
    :: (SyntacticAcc fa, Applicative f)
    => RebuildAcc acc
    -> RebuildAvar f fa acc aenv aenv'
    -> PreOpenSeq acc aenv senv t
    -> f (PreOpenSeq acc aenv' senv t)
rebuildSeq k v seq =
  case seq of
    Producer p s -> Producer <$> (rebuildP k v p) <*> (rebuildSeq k v s)
    Consumer c   -> Consumer <$> (rebuildC k v c)
    Reify ix     -> pure $ Reify ix

{-# INLINEABLE rebuildP #-}
rebuildP :: (SyntacticAcc fa, Applicative f)
         => RebuildAcc acc
         -> RebuildAvar f fa acc aenv aenv'
         -> Producer acc aenv senv a
         -> f (Producer acc aenv' senv a)
rebuildP k v p =
  case p of
    StreamIn arrs        -> pure (StreamIn arrs)
    ToSeq sl slix acc    -> ToSeq sl slix <$> k v acc
    MapSeq f x           -> MapSeq <$> rebuildAfun k v f <*> pure x
    ChunkedMapSeq f x    -> ChunkedMapSeq <$> rebuildAfun k v f <*> pure x
    ZipWithSeq f x y     -> ZipWithSeq <$> rebuildAfun k v f <*> pure x <*> pure y
    ScanSeq f e x        -> ScanSeq <$> rebuildFun (pure . IE) v f <*> rebuildOpenExp (pure . IE) v e <*> pure x

{-# INLINEABLE rebuildC #-}
rebuildC :: forall acc fa f aenv aenv' senv a. (SyntacticAcc fa, Applicative f)
         => RebuildAcc acc
         -> RebuildAvar f fa acc aenv aenv'
         -> Consumer acc aenv senv a
         -> f (Consumer acc aenv' senv a)
rebuildC k v c =
  case c of
    FoldSeq f e x          -> FoldSeq <$> rebuildFun (pure . IE) v f <*> rebuildOpenExp (pure . IE) v e <*> pure x
    FoldSeqFlatten f acc x -> FoldSeqFlatten <$> rebuildAfun k v f <*> k v acc <*> pure x
    Stuple t               -> Stuple <$> rebuildT t
  where
    rebuildT :: Atuple (Consumer acc aenv senv) t -> f (Atuple (Consumer acc aenv' senv) t)
    rebuildT NilAtup        = pure NilAtup
    rebuildT (SnocAtup t s) = SnocAtup <$> (rebuildT t) <*> (rebuildC k v s)
--}

extractExpVars :: OpenExp env aenv a -> Maybe (ExpVars env a)
extractExpVars Nil          = Just TupRunit
extractExpVars (Pair e1 e2) = TupRpair <$> extractExpVars e1 <*> extractExpVars e2
extractExpVars (Evar v)     = Just $ TupRsingle v
extractExpVars _            = Nothing