{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_HADDOCK hide #-}
module Data.Array.Accelerate.Trafo.Substitution (
inline, inlineVars, compose,
subTop, subAtop,
(:>), Sink(..), SinkExp(..), weakenVars,
(:?>), strengthen, strengthenE,
RebuildAcc, Rebuildable(..), RebuildableAcc,
RebuildableExp(..), rebuildWeakenVar, rebuildLHS,
OpenAccFun(..), OpenAccExp(..),
isIdentity, isIdentityIndexing, extractExpVars,
bindingIsTrivial,
) where
import Data.Array.Accelerate.AST
import Data.Array.Accelerate.AST.LeftHandSide
import Data.Array.Accelerate.AST.Var
import Data.Array.Accelerate.AST.Idx
import Data.Array.Accelerate.AST.Environment
import Data.Array.Accelerate.Analysis.Match
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Representation.Array
import qualified Data.Array.Accelerate.Debug.Stats as Stats
import Data.Kind
import Control.Applicative hiding ( Const )
import Control.Monad
import Prelude hiding ( exp, seq )
infixr `compose`
lhsFullVars :: forall s a env1 env2. LeftHandSide s a env1 env2 -> Maybe (Vars s env2 a)
lhsFullVars = fmap snd . go weakenId
where
go :: forall env env' b. (env' :> env2) -> LeftHandSide s b env env' -> Maybe (env :> env2, Vars s env2 b)
go k (LeftHandSideWildcard TupRunit) = Just (k, TupRunit)
go k (LeftHandSideSingle s) = Just $ (weakenSucc $ k, TupRsingle $ Var s $ k >:> ZeroIdx)
go k (LeftHandSidePair l1 l2)
| Just (k', v2) <- go k l2
, Just (k'', v1) <- go k' l1 = Just (k'', TupRpair v1 v2)
go _ _ = Nothing
bindingIsTrivial :: LeftHandSide s a env1 env2 -> Vars s env2 b -> Maybe (a :~: b)
bindingIsTrivial lhs vars
| Just lhsVars <- lhsFullVars lhs
, Just Refl <- matchVars vars lhsVars
= Just Refl
bindingIsTrivial _ _ = Nothing
isIdentity :: OpenFun env aenv (a -> b) -> Maybe (a :~: b)
isIdentity (Lam lhs (Body (extractExpVars -> Just vars))) = bindingIsTrivial lhs vars
isIdentity _ = Nothing
isIdentityIndexing :: OpenFun env aenv (a -> b) -> Maybe (ArrayVar aenv (Array a b))
isIdentityIndexing (Lam lhs (Body body))
| Index avar ix <- body
, Just vars <- extractExpVars ix
, Just Refl <- bindingIsTrivial lhs vars
= Just avar
isIdentityIndexing _ = Nothing
inline :: OpenExp (env, s) aenv t
-> OpenExp env aenv s
-> OpenExp env aenv t
inline f g = Stats.substitution "inline" $ rebuildE (subTop g) f
inlineVars :: forall env env' aenv t1 t2.
ELeftHandSide t1 env env'
-> OpenExp env' aenv t2
-> OpenExp env aenv t1
-> Maybe (OpenExp env aenv t2)
inlineVars lhsBound expr bound
| Just vars <- lhsFullVars lhsBound = substitute (strengthenWithLHS lhsBound) weakenId vars expr
where
substitute
:: forall env1 env2 t.
env1 :?> env2
-> env :> env2
-> ExpVars env1 t1
-> OpenExp env1 aenv t
-> Maybe (OpenExp env2 aenv t)
substitute _ k2 vars (extractExpVars -> Just vars')
| Just Refl <- matchVars vars vars' = Just $ weakenE k2 bound
substitute k1 k2 vars topExp = case topExp of
Let lhs e1 e2
| Exists lhs' <- rebuildLHS lhs
-> Let lhs' <$> travE e1 <*> substitute (strengthenAfter lhs lhs' k1) (weakenWithLHS lhs' .> k2) (weakenWithLHS lhs `weakenVars` vars) e2
Evar (Var t ix) -> Evar . Var t <$> k1 ix
Foreign tp asm f e1 -> Foreign tp asm f <$> travE e1
Pair e1 e2 -> Pair <$> travE e1 <*> travE e2
Nil -> Just Nil
VecPack vec e1 -> VecPack vec <$> travE e1
VecUnpack vec e1 -> VecUnpack vec <$> travE e1
IndexSlice si e1 e2 -> IndexSlice si <$> travE e1 <*> travE e2
IndexFull si e1 e2 -> IndexFull si <$> travE e1 <*> travE e2
ToIndex shr e1 e2 -> ToIndex shr <$> travE e1 <*> travE e2
FromIndex shr e1 e2 -> FromIndex shr <$> travE e1 <*> travE e2
Case e1 rhs def -> Case <$> travE e1 <*> mapM (\(t,c) -> (t,) <$> travE c) rhs <*> travMaybeE def
Cond e1 e2 e3 -> Cond <$> travE e1 <*> travE e2 <*> travE e3
While f1 f2 e1 -> While <$> travF f1 <*> travF f2 <*> travE e1
Const t c -> Just $ Const t c
PrimConst c -> Just $ PrimConst c
PrimApp p e1 -> PrimApp p <$> travE e1
Index a e1 -> Index a <$> travE e1
LinearIndex a e1 -> LinearIndex a <$> travE e1
Shape a -> Just $ Shape a
ShapeSize shr e1 -> ShapeSize shr <$> travE e1
Undef t -> Just $ Undef t
Coerce t1 t2 e1 -> Coerce t1 t2 <$> travE e1
where
travE :: OpenExp env1 aenv s -> Maybe (OpenExp env2 aenv s)
travE = substitute k1 k2 vars
travF :: OpenFun env1 aenv s -> Maybe (OpenFun env2 aenv s)
travF = substituteF k1 k2 vars
travMaybeE :: Maybe (OpenExp env1 aenv s) -> Maybe (Maybe (OpenExp env2 aenv s))
travMaybeE Nothing = pure Nothing
travMaybeE (Just x) = Just <$> travE x
substituteF :: forall env1 env2 t.
env1 :?> env2
-> env :> env2
-> ExpVars env1 t1
-> OpenFun env1 aenv t
-> Maybe (OpenFun env2 aenv t)
substituteF k1 k2 vars (Body e) = Body <$> substitute k1 k2 vars e
substituteF k1 k2 vars (Lam lhs f)
| Exists lhs' <- rebuildLHS lhs = Lam lhs' <$> substituteF (strengthenAfter lhs lhs' k1) (weakenWithLHS lhs' .> k2) (weakenWithLHS lhs `weakenVars` vars) f
inlineVars _ _ _ = Nothing
compose :: HasCallStack
=> OpenFun env aenv (b -> c)
-> OpenFun env aenv (a -> b)
-> OpenFun env aenv (a -> c)
compose f@(Lam lhsB (Body c)) g@(Lam lhsA (Body b))
| Stats.substitution "compose" False = undefined
| Just Refl <- isIdentity f = g
| Just Refl <- isIdentity g = f
| Exists lhsB' <- rebuildLHS lhsB
= Lam lhsA
$ Body
$ Let lhsB' b
$ weakenE (sinkWithLHS lhsB lhsB' $ weakenWithLHS lhsA) c
compose _
_ = error "compose: impossible evaluation"
subTop :: OpenExp env aenv s -> ExpVar (env, s) t -> OpenExp env aenv t
subTop s (Var _ ZeroIdx ) = s
subTop _ (Var tp (SuccIdx ix)) = Evar $ Var tp ix
subAtop :: PreOpenAcc acc aenv t -> ArrayVar (aenv, t) (Array sh2 e2) -> PreOpenAcc acc aenv (Array sh2 e2)
subAtop t (Var _ ZeroIdx ) = t
subAtop _ (Var repr (SuccIdx idx)) = Avar $ Var repr idx
data Identity a = Identity { runIdentity :: a }
instance Functor Identity where
{-# INLINE fmap #-}
fmap f (Identity a) = Identity (f a)
instance Applicative Identity where
{-# INLINE (<*>) #-}
{-# INLINE pure #-}
Identity f <*> Identity a = Identity (f a)
pure a = Identity a
class Rebuildable f where
{-# MINIMAL rebuildPartial #-}
type AccClo f :: Type -> Type -> Type
rebuildPartial :: (Applicative f', SyntacticAcc fa)
=> (forall sh e. ArrayVar aenv (Array sh e) -> f' (fa (AccClo f) aenv' (Array sh e)))
-> f aenv a
-> f' (f aenv' a)
{-# INLINEABLE rebuildA #-}
rebuildA :: (SyntacticAcc fa)
=> (forall sh e. ArrayVar aenv (Array sh e) -> fa (AccClo f) aenv' (Array sh e))
-> f aenv a
-> f aenv' a
rebuildA av = runIdentity . rebuildPartial (Identity . av)
class RebuildableExp f where
{-# MINIMAL rebuildPartialE #-}
rebuildPartialE :: (Applicative f', SyntacticExp fe)
=> (forall e'. ExpVar env e' -> f' (fe env' aenv e'))
-> f env aenv e
-> f' (f env' aenv e)
{-# INLINEABLE rebuildE #-}
rebuildE :: SyntacticExp fe
=> (forall e'. ExpVar env e' -> fe env' aenv e')
-> f env aenv e
-> f env' aenv e
rebuildE v = runIdentity . rebuildPartialE (Identity . v)
type RebuildableAcc acc = (Rebuildable acc, AccClo acc ~ acc)
data OpenAccExp (acc :: Type -> Type -> Type) env aenv a where
OpenAccExp :: { unOpenAccExp :: OpenExp env aenv a } -> OpenAccExp acc env aenv a
data OpenAccFun (acc :: Type -> Type -> Type) env aenv a where
OpenAccFun :: { unOpenAccFun :: OpenFun env aenv a } -> OpenAccFun acc env aenv a
instance Rebuildable (OpenAccExp acc env) where
type AccClo (OpenAccExp acc env) = acc
{-# INLINEABLE rebuildPartial #-}
rebuildPartial v (OpenAccExp e) = OpenAccExp <$> Stats.substitution "rebuild" (rebuildOpenExp (pure . IE) (reindexAvar v) e)
instance Rebuildable (OpenAccFun acc env) where
type AccClo (OpenAccFun acc env) = acc
{-# INLINEABLE rebuildPartial #-}
rebuildPartial v (OpenAccFun f) = OpenAccFun <$> Stats.substitution "rebuild" (rebuildFun (pure . IE) (reindexAvar v) f)
instance RebuildableAcc acc => Rebuildable (PreOpenAcc acc) where
type AccClo (PreOpenAcc acc) = acc
{-# INLINEABLE rebuildPartial #-}
rebuildPartial x = Stats.substitution "rebuild" $ rebuildPreOpenAcc rebuildPartial x
instance RebuildableAcc acc => Rebuildable (PreOpenAfun acc) where
type AccClo (PreOpenAfun acc) = acc
{-# INLINEABLE rebuildPartial #-}
rebuildPartial x = Stats.substitution "rebuild" $ rebuildAfun rebuildPartial x
instance Rebuildable OpenAcc where
type AccClo OpenAcc = OpenAcc
{-# INLINEABLE rebuildPartial #-}
rebuildPartial x = Stats.substitution "rebuild" $ rebuildOpenAcc x
instance RebuildableExp OpenExp where
{-# INLINEABLE rebuildPartialE #-}
rebuildPartialE v x = Stats.substitution "rebuild" $ rebuildOpenExp v (ReindexAvar pure) x
instance RebuildableExp OpenFun where
{-# INLINEABLE rebuildPartialE #-}
rebuildPartialE v x = Stats.substitution "rebuild" $ rebuildFun v (ReindexAvar pure) x
class Sink f where
weaken :: env :> env' -> f env t -> f env' t
instance Sink Idx where
{-# INLINEABLE weaken #-}
weaken = (>:>)
instance Sink (Var s) where
{-# INLINEABLE weaken #-}
weaken k (Var s ix) = Var s (k >:> ix)
weakenVars :: env :> env' -> Vars s env t -> Vars s env' t
weakenVars _ TupRunit = TupRunit
weakenVars k (TupRsingle v) = TupRsingle $ weaken k v
weakenVars k (TupRpair v w) = TupRpair (weakenVars k v) (weakenVars k w)
rebuildWeakenVar :: env :> env' -> ArrayVar env (Array sh e) -> PreOpenAcc acc env' (Array sh e)
rebuildWeakenVar k (Var s idx) = Avar $ Var s $ k >:> idx
rebuildWeakenEvar :: env :> env' -> ExpVar env t -> OpenExp env' aenv t
rebuildWeakenEvar k (Var s idx) = Evar $ Var s $ k >:> idx
instance RebuildableAcc acc => Sink (PreOpenAcc acc) where
{-# INLINEABLE weaken #-}
weaken k = Stats.substitution "weaken" . rebuildA (rebuildWeakenVar k)
instance RebuildableAcc acc => Sink (PreOpenAfun acc) where
{-# INLINEABLE weaken #-}
weaken k = Stats.substitution "weaken" . rebuildA (rebuildWeakenVar k)
instance Sink (OpenExp env) where
{-# INLINEABLE weaken #-}
weaken k = Stats.substitution "weaken" . runIdentity . rebuildOpenExp (Identity . Evar) (ReindexAvar (Identity . weaken k))
instance Sink (OpenFun env) where
{-# INLINEABLE weaken #-}
weaken k = Stats.substitution "weaken" . runIdentity . rebuildFun (Identity . Evar) (ReindexAvar (Identity . weaken k))
instance Sink Boundary where
{-# INLINEABLE weaken #-}
weaken k bndy =
case bndy of
Clamp -> Clamp
Mirror -> Mirror
Wrap -> Wrap
Constant c -> Constant c
Function f -> Function (weaken k f)
instance Sink OpenAcc where
{-# INLINEABLE weaken #-}
weaken k = Stats.substitution "weaken" . rebuildA (rebuildWeakenVar k)
class SinkExp f where
weakenE :: env :> env' -> f env aenv t -> f env' aenv t
instance SinkExp OpenExp where
{-# INLINEABLE weakenE #-}
weakenE v = Stats.substitution "weakenE" . rebuildE (rebuildWeakenEvar v)
instance SinkExp OpenFun where
{-# INLINEABLE weakenE #-}
weakenE v = Stats.substitution "weakenE" . rebuildE (rebuildWeakenEvar v)
type env :?> env' = forall t'. Idx env t' -> Maybe (Idx env' t')
{-# INLINEABLE strengthen #-}
strengthen :: forall f env env' t. Rebuildable f => env :?> env' -> f env t -> Maybe (f env' t)
strengthen k x = Stats.substitution "strengthen" $ rebuildPartial @f @Maybe @IdxA (\(Var s ix) -> fmap (IA . Var s) $ k ix) x
{-# INLINEABLE strengthenE #-}
strengthenE :: forall f env env' aenv t. RebuildableExp f => env :?> env' -> f env aenv t -> Maybe (f env' aenv t)
strengthenE k x = Stats.substitution "strengthenE" $ rebuildPartialE @f @Maybe @IdxE (\(Var tp ix) -> fmap (IE . Var tp) $ k ix) x
strengthenWithLHS :: LeftHandSide s t env1 env2 -> env2 :?> env1
strengthenWithLHS (LeftHandSideWildcard _) = Just
strengthenWithLHS (LeftHandSideSingle _) = \ix -> case ix of
ZeroIdx -> Nothing
SuccIdx i -> Just i
strengthenWithLHS (LeftHandSidePair l1 l2) = strengthenWithLHS l2 >=> strengthenWithLHS l1
strengthenAfter :: LeftHandSide s t env1 env2 -> LeftHandSide s t env1' env2' -> env1 :?> env1' -> env2 :?> env2'
strengthenAfter (LeftHandSideWildcard _) (LeftHandSideWildcard _) k = k
strengthenAfter (LeftHandSideSingle _) (LeftHandSideSingle _) k = \ix -> case ix of
ZeroIdx -> Just ZeroIdx
SuccIdx i -> SuccIdx <$> k i
strengthenAfter (LeftHandSidePair l1 l2) (LeftHandSidePair l1' l2') k =
strengthenAfter l2 l2' $ strengthenAfter l1 l1' k
strengthenAfter _ _ _ = error "Substitution.strengthenAfter: left hand sides do not match"
class SyntacticExp f where
varIn :: ExpVar env t -> f env aenv t
expOut :: f env aenv t -> OpenExp env aenv t
weakenExp :: f env aenv t -> f (env, s) aenv t
newtype IdxE env aenv t = IE { unIE :: ExpVar env t }
instance SyntacticExp IdxE where
varIn = IE
expOut = Evar . unIE
weakenExp (IE (Var tp ix)) = IE $ Var tp $ SuccIdx ix
instance SyntacticExp OpenExp where
varIn = Evar
expOut = id
weakenExp = runIdentity . rebuildOpenExp (Identity . weakenExp . IE) (ReindexAvar Identity)
{-# INLINEABLE shiftE #-}
shiftE
:: (Applicative f, SyntacticExp fe)
=> RebuildEvar f fe env env' aenv
-> RebuildEvar f fe (env, s) (env', s) aenv
shiftE _ (Var tp ZeroIdx) = pure $ varIn (Var tp ZeroIdx)
shiftE v (Var tp (SuccIdx ix)) = weakenExp <$> v (Var tp ix)
{-# INLINEABLE shiftE' #-}
shiftE'
:: (Applicative f, SyntacticExp fa)
=> ELeftHandSide t env1 env1'
-> ELeftHandSide t env2 env2'
-> RebuildEvar f fa env1 env2 aenv
-> RebuildEvar f fa env1' env2' aenv
shiftE' (LeftHandSideWildcard _) (LeftHandSideWildcard _) v = v
shiftE' (LeftHandSideSingle _) (LeftHandSideSingle _) v = shiftE v
shiftE' (LeftHandSidePair a1 b1) (LeftHandSidePair a2 b2) v = shiftE' b1 b2 $ shiftE' a1 a2 v
shiftE' _ _ _ = error "Substitution: left hand sides do not match"
{-# INLINEABLE rebuildMaybeExp #-}
rebuildMaybeExp
:: (HasCallStack, Applicative f, SyntacticExp fe)
=> RebuildEvar f fe env env' aenv'
-> ReindexAvar f aenv aenv'
-> Maybe (OpenExp env aenv t)
-> f (Maybe (OpenExp env' aenv' t))
rebuildMaybeExp _ _ Nothing = pure Nothing
rebuildMaybeExp v av (Just x) = Just <$> rebuildOpenExp v av x
{-# INLINEABLE rebuildOpenExp #-}
rebuildOpenExp
:: (HasCallStack, Applicative f, SyntacticExp fe)
=> RebuildEvar f fe env env' aenv'
-> ReindexAvar f aenv aenv'
-> OpenExp env aenv t
-> f (OpenExp env' aenv' t)
rebuildOpenExp v av@(ReindexAvar reindex) exp =
case exp of
Const t c -> pure $ Const t c
PrimConst c -> pure $ PrimConst c
Undef t -> pure $ Undef t
Evar var -> expOut <$> v var
Let lhs a b
| Exists lhs' <- rebuildLHS lhs
-> Let lhs' <$> rebuildOpenExp v av a <*> rebuildOpenExp (shiftE' lhs lhs' v) av b
Pair e1 e2 -> Pair <$> rebuildOpenExp v av e1 <*> rebuildOpenExp v av e2
Nil -> pure Nil
VecPack vec e -> VecPack vec <$> rebuildOpenExp v av e
VecUnpack vec e -> VecUnpack vec <$> rebuildOpenExp v av e
IndexSlice x ix sh -> IndexSlice x <$> rebuildOpenExp v av ix <*> rebuildOpenExp v av sh
IndexFull x ix sl -> IndexFull x <$> rebuildOpenExp v av ix <*> rebuildOpenExp v av sl
ToIndex shr sh ix -> ToIndex shr <$> rebuildOpenExp v av sh <*> rebuildOpenExp v av ix
FromIndex shr sh ix -> FromIndex shr <$> rebuildOpenExp v av sh <*> rebuildOpenExp v av ix
Case e rhs def -> Case <$> rebuildOpenExp v av e <*> sequenceA [ (t,) <$> rebuildOpenExp v av c | (t,c) <- rhs ] <*> rebuildMaybeExp v av def
Cond p t e -> Cond <$> rebuildOpenExp v av p <*> rebuildOpenExp v av t <*> rebuildOpenExp v av e
While p f x -> While <$> rebuildFun v av p <*> rebuildFun v av f <*> rebuildOpenExp v av x
PrimApp f x -> PrimApp f <$> rebuildOpenExp v av x
Index a sh -> Index <$> reindex a <*> rebuildOpenExp v av sh
LinearIndex a i -> LinearIndex <$> reindex a <*> rebuildOpenExp v av i
Shape a -> Shape <$> reindex a
ShapeSize shr sh -> ShapeSize shr <$> rebuildOpenExp v av sh
Foreign tp ff f e -> Foreign tp ff f <$> rebuildOpenExp v av e
Coerce t1 t2 e -> Coerce t1 t2 <$> rebuildOpenExp v av e
{-# INLINEABLE rebuildFun #-}
rebuildFun
:: (HasCallStack, Applicative f, SyntacticExp fe)
=> RebuildEvar f fe env env' aenv'
-> ReindexAvar f aenv aenv'
-> OpenFun env aenv t
-> f (OpenFun env' aenv' t)
rebuildFun v av fun =
case fun of
Body e -> Body <$> rebuildOpenExp v av e
Lam lhs f
| Exists lhs' <- rebuildLHS lhs
-> Lam lhs' <$> rebuildFun (shiftE' lhs lhs' v) av f
type RebuildAcc acc =
forall aenv aenv' f fa a. (HasCallStack, Applicative f, SyntacticAcc fa)
=> RebuildAvar f fa acc aenv aenv'
-> acc aenv a
-> f (acc aenv' a)
newtype IdxA (acc :: Type -> Type -> Type) aenv t = IA { unIA :: ArrayVar aenv t }
class SyntacticAcc f where
avarIn :: ArrayVar aenv (Array sh e) -> f acc aenv (Array sh e)
accOut :: f acc aenv (Array sh e) -> PreOpenAcc acc aenv (Array sh e)
weakenAcc :: RebuildAcc acc -> f acc aenv (Array sh e) -> f acc (aenv, s) (Array sh e)
instance SyntacticAcc IdxA where
avarIn = IA
accOut = Avar . unIA
weakenAcc _ (IA (Var s idx)) = IA $ Var s $ SuccIdx idx
instance SyntacticAcc PreOpenAcc where
avarIn = Avar
accOut = id
weakenAcc k = runIdentity . rebuildPreOpenAcc k (Identity . weakenAcc k . IA)
type RebuildAvar f (fa :: (Type -> Type -> Type) -> Type -> Type -> Type) acc aenv aenv'
= forall sh e. ArrayVar aenv (Array sh e) -> f (fa acc aenv' (Array sh e))
type RebuildEvar f fe env env' aenv' =
forall t'. ExpVar env t' -> f (fe env' aenv' t')
newtype ReindexAvar f aenv aenv' =
ReindexAvar (forall sh e. ArrayVar aenv (Array sh e) -> f (ArrayVar aenv' (Array sh e)))
reindexAvar
:: forall f fa acc aenv aenv'.
(HasCallStack, Applicative f, SyntacticAcc fa)
=> RebuildAvar f fa acc aenv aenv'
-> ReindexAvar f aenv aenv'
reindexAvar v = ReindexAvar f where
f :: forall sh e. ArrayVar aenv (Array sh e) -> f (ArrayVar aenv' (Array sh e))
f var = g <$> v var
g :: fa acc aenv' (Array sh e) -> ArrayVar aenv' (Array sh e)
g fa = case accOut fa of
Avar var' -> var'
_ -> internalError "An Avar which was used in an Exp was mapped to an array term other than Avar. This mapping is invalid as an Exp can only contain array variables."
{-# INLINEABLE shiftA #-}
shiftA
:: (HasCallStack, Applicative f, SyntacticAcc fa)
=> RebuildAcc acc
-> RebuildAvar f fa acc aenv aenv'
-> ArrayVar (aenv, s) (Array sh e)
-> f (fa acc (aenv', s) (Array sh e))
shiftA _ _ (Var s ZeroIdx) = pure $ avarIn $ Var s ZeroIdx
shiftA k v (Var s (SuccIdx ix)) = weakenAcc k <$> v (Var s ix)
shiftA'
:: (HasCallStack, Applicative f, SyntacticAcc fa)
=> ALeftHandSide t aenv1 aenv1'
-> ALeftHandSide t aenv2 aenv2'
-> RebuildAcc acc
-> RebuildAvar f fa acc aenv1 aenv2
-> RebuildAvar f fa acc aenv1' aenv2'
shiftA' (LeftHandSideWildcard _) (LeftHandSideWildcard _) _ v = v
shiftA' (LeftHandSideSingle _) (LeftHandSideSingle _) k v = shiftA k v
shiftA' (LeftHandSidePair a1 b1) (LeftHandSidePair a2 b2) k v = shiftA' b1 b2 k $ shiftA' a1 a2 k v
shiftA' _ _ _ _ = internalError "left hand sides do not match"
{-# INLINEABLE rebuildOpenAcc #-}
rebuildOpenAcc
:: (HasCallStack, Applicative f, SyntacticAcc fa)
=> (forall sh e. ArrayVar aenv (Array sh e) -> f (fa OpenAcc aenv' (Array sh e)))
-> OpenAcc aenv t
-> f (OpenAcc aenv' t)
rebuildOpenAcc av (OpenAcc acc) = OpenAcc <$> rebuildPreOpenAcc rebuildOpenAcc av acc
{-# INLINEABLE rebuildPreOpenAcc #-}
rebuildPreOpenAcc
:: (HasCallStack, Applicative f, SyntacticAcc fa)
=> RebuildAcc acc
-> RebuildAvar f fa acc aenv aenv'
-> PreOpenAcc acc aenv t
-> f (PreOpenAcc acc aenv' t)
rebuildPreOpenAcc k av acc =
case acc of
Use repr a -> pure $ Use repr a
Alet lhs a b -> rebuildAlet k av lhs a b
Avar ix -> accOut <$> av ix
Apair as bs -> Apair <$> k av as <*> k av bs
Anil -> pure Anil
Apply repr f a -> Apply repr <$> rebuildAfun k av f <*> k av a
Acond p t e -> Acond <$> rebuildOpenExp (pure . IE) av' p <*> k av t <*> k av e
Awhile p f a -> Awhile <$> rebuildAfun k av p <*> rebuildAfun k av f <*> k av a
Unit tp e -> Unit tp <$> rebuildOpenExp (pure . IE) av' e
Reshape shr e a -> Reshape shr <$> rebuildOpenExp (pure . IE) av' e <*> k av a
Generate repr e f -> Generate repr <$> rebuildOpenExp (pure . IE) av' e <*> rebuildFun (pure . IE) av' f
Transform repr sh ix f a -> Transform repr <$> rebuildOpenExp (pure . IE) av' sh <*> rebuildFun (pure . IE) av' ix <*> rebuildFun (pure . IE) av' f <*> k av a
Replicate sl slix a -> Replicate sl <$> rebuildOpenExp (pure . IE) av' slix <*> k av a
Slice sl a slix -> Slice sl <$> k av a <*> rebuildOpenExp (pure . IE) av' slix
Map tp f a -> Map tp <$> rebuildFun (pure . IE) av' f <*> k av a
ZipWith tp f a1 a2 -> ZipWith tp <$> rebuildFun (pure . IE) av' f <*> k av a1 <*> k av a2
Fold f z a -> Fold <$> rebuildFun (pure . IE) av' f <*> rebuildMaybeExp (pure . IE) av' z <*> k av a
FoldSeg itp f z a s -> FoldSeg itp <$> rebuildFun (pure . IE) av' f <*> rebuildMaybeExp (pure . IE) av' z <*> k av a <*> k av s
Scan d f z a -> Scan d <$> rebuildFun (pure . IE) av' f <*> rebuildMaybeExp (pure . IE) av' z <*> k av a
Scan' d f z a -> Scan' d <$> rebuildFun (pure . IE) av' f <*> rebuildOpenExp (pure . IE) av' z <*> k av a
Permute f1 a1 f2 a2 -> Permute <$> rebuildFun (pure . IE) av' f1 <*> k av a1 <*> rebuildFun (pure . IE) av' f2 <*> k av a2
Backpermute shr sh f a -> Backpermute shr <$> rebuildOpenExp (pure . IE) av' sh <*> rebuildFun (pure . IE) av' f <*> k av a
Stencil sr tp f b a -> Stencil sr tp <$> rebuildFun (pure . IE) av' f <*> rebuildBoundary av' b <*> k av a
Stencil2 s1 s2 tp f b1 a1 b2 a2
-> Stencil2 s1 s2 tp <$> rebuildFun (pure . IE) av' f <*> rebuildBoundary av' b1 <*> k av a1 <*> rebuildBoundary av' b2 <*> k av a2
Aforeign repr ff afun as -> Aforeign repr ff afun <$> k av as
where
av' = reindexAvar av
{-# INLINEABLE rebuildAfun #-}
rebuildAfun
:: (HasCallStack, Applicative f, SyntacticAcc fa)
=> RebuildAcc acc
-> RebuildAvar f fa acc aenv aenv'
-> PreOpenAfun acc aenv t
-> f (PreOpenAfun acc aenv' t)
rebuildAfun k av (Abody b) = Abody <$> k av b
rebuildAfun k av (Alam lhs1 f)
| Exists lhs2 <- rebuildLHS lhs1
= Alam lhs2 <$> rebuildAfun k (shiftA' lhs1 lhs2 k av) f
rebuildAlet
:: forall f fa acc aenv1 aenv1' aenv2 bndArrs arrs. (HasCallStack, Applicative f, SyntacticAcc fa)
=> RebuildAcc acc
-> RebuildAvar f fa acc aenv1 aenv2
-> ALeftHandSide bndArrs aenv1 aenv1'
-> acc aenv1 bndArrs
-> acc aenv1' arrs
-> f (PreOpenAcc acc aenv2 arrs)
rebuildAlet k av lhs1 bind1 body1
| Exists lhs2 <- rebuildLHS lhs1
= Alet lhs2 <$> k av bind1 <*> k (shiftA' lhs1 lhs2 k av) body1
{-# INLINEABLE rebuildLHS #-}
rebuildLHS :: LeftHandSide s t aenv1 aenv1' -> Exists (LeftHandSide s t aenv2)
rebuildLHS (LeftHandSideWildcard r) = Exists $ LeftHandSideWildcard r
rebuildLHS (LeftHandSideSingle s) = Exists $ LeftHandSideSingle s
rebuildLHS (LeftHandSidePair as bs)
| Exists as' <- rebuildLHS as
, Exists bs' <- rebuildLHS bs
= Exists $ LeftHandSidePair as' bs'
{-# INLINEABLE rebuildBoundary #-}
rebuildBoundary
:: Applicative f
=> ReindexAvar f aenv aenv'
-> Boundary aenv t
-> f (Boundary aenv' t)
rebuildBoundary av bndy =
case bndy of
Clamp -> pure Clamp
Mirror -> pure Mirror
Wrap -> pure Wrap
Constant v -> pure (Constant v)
Function f -> Function <$> rebuildFun (pure . IE) av f
extractExpVars :: OpenExp env aenv a -> Maybe (ExpVars env a)
extractExpVars Nil = Just TupRunit
extractExpVars (Pair e1 e2) = TupRpair <$> extractExpVars e1 <*> extractExpVars e2
extractExpVars (Evar v) = Just $ TupRsingle v
extractExpVars _ = Nothing