{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_HADDOCK hide #-}
module Data.Array.Accelerate.Trafo.Substitution (
inline, substitute, compose,
subTop, subAtop,
(:>), Sink(..), SinkExp(..),
(:?>), strengthen, strengthenE,
RebuildAcc, Rebuildable(..), RebuildableAcc,
RebuildableExp(..), RebuildTup(..)
) where
import Control.Applicative hiding ( Const )
import Prelude hiding ( exp, seq )
import Data.Array.Accelerate.AST
import Data.Array.Accelerate.Array.Sugar ( Elt, Arrays, Tuple(..), Atuple(..) )
import qualified Data.Array.Accelerate.Debug.Stats as Stats
-- Fixity declarations: 'compose' and 'substitute' associate to the right when
-- written infix. No precedence level is given, so the language default
-- precedence applies.
infixr `compose`
infixr `substitute`
-- | Replace the most recently bound scalar variable of the first expression
-- with the second expression.
--
-- Occurrences of the top variable become the given expression; every other
-- variable is shifted down by one binder (see 'subTop'). The call is reported
-- to the statistics module under the name @"inline"@.
inline :: RebuildableAcc acc
       => PreOpenExp acc (env, s) aenv t
       -> PreOpenExp acc env      aenv s
       -> PreOpenExp acc env      aenv t
inline body arg = Stats.substitution "inline" (rebuildE (subTop arg) body)
-- | Sequence two open expressions that each have one extra scalar variable in
-- scope: the result binds @g@ with a 'Let' and evaluates @f@ underneath that
-- binding, so the top variable of @f@ refers to the value of @g@.
--
-- When @g@ is exactly the top variable, @f@ is returned unchanged and no
-- binding is introduced.
substitute :: (RebuildableAcc acc, Elt b, Elt c)
           => PreOpenExp acc (env, b) aenv c
           -> PreOpenExp acc (env, a) aenv b
           -> PreOpenExp acc (env, a) aenv c
substitute f g
  -- NOTE(review): this guard looks like it exists only to report a
  -- "substitute" event; presumably 'Stats.substitution' hands back the 'False'
  -- it is given so the branch is never taken -- confirm against Debug.Stats.
  | Stats.substitution "substitute" False = undefined
  | Var ZeroIdx <- g                      = f
  | otherwise                             = Let g (rebuildE reindex f)
  where
    -- Re-target the variables of @f@ for use under the new 'Let': the top
    -- variable stays on top (now naming the let-bound value), and every other
    -- variable skips two binders instead of one.
    reindex :: Elt c => Idx (env,b) c -> PreOpenExp acc ((env,a),b) aenv c
    reindex ix = case ix of
      ZeroIdx    -> Var ZeroIdx
      SuccIdx up -> Var (SuccIdx (SuccIdx up))
-- | Compose two unary scalar functions. The bodies are combined with
-- 'substitute' and wrapped back up as a single-argument function; the call is
-- reported to the statistics module as @"compose"@.
--
-- Any argument that is not of the shape @Lam (Body _)@ triggers a runtime
-- 'error'.
compose :: (RebuildableAcc acc, Elt c)
        => PreOpenFun acc env aenv (b -> c)
        -> PreOpenFun acc env aenv (a -> b)
        -> PreOpenFun acc env aenv (a -> c)
compose (Lam (Body outer)) (Lam (Body inner)) =
  Stats.substitution "compose" (Lam (Body (substitute outer inner)))
compose _ _ =
  error "compose: impossible evaluation"
-- | Scalar substitution for the top variable only: index zero maps to the
-- supplied expression, any other index maps to the variable one level down.
subTop :: Elt t => PreOpenExp acc env aenv s -> Idx (env, s) t -> PreOpenExp acc env aenv t
subTop new ix =
  case ix of
    ZeroIdx      -> new
    SuccIdx rest -> Var rest
-- | Array-level counterpart of 'subTop': index zero maps to the supplied array
-- term, any other index maps to the array variable one level down.
subAtop :: Arrays t => PreOpenAcc acc aenv s -> Idx (aenv, s) t -> PreOpenAcc acc aenv t
subAtop new ix =
  case ix of
    ZeroIdx      -> new
    SuccIdx rest -> Avar rest
-- | The trivial applicative functor, used to run the effectful
-- 'rebuildPartial' / 'rebuildPartialE' traversals as pure ones.
--
-- Declared as a @newtype@ rather than @data@: the wrapper has a single
-- constructor with a single field, so a @newtype@ has no runtime
-- representation and pattern matching on it never forces anything.
newtype Identity a = Identity { runIdentity :: a }

instance Functor Identity where
  fmap f (Identity a) = Identity (f a)

instance Applicative Identity where
  pure a                    = Identity a
  Identity f <*> Identity a = Identity (f a)
-- | Terms whose free /array/ variables can be replaced.
--
-- 'AccClo' names the array-computation type that replacement terms are
-- parameterised by.
class Rebuildable f where
  {-# MINIMAL rebuildPartial #-}
  type AccClo f :: (* -> * -> *)

  -- | Rebuild a term, mapping every array variable through the supplied
  -- function inside an arbitrary 'Applicative'.
  rebuildPartial :: (Applicative f', SyntacticAcc fa)
                 => (forall a'. Arrays a' => Idx aenv a' -> f' (fa (AccClo f) aenv' a'))
                 -> f aenv  a
                 -> f' (f aenv' a)

  -- | Pure variant of 'rebuildPartial', obtained by running the traversal in
  -- the 'Identity' applicative.
  {-# INLINEABLE rebuildA #-}
  rebuildA :: (SyntacticAcc fa)
           => (forall a'. Arrays a' => Idx aenv a' -> fa (AccClo f) aenv' a')
           -> f aenv  a
           -> f aenv' a
  rebuildA av term = runIdentity (rebuildPartial (Identity . av) term)
-- | Terms whose free /scalar/ variables can be replaced.
class RebuildableExp f where
  {-# MINIMAL rebuildPartialE #-}

  -- | Rebuild a term, mapping every scalar variable through the supplied
  -- function inside an arbitrary 'Applicative'.
  rebuildPartialE :: (Applicative f', SyntacticExp fe)
                  => (forall e'. Elt e' => Idx env e' -> f' (fe (AccClo (f env)) env' aenv e'))
                  -> f env  aenv e
                  -> f' (f env' aenv e)

  -- | Pure variant of 'rebuildPartialE', obtained by running the traversal in
  -- the 'Identity' applicative.
  {-# INLINABLE rebuildE #-}
  rebuildE :: SyntacticExp fe
           => (forall e'. Elt e' => Idx env e' -> fe (AccClo (f env)) env' aenv e')
           -> f env  aenv e
           -> f env' aenv e
  rebuildE v term = runIdentity (rebuildPartialE (Identity . v) term)
-- | Constraint synonym for an array-computation type @acc@ that is
-- 'Rebuildable' and whose associated 'AccClo' is @acc@ itself.
type RebuildableAcc acc = (Rebuildable acc, AccClo acc ~ acc)
-- | Replace array variables inside a scalar expression. Scalar variables are
-- passed through unchanged (each index is re-wrapped with 'IE').
instance RebuildableAcc acc => Rebuildable (PreOpenExp acc env) where
  type AccClo (PreOpenExp acc env) = acc
  {-# INLINEABLE rebuildPartial #-}
  rebuildPartial av term = rebuildPreOpenExp rebuildPartial (pure . IE) av term
-- | Replace array variables inside a scalar function. Scalar variables are
-- passed through unchanged (each index is re-wrapped with 'IE').
instance RebuildableAcc acc => Rebuildable (PreOpenFun acc env) where
  type AccClo (PreOpenFun acc env) = acc
  {-# INLINEABLE rebuildPartial #-}
  rebuildPartial av fun = rebuildFun rebuildPartial (pure . IE) av fun
-- | Replace array variables inside an array term, using the 'Rebuildable'
-- instance of @acc@ to descend into sub-computations.
instance RebuildableAcc acc => Rebuildable (PreOpenAcc acc) where
  type AccClo (PreOpenAcc acc) = acc
  {-# INLINEABLE rebuildPartial #-}
  rebuildPartial av pacc = rebuildPreOpenAcc rebuildPartial av pacc
-- | Replace array variables inside an array function, using the 'Rebuildable'
-- instance of @acc@ to descend into sub-computations.
instance RebuildableAcc acc => Rebuildable (PreOpenAfun acc) where
  type AccClo (PreOpenAfun acc) = acc
  {-# INLINEABLE rebuildPartial #-}
  rebuildPartial av afun = rebuildAfun rebuildPartial av afun
-- | Wrapper that gives a tuple of scalar expressions the argument order
-- required by the 'Rebuildable' class.
newtype RebuildTup acc env aenv t = RebuildTup { unRTup :: Tuple (PreOpenExp acc env aenv) t }

-- | Replace array variables in every component of the wrapped tuple. Scalar
-- variables are passed through unchanged.
instance RebuildableAcc acc => Rebuildable (RebuildTup acc env) where
  type AccClo (RebuildTup acc env) = acc
  {-# INLINEABLE rebuildPartial #-}
  rebuildPartial av (RebuildTup tup) =
    RebuildTup <$> rebuildTup rebuildPartial (pure . IE) av tup
-- | 'OpenAcc' is its own closure type; rebuilding is delegated to
-- 'rebuildOpenAcc'.
instance Rebuildable OpenAcc where
  type AccClo OpenAcc = OpenAcc
  {-# INLINEABLE rebuildPartial #-}
  rebuildPartial av acc = rebuildOpenAcc av acc
-- | Replace scalar variables inside a scalar expression. Array variables are
-- passed through unchanged (each index is re-wrapped with 'IA').
instance RebuildableAcc acc => RebuildableExp (PreOpenExp acc) where
  {-# INLINEABLE rebuildPartialE #-}
  rebuildPartialE v term = rebuildPreOpenExp rebuildPartial v (pure . IA) term
-- | Replace scalar variables inside a scalar function. Array variables are
-- passed through unchanged (each index is re-wrapped with 'IA').
instance RebuildableAcc acc => RebuildableExp (PreOpenFun acc) where
  {-# INLINEABLE rebuildPartialE #-}
  rebuildPartialE v fun = rebuildFun rebuildPartial v (pure . IA) fun
-- | A type-preserving mapping of indices from environment @env@ into
-- environment @env'@.
type env :> env' = forall t'. Idx env t' -> Idx env' t'

-- | Things indexed by an environment that can be moved into another
-- environment, given a mapping of indices.
class Sink f where
  weaken :: env :> env' -> f env t -> f env' t

-- | For a bare index, weakening is just applying the index mapping.
instance Sink Idx where
  {-# INLINEABLE weaken #-}
  weaken k ix = k ix
-- | Rename the array variables of an array term through the index mapping.
-- The call is reported to the statistics module as @"weaken"@.
instance RebuildableAcc acc => Sink (PreOpenAcc acc) where
  {-# INLINEABLE weaken #-}
  weaken k pacc = Stats.substitution "weaken" (rebuildA (Avar . k) pacc)
-- | Rename the array variables of an array function through the index
-- mapping. The call is reported to the statistics module as @"weaken"@.
instance RebuildableAcc acc => Sink (PreOpenAfun acc) where
  {-# INLINEABLE weaken #-}
  weaken k afun = Stats.substitution "weaken" (rebuildA (Avar . k) afun)
-- | Rename the /array/ variables of a scalar expression through the index
-- mapping. The call is reported to the statistics module as @"weaken"@.
instance RebuildableAcc acc => Sink (PreOpenExp acc env) where
  {-# INLINEABLE weaken #-}
  weaken k term = Stats.substitution "weaken" (rebuildA (Avar . k) term)
-- | Rename the /array/ variables of a scalar function through the index
-- mapping. The call is reported to the statistics module as @"weaken"@.
instance RebuildableAcc acc => Sink (PreOpenFun acc env) where
  {-# INLINEABLE weaken #-}
  weaken k fun = Stats.substitution "weaken" (rebuildA (Avar . k) fun)
-- | Rename the /array/ variables of a wrapped scalar tuple through the index
-- mapping. The call is reported to the statistics module as @"weaken"@.
instance RebuildableAcc acc => Sink (RebuildTup acc env) where
  {-# INLINEABLE weaken #-}
  weaken k tup = Stats.substitution "weaken" (rebuildA (Avar . k) tup)
-- | Weaken a stencil boundary condition. Only the 'Function' boundary carries
-- a term that needs weakening; the other constructors are rebuilt as they are.
instance RebuildableAcc acc => Sink (PreBoundary acc) where
  {-# INLINEABLE weaken #-}
  weaken _ Clamp        = Clamp
  weaken _ Mirror       = Mirror
  weaken _ Wrap         = Wrap
  weaken _ (Constant c) = Constant c
  weaken k (Function f) = Function (weaken k f)
-- | Rename the array variables of an 'OpenAcc' term through the index
-- mapping. The call is reported to the statistics module as @"weaken"@.
instance Sink OpenAcc where
  {-# INLINEABLE weaken #-}
  weaken k acc = Stats.substitution "weaken" (rebuildA (Avar . k) acc)
-- | Scalar terms that can be moved into another /scalar/ environment, given a
-- mapping of indices. The array environment is left untouched.
class SinkExp f where
  weakenE :: env :> env' -> f env aenv t -> f env' aenv t

-- | Rename the scalar variables of an expression through the index mapping.
-- The call is reported to the statistics module as @"weakenE"@.
instance RebuildableAcc acc => SinkExp (PreOpenExp acc) where
  {-# INLINEABLE weakenE #-}
  weakenE v term = Stats.substitution "weakenE" (rebuildE (IE . v) term)

-- | Rename the scalar variables of a function through the index mapping.
-- The call is reported to the statistics module as @"weakenE"@.
instance RebuildableAcc acc => SinkExp (PreOpenFun acc) where
  {-# INLINEABLE weakenE #-}
  weakenE v fun = Stats.substitution "weakenE" (rebuildE (IE . v) fun)
-- | A partial, type-preserving mapping of indices from @env@ into @env'@:
-- an index may have no image in the target environment.
type env :?> env' = forall t'. Idx env t' -> Maybe (Idx env' t')

-- | Try to move a term into another array environment using a partial index
-- mapping. The traversal runs in 'Maybe', so the result is 'Nothing' as soon
-- as the mapping returns 'Nothing' for a variable that is encountered.
{-# INLINEABLE strengthen #-}
strengthen :: Rebuildable f => env :?> env' -> f env t -> Maybe (f env' t)
strengthen k term = rebuildPartial (fmap IA . k) term
-- | Try to move a term into another scalar environment using a partial index
-- mapping. The traversal runs in 'Maybe', so the result is 'Nothing' as soon
-- as the mapping returns 'Nothing' for a variable that is encountered.
{-# INLINEABLE strengthenE #-}
strengthenE :: RebuildableExp f => env :?> env' -> f env aenv t -> Maybe (f env' aenv t)
strengthenE k term = rebuildPartialE (fmap IE . k) term
-- | The kinds of thing a scalar variable may be replaced with during a
-- rebuild: it must be possible to inject a variable, to turn the thing into an
-- expression, and to push it under one more scalar or array binder.
class SyntacticExp f where
  varIn        :: Elt t => Idx env t        -> f acc env aenv t
  expOut       :: Elt t => f acc env aenv t -> PreOpenExp acc env aenv t
  weakenExp    :: Elt t => RebuildAcc acc -> f acc env aenv t -> f acc (env, s) aenv t
  weakenExpAcc :: Elt t => RebuildAcc acc -> f acc env aenv t -> f acc env (aenv, s) t

-- | A scalar index carrying phantom @acc@ and @aenv@ parameters so that it
-- fits the shape expected by 'SyntacticExp'.
newtype IdxE (acc :: * -> * -> *) env aenv t = IE { unIE :: Idx env t }

-- | Variables replaced by variables, i.e. renaming.
instance SyntacticExp IdxE where
  varIn ix                 = IE ix
  expOut (IE ix)           = Var ix
  -- Going under a scalar binder shifts the index by one.
  weakenExp _ (IE ix)      = IE (SuccIdx ix)
  -- Going under an array binder leaves a scalar index as it is.
  weakenExpAcc _ (IE ix)   = IE ix
-- | Variables replaced by expressions, i.e. substitution. Pushing an
-- expression under a binder is itself a rebuild that renames the variables of
-- the affected environment and leaves the other environment alone.
instance SyntacticExp PreOpenExp where
  varIn ix = Var ix
  expOut e = e
  weakenExp k e =
    runIdentity (rebuildPreOpenExp k (Identity . weakenExp k . IE) (Identity . IA) e)
  weakenExpAcc k e =
    runIdentity (rebuildPreOpenExp k (Identity . IE) (Identity . weakenAcc k . IA) e)
-- | Extend a scalar variable mapping so that it can be used underneath one
-- additional scalar binder: the newly bound variable maps to itself, and the
-- image of every other variable is weakened past the new binder.
{-# INLINEABLE shiftE #-}
shiftE
    :: (Applicative f, SyntacticExp fe, Elt t)
    => RebuildAcc acc
    -> (forall t'. Elt t' => Idx env t' -> f (fe acc env' aenv t'))
    -> Idx (env, s) t
    -> f (fe acc (env', s) aenv t)
shiftE k v ix =
  case ix of
    ZeroIdx      -> pure (varIn ZeroIdx)
    SuccIdx rest -> weakenExp k <$> v rest
-- | Rebuild a scalar expression, replacing its scalar and array variables.
--
-- Arguments, in order: the function used to rebuild embedded array
-- computations, the mapping for scalar variables, the mapping for array
-- variables, and the expression to traverse. Constructors without
-- sub-terms are returned via 'pure'; scalar binders ('Let', and 'Lam' via
-- 'rebuildFun') extend the scalar mapping with 'shiftE'.
{-# INLINEABLE rebuildPreOpenExp #-}
rebuildPreOpenExp
    :: (Applicative f, SyntacticExp fe, SyntacticAcc fa)
    => RebuildAcc acc
    -> (forall t'. Elt t'    => Idx env t'  -> f (fe acc env' aenv' t'))
    -> (forall t'. Arrays t' => Idx aenv t' -> f (fa acc aenv' t'))
    -> PreOpenExp acc env  aenv  t
    -> f (PreOpenExp acc env' aenv' t)
rebuildPreOpenExp rb ev aev term =
  case term of
    Const c            -> pure (Const c)
    PrimConst c        -> pure (PrimConst c)
    IndexNil           -> pure IndexNil
    IndexAny           -> pure IndexAny
    Var ix             -> expOut <$> ev ix
    Let bnd body       -> Let <$> rebuildPreOpenExp rb ev aev bnd
                              <*> rebuildPreOpenExp rb (shiftE rb ev) aev body
    Tuple tup          -> Tuple <$> rebuildTup rb ev aev tup
    Prj tup e          -> Prj tup <$> rebuildPreOpenExp rb ev aev e
    IndexCons sh sz    -> IndexCons <$> rebuildPreOpenExp rb ev aev sh
                                    <*> rebuildPreOpenExp rb ev aev sz
    IndexHead sh       -> IndexHead <$> rebuildPreOpenExp rb ev aev sh
    IndexTail sh       -> IndexTail <$> rebuildPreOpenExp rb ev aev sh
    IndexSlice x ix sh -> IndexSlice x <$> rebuildPreOpenExp rb ev aev ix
                                       <*> rebuildPreOpenExp rb ev aev sh
    IndexFull x ix sl  -> IndexFull x <$> rebuildPreOpenExp rb ev aev ix
                                      <*> rebuildPreOpenExp rb ev aev sl
    ToIndex sh ix      -> ToIndex <$> rebuildPreOpenExp rb ev aev sh
                                  <*> rebuildPreOpenExp rb ev aev ix
    FromIndex sh ix    -> FromIndex <$> rebuildPreOpenExp rb ev aev sh
                                    <*> rebuildPreOpenExp rb ev aev ix
    Cond p t e         -> Cond <$> rebuildPreOpenExp rb ev aev p
                               <*> rebuildPreOpenExp rb ev aev t
                               <*> rebuildPreOpenExp rb ev aev e
    While p f x        -> While <$> rebuildFun rb ev aev p
                                <*> rebuildFun rb ev aev f
                                <*> rebuildPreOpenExp rb ev aev x
    PrimApp f x        -> PrimApp f <$> rebuildPreOpenExp rb ev aev x
    -- Embedded array computations are rebuilt with the array mapping only.
    Index a sh         -> Index <$> rb aev a
                                <*> rebuildPreOpenExp rb ev aev sh
    LinearIndex a i    -> LinearIndex <$> rb aev a
                                      <*> rebuildPreOpenExp rb ev aev i
    Shape a            -> Shape <$> rb aev a
    ShapeSize sh       -> ShapeSize <$> rebuildPreOpenExp rb ev aev sh
    Intersect s t      -> Intersect <$> rebuildPreOpenExp rb ev aev s
                                    <*> rebuildPreOpenExp rb ev aev t
    Union s t          -> Union <$> rebuildPreOpenExp rb ev aev s
                                <*> rebuildPreOpenExp rb ev aev t
    Foreign ff f e     -> Foreign ff f <$> rebuildPreOpenExp rb ev aev e
-- | Rebuild every component of a tuple of scalar expressions with
-- 'rebuildPreOpenExp', using the same variable mappings for each component.
{-# INLINEABLE rebuildTup #-}
rebuildTup
    :: (Applicative f, SyntacticExp fe, SyntacticAcc fa)
    => RebuildAcc acc
    -> (forall t'. Elt t'    => Idx env t'  -> f (fe acc env' aenv' t'))
    -> (forall t'. Arrays t' => Idx aenv t' -> f (fa acc aenv' t'))
    -> Tuple (PreOpenExp acc env  aenv)  t
    -> f (Tuple (PreOpenExp acc env' aenv') t)
rebuildTup _  _  _   NilTup           = pure NilTup
rebuildTup rb ev aev (SnocTup rest e) =
  SnocTup <$> rebuildTup rb ev aev rest
          <*> rebuildPreOpenExp rb ev aev e
-- | Rebuild a scalar function. Each 'Lam' extends the scalar variable mapping
-- with 'shiftE' before descending; the 'Body' is rebuilt with
-- 'rebuildPreOpenExp'.
{-# INLINEABLE rebuildFun #-}
rebuildFun
    :: (Applicative f, SyntacticExp fe, SyntacticAcc fa)
    => RebuildAcc acc
    -> (forall t'. Elt t'    => Idx env t'  -> f (fe acc env' aenv' t'))
    -> (forall t'. Arrays t' => Idx aenv t' -> f (fa acc aenv' t'))
    -> PreOpenFun acc env  aenv  t
    -> f (PreOpenFun acc env' aenv' t)
rebuildFun rb ev aev (Body e) = Body <$> rebuildPreOpenExp rb ev aev e
rebuildFun rb ev aev (Lam f)  = Lam  <$> rebuildFun rb (shiftE rb ev) aev f
-- | The type of a function that rebuilds an array computation @acc@: given a
-- mapping for array variables that runs in some 'Applicative' @f@, it
-- transports a term from environment @aenv@ to @aenv'@. It is passed
-- explicitly to the @rebuild*@ functions so they can descend into embedded
-- array computations.
type RebuildAcc acc =
  forall aenv aenv' f fa a. (Applicative f, SyntacticAcc fa)
  => (forall a'. Arrays a' => Idx aenv a' -> f (fa acc aenv' a'))
  -> acc aenv a
  -> f (acc aenv' a)
-- | The kinds of thing an array variable may be replaced with during a
-- rebuild: it must be possible to inject a variable, to turn the thing into an
-- array term, and to push it under one more array binder.
class SyntacticAcc f where
  avarIn    :: Arrays t => Idx aenv t     -> f acc aenv t
  accOut    :: Arrays t => f acc aenv t   -> PreOpenAcc acc aenv t
  weakenAcc :: Arrays t => RebuildAcc acc -> f acc aenv t -> f acc (aenv, s) t

-- | An array index carrying a phantom @acc@ parameter so that it fits the
-- shape expected by 'SyntacticAcc'.
newtype IdxA (acc :: * -> * -> *) aenv t = IA { unIA :: Idx aenv t }

-- | Variables replaced by variables, i.e. renaming.
instance SyntacticAcc IdxA where
  avarIn ix           = IA ix
  accOut (IA ix)      = Avar ix
  -- Going under an array binder shifts the index by one.
  weakenAcc _ (IA ix) = IA (SuccIdx ix)
-- | Variables replaced by array terms, i.e. substitution. Pushing a term
-- under a binder is itself a rebuild that shifts every array variable by one.
instance SyntacticAcc PreOpenAcc where
  avarIn ix = Avar ix
  accOut a  = a
  weakenAcc k pacc =
    runIdentity (rebuildPreOpenAcc k (Identity . weakenAcc k . IA) pacc)
-- | Extend an array variable mapping so that it can be used underneath one
-- additional array binder: the newly bound variable maps to itself, and the
-- image of every other variable is weakened past the new binder.
{-# INLINEABLE shiftA #-}
shiftA
    :: (Applicative f, SyntacticAcc fa, Arrays t)
    => RebuildAcc acc
    -> (forall t'. Arrays t' => Idx aenv t' -> f (fa acc aenv' t'))
    -> Idx (aenv, s) t
    -> f (fa acc (aenv', s) t)
shiftA k v ix =
  case ix of
    ZeroIdx      -> pure (avarIn ZeroIdx)
    SuccIdx rest -> weakenAcc k <$> v rest
-- | Rebuild an 'OpenAcc' term: unwrap it, rebuild the underlying term with
-- 'rebuildPreOpenAcc' (recursing through this function for sub-computations),
-- and wrap the result again.
{-# INLINEABLE rebuildOpenAcc #-}
rebuildOpenAcc
    :: (Applicative f, SyntacticAcc fa)
    => (forall t'. Arrays t' => Idx aenv t' -> f (fa OpenAcc aenv' t'))
    -> OpenAcc aenv  t
    -> f (OpenAcc aenv' t)
rebuildOpenAcc av (OpenAcc pacc) =
  fmap OpenAcc (rebuildPreOpenAcc rebuildOpenAcc av pacc)
-- | Rebuild an array term, replacing its array variables.
--
-- Arguments, in order: the function used to rebuild nested array
-- computations, the mapping for array variables, and the term to traverse.
-- Scalar expressions and functions embedded in the term are rebuilt with the
-- same array mapping while their scalar variables are kept as they are
-- (@pure . IE@). Array binders ('Alet', and 'Alam' via 'rebuildAfun') extend
-- the mapping with 'shiftA'.
{-# INLINEABLE rebuildPreOpenAcc #-}
rebuildPreOpenAcc
    :: (Applicative f, SyntacticAcc fa)
    => RebuildAcc acc
    -> (forall t'. Arrays t' => Idx aenv t' -> f (fa acc aenv' t'))
    -> PreOpenAcc acc aenv  t
    -> f (PreOpenAcc acc aenv' t)
rebuildPreOpenAcc rb aev pacc =
  case pacc of
    Use a                  -> pure (Use a)
    Alet bnd body          -> Alet <$> rb aev bnd
                                   <*> rb (shiftA rb aev) body
    Avar ix                -> accOut <$> aev ix
    Atuple tup             -> Atuple <$> rebuildAtup rb aev tup
    Aprj tup a             -> Aprj tup <$> rb aev a
    Apply f a              -> Apply <$> rebuildAfun rb aev f
                                    <*> rb aev a
    Acond p t e            -> Acond <$> rebuildPreOpenExp rb (pure . IE) aev p
                                    <*> rb aev t
                                    <*> rb aev e
    Awhile p f a           -> Awhile <$> rebuildAfun rb aev p
                                     <*> rebuildAfun rb aev f
                                     <*> rb aev a
    Unit e                 -> Unit <$> rebuildPreOpenExp rb (pure . IE) aev e
    Reshape e a            -> Reshape <$> rebuildPreOpenExp rb (pure . IE) aev e
                                      <*> rb aev a
    Generate e f           -> Generate <$> rebuildPreOpenExp rb (pure . IE) aev e
                                       <*> rebuildFun rb (pure . IE) aev f
    Transform sh ix f a    -> Transform <$> rebuildPreOpenExp rb (pure . IE) aev sh
                                        <*> rebuildFun rb (pure . IE) aev ix
                                        <*> rebuildFun rb (pure . IE) aev f
                                        <*> rb aev a
    Replicate sl slix a    -> Replicate sl <$> rebuildPreOpenExp rb (pure . IE) aev slix
                                           <*> rb aev a
    Slice sl a slix        -> Slice sl <$> rb aev a
                                       <*> rebuildPreOpenExp rb (pure . IE) aev slix
    Map f a                -> Map <$> rebuildFun rb (pure . IE) aev f
                                  <*> rb aev a
    ZipWith f a1 a2        -> ZipWith <$> rebuildFun rb (pure . IE) aev f
                                      <*> rb aev a1
                                      <*> rb aev a2
    Fold f z a             -> Fold <$> rebuildFun rb (pure . IE) aev f
                                   <*> rebuildPreOpenExp rb (pure . IE) aev z
                                   <*> rb aev a
    Fold1 f a              -> Fold1 <$> rebuildFun rb (pure . IE) aev f
                                    <*> rb aev a
    FoldSeg f z a s        -> FoldSeg <$> rebuildFun rb (pure . IE) aev f
                                      <*> rebuildPreOpenExp rb (pure . IE) aev z
                                      <*> rb aev a
                                      <*> rb aev s
    Fold1Seg f a s         -> Fold1Seg <$> rebuildFun rb (pure . IE) aev f
                                       <*> rb aev a
                                       <*> rb aev s
    Scanl f z a            -> Scanl <$> rebuildFun rb (pure . IE) aev f
                                    <*> rebuildPreOpenExp rb (pure . IE) aev z
                                    <*> rb aev a
    Scanl' f z a           -> Scanl' <$> rebuildFun rb (pure . IE) aev f
                                     <*> rebuildPreOpenExp rb (pure . IE) aev z
                                     <*> rb aev a
    Scanl1 f a             -> Scanl1 <$> rebuildFun rb (pure . IE) aev f
                                     <*> rb aev a
    Scanr f z a            -> Scanr <$> rebuildFun rb (pure . IE) aev f
                                    <*> rebuildPreOpenExp rb (pure . IE) aev z
                                    <*> rb aev a
    Scanr' f z a           -> Scanr' <$> rebuildFun rb (pure . IE) aev f
                                     <*> rebuildPreOpenExp rb (pure . IE) aev z
                                     <*> rb aev a
    Scanr1 f a             -> Scanr1 <$> rebuildFun rb (pure . IE) aev f
                                     <*> rb aev a
    Permute f1 a1 f2 a2    -> Permute <$> rebuildFun rb (pure . IE) aev f1
                                      <*> rb aev a1
                                      <*> rebuildFun rb (pure . IE) aev f2
                                      <*> rb aev a2
    Backpermute sh f a     -> Backpermute <$> rebuildPreOpenExp rb (pure . IE) aev sh
                                          <*> rebuildFun rb (pure . IE) aev f
                                          <*> rb aev a
    Stencil f b a          -> Stencil <$> rebuildFun rb (pure . IE) aev f
                                      <*> rebuildBoundary rb aev b
                                      <*> rb aev a
    Stencil2 f b1 a1 b2 a2 -> Stencil2 <$> rebuildFun rb (pure . IE) aev f
                                       <*> rebuildBoundary rb aev b1
                                       <*> rb aev a1
                                       <*> rebuildBoundary rb aev b2
                                       <*> rb aev a2
    -- Only the argument of a foreign call is rebuilt; the foreign function
    -- and its fallback implementation are passed through as they are.
    Aforeign ff afun as    -> Aforeign ff afun <$> rb aev as
-- | Rebuild an array function. Each 'Alam' extends the array variable mapping
-- with 'shiftA' before descending; the 'Abody' is rebuilt with the supplied
-- rebuilder.
{-# INLINEABLE rebuildAfun #-}
rebuildAfun
    :: (Applicative f, SyntacticAcc fa)
    => RebuildAcc acc
    -> (forall t'. Arrays t' => Idx aenv t' -> f (fa acc aenv' t'))
    -> PreOpenAfun acc aenv  t
    -> f (PreOpenAfun acc aenv' t)
rebuildAfun rb aev (Abody b) = Abody <$> rb aev b
rebuildAfun rb aev (Alam f)  = Alam  <$> rebuildAfun rb (shiftA rb aev) f
-- | Rebuild every component of a tuple of array computations with the
-- supplied rebuilder, using the same variable mapping for each component.
{-# INLINEABLE rebuildAtup #-}
rebuildAtup
    :: (Applicative f, SyntacticAcc fa)
    => RebuildAcc acc
    -> (forall t'. Arrays t' => Idx aenv t' -> f (fa acc aenv' t'))
    -> Atuple (acc aenv)  t
    -> f (Atuple (acc aenv') t)
rebuildAtup _  _   NilAtup           = pure NilAtup
rebuildAtup rb aev (SnocAtup rest a) =
  SnocAtup <$> rebuildAtup rb aev rest
           <*> rb aev a
-- | Rebuild a stencil boundary condition. Only a 'Function' boundary contains
-- a term to traverse (a scalar function, whose scalar variables are kept as
-- they are); the remaining constructors are returned via 'pure'.
{-# INLINEABLE rebuildBoundary #-}
rebuildBoundary
    :: (Applicative f, SyntacticAcc fa)
    => RebuildAcc acc
    -> (forall t'. Arrays t' => Idx aenv t' -> f (fa acc aenv' t'))
    -> PreBoundary acc aenv  t
    -> f (PreBoundary acc aenv' t)
rebuildBoundary _  _   Clamp        = pure Clamp
rebuildBoundary _  _   Mirror       = pure Mirror
rebuildBoundary _  _   Wrap         = pure Wrap
rebuildBoundary _  _   (Constant v) = pure (Constant v)
rebuildBoundary rb aev (Function f) = Function <$> rebuildFun rb (pure . IE) aev f