{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE IncoherentInstances #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
{-# OPTIONS_GHC -fno-warn-name-shadowing #-}
module Data.Array.Accelerate.Trafo.Fusion (
DelayedAcc, DelayedOpenAcc(..),
DelayedAfun, DelayedOpenAfun,
DelayedExp, DelayedFun, DelayedOpenExp, DelayedOpenFun,
convertAcc, convertAfun,
) where
import Prelude hiding ( exp, until )
import Data.Array.Accelerate.AST
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Trafo.Base
import Data.Array.Accelerate.Trafo.Shrink
import Data.Array.Accelerate.Trafo.Simplify
import Data.Array.Accelerate.Trafo.Substitution
import Data.Array.Accelerate.Array.Representation ( SliceIndex(..) )
import Data.Array.Accelerate.Array.Sugar ( Array, Arrays(..), ArraysR(..), ArrRepr
, Elt, EltRepr, Shape, Tuple(..), Atuple(..)
, IsAtuple, TupleRepr )
import Data.Array.Accelerate.Product
import qualified Data.Array.Accelerate.Debug as Stats
#ifdef ACCELERATE_DEBUG
import System.IO.Unsafe
#endif
-- | Convert a closed array program into the delayed representation.
--
-- The flag enables array fusion; when it is 'False' every intermediate
-- result is forced back to a manifest array during embedding. The
-- simplifier statistics are reset first in debug builds.
convertAcc :: Arrays arrs => Bool -> Acc arrs -> DelayedAcc arrs
convertAcc fuseAcc acc = withSimplStats (convertOpenAcc fuseAcc acc)
-- | Convert a closed array function into the delayed representation.
--
-- The flag enables array fusion, exactly as for 'convertAcc'.
convertAfun :: Bool -> Afun f -> DelayedAfun f
convertAfun fuseAcc afun = withSimplStats (convertOpenAfun fuseAcc afun)
-- | Reset the simplifier statistics, then return the argument unchanged.
--
-- In non-debug builds this is the identity. In debug builds the reset is
-- sequenced together with the argument inside one 'unsafePerformIO', so the
-- action depends on @x@. A standalone @unsafePerformIO Stats.resetSimplCount@
-- mentions no local variable, so GHC's full-laziness optimisation may float
-- it to the top level. It would then run at most once per program instead of
-- once per conversion.
withSimplStats :: a -> a
#ifdef ACCELERATE_DEBUG
withSimplStats x = unsafePerformIO (Stats.resetSimplCount >> return x)
#else
withSimplStats x = x
#endif
-- | Embed an open array term, compute the embedding back into a term, and
-- convert the result to the delayed representation with 'manifest'.
convertOpenAcc :: Arrays arrs => Bool -> OpenAcc aenv arrs -> DelayedOpenAcc aenv arrs
convertOpenAcc fuseAcc acc =
  manifest fuseAcc (computeAcc (embedOpenAcc fuseAcc acc))
-- | Convert the array argument of a producer or consumer into a 'Delayed'
-- term, described by its extent, an index function and a linear-index
-- function.
--
-- The argument is embedded first. The pattern requires the embedding to
-- introduce no extra bindings ('BaseEnv'); any other result is a pattern
-- match failure.
-- NOTE(review): arguments reaching here are expected to come from compute'
-- (an array variable, or one producer over an array variable), which embed
-- with 'BaseEnv'. Confirm this when adding new producers.
delayed :: (Shape sh, Elt e) => Bool -> OpenAcc aenv (Array sh e) -> DelayedOpenAcc aenv (Array sh e)
delayed fuseAcc (embedOpenAcc fuseAcc -> Embed BaseEnv cc) =
  case cc of
    -- A manifest array: index it directly.
    Done v -> Delayed (arrayShape v) (indexArray v) (linearIndex v)
    -- A generator: linear indexing converts the linear index back to a
    -- multidimensional index with fromIndex first.
    Yield (cvtE -> sh) (cvtF -> f) -> Delayed sh f (f `compose` fromIndex sh)
    Step (cvtE -> sh) (cvtF -> p) (cvtF -> f) v
      -- Same extent as the source and an identity index map: an element-wise
      -- map, so linear indexing can read the source array directly.
      | Just Refl <- match sh (arrayShape v)
      , Just Refl <- isIdentity p
      -> Delayed sh (f `compose` indexArray v) (f `compose` linearIndex v)
      -- General case: apply the index map, read the source, apply the value
      -- function.
      | f' <- f `compose` indexArray v `compose` p
      -> Delayed sh f' (f' `compose` fromIndex sh)
  where
    -- Scalar conversions into the delayed representation.
    cvtE :: OpenExp env aenv t -> DelayedOpenExp env aenv t
    cvtE = convertOpenExp fuseAcc

    cvtF :: OpenFun env aenv f -> DelayedOpenFun env aenv f
    cvtF (Lam f) = Lam (cvtF f)
    cvtF (Body b) = Body (cvtE b)
-- | Convert an embedded array term into the delayed representation.
--
-- Every array computation becomes a 'Manifest' node. Array arguments that
-- are read element-wise (the inputs of map-like producers, folds, scans and
-- the permuted array of 'Permute') are converted with 'delayed'. The other
-- array arguments are converted recursively with 'manifest'.
--
-- 'Replicate', 'Slice' and 'ZipWith' should already have been fused away:
-- embedPreAcc turns them into delayed terms, and compute' never produces
-- them. Meeting one here is an internal error.
manifest :: Bool -> OpenAcc aenv a -> DelayedOpenAcc aenv a
manifest fuseAcc (OpenAcc pacc) =
  let fusionError = $internalError "manifest" "unexpected fusible materials"
  in
  Manifest $ case pacc of
    -- Non-fusible terms: convert structurally.
    Avar ix                 -> Avar ix
    Use arr                 -> Use arr
    Unit e                  -> Unit (cvtE e)
    Alet bnd body           -> alet (manifest fuseAcc bnd) (manifest fuseAcc body)
    Acond p t e             -> Acond (cvtE p) (manifest fuseAcc t) (manifest fuseAcc e)
    Awhile p f a            -> Awhile (cvtAF p) (cvtAF f) (manifest fuseAcc a)
    Atuple tup              -> Atuple (cvtAT tup)
    Aprj ix tup             -> Aprj ix (manifest fuseAcc tup)
    Apply f a               -> Apply (cvtAF f) (manifest fuseAcc a)
    Aforeign ff f a         -> Aforeign ff (cvtAF f) (manifest fuseAcc a)

    -- Producers still present as manifest arrays, e.g. the final result or a
    -- binding that was not eliminated during embedding.
    Map f a                 -> Map (cvtF f) (delayed fuseAcc a)
    Generate sh f           -> Generate (cvtE sh) (cvtF f)
    Transform sh p f a      -> Transform (cvtE sh) (cvtF p) (cvtF f) (delayed fuseAcc a)
    Backpermute sh p a      -> Backpermute (cvtE sh) (cvtF p) (delayed fuseAcc a)
    Reshape sl a            -> Reshape (cvtE sl) (manifest fuseAcc a)

    -- Producers that must have been fused by the embedding pass.
    Replicate{}             -> fusionError
    Slice{}                 -> fusionError
    ZipWith{}               -> fusionError

    -- Consumers: their element-wise inputs are delayed. The Permute default
    -- array and the stencil inputs stay manifest.
    Fold f z a              -> Fold (cvtF f) (cvtE z) (delayed fuseAcc a)
    Fold1 f a               -> Fold1 (cvtF f) (delayed fuseAcc a)
    FoldSeg f z a s         -> FoldSeg (cvtF f) (cvtE z) (delayed fuseAcc a) (delayed fuseAcc s)
    Fold1Seg f a s          -> Fold1Seg (cvtF f) (delayed fuseAcc a) (delayed fuseAcc s)
    Scanl f z a             -> Scanl (cvtF f) (cvtE z) (delayed fuseAcc a)
    Scanl1 f a              -> Scanl1 (cvtF f) (delayed fuseAcc a)
    Scanl' f z a            -> Scanl' (cvtF f) (cvtE z) (delayed fuseAcc a)
    Scanr f z a             -> Scanr (cvtF f) (cvtE z) (delayed fuseAcc a)
    Scanr1 f a              -> Scanr1 (cvtF f) (delayed fuseAcc a)
    Scanr' f z a            -> Scanr' (cvtF f) (cvtE z) (delayed fuseAcc a)
    Permute f d p a         -> Permute (cvtF f) (manifest fuseAcc d) (cvtF p) (delayed fuseAcc a)
    Stencil f x a           -> Stencil (cvtF f) (cvtB x) (manifest fuseAcc a)
    Stencil2 f x a y b      -> Stencil2 (cvtF f) (cvtB x) (manifest fuseAcc a) (cvtB y) (manifest fuseAcc b)

  where
    -- Flatten a let-binding whose body just returns the bound variable into
    -- the bound term itself. Such bindings arise, for example, from done
    -- followed by compute during embedding.
    alet bnd body
      | Manifest (Avar ZeroIdx) <- body
      , Manifest x              <- bnd
      = x

      | otherwise
      = Alet bnd body

    cvtAT :: Atuple (OpenAcc aenv) a -> Atuple (DelayedOpenAcc aenv) a
    cvtAT NilAtup = NilAtup
    cvtAT (SnocAtup t a) = cvtAT t `SnocAtup` manifest fuseAcc a

    cvtAF :: OpenAfun aenv f -> PreOpenAfun DelayedOpenAcc aenv f
    cvtAF (Alam f) = Alam (cvtAF f)
    cvtAF (Abody b) = Abody (manifest fuseAcc b)

    cvtF :: OpenFun env aenv f -> DelayedOpenFun env aenv f
    cvtF (Lam f) = Lam (cvtF f)
    cvtF (Body b) = Body (cvtE b)

    cvtE :: OpenExp env aenv t -> DelayedOpenExp env aenv t
    cvtE = convertOpenExp fuseAcc

    cvtB :: Boundary aenv t -> PreBoundary DelayedOpenAcc aenv t
    cvtB Clamp = Clamp
    cvtB Mirror = Mirror
    cvtB Wrap = Wrap
    cvtB (Constant v) = Constant v
    cvtB (Function f) = Function (cvtF f)
-- | Convert a scalar expression into the delayed representation.
--
-- Array terms that occur inside expressions ('Index', 'LinearIndex' and
-- 'Shape') are converted with 'manifest'. Every other constructor is
-- rebuilt structurally with its sub-expressions converted.
convertOpenExp :: Bool -> OpenExp env aenv t -> DelayedOpenExp env aenv t
convertOpenExp fuseAcc exp =
  case exp of
    Let bnd body            -> Let (cvtE bnd) (cvtE body)
    Var ix                  -> Var ix
    Const c                 -> Const c
    Tuple tup               -> Tuple (cvtT tup)
    Prj ix t                -> Prj ix (cvtE t)
    IndexNil                -> IndexNil
    IndexCons sh sz         -> IndexCons (cvtE sh) (cvtE sz)
    IndexHead sh            -> IndexHead (cvtE sh)
    IndexTail sh            -> IndexTail (cvtE sh)
    IndexAny                -> IndexAny
    IndexSlice x ix sh      -> IndexSlice x (cvtE ix) (cvtE sh)
    IndexFull x ix sl       -> IndexFull x (cvtE ix) (cvtE sl)
    ToIndex sh ix           -> ToIndex (cvtE sh) (cvtE ix)
    FromIndex sh ix         -> FromIndex (cvtE sh) (cvtE ix)
    Cond p t e              -> Cond (cvtE p) (cvtE t) (cvtE e)
    While p f x             -> While (cvtF p) (cvtF f) (cvtE x)
    PrimConst c             -> PrimConst c
    PrimApp f x             -> PrimApp f (cvtE x)
    -- Array terms embedded in scalar code become manifest arrays.
    Index a sh              -> Index (manifest fuseAcc a) (cvtE sh)
    LinearIndex a i         -> LinearIndex (manifest fuseAcc a) (cvtE i)
    Shape a                 -> Shape (manifest fuseAcc a)
    ShapeSize sh            -> ShapeSize (cvtE sh)
    Intersect s t           -> Intersect (cvtE s) (cvtE t)
    Union s t               -> Union (cvtE s) (cvtE t)
    Foreign ff f e          -> Foreign ff (cvtF f) (cvtE e)
  where
    cvtT :: Tuple (OpenExp env aenv) t -> Tuple (DelayedOpenExp env aenv) t
    cvtT NilTup = NilTup
    cvtT (SnocTup t e) = cvtT t `SnocTup` cvtE e

    cvtF :: OpenFun env aenv f -> DelayedOpenFun env aenv f
    cvtF (Lam f) = Lam (cvtF f)
    cvtF (Body b) = Body (cvtE b)

    cvtE :: OpenExp env aenv t -> DelayedOpenExp env aenv t
    cvtE = convertOpenExp fuseAcc
-- | Convert an open array function: the body is converted with
-- 'convertOpenAcc' and each lambda binder is preserved.
convertOpenAfun :: Bool -> OpenAfun aenv f -> DelayedOpenAfun aenv f
convertOpenAfun fuseAcc afun =
  case afun of
    Alam f  -> Alam (convertOpenAfun fuseAcc f)
    Abody b -> Abody (convertOpenAcc fuseAcc b)
-- | The type of a function that embeds an array term of any environment and
-- result type into the 'Embed' representation.
type EmbedAcc acc = forall env a. Arrays a => acc env a -> Embed acc env a
-- | The let-elimination heuristic. Given a bound term and the body that
-- refers to it as its newest variable, it answers 'True' when the binding
-- should be eliminated (inlined) and 'False' when it should be kept
-- (see aletD').
type ElimAcc acc = forall env bound body. acc env bound -> acc (env, bound) body -> Bool
-- | Embed an 'OpenAcc' term, using this same function to embed sub-terms.
--
-- A let-bound array is eliminated (inlined into its body) only when the body
-- uses it at most once.
embedOpenAcc :: Arrays arrs => Bool -> OpenAcc aenv arrs -> Embed OpenAcc aenv arrs
embedOpenAcc fuseAcc (OpenAcc pacc) =
  embedPreAcc fuseAcc (embedOpenAcc fuseAcc) elimOpenAcc pacc
  where
    -- Decide whether a let-binding may be eliminated: yes when the bound
    -- variable (ZeroIdx) occurs at most useLimit times in the body. The
    -- bound term itself is not inspected.
    elimOpenAcc :: ElimAcc OpenAcc
    elimOpenAcc _bnd body = count False ZeroIdx body <= useLimit
      where
        -- Largest number of uses for which inlining is still performed.
        useLimit = 1

        -- Count uses of an array variable in a term via usesOfPreAcc.
        -- NOTE(review): the meaning of the Bool flag is defined by
        -- usesOfPreAcc, which is not visible here.
        count :: UsesOfAcc OpenAcc
        count no ix (OpenAcc pacc') = usesOfPreAcc no count ix pacc'
-- | Embed a single array operation, embedding its array arguments with the
-- supplied function.
--
-- Producers become delayed 'Cunctation's so they can be fused into whatever
-- consumes them. Consumers (folds, scans, 'Permute', stencils) receive their
-- argument as a computed term and are themselves let-bound, so their result
-- is an array variable ('Done'). Control-flow and other non-fusible terms
-- have their sub-terms embedded and computed back into terms (no fusion
-- across that boundary). Scalar functions and expressions are simplified.
--
-- When the flag is 'False', the result of each step is forced back to a
-- manifest binding (see unembed), which disables fusion.
embedPreAcc
    :: forall acc aenv arrs. (Kit acc, Arrays arrs)
    => Bool
    -> EmbedAcc acc
    -> ElimAcc acc
    -> PreOpenAcc acc aenv arrs
    -> Embed acc aenv arrs
embedPreAcc fuseAcc embedAcc elimAcc pacc
  = unembed
  $ case pacc of
      -- Non-fusible terms: special handling for application, let, projection
      -- and conditionals; the rest are computed and bound with done.
      Apply f a           -> applyD (cvtAF f) (cvtA a)
      Alet bnd body       -> aletD embedAcc elimAcc bnd body
      Aprj ix tup         -> aprjD embedAcc ix tup
      Acond p at ae       -> acondD embedAcc (cvtE p) at ae
      Awhile p f a        -> done $ Awhile (cvtAF p) (cvtAF f) (cvtA a)
      Atuple tup          -> done $ Atuple (cvtAT tup)
      Aforeign ff f a     -> done $ Aforeign ff (cvtAF f) (cvtA a)
      Avar v              -> done $ Avar v
      Use arrs            -> done $ Use arrs
      Unit e              -> done $ Unit (cvtE e)

      -- Producers: kept as delayed terms. Reshape stays a manifest Reshape
      -- when its argument is already a manifest array (see reshapeD).
      Generate sh f       -> generateD (cvtE sh) (cvtF f)
      Map f a             -> mapD (cvtF f) (embedAcc a)
      ZipWith f a b       -> fuse2 (into zipWithD (cvtF f)) a b
      Transform sh p f a  -> transformD (cvtE sh) (cvtF p) (cvtF f) (embedAcc a)
      Backpermute sl p a  -> fuse (into2 backpermuteD (cvtE sl) (cvtF p)) a
      Slice slix a sl     -> fuse (into (sliceD slix) (cvtE sl)) a
      Replicate slix sh a -> fuse (into (replicateD slix) (cvtE sh)) a
      Reshape sl a        -> reshapeD (embedAcc a) (cvtE sl)

      -- Consumers: the delayed argument is handed over as a computed term,
      -- and the consumer itself is let-bound.
      Fold f z a          -> embed (into2 Fold (cvtF f) (cvtE z)) a
      Fold1 f a           -> embed (into Fold1 (cvtF f)) a
      FoldSeg f z a s     -> embed2 (into2 FoldSeg (cvtF f) (cvtE z)) a s
      Fold1Seg f a s      -> embed2 (into Fold1Seg (cvtF f)) a s
      Scanl f z a         -> embed (into2 Scanl (cvtF f) (cvtE z)) a
      Scanl1 f a          -> embed (into Scanl1 (cvtF f)) a
      Scanl' f z a        -> embed (into2 Scanl' (cvtF f) (cvtE z)) a
      Scanr f z a         -> embed (into2 Scanr (cvtF f) (cvtE z)) a
      Scanr1 f a          -> embed (into Scanr1 (cvtF f)) a
      Scanr' f z a        -> embed (into2 Scanr' (cvtF f) (cvtE z)) a
      Permute f d p a     -> embed2 (into2 permute (cvtF f) (cvtF p)) d a

      -- Stencils: their arguments are first forced to manifest bindings
      -- (lift uses bind), so the stencil reads an array variable.
      Stencil f x a       -> lift (into2 Stencil (cvtF f) (cvtB x)) a
      Stencil2 f x a y b  -> lift2 (into3 stencil2 (cvtF f) (cvtB x) (cvtB y)) a b

  where
    -- With fusion disabled, compute the embedded result back into a term
    -- and bind it, so no delayed term escapes this step.
    unembed :: Embed acc aenv arrs -> Embed acc aenv arrs
    unembed x
      | fuseAcc = x
      | otherwise = done (compute x)

    -- Embed an argument and immediately compute it back to a term: no fusion
    -- across this boundary.
    cvtA :: Arrays a => acc aenv' a -> acc aenv' a
    cvtA = computeAcc . embedAcc

    cvtAT :: Atuple (acc aenv') a -> Atuple (acc aenv') a
    cvtAT NilAtup = NilAtup
    cvtAT (SnocAtup tup a) = cvtAT tup `SnocAtup` cvtA a

    cvtAF :: PreOpenAfun acc aenv' f -> PreOpenAfun acc aenv' f
    cvtAF (Alam f) = Alam (cvtAF f)
    cvtAF (Abody a) = Abody (cvtA a)

    -- Argument reorderings: scalar arguments first (captured by into2 and
    -- into3), array arguments last (supplied by trav2).
    permute f p d a = Permute f d p a
    stencil2 f x y a b = Stencil2 f x a y b

    -- Scalar functions and expressions are simplified while embedding.
    cvtF :: PreFun acc aenv' t -> PreFun acc aenv' t
    cvtF = simplify

    cvtE :: Elt t => PreExp acc aenv' t -> PreExp acc aenv' t
    cvtE = simplify

    cvtB :: PreBoundary acc aenv' t -> PreBoundary acc aenv' t
    cvtB Clamp = Clamp
    cvtB Mirror = Mirror
    cvtB Wrap = Wrap
    cvtB (Constant c) = Constant c
    cvtB (Function f) = Function (cvtF f)

    -- Sink the captured arguments of op into the extended environment before
    -- applying it, so op can run underneath the bindings that embedding
    -- introduced.
    into :: Sink f => (f env' a -> b) -> f env a -> Extend acc env env' -> b
    into op a env = op (sink env a)

    into2 :: (Sink f1, Sink f2)
          => (f1 env' a -> f2 env' b -> c) -> f1 env a -> f2 env b -> Extend acc env env' -> c
    into2 op a b env = op (sink env a) (sink env b)

    into3 :: (Sink f1, Sink f2, Sink f3)
          => (f1 env' a -> f2 env' b -> f3 env' c -> d) -> f1 env a -> f2 env b -> f3 env c -> Extend acc env env' -> d
    into3 op a b c env = op (sink env a) (sink env b) (sink env c)

    -- Apply a Cunctation transformer to the embedding of one argument,
    -- keeping that embedding's bindings.
    fuse :: Arrays as
         => (forall aenv'. Extend acc aenv aenv' -> Cunctation acc aenv' as -> Cunctation acc aenv' bs)
         -> acc aenv as
         -> Embed acc aenv bs
    fuse op (embedAcc -> Embed env cc) = Embed env (op env cc)

    -- As fuse, for two arguments. The second argument is embedded under the
    -- first one's bindings, and the first cunctation is sunk under the
    -- second's, so both live in the combined environment.
    fuse2 :: (Arrays as, Arrays bs)
          => (forall aenv'. Extend acc aenv aenv' -> Cunctation acc aenv' as -> Cunctation acc aenv' bs -> Cunctation acc aenv' cs)
          -> acc aenv as
          -> acc aenv bs
          -> Embed acc aenv cs
    fuse2 op a1 a0
      | Embed env1 cc1 <- embedAcc a1
      , Embed env0 cc0 <- embedAcc (sink env1 a0)
      , env <- env1 `append` env0
      = Embed env (op env (sink env0 cc1) cc0)

    -- Consumers whose arguments are passed as computed terms without being
    -- bound first. The term may still be a Generate, Map, Backpermute or
    -- Transform, which manifest later turns into a Delayed argument.
    embed :: (Arrays as, Arrays bs)
          => (forall aenv'. Extend acc aenv aenv' -> acc aenv' as -> PreOpenAcc acc aenv' bs)
          -> acc aenv as
          -> Embed acc aenv bs
    embed = trav1 id

    embed2 :: forall aenv as bs cs. (Arrays as, Arrays bs, Arrays cs)
           => (forall aenv'. Extend acc aenv aenv' -> acc aenv' as -> acc aenv' bs -> PreOpenAcc acc aenv' cs)
           -> acc aenv as
           -> acc aenv bs
           -> Embed acc aenv cs
    embed2 = trav2 id id

    -- Consumers whose arguments are first forced into manifest bindings with
    -- bind, so the consumer always receives an array variable.
    lift :: (Arrays as, Arrays bs)
         => (forall aenv'. Extend acc aenv aenv' -> acc aenv' as -> PreOpenAcc acc aenv' bs)
         -> acc aenv as
         -> Embed acc aenv bs
    lift = trav1 bind

    lift2 :: forall aenv as bs cs. (Arrays as, Arrays bs, Arrays cs)
          => (forall aenv'. Extend acc aenv aenv' -> acc aenv' as -> acc aenv' bs -> PreOpenAcc acc aenv' cs)
          -> acc aenv as
          -> acc aenv bs
          -> Embed acc aenv cs
    lift2 = trav2 bind bind

    -- Embed the argument, post-process the embedding with f, compute it, and
    -- let-bind the consumer applied to it. The result is the new binding.
    trav1 :: (Arrays as, Arrays bs)
          => (forall aenv'. Embed acc aenv' as -> Embed acc aenv' as)
          -> (forall aenv'. Extend acc aenv aenv' -> acc aenv' as -> PreOpenAcc acc aenv' bs)
          -> acc aenv as
          -> Embed acc aenv bs
    trav1 f op (f . embedAcc -> Embed env cc)
      = Embed (env `PushEnv` inject (op env (inject (compute' cc)))) (Done ZeroIdx)

    -- As trav1, for two arguments (second embedded under the first's
    -- bindings, first sunk under the second's).
    trav2 :: forall aenv as bs cs. (Arrays as, Arrays bs, Arrays cs)
          => (forall aenv'. Embed acc aenv' as -> Embed acc aenv' as)
          -> (forall aenv'. Embed acc aenv' bs -> Embed acc aenv' bs)
          -> (forall aenv'. Extend acc aenv aenv' -> acc aenv' as -> acc aenv' bs -> PreOpenAcc acc aenv' cs)
          -> acc aenv as
          -> acc aenv bs
          -> Embed acc aenv cs
    trav2 f1 f0 op (f1 . embedAcc -> Embed env1 cc1) (f0 . embedAcc . sink env1 -> Embed env0 cc0)
      | env <- env1 `append` env0
      , acc1 <- inject . compute' $ sink env0 cc1
      , acc0 <- inject . compute' $ cc0
      = Embed (env `PushEnv` inject (op env acc1 acc0)) (Done ZeroIdx)

    -- Let-bind a cunctation as a manifest array unless it already is an
    -- array variable.
    bind :: Arrays as => Embed acc aenv' as -> Embed acc aenv' as
    bind (Embed env cc)
      | Done{} <- cc = Embed env cc
      | otherwise = Embed (env `PushEnv` inject (compute' cc)) (Done ZeroIdx)
-- | An embedded array computation. It pairs a sequence of let-bindings,
-- which extends the environment @aenv@ to some @aenv'@, with a delayed
-- computation that is valid in @aenv'@. compute turns it back into a term
-- with @bind env (compute' cc)@.
data Embed acc aenv a where
  Embed :: Extend acc aenv aenv'
        -> Cunctation acc aenv' a
        -> Embed acc aenv a
-- | A delayed array computation in the environment @aenv@.
data Cunctation acc aenv a where

  -- An array that is already manifest, bound to the given array variable.
  -- Computed as that variable.
  Done  :: Arrays a
        => Idx aenv a
        -> Cunctation acc aenv a

  -- A generator: the extent and a function from index to element. Computed
  -- as a Generate.
  Yield :: (Shape sh, Elt e)
        => PreExp acc aenv sh
        -> PreFun acc aenv (sh -> e)
        -> Cunctation acc aenv (Array sh e)

  -- A transformation of a manifest array variable. Holds the output extent,
  -- an index map from output index to source index, and a function applied
  -- to each source element. Computed as a Transform, or as a simpler term
  -- when the index map and/or value function is the identity (see compute').
  Step  :: (Shape sh, Shape sh', Elt a, Elt b)
        => PreExp acc aenv sh'
        -> PreFun acc aenv (sh' -> sh)
        -> PreFun acc aenv (a -> b)
        -> Idx aenv (Array sh a)
        -> Cunctation acc aenv (Array sh' b)
-- | Simplify every scalar component of a cunctation. Array variables are
-- left as they are.
instance Kit acc => Simplify (Cunctation acc aenv a) where
  simplify cc =
    case cc of
      Done v        -> Done v
      Yield sh f    -> Yield (simplify sh) (simplify f)
      Step sh p f v -> Step (simplify sh) (simplify p) (simplify f) v
-- | Embed a term that will not be fused further.
--
-- An array variable is referenced directly without adding a binding. Any
-- other term is let-bound and the result refers to that new binding.
done :: (Arrays a, Kit acc) => PreOpenAcc acc aenv a -> Embed acc aenv a
done pacc =
  case pacc of
    Avar v -> Embed BaseEnv (Done v)
    _      -> Embed (BaseEnv `PushEnv` inject pacc) (Done ZeroIdx)
-- | View a delayed array in 'Yield' form: its extent together with a
-- function from index to element.
--
-- A 'Step' becomes a 'Yield' by composing its value function, the read from
-- the source array and its index map. A manifest array ('Done') becomes its
-- own shape and indexing function.
yield :: Kit acc
      => Cunctation acc aenv (Array sh e)
      -> Cunctation acc aenv (Array sh e)
yield cc =
  case cc of
    Yield{}       -> cc
    Step sh p f v -> Yield sh (f `compose` indexArray v `compose` p)
    Done v
      -- Matching ArraysRarray brings the array's Shape/Elt constraints into
      -- scope, which arrayShape and indexArray need.
      -- NOTE(review): this presumably always succeeds for an Array type, and
      -- the fall-through is treated as unreachable.
      | ArraysRarray <- accType cc -> Yield (arrayShape v) (indexArray v)
      | otherwise                  -> $internalError "yield" "impossible case"
-- | View a delayed array in 'Step' form, if it has one.
--
-- A 'Yield' has no 'Step' form and gives 'Nothing'. A manifest array
-- ('Done') becomes a 'Step' over itself, with its own shape and identity
-- index and value functions.
step :: Kit acc
     => Cunctation acc aenv (Array sh e)
     -> Maybe (Cunctation acc aenv (Array sh e))
step cc =
  case cc of
    Yield{} -> Nothing
    Step{}  -> Just cc
    Done v
      -- Matching ArraysRarray brings the array's Shape/Elt constraints into
      -- scope, which arrayShape and identity need.
      -- NOTE(review): this presumably always succeeds for an Array type, and
      -- the fall-through is treated as unreachable.
      | ArraysRarray <- accType cc -> Just $ Step (arrayShape v) identity identity v
      | otherwise                  -> $internalError "step" "impossible case"
-- | The extent of a delayed array, taken from its 'Step' view when it has
-- one and from its 'Yield' view otherwise.
shape :: Kit acc => Cunctation acc aenv (Array sh e) -> PreExp acc aenv sh
shape cc =
  case step cc of
    Just (Step sh _ _ _) -> sh
    _                    ->
      case yield cc of
        Yield sh _ -> sh
-- | The array-representation witness for the result type of a cunctation.
-- The cunctation itself is never inspected; only its type is used.
accType :: forall acc aenv a. Arrays a => Cunctation acc aenv a -> ArraysR (ArrRepr a)
accType = const (arrays (undefined :: a))
-- | Weakening a cunctation weakens every scalar component and array variable
-- it contains.
instance Kit acc => Sink (Cunctation acc) where
  weaken k (Done v)        = Done (weaken k v)
  weaken k (Step sh p f v) = Step (weaken k sh) (weaken k p) (weaken k f) (weaken k v)
  weaken k (Yield sh f)    = Yield (weaken k sh) (weaken k f)
-- | Turn an embedding back into an array term: compute the cunctation and
-- wrap it in the embedding's bindings.
compute :: (Kit acc, Arrays arrs) => Embed acc aenv arrs -> PreOpenAcc acc aenv arrs
compute embedding =
  case embedding of
    Embed env cc -> bind env (compute' cc)
-- | Turn a (simplified) cunctation into the cheapest equivalent array term.
--
-- 'Done' gives the array variable, 'Yield' gives a Generate, and 'Step'
-- gives a Transform or a simpler term (variable, Map or Backpermute) when
-- its extent, index map or value function is trivial.
compute' :: (Kit acc, Arrays arrs) => Cunctation acc aenv arrs -> PreOpenAcc acc aenv arrs
compute' cc =
  case simplify cc of
    Done v     -> Avar v
    Yield sh f -> Generate sh f
    Step sh p f v
      -- Source extent and identity index map: at most an element-wise map
      -- remains.
      | Just Refl <- match sh (simplify (arrayShape v))
      , Just Refl <- isIdentity p
      -> case isIdentity f of
           Just Refl -> Avar v
           Nothing   -> Map f (avarIn v)
      -- Only the index map is non-trivial.
      | Just Refl <- isIdentity f -> Backpermute sh p (avarIn v)
      | otherwise                 -> Transform sh p f (avarIn v)
-- | 'compute', with the resulting term injected into the array-term type.
computeAcc :: (Kit acc, Arrays arrs) => Embed acc aenv arrs -> acc aenv arrs
computeAcc embedding = inject (compute embedding)
-- | A generator is delayed as-is, with no new bindings.
generateD :: (Shape sh, Elt e)
          => PreExp acc aenv sh
          -> PreFun acc aenv (sh -> e)
          -> Embed acc aenv (Array sh e)
generateD sh f = Stats.ruleFired "generateD" (Embed BaseEnv (Yield sh f))
-- | Fuse an element-wise function into a delayed array.
--
-- Tuple projections over a manifest array are handled by unzipD first. In
-- all other cases the function is composed after the element function of
-- the 'Step' view when one exists, which keeps the source array variable.
-- Otherwise it is composed into the 'Yield' view.
mapD :: (Kit acc, Shape sh, Elt b)
     => PreFun acc aenv (a -> b)
     -> Embed acc aenv (Array sh a)
     -> Embed acc aenv (Array sh b)
mapD f (unzipD f -> Just a) = a
mapD f (Embed env cc)
  = Stats.ruleFired "mapD"
  $ Embed env (go cc)
  where
    -- f lives in aenv, so it is sunk into the embedding's environment.
    go (step -> Just (Step sh ix g v)) = Step sh ix (sink env f `compose` g) v
    go (yield -> Yield sh g) = Yield sh (sink env f `compose` g)
-- | Special case of 'mapD' for a tuple projection over a manifest array.
--
-- When the function is a (possibly nested) projection of its argument and
-- the array is 'Done', the map is let-bound as a manifest Map instead of
-- being fused. Nested projections are peeled one level at a time through
-- the recursive call. Any other case gives 'Nothing'.
-- NOTE(review): the intent is presumably to let backends treat unzip
-- specially; confirm against the backends.
unzipD
    :: (Kit acc, Shape sh, Elt b)
    => PreFun acc aenv (a -> b)
    -> Embed acc aenv (Array sh a)
    -> Maybe (Embed acc aenv (Array sh b))
unzipD f (Embed env (Done v))
  -- A single projection of the element.
  | Lam (Body (Prj tix (Var ZeroIdx))) <- f
  = Stats.ruleFired "unzipD"
  $ let f' = Lam (Body (Prj tix (Var ZeroIdx)))
        a' = avarIn v
    in
    Just $ Embed (env `PushEnv` inject (Map f' a')) (Done ZeroIdx)

  -- A projection of a projection: unzip the inner one first, then project
  -- from its manifest result.
  | Lam (Body (Prj tix p@Prj{})) <- f
  , Just (Embed env' (Done v')) <- unzipD (Lam (Body p)) (Embed env (Done v))
  = Stats.ruleFired "unzipD"
  $ let f' = Lam (Body (Prj tix (Var ZeroIdx)))
        a' = avarIn v'
    in
    Just $ Embed (env' `PushEnv` inject (Map f' a')) (Done ZeroIdx)

unzipD _ _
  = Nothing
-- | Fuse a backwards permutation with output extent @sh'@ and index map @p@
-- (output index to input index) into a delayed array.
--
-- In 'Step' form the new index map is composed before the existing one, so
-- the source array variable is kept. Otherwise it is composed into the
-- 'Yield' view. The previous extent is discarded.
backpermuteD
    :: (Kit acc, Shape sh')
    => PreExp acc aenv sh'
    -> PreFun acc aenv (sh' -> sh)
    -> Cunctation acc aenv (Array sh e)
    -> Cunctation acc aenv (Array sh' e)
backpermuteD sh' p = Stats.ruleFired "backpermuteD" . go
  where
    go (step -> Just (Step _ q f v)) = Step sh' (q `compose` p) f v
    go (yield -> Yield _ g) = Yield sh' (g `compose` p)
-- | Fuse a Transform into a delayed array. The value function is fused first
-- (via 'mapD'), then the index map and new extent (via 'backpermuteD').
transformD
    :: (Kit acc, Shape sh, Shape sh', Elt b)
    => PreExp acc aenv sh'
    -> PreFun acc aenv (sh' -> sh)
    -> PreFun acc aenv (a -> b)
    -> Embed acc aenv (Array sh a)
    -> Embed acc aenv (Array sh' b)
transformD sh' p f embedding =
  Stats.ruleFired "transformD" (fuseWith (sinkBoth backpermuteD sh' p) (mapD f embedding))
  where
    -- Apply a cunctation transformer underneath the embedding's bindings.
    fuseWith :: (forall aenv'. Extend acc aenv aenv' -> Cunctation acc aenv' as -> Cunctation acc aenv' bs)
             -> Embed acc aenv as
             -> Embed acc aenv bs
    fuseWith op (Embed env cc) = Embed env (op env cc)

    -- Sink both captured arguments into the extended environment, then
    -- apply op.
    sinkBoth :: (Sink f1, Sink f2)
             => (f1 env' a -> f2 env' b -> c) -> f1 env a -> f2 env b -> Extend acc env env' -> c
    sinkBoth op x y env = op (sink env x) (sink env y)
-- | Fuse a Replicate into a delayed array, expressed as a backwards
-- permutation. The output extent is built with IndexFull from the slice
-- specification and the input extent. Each output index is mapped back to
-- an input index by 'extend'.
replicateD
    :: (Kit acc, Shape sh, Shape sl, Elt slix)
    => SliceIndex (EltRepr slix) (EltRepr sl) co (EltRepr sh)
    -> PreExp acc aenv slix
    -> Cunctation acc aenv (Array sl e)
    -> Cunctation acc aenv (Array sh e)
replicateD sliceIndex slix cc =
  Stats.ruleFired "replicateD" (backpermuteD outExtent toSource cc)
  where
    outExtent = IndexFull sliceIndex slix (shape cc)
    toSource  = extend sliceIndex slix
-- | Fuse a Slice into a delayed array, expressed as a backwards permutation.
-- The output extent is built with IndexSlice from the slice specification
-- and the input extent. Each output index is mapped back to an input index
-- by 'restrict'.
sliceD
    :: (Kit acc, Shape sh, Shape sl, Elt slix)
    => SliceIndex (EltRepr slix) (EltRepr sl) co (EltRepr sh)
    -> PreExp acc aenv slix
    -> Cunctation acc aenv (Array sh e)
    -> Cunctation acc aenv (Array sl e)
sliceD sliceIndex slix cc =
  Stats.ruleFired "sliceD" (backpermuteD outExtent toSource cc)
  where
    outExtent = IndexSlice sliceIndex slix (shape cc)
    toSource  = restrict sliceIndex slix
-- | Fuse a Reshape into a delayed array.
--
-- A manifest array is reshaped with an explicit, let-bound Reshape.
-- Otherwise the reshape becomes a backwards permutation that goes through
-- the linear index ('reindex').
reshapeD
    :: (Kit acc, Shape sh, Shape sl, Elt e)
    => Embed acc aenv (Array sh e)
    -> PreExp acc aenv sl
    -> Embed acc aenv (Array sl e)
reshapeD (Embed env cc) sl =
  let sl' = sink env sl
  in  case cc of
        Done v -> Embed (env `PushEnv` inject (Reshape sl' (avarIn v))) (Done ZeroIdx)
        _      -> Stats.ruleFired "reshapeD"
                $ Embed env (backpermuteD sl' (reindex (shape cc) sl') cc)
-- | Fuse a binary element-wise combination of two delayed arrays. The
-- result extent is the intersection of the two argument extents.
--
-- If both arguments step over the same array variable with the same index
-- map, the result stays in 'Step' form over that array. Otherwise both are
-- viewed in 'Yield' form and their element functions are combined.
zipWithD :: (Kit acc, Shape sh, Elt a, Elt b, Elt c)
         => PreFun acc aenv (a -> b -> c)
         -> Cunctation acc aenv (Array sh a)
         -> Cunctation acc aenv (Array sh b)
         -> Cunctation acc aenv (Array sh c)
zipWithD f cc1 cc0
  -- Same source array and the same index map: stay in Step form.
  | Just (Step sh1 p1 f1 v1) <- step cc1
  , Just (Step sh0 p0 f0 v0) <- step cc0
  , Just Refl <- match v1 v0
  , Just Refl <- match p1 p0
  = Stats.ruleFired "zipWithD/step"
  $ Step (sh1 `Intersect` sh0) p0 (combine f f1 f0) v0

  -- Otherwise combine the two index-to-value functions.
  | Yield sh1 f1 <- yield cc1
  , Yield sh0 f0 <- yield cc0
  = Stats.ruleFired "zipWithD"
  $ Yield (sh1 `Intersect` sh0) (combine f f1 f0)

  where
    -- Build the function that, for an input x, let-binds ixa applied to x
    -- and then ixb applied to x, and evaluates c on the two results. c is
    -- weakened so its body lives under the new binder for x, and ixb is
    -- weakened again to skip the binding introduced for the ixa result.
    combine :: forall acc aenv a b c e. (Kit acc, Elt a, Elt b, Elt c)
            => PreFun acc aenv (a -> b -> c)
            -> PreFun acc aenv (e -> a)
            -> PreFun acc aenv (e -> b)
            -> PreFun acc aenv (e -> c)
    combine c ixa ixb
      | Lam (Lam (Body c')) <- weakenE SuccIdx c :: PreOpenFun acc ((),e) aenv (a -> b -> c)
      , Lam (Body ixa') <- ixa
      , Lam (Body ixb') <- ixb
      = Lam $ Body $ Let ixa' $ Let (weakenE SuccIdx ixb') c'
-- | Embed a let-binding.
--
-- When the bound term embeds to an array variable, that variable is
-- substituted directly into the body and the two sets of bindings are
-- concatenated ("float"). Otherwise both sides are embedded and aletD'
-- decides whether to keep or eliminate the binding.
aletD :: (Kit acc, Arrays arrs, Arrays brrs)
      => EmbedAcc acc
      -> ElimAcc acc
      -> acc aenv arrs
      -> acc (aenv,arrs) brrs
      -> Embed acc aenv brrs
aletD embedAcc elimAcc (embedAcc -> Embed env1 cc1) acc0
  | Done v1 <- cc1
  , Embed env0 cc0 <- embedAcc $ rebuildA (subAtop (Avar v1) . sink1 env1) acc0
  = Stats.ruleFired "aletD/float"
  $ Embed (env1 `append` env0) cc0

  | otherwise
  = aletD' embedAcc elimAcc (Embed env1 cc1) (embedAcc acc0)
-- | Second stage of let-binding embedding, for a bound term that did not
-- embed to an array variable (that case is handled by 'aletD').
--
-- If the elimination heuristic returns 'False', the bound term is computed
-- and let-bound in front of the body's own bindings. Otherwise the binding
-- is eliminated, in two steps. First, scalar reads of the bound array in
-- the body ('Shape', 'Index', 'LinearIndex') are rewritten to use the bound
-- term's extent and index function directly. Second, any remaining
-- array-level uses are substituted with the computed bound term, and the
-- body is embedded again.
aletD' :: forall acc aenv arrs brrs. (Kit acc, Arrays arrs, Arrays brrs)
       => EmbedAcc acc
       -> ElimAcc acc
       -> Embed acc aenv arrs
       -> Embed acc (aenv, arrs) brrs
       -> Embed acc aenv brrs
aletD' embedAcc elimAcc (Embed env1 cc1) (Embed env0 cc0)

  -- Keep the binding.
  | acc1 <- compute (Embed env1 cc1)
  , False <- elimAcc (inject acc1) acc0
  = Stats.ruleFired "aletD/bind"
  $ Embed (BaseEnv `PushEnv` inject acc1 `append` env0) cc0

  -- Eliminate the binding. Matching on cc1 exposes that the bound value is
  -- an Array, as eliminate requires. Done is not listed because aletD only
  -- calls this function when the bound term is not Done.
  | acc0' <- sink1 env1 acc0
  = Stats.ruleFired "aletD/eliminate"
  $ case cc1 of
      Step{} -> eliminate env1 cc1 acc0'
      Yield{} -> eliminate env1 cc1 acc0'

  where
    -- The body, computed back into a term.
    acc0 :: acc (aenv, arrs) brrs
    acc0 = computeAcc (Embed env0 cc0)

    -- Inline the bound cunctation into the body. The extent and index
    -- function of the bound array are derived from its cunctation.
    eliminate :: forall aenv aenv' sh e brrs. (Shape sh, Elt e, Arrays brrs)
              => Extend acc aenv aenv'
              -> Cunctation acc aenv' (Array sh e)
              -> acc (aenv', Array sh e) brrs
              -> Embed acc aenv brrs
    eliminate env1 cc1 body
      | Done v1 <- cc1 = elim (arrayShape v1) (indexArray v1)
      | Step sh1 p1 f1 v1 <- cc1 = elim sh1 (f1 `compose` indexArray v1 `compose` p1)
      | Yield sh1 f1 <- cc1 = elim sh1 f1
      where
        -- The bound term, used for array-level uses that remain.
        bnd :: PreOpenAcc acc aenv' (Array sh e)
        bnd = compute' cc1

        -- Rewrite scalar reads of variable ZeroIdx in the body (with the
        -- extent and index function weakened past that binder), substitute
        -- bnd for the variable, then embed the result.
        elim :: PreExp acc aenv' sh -> PreFun acc aenv' (sh -> e) -> Embed acc aenv brrs
        elim sh1 f1
          | sh1' <- weaken SuccIdx sh1
          , f1' <- weaken SuccIdx f1
          , Embed env0' cc0' <- embedAcc $ rebuildA (subAtop bnd) $ kmap (replaceA sh1' f1' ZeroIdx) body
          = Embed (env1 `append` env0') cc0'

    -- Replace scalar uses of array variable avar in an expression:
    -- Shape avar becomes sh', and Index/LinearIndex on avar become the
    -- index function f' applied to the (converted) index. Under a scalar
    -- binder, sh' and f' are weakened.
    replaceE :: forall env aenv sh e t. (Shape sh, Elt e)
             => PreOpenExp acc env aenv sh -> PreOpenFun acc env aenv (sh -> e) -> Idx aenv (Array sh e)
             -> PreOpenExp acc env aenv t
             -> PreOpenExp acc env aenv t
    replaceE sh' f' avar exp =
      case exp of
        Let x y -> Let (cvtE x) (replaceE (weakenE SuccIdx sh') (weakenE SuccIdx f') avar y)
        Var i -> Var i
        Foreign ff f e -> Foreign ff f (cvtE e)
        Const c -> Const c
        Tuple t -> Tuple (cvtT t)
        Prj ix e -> Prj ix (cvtE e)
        IndexNil -> IndexNil
        IndexCons sl sz -> IndexCons (cvtE sl) (cvtE sz)
        IndexHead sh -> IndexHead (cvtE sh)
        IndexTail sz -> IndexTail (cvtE sz)
        IndexAny -> IndexAny
        IndexSlice x ix sh -> IndexSlice x (cvtE ix) (cvtE sh)
        IndexFull x ix sl -> IndexFull x (cvtE ix) (cvtE sl)
        ToIndex sh ix -> ToIndex (cvtE sh) (cvtE ix)
        FromIndex sh i -> FromIndex (cvtE sh) (cvtE i)
        Cond p t e -> Cond (cvtE p) (cvtE t) (cvtE e)
        PrimConst c -> PrimConst c
        PrimApp g x -> PrimApp g (cvtE x)
        ShapeSize sh -> ShapeSize (cvtE sh)
        Intersect sh sl -> Intersect (cvtE sh) (cvtE sl)
        Union s t -> Union (cvtE s) (cvtE t)
        While p f x -> While (replaceF sh' f' avar p) (replaceF sh' f' avar f) (cvtE x)

        -- The extent of the bound array is known directly.
        Shape a
          | Just Refl <- match a a' -> Stats.substitution "replaceE/shape" sh'
          | otherwise -> exp

        -- Reading the bound array: let-bind the index and use f's body.
        Index a sh
          | Just Refl <- match a a'
          , Lam (Body b) <- f' -> Stats.substitution "replaceE/!" . cvtE $ Let sh b
          | otherwise -> Index a (cvtE sh)

        -- Linear read: convert the linear index to a multidimensional one
        -- with fromIndex over sh', then proceed as for Index.
        LinearIndex a i
          | Just Refl <- match a a'
          , Lam (Body b) <- f' -> Stats.substitution "replaceE/!!" . cvtE $ Let (Let i (FromIndex (weakenE SuccIdx sh') (Var ZeroIdx))) b
          | otherwise -> LinearIndex a (cvtE i)

      where
        a' :: acc aenv (Array sh e)
        a' = avarIn avar

        cvtE :: PreOpenExp acc env aenv s -> PreOpenExp acc env aenv s
        cvtE = replaceE sh' f' avar

        cvtT :: Tuple (PreOpenExp acc env aenv) s -> Tuple (PreOpenExp acc env aenv) s
        cvtT NilTup = NilTup
        cvtT (SnocTup t e) = cvtT t `SnocTup` cvtE e

    -- replaceE lifted to scalar functions, weakening sh' and f' under each
    -- lambda binder.
    replaceF :: forall env aenv sh e t. (Shape sh, Elt e)
             => PreOpenExp acc env aenv sh -> PreOpenFun acc env aenv (sh -> e) -> Idx aenv (Array sh e)
             -> PreOpenFun acc env aenv t
             -> PreOpenFun acc env aenv t
    replaceF sh' f' avar fun =
      case fun of
        Body e -> Body (replaceE sh' f' avar e)
        Lam f -> Lam (replaceF (weakenE SuccIdx sh') (weakenE SuccIdx f') avar f)

    -- replaceE lifted to array terms: rewrites every scalar sub-term, and
    -- shifts avar (and weakens sh'/f') under array-level binders. Array
    -- variables themselves are left untouched; they are substituted later
    -- by rebuildA in elim.
    replaceA :: forall aenv sh e a. (Shape sh, Elt e)
             => PreExp acc aenv sh -> PreFun acc aenv (sh -> e) -> Idx aenv (Array sh e)
             -> PreOpenAcc acc aenv a
             -> PreOpenAcc acc aenv a
    replaceA sh' f' avar pacc =
      case pacc of
        Avar v
          | Just Refl <- match v avar -> Avar avar
          | otherwise -> Avar v

        Alet bnd body ->
          let sh'' = weaken SuccIdx sh'
              f'' = weaken SuccIdx f'
          in
          Alet (cvtA bnd) (kmap (replaceA sh'' f'' (SuccIdx avar)) body)

        Use arrs -> Use arrs
        Unit e -> Unit (cvtE e)
        Acond p at ae -> Acond (cvtE p) (cvtA at) (cvtA ae)
        Aprj ix tup -> Aprj ix (cvtA tup)
        Atuple tup -> Atuple (cvtAT tup)
        Awhile p f a -> Awhile (cvtAF p) (cvtAF f) (cvtA a)
        Apply f a -> Apply (cvtAF f) (cvtA a)
        -- NOTE(review): the foreign function f is left untouched, presumably
        -- because it is closed and cannot refer to avar. Confirm.
        Aforeign ff f a -> Aforeign ff f (cvtA a)
        Generate sh f -> Generate (cvtE sh) (cvtF f)
        Map f a -> Map (cvtF f) (cvtA a)
        ZipWith f a b -> ZipWith (cvtF f) (cvtA a) (cvtA b)
        Backpermute sh p a -> Backpermute (cvtE sh) (cvtF p) (cvtA a)
        Transform sh p f a -> Transform (cvtE sh) (cvtF p) (cvtF f) (cvtA a)
        Slice slix a sl -> Slice slix (cvtA a) (cvtE sl)
        Replicate slix sh a -> Replicate slix (cvtE sh) (cvtA a)
        Reshape sl a -> Reshape (cvtE sl) (cvtA a)
        Fold f z a -> Fold (cvtF f) (cvtE z) (cvtA a)
        Fold1 f a -> Fold1 (cvtF f) (cvtA a)
        FoldSeg f z a s -> FoldSeg (cvtF f) (cvtE z) (cvtA a) (cvtA s)
        Fold1Seg f a s -> Fold1Seg (cvtF f) (cvtA a) (cvtA s)
        Scanl f z a -> Scanl (cvtF f) (cvtE z) (cvtA a)
        Scanl1 f a -> Scanl1 (cvtF f) (cvtA a)
        Scanl' f z a -> Scanl' (cvtF f) (cvtE z) (cvtA a)
        Scanr f z a -> Scanr (cvtF f) (cvtE z) (cvtA a)
        Scanr1 f a -> Scanr1 (cvtF f) (cvtA a)
        Scanr' f z a -> Scanr' (cvtF f) (cvtE z) (cvtA a)
        Permute f d p a -> Permute (cvtF f) (cvtA d) (cvtF p) (cvtA a)
        Stencil f x a -> Stencil (cvtF f) (cvtB x) (cvtA a)
        Stencil2 f x a y b -> Stencil2 (cvtF f) (cvtB x) (cvtA a) (cvtB y) (cvtA b)

      where
        cvtA :: acc aenv s -> acc aenv s
        cvtA = kmap (replaceA sh' f' avar)

        cvtE :: PreExp acc aenv s -> PreExp acc aenv s
        cvtE = replaceE sh' f' avar

        cvtF :: PreFun acc aenv s -> PreFun acc aenv s
        cvtF = replaceF sh' f' avar

        cvtB :: PreBoundary acc aenv s -> PreBoundary acc aenv s
        cvtB Clamp = Clamp
        cvtB Mirror = Mirror
        cvtB Wrap = Wrap
        cvtB (Constant c) = Constant c
        cvtB (Function f) = Function (cvtF f)

        cvtAT :: Atuple (acc aenv) s -> Atuple (acc aenv) s
        cvtAT NilAtup = NilAtup
        cvtAT (SnocAtup tup a) = cvtAT tup `SnocAtup` cvtA a

        -- Array functions: shift avar and weaken sh'/f' under each array
        -- lambda binder.
        cvtAF :: PreOpenAfun acc aenv s -> PreOpenAfun acc aenv s
        cvtAF = cvt sh' f' avar
          where
            cvt :: forall aenv a.
                   PreExp acc aenv sh -> PreFun acc aenv (sh -> e) -> Idx aenv (Array sh e)
                -> PreOpenAfun acc aenv a
                -> PreOpenAfun acc aenv a
            cvt sh'' f'' avar' (Abody a) = Abody $ kmap (replaceA sh'' f'' avar') a
            cvt sh'' f'' avar' (Alam af) = Alam $ cvt (weaken SuccIdx sh'')
                                                      (weaken SuccIdx f'')
                                                      (SuccIdx avar')
                                                      af
-- | Embed a function application. Applying the identity array function
-- reduces to the argument itself. Anything else stays an Apply node and is
-- bound with 'done'.
applyD :: (Kit acc, Arrays as, Arrays bs)
       => PreOpenAfun acc aenv (as -> bs)
       -> acc aenv as
       -> Embed acc aenv bs
applyD afun x =
  case afun of
    Alam (Abody body)
      | Avar ZeroIdx <- extract body
      -> Stats.ruleFired "applyD/identity" (done (extract x))
    _ -> done (Apply afun x)
-- | Embed an array-level conditional.
--
-- A constant predicate selects its branch statically. Identical branches
-- (according to match) collapse to one. Otherwise both branches are
-- embedded and computed separately, and the conditional is bound with
-- 'done'.
acondD :: (Kit acc, Arrays arrs)
       => EmbedAcc acc
       -> PreExp acc aenv Bool
       -> acc aenv arrs
       -> acc aenv arrs
       -> Embed acc aenv arrs
acondD embedAcc p t e =
  case p of
    Const True  -> Stats.knownBranch "True" (embedAcc t)
    Const False -> Stats.knownBranch "False" (embedAcc e)
    _
      | Just Refl <- match t e -> Stats.knownBranch "redundant" (embedAcc e)
      | otherwise              -> done (Acond p (computeAcc (embedAcc t)) (computeAcc (embedAcc e)))
-- | Embed a projection out of an array tuple.
--
-- Projecting from a literal tuple selects the component statically and
-- embeds it. Otherwise the tuple is embedded and computed, and the
-- projection is bound with 'done'.
aprjD :: forall acc aenv arrs a. (Kit acc, IsAtuple arrs, Arrays arrs, Arrays a)
      => EmbedAcc acc
      -> TupleIdx (TupleRepr arrs) a
      -> acc aenv arrs
      -> Embed acc aenv a
aprjD embedAcc ix arr =
  case extract arr of
    Atuple tup -> Stats.ruleFired "aprj/Atuple" (embedAcc (select ix tup))
    _          -> done (Aprj ix (computeAcc (embedAcc arr)))
  where
    -- Walk the snoc-list of tuple components down to the requested index.
    select :: TupleIdx atup a -> Atuple (acc aenv) atup -> acc aenv a
    select ZeroTupIdx        (SnocAtup _ x) = x
    select (SuccTupIdx rest) (SnocAtup t _) = select rest t
-- | Recognise the syntactic identity function and return a proof that its
-- argument and result types coincide.
isIdentity :: PreFun acc aenv (a -> b) -> Maybe (a :~: b)
isIdentity (Lam (Body (Var ZeroIdx))) = Just Refl
isIdentity _                          = Nothing
-- | The scalar identity function: a lambda returning its bound variable.
identity :: Elt a => PreOpenFun acc env aenv (a -> a)
identity = Lam $ Body $ Var ZeroIdx
-- | The function converting a multidimensional index into a linear index
-- with respect to the given extent. The extent is weakened so it remains
-- valid under the lambda binder.
toIndex :: (Kit acc, Shape sh) => PreOpenExp acc env aenv sh -> PreOpenFun acc env aenv (sh -> Int)
toIndex sh = Lam $ Body $ ToIndex (weakenE SuccIdx sh) (Var ZeroIdx)
-- | The function converting a linear index into a multidimensional index
-- with respect to the given extent. The extent is weakened so it remains
-- valid under the lambda binder.
fromIndex :: (Kit acc, Shape sh) => PreOpenExp acc env aenv sh -> PreOpenFun acc env aenv (Int -> sh)
fromIndex sh = Lam $ Body $ FromIndex (weakenE SuccIdx sh) (Var ZeroIdx)
-- | Map an index in extent @sh@ to the index at the same linear position in
-- extent @sh'@. When the two extents match, this is the identity.
reindex :: (Kit acc, Shape sh, Shape sh')
        => PreOpenExp acc env aenv sh'
        -> PreOpenExp acc env aenv sh
        -> PreOpenFun acc env aenv (sh -> sh')
reindex sh' sh =
  case match sh sh' of
    Just Refl -> identity
    Nothing   -> fromIndex sh' `compose` toIndex sh
-- | Index map used by 'replicateD'. It takes an index of the replicated
-- (larger) array to the source index by dropping the replicated dimensions
-- with IndexSlice.
extend :: (Kit acc, Shape sh, Shape sl, Elt slix)
       => SliceIndex (EltRepr slix) (EltRepr sl) co (EltRepr sh)
       -> PreExp acc aenv slix
       -> PreFun acc aenv (sh -> sl)
extend sliceIndex slix = Lam $ Body $ IndexSlice sliceIndex (weakenE SuccIdx slix) (Var ZeroIdx)
-- | Index map used by 'sliceD'. It takes an index of the slice (smaller
-- array) to the source index by re-inserting the fixed dimensions with
-- IndexFull.
restrict :: (Kit acc, Shape sh, Shape sl, Elt slix)
         => SliceIndex (EltRepr slix) (EltRepr sl) co (EltRepr sh)
         -> PreExp acc aenv slix
         -> PreFun acc aenv (sl -> sh)
restrict sliceIndex slix = Lam $ Body $ IndexFull sliceIndex (weakenE SuccIdx slix) (Var ZeroIdx)
-- | The extent of the array bound to the given array variable.
arrayShape :: (Kit acc, Shape sh, Elt e) => Idx aenv (Array sh e) -> PreExp acc aenv sh
arrayShape v = Shape (avarIn v)
-- | The function reading the array variable at a multidimensional index.
indexArray :: (Kit acc, Shape sh, Elt e) => Idx aenv (Array sh e) -> PreFun acc aenv (sh -> e)
indexArray v = Lam $ Body $ Index (avarIn v) (Var ZeroIdx)
-- | The function reading the array variable at a linear index.
linearIndex :: (Kit acc, Shape sh, Elt e) => Idx aenv (Array sh e) -> PreFun acc aenv (Int -> e)
linearIndex v = Lam $ Body $ LinearIndex (avarIn v) (Var ZeroIdx)