{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
{-# OPTIONS_GHC -fno-warn-name-shadowing #-}
{-# OPTIONS_HADDOCK hide #-}
module Data.Array.Accelerate.Trafo.Fusion (
convertAcc, convertAccWith,
convertAfun, convertAfunWith,
) where
import Data.BitSet
import Data.Array.Accelerate.AST
import Data.Array.Accelerate.AST.LeftHandSide
import Data.Array.Accelerate.AST.Environment
import Data.Array.Accelerate.AST.Var
import Data.Array.Accelerate.AST.Idx
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Trafo.Config
import Data.Array.Accelerate.Trafo.Var
import Data.Array.Accelerate.Trafo.Delayed
import Data.Array.Accelerate.Trafo.Environment
import Data.Array.Accelerate.Trafo.Shrink
import Data.Array.Accelerate.Trafo.Simplify
import Data.Array.Accelerate.Trafo.Substitution
import Data.Array.Accelerate.Representation.Array ( Array, ArrayR(..), ArraysR )
import Data.Array.Accelerate.Representation.Shape ( ShapeR(..), shapeType )
import Data.Array.Accelerate.Representation.Slice
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Analysis.Match
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Debug.Flags ( array_fusion )
import qualified Data.Array.Accelerate.Debug.Stats as Stats
#ifdef ACCELERATE_DEBUG
import System.IO.Unsafe
#endif
import Control.Lens ( over, mapped, _2 )
import Prelude hiding ( exp, until )
-- | Convert a closed array computation into its delayed form, using
-- 'defaultOptions' as the configuration. See 'convertAccWith'.
convertAcc :: HasCallStack => Acc arrs -> DelayedAcc arrs
convertAcc acc = convertAccWith defaultOptions acc
-- | Convert a closed array computation into its delayed form under the given
-- configuration. The conversion itself is done by 'convertOpenAcc'; the
-- result is passed through 'withSimplStats'.
convertAccWith :: HasCallStack => Config -> Acc arrs -> DelayedAcc arrs
convertAccWith config acc = withSimplStats (convertOpenAcc config acc)
-- | Convert a closed array function into its delayed form, using
-- 'defaultOptions' as the configuration. See 'convertAfunWith'.
convertAfun :: HasCallStack => Afun f -> DelayedAfun f
convertAfun afun = convertAfunWith defaultOptions afun
-- | Convert a closed array function into its delayed form under the given
-- configuration. The conversion itself is done by 'convertOpenAfun'; the
-- result is passed through 'withSimplStats'.
convertAfunWith :: HasCallStack => Config -> Afun f -> DelayedAfun f
convertAfunWith config afun = withSimplStats (convertOpenAfun config afun)
-- | Hook for the simplifier statistics.
--
-- When ACCELERATE_DEBUG is defined, evaluating the result to weak head
-- normal form first forces 'Stats.resetSimplCount'. Without the debug flag
-- this is the identity function.
withSimplStats :: a -> a
#ifdef ACCELERATE_DEBUG
-- 'unsafePerformIO' is used only for the side effect of the action; 'seq'
-- ties that effect to the demand for the returned value.
withSimplStats x = seq (unsafePerformIO Stats.resetSimplCount) x
#else
withSimplStats x = x
#endif
-- | Convert an open array computation into its delayed form.
--
-- The term is first embedded with 'embedOpenAcc', turned back into a plain
-- 'OpenAcc' term with 'computeAcc', and finally converted to the delayed
-- representation with 'manifest'.
convertOpenAcc
    :: HasCallStack
    => Config
    -> OpenAcc aenv arrs
    -> DelayedOpenAcc aenv arrs
convertOpenAcc config acc =
  manifest config (computeAcc (embedOpenAcc config acc))
-- | Convert an array-valued term that appears as the argument of another
-- operation into a 'DelayedOpenAcc'.
--
-- The term is embedded with 'embedOpenAcc'. If embedding produced no extra
-- bindings ('BaseEnv'), the simplified cunctation is turned directly into
-- either a 'Manifest' variable reference or a 'Delayed' node; otherwise the
-- term is recomputed with 'computeAcc' and handed to 'manifest'.
delayed
    :: HasCallStack
    => Config
    -> OpenAcc aenv (Array sh e)
    -> DelayedOpenAcc aenv (Array sh e)
delayed config (embedOpenAcc config -> Embed env cc)
  | BaseEnv <- env
  = case simplifyCC cc of
      -- Already just array variables: refer to them as manifest data.
      Done v -> avarsIn Manifest v
      -- The last argument of 'Delayed' is the same function pre-composed with
      -- 'fromIndex', i.e. it takes an Int instead of a shape.
      Yield aR sh f -> Delayed aR sh f (f `compose` fromIndex (arrayRshape aR) sh)
      Step aR sh p f v
        -- Same shape as the source array and identity index transform: index
        -- the source directly, both by shape and by linear index.
        | Just Refl <- matchOpenExp sh (arrayShape v)
        , Just Refl <- isIdentity p -> Delayed aR sh (f `compose` indexArray v) (f `compose` linearIndex v)
        -- General case (this guard always succeeds): go through the index
        -- transform p.
        | f' <- f `compose` indexArray v `compose` p -> Delayed aR sh f' (f' `compose` fromIndex (arrayRshape aR) sh)
  | otherwise
  = manifest config (computeAcc (Embed env cc))
-- | Convert an 'OpenAcc' term into the 'DelayedOpenAcc' representation,
-- wrapping the converted node in 'Manifest'.
--
-- Array arguments of each node are converted recursively with either
-- 'manifest' or 'delayed', depending on the argument position. The
-- constructors 'Replicate', 'Slice' and 'ZipWith' are not expected here and
-- raise an internal error.
manifest
    :: HasCallStack
    => Config
    -> OpenAcc aenv a
    -> DelayedOpenAcc aenv a
manifest config (OpenAcc pacc) =
  let fusionError = internalError "unexpected fusible materials"
  in
  Manifest $ case pacc of
    -- Structural / control-flow nodes: sub-terms are converted with 'manifest'.
    Avar ix -> Avar ix
    Use aR a -> Use aR a
    Unit t e -> Unit t e
    Alet lhs bnd body -> alet lhs (manifest config bnd) (manifest config body)
    Acond p t e -> Acond p (manifest config t) (manifest config e)
    Awhile p f a -> Awhile (cvtAF p) (cvtAF f) (manifest config a)
    Apair a1 a2 -> Apair (manifest config a1) (manifest config a2)
    Anil -> Anil
    Apply repr f a -> apply repr (cvtAF f) (manifest config a)
    Aforeign repr ff f a -> Aforeign repr ff (cvtAF f) (manifest config a)
    -- Nodes whose source array is converted with 'delayed' (except 'Reshape',
    -- whose argument is converted with 'manifest').
    Map t f a -> Map t f (delayed config a)
    Generate repr sh f -> Generate repr sh f
    Transform repr sh p f a -> Transform repr sh p f (delayed config a)
    Backpermute shR sh p a -> Backpermute shR sh p (delayed config a)
    Reshape slr sl a -> Reshape slr sl (manifest config a)
    -- These constructors are reported as an internal error if encountered.
    Replicate{} -> fusionError
    Slice{} -> fusionError
    ZipWith{} -> fusionError
    -- Remaining operations: input arrays are converted with 'delayed'; the
    -- second argument of 'Permute' is converted with 'manifest'.
    Fold f z a -> Fold f z (delayed config a)
    FoldSeg i f z a s -> FoldSeg i f z (delayed config a) (delayed config s)
    Scan d f z a -> Scan d f z (delayed config a)
    Scan' d f z a -> Scan' d f z (delayed config a)
    Permute f d p a -> Permute f (manifest config d) p (delayed config a)
    Stencil s t f x a -> Stencil s t f x (delayed config a)
    Stencil2 s1 s2 t f x a y b
      -> Stencil2 s1 s2 t f x (delayed config a) y (delayed config b)
  where
    -- Build a let-binding, unless the body consists only of array variables
    -- for which 'bindingIsTrivial' succeeds against the left-hand side and
    -- the bound term is 'Manifest'; in that case the bound term is returned
    -- directly.
    alet :: HasCallStack
         => ALeftHandSide a aenv aenv'
         -> DelayedOpenAcc aenv a
         -> DelayedOpenAcc aenv' b
         -> PreOpenAcc DelayedOpenAcc aenv b
    alet lhs bnd body
      | Just bodyVars <- extractDelayedArrayVars body
      , Just Refl <- bindingIsTrivial lhs bodyVars
      , Manifest x <- bnd
      = x
      | otherwise
      = Alet lhs bnd body
    -- Build an application node, unless the function is a single lambda whose
    -- body consists only of array variables for which 'bindingIsTrivial'
    -- succeeds and the argument is 'Manifest'; in that case the argument is
    -- returned directly (recorded as rule "applyD/identity").
    apply :: HasCallStack
          => ArraysR b
          -> PreOpenAfun DelayedOpenAcc aenv (a -> b)
          -> DelayedOpenAcc aenv a
          -> PreOpenAcc DelayedOpenAcc aenv b
    apply repr afun x
      | Alam lhs (Abody body) <- afun
      , Just bodyVars <- extractDelayedArrayVars body
      , Just Refl <- bindingIsTrivial lhs bodyVars
      , Manifest x' <- x
      = Stats.ruleFired "applyD/identity" x'
      | otherwise
      = Apply repr afun x
    -- Convert the body of an array function with 'manifest'.
    cvtAF :: HasCallStack => OpenAfun aenv f -> PreOpenAfun DelayedOpenAcc aenv f
    cvtAF (Alam lhs f) = Alam lhs (cvtAF f)
    cvtAF (Abody b) = Abody (manifest config b)
-- | Convert an open array function into its delayed form: lambdas are
-- preserved and the body is converted with 'convertOpenAcc'.
convertOpenAfun :: HasCallStack => Config -> OpenAfun aenv f -> DelayedOpenAfun aenv f
convertOpenAfun config = \case
  Alam lhs rest -> Alam lhs (convertOpenAfun config rest)
  Abody body    -> Abody (convertOpenAcc config body)
-- | A function that turns a term of representation @acc@ into an 'Embed'.
type EmbedAcc acc = forall aenv arrs. acc aenv arrs -> Embed acc aenv arrs
-- | A predicate over a bound term and the body in which it is bound (the body
-- lives in the environment extended by the bound term). In 'aletD'' a result
-- of 'False' keeps the let-binding; otherwise the binding is eliminated.
type ElimAcc acc = forall aenv s t. acc aenv s -> acc (aenv,s) t -> Bool
-- | Embed an 'OpenAcc' term by unwrapping it and delegating to
-- 'embedPreOpenAcc', supplying the recursive embedding function, a term
-- matcher and the let-elimination predicate.
embedOpenAcc :: HasCallStack => Config -> OpenAcc aenv arrs -> Embed OpenAcc aenv arrs
embedOpenAcc config (OpenAcc pacc) =
  embedPreOpenAcc config matchOpenAcc (embedOpenAcc config) elimOpenAcc pacc
  where
    -- The predicate holds when the use count of the bound variable (index
    -- zero) in the body, as computed by 'usesOfPreAcc', does not exceed the
    -- limit. The bound term itself is not inspected.
    elimOpenAcc :: ElimAcc OpenAcc
    elimOpenAcc _bnd body = uses False ZeroIdx body <= maxUses
      where
        maxUses = 1

        uses :: UsesOfAcc OpenAcc
        uses no ix (OpenAcc inner) = usesOfPreAcc no uses ix inner

    -- Structural matching of two terms, one constructor layer at a time.
    matchOpenAcc :: MatchAcc OpenAcc
    matchOpenAcc (OpenAcc lhsAcc) (OpenAcc rhsAcc) =
      matchPreOpenAcc matchOpenAcc lhsAcc rhsAcc
-- | Embed a single layer of an array term.
--
-- Each constructor is dispatched to a combinator that produces an 'Embed':
-- an environment of bindings together with a 'Cunctation' describing the
-- result. Some constructors are combined directly with the cunctation of
-- their argument ('mapD', 'backpermuteD', 'zipWithD', ...); the others are
-- rebuilt and let-bound via 'done', 'embed' or 'embed2'. The final result
-- is post-processed by the local 'unembed'.
embedPreOpenAcc
    :: HasCallStack
    => Config
    -> MatchAcc OpenAcc
    -> EmbedAcc OpenAcc
    -> ElimAcc OpenAcc
    -> PreOpenAcc OpenAcc aenv arrs
    -> Embed OpenAcc aenv arrs
embedPreOpenAcc config matchAcc embedAcc elimAcc pacc
  = unembed
  $ case pacc of
    -- Bindings and control flow.
    Alet lhs bnd body -> aletD embedAcc elimAcc lhs bnd body
    Anil -> done $ Anil
    Acond p at ae -> acondD matchAcc embedAcc (cvtE p) at ae
    Apply aR f a -> done $ Apply aR (cvtAF f) (cvtA a)
    Awhile p f a -> done $ Awhile (cvtAF p) (cvtAF f) (cvtA a)
    Apair a1 a2 -> done $ Apair (cvtA a1) (cvtA a2)
    Aforeign aR ff f a -> done $ Aforeign aR ff (cvtAF f) (cvtA a)
    -- Leaves.
    Avar v -> done $ Avar v
    Use aR a -> done $ Use aR a
    Unit t e -> done $ Unit t (cvtE e)
    -- Operations expressed as a cunctation ('Yield' / 'Step').
    Generate aR sh f -> generateD aR (cvtE sh) (cvtF f)
    Map t f a -> mapD t (cvtF f) (embedAcc a)
    ZipWith t f a b -> fuse2 (into (zipWithD t) (cvtF f)) a b
    Transform aR sh p f a -> transformD aR (cvtE sh) (cvtF p) (cvtF f) (embedAcc a)
    Backpermute slr sl p a
      -> fuse (into2 (backpermuteD slr) (cvtE sl) (cvtF p)) a
    Slice slix a sl -> fuse (into (sliceD slix) (cvtE sl)) a
    Replicate slix sh a -> fuse (into (replicateD slix) (cvtE sh)) a
    Reshape slr sl a -> reshapeD slr (embedAcc a) (cvtE sl)
    -- Operations that are rebuilt around their (computed) arguments and then
    -- let-bound.
    Fold f z a -> embed aR (into2M Fold (cvtF f) (cvtE <$> z)) a
    FoldSeg i f z a s -> embed2 aR (into2M (FoldSeg i) (cvtF f) (cvtE <$> z)) a s
    Scan d f z a -> embed aR (into2M (Scan d) (cvtF f) (cvtE <$> z)) a
    Scan' d f z a -> embed aR (into2 (Scan' d) (cvtF f) (cvtE z)) a
    Permute f d p a -> embed2 aR (into2 permute (cvtF f) (cvtF p)) d a
    Stencil s t f x a -> embed aR (into2 (stencil1 s t) (cvtF f) (cvtB x)) a
    Stencil2 s1 s2 t f x a y b
      -> embed2 aR (into3 (stencil2 s1 s2 t) (cvtF f) (cvtB x) (cvtB y)) a b
  where
    -- Representation of the result of the term being embedded.
    aR = arraysR pacc
    -- If the 'array_fusion' flag is a member of the configuration options the
    -- embedding is returned unchanged. Otherwise the cunctation is computed
    -- right away: the result is either the array variables it already
    -- denotes, or a fresh binding pushed onto the environment.
    unembed :: HasCallStack => Embed OpenAcc aenv arrs -> Embed OpenAcc aenv arrs
    unembed x
      | array_fusion `member` options config = x
      | Embed env cc <- x
      , pacc <- compute cc
      = case avarsOut extractOpenAcc pacc of
          Just vars -> Embed env $ Done vars
          _
            | DeclareVars lhs _ value <- declareVars (arraysR pacc)
            -> Embed (PushEnv env lhs $ OpenAcc pacc) $ Done $ value weakenId
    -- Embed a sub-term and immediately turn it back into a plain term.
    cvtA :: HasCallStack => OpenAcc aenv' a -> OpenAcc aenv' a
    cvtA = computeAcc . embedAcc
    -- Apply 'cvtA' to the body of an array function.
    cvtAF :: HasCallStack => PreOpenAfun OpenAcc aenv' f -> PreOpenAfun OpenAcc aenv' f
    cvtAF (Alam lhs f) = Alam lhs (cvtAF f)
    cvtAF (Abody a) = Abody (cvtA a)
    -- Constructors with their arguments reordered so that the scalar
    -- functions come first and the array arguments last, as expected by
    -- 'into2' / 'into3' and 'embed' / 'embed2'.
    permute f p d a = Permute f d p a
    stencil1 s t f x a = Stencil s t f x a
    stencil2 s1 s2 t f x y a b = Stencil2 s1 s2 t f x a y b
    -- Scalar functions and expressions are simplified on the way through.
    cvtF :: HasCallStack => Fun aenv' t -> Fun aenv' t
    cvtF = simplifyFun
    cvtE :: HasCallStack => Exp aenv' t -> Exp aenv' t
    cvtE = simplifyExp
    -- Only the 'Function' boundary contains a function to simplify.
    cvtB :: HasCallStack => Boundary aenv' t -> Boundary aenv' t
    cvtB Clamp = Clamp
    cvtB Mirror = Mirror
    cvtB Wrap = Wrap
    cvtB (Constant c) = Constant c
    cvtB (Function f) = Function (cvtF f)
    -- 'into' and friends: apply an operation to one, two or three arguments
    -- after sinking each of them through the environment extension with
    -- 'sinkA'. 'into2M' takes an optional second argument.
    into :: (HasCallStack, Sink f)
         => (f env' a -> b)
         -> f env a
         -> Extend ArrayR OpenAcc env env'
         -> b
    into op a env = op (sinkA env a)
    into2 :: (HasCallStack, Sink f1, Sink f2)
          => (f1 env' a -> f2 env' b -> c)
          -> f1 env a
          -> f2 env b
          -> Extend ArrayR OpenAcc env env'
          -> c
    into2 op a b env = op (sinkA env a) (sinkA env b)
    into2M :: (HasCallStack, Sink f1, Sink f2)
           => (f1 env' a -> Maybe (f2 env' b) -> c)
           -> f1 env a
           -> Maybe (f2 env b)
           -> Extend ArrayR acc env env'
           -> c
    into2M op a b env = op (sinkA env a) (sinkA env <$> b)
    into3 :: (HasCallStack, Sink f1, Sink f2, Sink f3)
          => (f1 env' a -> f2 env' b -> f3 env' c -> d)
          -> f1 env a
          -> f2 env b
          -> f3 env c
          -> Extend ArrayR OpenAcc env env'
          -> d
    into3 op a b c env = op (sinkA env a) (sinkA env b) (sinkA env c)
    -- Embed the argument and transform its cunctation with the given
    -- operation, keeping the argument's environment.
    fuse :: HasCallStack
         => (forall aenv'. Extend ArrayR OpenAcc aenv aenv' -> Cunctation aenv' as -> Cunctation aenv' bs)
         -> OpenAcc aenv as
         -> Embed OpenAcc aenv bs
    fuse op (embedAcc -> Embed env cc) = Embed env (op env cc)
    -- Binary version of 'fuse'. The second argument is embedded underneath
    -- the bindings of the first (after sinking it through env1), and the
    -- first cunctation is then sunk through the bindings of the second.
    fuse2 :: HasCallStack
          => (forall aenv'. Extend ArrayR OpenAcc aenv aenv' -> Cunctation aenv' as -> Cunctation aenv' bs -> Cunctation aenv' cs)
          -> OpenAcc aenv as
          -> OpenAcc aenv bs
          -> Embed OpenAcc aenv cs
    fuse2 op a1 a0
      | Embed env1 cc1 <- embedAcc a1
      , Embed env0 cc0 <- embedAcc (sinkA env1 a0)
      , env <- env1 `append` env0
      = Embed env (op env (sinkA env0 cc1) cc0)
    -- Rebuild an operation around its embedded argument and let-bind the
    -- result, returning fresh variables for it.
    --
    -- If the argument is already 'Done', it is first turned back into a term
    -- of the original environment with 'computeAcc' and the new binding is
    -- pushed onto an empty environment. Otherwise the argument's bindings are
    -- kept and the new binding is pushed on top of them.
    embed :: HasCallStack
          => ArraysR bs
          -> (forall aenv'. Extend ArrayR OpenAcc aenv aenv' -> OpenAcc aenv' as -> PreOpenAcc OpenAcc aenv' bs)
          -> OpenAcc aenv as
          -> Embed OpenAcc aenv bs
    embed reprBs op (embedAcc -> Embed env cc)
      | Done{} <- cc
      , DeclareVars lhs _ value <- declareVars reprBs
      = Embed (PushEnv BaseEnv lhs $ OpenAcc (op BaseEnv (computeAcc (Embed env cc)))) $ Done $ value weakenId
      | otherwise
      , DeclareVars lhs _ value <- declareVars reprBs
      = Embed (PushEnv env lhs $ OpenAcc (op env (OpenAcc (compute cc)))) $ Done $ value weakenId
    -- Binary version of 'embed'.
    --
    -- * First argument 'Done': compute it and defer to 'embed' for the second.
    -- * Second argument 'Done': only the first argument's bindings are kept;
    --   the second is turned back into a term with 'computeAcc'.
    -- * Otherwise: both sets of bindings are kept (env1 then env0).
    embed2 :: HasCallStack
           => ArraysR cs
           -> (forall aenv'. Extend ArrayR OpenAcc aenv aenv' -> OpenAcc aenv' as -> OpenAcc aenv' bs -> PreOpenAcc OpenAcc aenv' cs)
           -> OpenAcc aenv as
           -> OpenAcc aenv bs
           -> Embed OpenAcc aenv cs
    embed2 reprCs op (embedAcc -> Embed env1 cc1) a0
      | Done{} <- cc1
      , a1 <- computeAcc (Embed env1 cc1)
      = embed reprCs (\env0 -> op env0 (sinkA env0 a1)) a0
      | Embed env0 cc0 <- embedAcc (sinkA env1 a0)
      , env <- env1 `append` env0
      = case cc0 of
          Done{}
            | DeclareVars lhs _ value <- declareVars reprCs
            -> Embed (PushEnv env1 lhs $ OpenAcc (op env1 (OpenAcc (compute cc1)) (computeAcc (Embed env0 cc0))))
                 $ Done
                 $ value weakenId
          _
            | DeclareVars lhs _ value <- declareVars reprCs
            -> Embed (PushEnv env lhs $ OpenAcc (op env (OpenAcc (compute (sinkA env0 cc1))) (OpenAcc (compute cc0))))
                 $ Done
                 $ value weakenId
-- | The result of embedding a term: an extension of the array environment
-- from @aenv@ to some (existentially hidden) @aenv'@, together with a
-- 'Cunctation' over that extended environment describing the result.
data Embed acc aenv a where
  Embed :: Extend ArrayR acc aenv aenv'
        -> Cunctation aenv' a
        -> Embed acc aenv a
-- | The representation of an 'Embed' is that of its cunctation; the
-- environment extension is ignored.
instance HasArraysR acc => HasArraysR (Embed acc) where
  arraysR (Embed _env cc) = arraysR cc
-- | A description of an array result that has not (necessarily) been
-- computed yet. There are three forms:
--
-- * 'Done': the result is given by array variables of the environment.
--
-- * 'Yield': an array described by its representation, its shape and a
--   function from index to element.
--
-- * 'Step': an array derived from an existing array variable. The fields are
--   the result representation, the result shape, a function from result
--   index to source index, a function from source element to result element,
--   and the source array variable. ('yield' shows the equivalent 'Yield':
--   element function, after indexing the source, after the index function.)
data Cunctation aenv a where
  Done :: ArrayVars aenv arrs
       -> Cunctation aenv arrs
  Yield :: ArrayR (Array sh e)
        -> Exp aenv sh
        -> Fun aenv (sh -> e)
        -> Cunctation aenv (Array sh e)
  Step :: ArrayR (Array sh' b)
       -> Exp aenv sh'
       -> Fun aenv (sh' -> sh)
       -> Fun aenv (a -> b)
       -> ArrayVar aenv (Array sh a)
       -> Cunctation aenv (Array sh' b)
-- | 'Done' takes its representation from the variables; 'Yield' and 'Step'
-- carry the representation of their single result array.
instance HasArraysR Cunctation where
  arraysR = \case
    Done vars        -> varsType vars
    Yield aR _ _     -> TupRsingle aR
    Step aR _ _ _ _  -> TupRsingle aR
-- | Weakening a cunctation weakens every environment-dependent component;
-- the array representations are untouched.
instance Sink Cunctation where
  weaken k (Done vars)            = Done (weakenVars k vars)
  weaken k (Step aR sh ix g var)  = Step aR (weaken k sh) (weaken k ix) (weaken k g) (weaken k var)
  weaken k (Yield aR sh g)        = Yield aR (weaken k sh) (weaken k g)
-- | Simplify the scalar expressions and functions inside a cunctation.
--
-- A 'Step' whose (simplified) shape matches the shape of its source array
-- and whose index and element functions are both identities is collapsed to
-- 'Done' on the source variable.
simplifyCC :: HasCallStack => Cunctation aenv a -> Cunctation aenv a
simplifyCC (Done vars) = Done vars
simplifyCC (Yield aR (simplifyExp -> extent) (simplifyFun -> gen)) =
  Yield aR extent gen
simplifyCC (Step aR (simplifyExp -> extent) (simplifyFun -> ixFun) (simplifyFun -> elemFun) var)
  | Just Refl <- matchOpenExp extent (arrayShape var)
  , Just Refl <- isIdentity ixFun
  , Just Refl <- isIdentity elemFun
  = Done (TupRsingle var)
  | otherwise
  = Step aR extent ixFun elemFun var
-- | Wrap a term as a finished ('Done') embedding.
--
-- If the term is already just array variables (as decided by 'avarsOut')
-- no binding is introduced. Otherwise the term is pushed as a single binding
-- onto an empty environment and the result refers to the freshly declared
-- variables.
done :: HasCallStack => PreOpenAcc OpenAcc aenv a -> Embed OpenAcc aenv a
done pacc =
  case avarsOut extractOpenAcc pacc of
    Just vars -> Embed BaseEnv (Done vars)
    Nothing   ->
      case declareVars (arraysR pacc) of
        DeclareVars lhs _ value ->
          Embed (PushEnv BaseEnv lhs (OpenAcc pacc)) (Done (value weakenId))
-- | A 'Done' cunctation that refers to the most recently bound array
-- (index zero) with the given representation.
doneZeroIdx :: ArrayR (Array sh e) -> Cunctation (aenv, Array sh e) (Array sh e)
doneZeroIdx repr = Done (TupRsingle (Var repr ZeroIdx))
-- | Recast a cunctation of a single array into 'Yield' form.
--
-- A 'Step' becomes the composition of its element function, indexing of the
-- source array, and its index function. A 'Done' variable becomes plain
-- indexing of that variable over its own shape.
yield :: HasCallStack
      => Cunctation aenv (Array sh e)
      -> Cunctation aenv (Array sh e)
yield cc@Yield{} = cc
yield (Step aR sh ixFun elemFun var) =
  Yield aR sh (elemFun `compose` indexArray var `compose` ixFun)
yield (Done (TupRsingle var@(Var aR _))) =
  Yield aR (arrayShape var) (indexArray var)
-- | Recast a cunctation of a single array into 'Step' form, if possible.
--
-- A 'Yield' has no source array and yields 'Nothing'. A 'Done' variable
-- becomes a 'Step' over that variable with identity index and element
-- functions.
step :: HasCallStack
     => Cunctation aenv (Array sh e)
     -> Maybe (Cunctation aenv (Array sh e))
step Yield{}   = Nothing
step cc@Step{} = Just cc
step (Done (TupRsingle var@(Var aR@(ArrayR shR tR) _))) =
  Just (Step aR (arrayShape var) (identity (shapeType shR)) (identity tR) var)
-- | The shape expression of a cunctation of a single array: taken from its
-- 'Step' form when one exists, and from its 'Yield' form otherwise.
shape :: HasCallStack => Cunctation aenv (Array sh e) -> Exp aenv sh
shape cc =
  case step cc of
    Just (Step _ extent _ _ _) -> extent
    _ | Yield _ extent _ <- yield cc -> extent
-- | Turn an 'Embed' back into a plain 'OpenAcc' term of the original
-- environment, re-introducing the environment's bindings as let-bindings.
--
-- With an empty environment this is just 'compute'. Otherwise the
-- simplified cunctation is converted to a term and wrapped in the bindings.
-- For a 'Step' that reads from the most recently bound array (index zero)
-- whose functions and shape do not mention that binding ('strengthen'
-- succeeds), the operation is applied directly to the bound term @top@ and
-- only the remaining bindings @bot@ are wrapped around it.
computeAcc
    :: HasCallStack
    => Embed OpenAcc aenv arrs
    -> OpenAcc aenv arrs
computeAcc (Embed BaseEnv cc) = OpenAcc (compute cc)
computeAcc (Embed env@(PushEnv bot lhs top) cc) =
  case simplifyCC cc of
    Done v -> bindA env (avarsIn OpenAcc v)
    Yield repr sh f -> bindA env (OpenAcc (Generate repr sh f))
    Step repr sh p f v@(Var _ ix)
      -- Same shape and identity index function: a 'Map'.
      | Just Refl <- matchOpenExp sh (arrayShape v)
      , Just Refl <- isIdentity p
      -> case ix of
           ZeroIdx
             | LeftHandSideSingle ArrayR{} <- lhs
             , Just (OpenAccFun g) <- strengthen noTop (OpenAccFun f)
             -> bindA bot (OpenAcc (Map (arrayRtype repr) g top))
           _ -> bindA env (OpenAcc (Map (arrayRtype repr) f (avarIn OpenAcc v)))
      -- Identity element function: a 'Backpermute'.
      | Just Refl <- isIdentity f
      -> case ix of
           ZeroIdx
             | LeftHandSideSingle ArrayR{} <- lhs
             , Just (OpenAccFun q) <- strengthen noTop (OpenAccFun p)
             , Just (OpenAccExp sz) <- strengthen noTop (OpenAccExp sh)
             -> bindA bot (OpenAcc (Backpermute (arrayRshape repr) sz q top))
           _ -> bindA env (OpenAcc (Backpermute (arrayRshape repr) sh p (avarIn OpenAcc v)))
      -- General case: a 'Transform'.
      | otherwise
      -> case ix of
           ZeroIdx
             | LeftHandSideSingle ArrayR{} <- lhs
             , Just (OpenAccFun g) <- strengthen noTop (OpenAccFun f)
             , Just (OpenAccFun q) <- strengthen noTop (OpenAccFun p)
             , Just (OpenAccExp sz) <- strengthen noTop (OpenAccExp sh)
             -> bindA bot (OpenAcc (Transform repr sz q g top))
           _ -> bindA env (OpenAcc (Transform repr sh p f (avarIn OpenAcc v)))
  where
    -- Wrap a term in the bindings of an environment, innermost first. A
    -- binding is dropped when the term under it consists only of array
    -- variables for which 'bindingIsTrivial' succeeds; the bound term then
    -- takes the place of the body.
    bindA :: HasCallStack
          => Extend ArrayR OpenAcc aenv aenv'
          -> OpenAcc aenv' a
          -> OpenAcc aenv a
    bindA BaseEnv b = b
    bindA (PushEnv env lhs a) b
      | Just vars <- extractOpenArrayVars b
      , Just Refl <- bindingIsTrivial lhs vars = bindA env a
      | otherwise = bindA env (OpenAcc (Alet lhs a b))
    -- Partial weakening that removes the top of the environment: it fails on
    -- index zero and shifts every other index down by one.
    noTop :: (aenv, a) :?> aenv
    noTop ZeroIdx = Nothing
    noTop (SuccIdx ix) = Just ix
-- | Convert a cunctation into a term in the same environment (no bindings
-- are introduced). The cunctation is simplified first.
--
-- 'Done' becomes the corresponding variable structure, 'Yield' becomes a
-- 'Generate', and 'Step' becomes a 'Map', 'Backpermute' or 'Transform'
-- depending on which of its functions are identities.
compute
    :: HasCallStack
    => Cunctation aenv arrs
    -> PreOpenAcc OpenAcc aenv arrs
compute cc = case simplifyCC cc of
  Done TupRunit -> Anil
  Done (TupRsingle v@(Var ArrayR{} _)) -> Avar v
  Done (TupRpair v1 v2) -> avarsIn OpenAcc v1 `Apair` avarsIn OpenAcc v2
  Yield repr sh f -> Generate repr sh f
  Step (ArrayR shR tR) sh p f v
    -- Same shape as the source and identity index function.
    | Just Refl <- matchOpenExp sh (arrayShape v)
    , Just Refl <- isIdentity p -> Map tR f (avarIn OpenAcc v)
    -- Identity element function.
    | Just Refl <- isIdentity f -> Backpermute shR sh p (avarIn OpenAcc v)
    | otherwise -> Transform (ArrayR shR tR) sh p f (avarIn OpenAcc v)
-- | Represent a generator as a 'Yield' cunctation with no bindings.
-- Records the rule "generateD".
generateD
    :: HasCallStack
    => ArrayR (Array sh e)
    -> Exp aenv sh
    -> Fun aenv (sh -> e)
    -> Embed OpenAcc aenv (Array sh e)
generateD repr sh f =
  Stats.ruleFired "generateD" (Embed BaseEnv (Yield repr sh f))
-- | Apply an element-wise function to an embedded array.
--
-- If 'unzipD' applies, its result is used. Otherwise the function (sunk
-- into the extended environment) is composed onto the element function of
-- the cunctation: in 'Step' form when available, in 'Yield' form otherwise.
-- Records the rule "mapD".
mapD :: HasCallStack
     => TypeR b
     -> Fun aenv (a -> b)
     -> Embed OpenAcc aenv (Array sh a)
     -> Embed OpenAcc aenv (Array sh b)
mapD tR f (unzipD tR f -> Just unzipped) = unzipped
mapD tR f (Embed env cc) =
  Stats.ruleFired "mapD" (Embed env (mapped' cc))
  where
    mapped' (step -> Just (Step (ArrayR shR _) extent ixFun elemFun var)) =
      Step (ArrayR shR tR) extent ixFun (sinkA env f `compose` elemFun) var
    mapped' (yield -> Yield (ArrayR shR _) extent gen) =
      Yield (ArrayR shR tR) extent (sinkA env f `compose` gen)
-- | Special case of mapping over an array that is already 'Done'.
--
-- If the body of the function consists only of expression variables (as
-- decided by 'extractExpVars'), the map is not merged into a cunctation
-- but emitted as an explicit 'Map' binding pushed onto the environment;
-- the result refers to that new binding. Returns 'Nothing' in every other
-- case.
unzipD
    :: HasCallStack
    => TypeR b
    -> Fun aenv (a -> b)
    -> Embed OpenAcc aenv (Array sh a)
    -> Maybe (Embed OpenAcc aenv (Array sh b))
unzipD tR f (Embed env cc@(Done v))
  | Lam lhs (Body a) <- f
  , Just vars <- extractExpVars a
  , ArrayR shR _ <- arrayR cc
  -- Rebuild the function from the extracted variables.
  , f' <- Lam lhs $ Body $ expVars vars
  = Just $ Embed (env `pushArrayEnv` OpenAcc (Map tR f' $ avarsIn OpenAcc v)) $ doneZeroIdx $ ArrayR shR tR
unzipD _ _ _
  = Nothing
-- | Apply an index transformation to a cunctation, giving it a new shape.
--
-- The new index function is applied before the existing index function (for
-- 'Step') or before the generator (for 'Yield'). The previous shape is
-- discarded. Records the rule "backpermuteD".
backpermuteD
    :: HasCallStack
    => ShapeR sh'
    -> Exp aenv sh'
    -> Fun aenv (sh' -> sh)
    -> Cunctation aenv (Array sh e)
    -> Cunctation aenv (Array sh' e)
backpermuteD shR' sh' p cc = Stats.ruleFired "backpermuteD" (reindexed cc)
  where
    reindexed (step -> Just (Step (ArrayR _ tR) _ ixFun elemFun var)) =
      Step (ArrayR shR' tR) sh' (ixFun `compose` p) elemFun var
    reindexed (yield -> Yield (ArrayR _ tR) _ gen) =
      Yield (ArrayR shR' tR) sh' (gen `compose` p)
-- | Combined element-wise map and index transformation of an embedded
-- array: first 'mapD', then 'backpermuteD' on the resulting cunctation.
-- Records the rule "transformD".
transformD
    :: HasCallStack
    => ArrayR (Array sh' b)
    -> Exp aenv sh'
    -> Fun aenv (sh' -> sh)
    -> Fun aenv (a -> b)
    -> Embed OpenAcc aenv (Array sh a)
    -> Embed OpenAcc aenv (Array sh' b)
transformD (ArrayR shR' tR) sh' p f acc =
  Stats.ruleFired "transformD"
    (fuse (into2 (backpermuteD shR') sh' p) (mapD tR f acc))
  where
    -- Transform the cunctation of an embedding, keeping its environment.
    fuse :: HasCallStack
         => (forall aenv'. Extend ArrayR OpenAcc aenv aenv' -> Cunctation aenv' as -> Cunctation aenv' bs)
         -> Embed OpenAcc aenv as
         -> Embed OpenAcc aenv bs
    fuse op (Embed env cc) = Embed env (op env cc)

    -- Sink two arguments through the environment before applying the
    -- operation.
    into2 :: (HasCallStack, Sink f1, Sink f2)
          => (f1 env' a -> f2 env' b -> c)
          -> f1 env a
          -> f2 env b
          -> Extend ArrayR OpenAcc env env'
          -> c
    into2 op x y env = op (sinkA env x) (sinkA env y)
-- | Express replication as a 'backpermuteD': the new shape is the
-- 'IndexFull' of the slice specification and the source shape, and the
-- index function is 'extend'. Records the rule "replicateD".
replicateD
    :: HasCallStack
    => SliceIndex slix sl co sh
    -> Exp aenv slix
    -> Cunctation aenv (Array sl e)
    -> Cunctation aenv (Array sh e)
replicateD sliceIndex slix cc =
  Stats.ruleFired "replicateD"
    ( backpermuteD
        (sliceDomainR sliceIndex)
        (IndexFull sliceIndex slix (shape cc))
        (extend sliceIndex slix)
        cc
    )
-- | Express slicing as a 'backpermuteD': the new shape is the 'IndexSlice'
-- of the slice specification and the source shape, and the index function
-- is 'restrict'. Records the rule "sliceD".
sliceD
    :: HasCallStack
    => SliceIndex slix sl co sh
    -> Exp aenv slix
    -> Cunctation aenv (Array sh e)
    -> Cunctation aenv (Array sl e)
sliceD sliceIndex slix cc =
  Stats.ruleFired "sliceD"
    ( backpermuteD
        (sliceShapeR sliceIndex)
        (IndexSlice sliceIndex slix (shape cc))
        (restrict sliceIndex slix)
        cc
    )
-- | Reshape an embedded array.
--
-- If the array is already 'Done', an explicit 'Reshape' binding is pushed
-- onto the environment and the result refers to it. Otherwise the reshape
-- is expressed as a 'backpermuteD' using 'reindex' between the old and new
-- shapes (recorded as rule "reshapeD").
reshapeD
    :: HasCallStack
    => ShapeR sl
    -> Embed OpenAcc aenv (Array sh e)
    -> Exp aenv sl
    -> Embed OpenAcc aenv (Array sl e)
reshapeD slr (Embed env cc) (sinkA env -> sl) =
  case cc of
    Done v ->
      Embed (env `pushArrayEnv` OpenAcc (Reshape slr sl (avarsIn OpenAcc v)))
            (doneZeroIdx repr)
    _ ->
      Stats.ruleFired "reshapeD"
        (Embed env (backpermuteD slr sl (reindex (arrayRshape (arrayR cc)) (shape cc) slr sl) cc))
  where
    -- Element type of the source array, reused for the reshaped result.
    ArrayR _ tR = arrayR cc
    repr = ArrayR slr tR
-- | Combine two cunctations element-wise with a binary function.
--
-- If both have a 'Step' form over the same array variable with matching
-- index functions, the result is again a 'Step' over that variable
-- (rule "zipWithD/step"). Otherwise both are brought into 'Yield' form and
-- combined (rule "zipWithD"). In both cases the result shape is the
-- 'intersect' of the two input shapes.
zipWithD
    :: HasCallStack
    => TypeR c
    -> Fun aenv (a -> b -> c)
    -> Cunctation aenv (Array sh a)
    -> Cunctation aenv (Array sh b)
    -> Cunctation aenv (Array sh c)
zipWithD tR f cc1 cc0
  | Just (Step (ArrayR shR _) sh1 p1 f1 v1) <- step cc1
  , Just (Step _ sh0 p0 f0 v0) <- step cc0
  , Just Refl <- matchVar v1 v0
  , Just Refl <- matchOpenFun p1 p0
  = Stats.ruleFired "zipWithD/step"
  $ Step (ArrayR shR tR) (intersect shR sh1 sh0) p0 (combine f f1 f0) v0
  | Yield (ArrayR shR _) sh1 f1 <- yield cc1
  , Yield _ sh0 f0 <- yield cc0
  = Stats.ruleFired "zipWithD"
  $ Yield (ArrayR shR tR) (intersect shR sh1 sh0) (combine f f1 f0)
  where
    -- Build the function that applies both unary functions to the same
    -- argument and passes the two results to the binary function, using
    -- let-bindings for the intermediate values.
    combine :: forall aenv a b c e. HasCallStack
            => Fun aenv (a -> b -> c)
            -> Fun aenv (e -> a)
            -> Fun aenv (e -> b)
            -> Fun aenv (e -> c)
    combine c ixa ixb
      | Lam lhs1 (Body ixa') <- ixa
      , Lam lhs2 (Body ixb') <- ixb
      = case matchELeftHandSide lhs1 lhs2 of
          -- Both functions bind their argument identically: reuse lhs1.
          Just Refl
            | Lam lhsA (Lam lhsB (Body c')) <- weakenE (weakenWithLHS lhs1) c
            -> Lam lhs1 $ Body $ Let lhsA ixa' $ Let lhsB (weakenE (weakenWithLHS lhsA) ixb') c'
          -- The left-hand sides differ: merge them with 'combineLhs' and
          -- weaken both bodies into the merged environment.
          Nothing
            | CombinedLHS lhs k1 k2 <- combineLhs lhs1 lhs2
            , Lam lhsA (Lam lhsB (Body c')) <- weakenE (weakenWithLHS lhs) c
            , ixa'' <- weakenE k1 ixa'
            -> Lam lhs $ Body $ Let lhsA ixa'' $ Let lhsB (weakenE (weakenWithLHS lhsA .> k2) ixb') c'
-- | Merge two left-hand sides that bind a value of the same type over the
-- same base environment into a single left-hand side, together with
-- weakenings from each original extended environment into the merged one
-- (see 'CombinedLHS').
combineLhs
    :: HasCallStack
    => LeftHandSide s t env env1'
    -> LeftHandSide s t env env2'
    -> CombinedLHS s t env1' env2' env
combineLhs = go weakenId weakenId
  where
    -- The two weakenings are accumulated while walking both left-hand sides
    -- in parallel.
    go :: env1 :> env -> env2 :> env -> LeftHandSide s t env1 env1' -> LeftHandSide s t env2 env2' -> CombinedLHS s t env1' env2' env
    -- Both wildcards: nothing is bound.
    go k1 k2 (LeftHandSideWildcard tR) (LeftHandSideWildcard _) = CombinedLHS (LeftHandSideWildcard tR) k1 k2
    -- Both bind a single variable: bind it once.
    go k1 k2 (LeftHandSideSingle tR) (LeftHandSideSingle _) = CombinedLHS (LeftHandSideSingle tR) (sink k1) (sink k2)
    -- Pairs: combine the components left to right, threading the weakenings.
    go k1 k2 (LeftHandSidePair l1 h1) (LeftHandSidePair l2 h2)
      | CombinedLHS l k1' k2' <- go k1 k2 l1 l2
      , CombinedLHS h k1'' k2'' <- go k1' k2' h1 h2 = CombinedLHS (LeftHandSidePair l h) k1'' k2''
    -- One side is a wildcard: rebuild the other side's left-hand side; the
    -- wildcard side is weakened past all of the newly bound variables.
    go k1 k2 (LeftHandSideWildcard _) lhs
      | Exists lhs' <- rebuildLHS lhs = CombinedLHS lhs' (weakenWithLHS lhs' .> k1) (sinkWithLHS lhs lhs' k2)
    go k1 k2 lhs (LeftHandSideWildcard _)
      | Exists lhs' <- rebuildLHS lhs = CombinedLHS lhs' (sinkWithLHS lhs lhs' k1) (weakenWithLHS lhs' .> k2)
-- | Result of 'combineLhs': a left-hand side extending @env@ to some hidden
-- @env'@, plus weakenings from each of the two original extended
-- environments (@env1'@ and @env2'@) into @env'@.
data CombinedLHS s t env1' env2' env where
  CombinedLHS :: LeftHandSide s t env env'
              -> env1' :> env'
              -> env2' :> env'
              -> CombinedLHS s t env1' env2' env
-- | Embed a let-binding.
--
-- The bound term is embedded first. If it binds a single array and its
-- cunctation is already a single array variable, that variable is
-- substituted for the bound variable in the body (after sinking the body
-- through the bindings of the bound term), and the body is embedded on top
-- of those bindings (rule "aletD/float"). Every other case is handled by
-- 'aletD''.
aletD :: HasCallStack
      => EmbedAcc OpenAcc
      -> ElimAcc OpenAcc
      -> ALeftHandSide arrs aenv aenv'
      -> OpenAcc aenv arrs
      -> OpenAcc aenv' brrs
      -> Embed OpenAcc aenv brrs
aletD embedAcc elimAcc lhs (embedAcc -> Embed env1 cc1) acc0
  | LeftHandSideSingle _ <- lhs
  , Done (TupRsingle v1@(Var ArrayR{} _)) <- cc1
  , Embed env0 cc0 <- embedAcc $ rebuildA (subAtop (Avar v1) . sink1 env1) acc0
  = Stats.ruleFired "aletD/float"
  $ Embed (env1 `append` env0) cc0
  | otherwise
  = aletD' embedAcc elimAcc lhs (Embed env1 cc1) (embedAcc acc0)
-- | Second stage of embedding a let-binding, given the embeddings of both
-- the bound term and the body.
--
-- For a binding of a single array, the elimination predicate decides:
--
-- * predicate is 'False': the bound term is computed and kept as a binding
--   underneath the bindings of the body (rule "aletD/bind");
--
-- * otherwise: the binding is eliminated by substituting its delayed
--   representation into the body (rule "aletD/eliminate"). Only 'Step' and
--   'Yield' are handled in this branch; a single-variable 'Done' is already
--   taken care of by 'aletD'.
--
-- Any other left-hand side is always kept as a binding.
aletD' :: forall aenv aenv' arrs brrs. HasCallStack
       => EmbedAcc OpenAcc
       -> ElimAcc OpenAcc
       -> ALeftHandSide arrs aenv aenv'
       -> Embed OpenAcc aenv arrs
       -> Embed OpenAcc aenv' brrs
       -> Embed OpenAcc aenv brrs
aletD' embedAcc elimAcc (LeftHandSideSingle ArrayR{}) (Embed env1 cc1) (Embed env0 cc0)
  | acc1 <- computeAcc (Embed env1 cc1)
  , False <- elimAcc acc1 acc0
  = Stats.ruleFired "aletD/bind"
  $ Embed (BaseEnv `pushArrayEnv` acc1 `append` env0) cc0
  | acc0' <- sink1 env1 acc0
  = Stats.ruleFired "aletD/eliminate"
  $ case cc1 of
      Step{} -> eliminate env1 cc1 acc0'
      Yield{} -> eliminate env1 cc1 acc0'
  where
    -- The body as a plain term.
    acc0 :: OpenAcc aenv' brrs
    acc0 = computeAcc (Embed env0 cc0)
    -- Apply a function underneath the 'OpenAcc' wrapper.
    kmap :: forall aenv a b. (PreOpenAcc OpenAcc aenv a -> PreOpenAcc OpenAcc aenv b)
         -> OpenAcc aenv a
         -> OpenAcc aenv b
    kmap f (OpenAcc pacc) = OpenAcc (f pacc)
    -- Remove the binding of an array by replacing uses of the bound variable
    -- in the body with the shape and element function of its cunctation, and
    -- then embedding the resulting body.
    eliminate
        :: forall aenv aenv' sh e brrs. HasCallStack
        => Extend ArrayR OpenAcc aenv aenv'
        -> Cunctation aenv' (Array sh e)
        -> OpenAcc (aenv', Array sh e) brrs
        -> Embed OpenAcc aenv brrs
    eliminate env1 cc1 body
      | Done v1 <- cc1
      , TupRsingle v1'@(Var r _) <- v1 = elim r (arrayShape v1') (indexArray v1')
      | Step r sh1 p1 f1 v1 <- cc1 = elim r sh1 (f1 `compose` indexArray v1 `compose` p1)
      | Yield r sh1 f1 <- cc1 = elim r sh1 f1
      where
        -- The bound term itself, substituted for any remaining occurrence of
        -- the bound variable.
        bnd :: PreOpenAcc OpenAcc aenv' (Array sh e)
        bnd = compute cc1
        -- The shape and function are weakened by one position because the
        -- body still lives under the binding while 'replaceA' runs;
        -- 'subAtop' then removes the binding.
        elim :: HasCallStack
             => ArrayR (Array sh e)
             -> Exp aenv' sh
             -> Fun aenv' (sh -> e)
             -> Embed OpenAcc aenv brrs
        elim r sh1 f1
          | sh1' <- weaken (weakenSucc' weakenId) sh1
          , f1' <- weaken (weakenSucc' weakenId) f1
          , Embed env0' cc0' <- embedAcc $ rebuildA (subAtop bnd) $ kmap (replaceA sh1' f1' $ Var r ZeroIdx) body
          = Embed (env1 `append` env0') cc0'
    -- Replace, inside a scalar expression, every use of the given array
    -- variable: 'Shape' becomes the given shape expression, and 'Index' /
    -- 'LinearIndex' become an application of the given function (expressed
    -- with a let-binding of the function's argument). All other
    -- constructors are traversed structurally.
    replaceE :: forall env aenv sh e t. HasCallStack
             => OpenExp env aenv sh
             -> OpenFun env aenv (sh -> e)
             -> ArrayVar aenv (Array sh e)
             -> OpenExp env aenv t
             -> OpenExp env aenv t
    replaceE sh' f' avar@(Var (ArrayR shR _) _) exp =
      case exp of
        -- Under a binder, the replacement shape and function are weakened
        -- past the newly bound variables.
        Let lhs x y -> let k = weakenWithLHS lhs
                       in Let lhs (cvtE x) (replaceE (weakenE k sh') (weakenE k f') avar y)
        Evar var -> Evar var
        Foreign tR ff f e -> Foreign tR ff f (cvtE e)
        Const tR c -> Const tR c
        Undef tR -> Undef tR
        Nil -> Nil
        Pair e1 e2 -> Pair (cvtE e1) (cvtE e2)
        VecPack vR e -> VecPack vR (cvtE e)
        VecUnpack vR e -> VecUnpack vR (cvtE e)
        IndexSlice x ix sh -> IndexSlice x (cvtE ix) (cvtE sh)
        IndexFull x ix sl -> IndexFull x (cvtE ix) (cvtE sl)
        ToIndex shR' sh ix -> ToIndex shR' (cvtE sh) (cvtE ix)
        FromIndex shR' sh i -> FromIndex shR' (cvtE sh) (cvtE i)
        Case e rhs def -> Case (cvtE e) (over (mapped . _2) cvtE rhs) (fmap cvtE def)
        Cond p t e -> Cond (cvtE p) (cvtE t) (cvtE e)
        PrimConst c -> PrimConst c
        PrimApp g x -> PrimApp g (cvtE x)
        ShapeSize shR' sh -> ShapeSize shR' (cvtE sh)
        While p f x -> While (replaceF sh' f' avar p) (replaceF sh' f' avar f) (cvtE x)
        Coerce t1 t2 e -> Coerce t1 t2 (cvtE e)
        Shape a
          | Just Refl <- matchVar a avar -> Stats.substitution "replaceE/shape" sh'
          | otherwise -> exp
        Index a sh
          | Just Refl <- matchVar a avar
          , Lam lhs (Body b) <- f' -> Stats.substitution "replaceE/!" . cvtE $ Let lhs sh b
          | otherwise -> Index a (cvtE sh)
        -- A linear index is first converted to a shape index with
        -- 'FromIndex' over the replacement shape.
        LinearIndex a i
          | Just Refl <- matchVar a avar
          , Lam lhs (Body b) <- f'
          -> Stats.substitution "replaceE/!!" . cvtE
               $ Let lhs
                   (Let (LeftHandSideSingle scalarTypeInt) i $ FromIndex shR (weakenE (weakenSucc' weakenId) sh') $ Evar $ Var scalarTypeInt ZeroIdx)
                   b
          | otherwise -> LinearIndex a (cvtE i)
      where
        cvtE :: OpenExp env aenv s -> OpenExp env aenv s
        cvtE = replaceE sh' f' avar
    -- 'replaceE' lifted to scalar functions; the replacement shape and
    -- function are weakened past each lambda.
    replaceF :: forall env aenv sh e t. HasCallStack
             => OpenExp env aenv sh
             -> OpenFun env aenv (sh -> e)
             -> ArrayVar aenv (Array sh e)
             -> OpenFun env aenv t
             -> OpenFun env aenv t
    replaceF sh' f' avar fun =
      case fun of
        Body e -> Body (replaceE sh' f' avar e)
        Lam lhs f -> let k = weakenWithLHS lhs
                     in Lam lhs (replaceF (weakenE k sh') (weakenE k f') avar f)
    -- 'replaceE' lifted to array terms: every scalar expression and function
    -- inside the term is rewritten, and array sub-terms are traversed
    -- recursively. The function argument of 'Aforeign' is left untouched.
    replaceA :: forall aenv sh e a. HasCallStack
             => Exp aenv sh
             -> Fun aenv (sh -> e)
             -> ArrayVar aenv (Array sh e)
             -> PreOpenAcc OpenAcc aenv a
             -> PreOpenAcc OpenAcc aenv a
    replaceA sh' f' avar pacc =
      case pacc of
        Avar v
          | Just Refl <- matchVar v avar -> Avar avar
          | otherwise -> Avar v
        -- Under an array binder everything, including the variable being
        -- replaced, is weakened past the new bindings.
        Alet lhs bnd (body :: OpenAcc aenv1 a) ->
          let w :: aenv :> aenv1
              w = weakenWithLHS lhs
              sh'' = weaken w sh'
              f'' = weaken w f'
          in
          Alet lhs (cvtA bnd) (kmap (replaceA sh'' f'' (weaken w avar)) body)
        Use repr arrs -> Use repr arrs
        Unit tR e -> Unit tR (cvtE e)
        Acond p at ae -> Acond (cvtE p) (cvtA at) (cvtA ae)
        Anil -> Anil
        Apair a1 a2 -> Apair (cvtA a1) (cvtA a2)
        Awhile p f a -> Awhile (cvtAF p) (cvtAF f) (cvtA a)
        Apply repr f a -> Apply repr (cvtAF f) (cvtA a)
        Aforeign repr ff f a -> Aforeign repr ff f (cvtA a)
        Generate repr sh f -> Generate repr (cvtE sh) (cvtF f)
        Map tR f a -> Map tR (cvtF f) (cvtA a)
        ZipWith tR f a b -> ZipWith tR (cvtF f) (cvtA a) (cvtA b)
        Backpermute shR sh p a -> Backpermute shR (cvtE sh) (cvtF p) (cvtA a)
        Transform repr sh p f a -> Transform repr (cvtE sh) (cvtF p) (cvtF f) (cvtA a)
        Slice slix a sl -> Slice slix (cvtA a) (cvtE sl)
        Replicate slix sh a -> Replicate slix (cvtE sh) (cvtA a)
        Reshape shR sl a -> Reshape shR (cvtE sl) (cvtA a)
        Fold f z a -> Fold (cvtF f) (cvtE <$> z) (cvtA a)
        FoldSeg i f z a s -> FoldSeg i (cvtF f) (cvtE <$> z) (cvtA a) (cvtA s)
        Scan d f z a -> Scan d (cvtF f) (cvtE <$> z) (cvtA a)
        Scan' d f z a -> Scan' d (cvtF f) (cvtE z) (cvtA a)
        Permute f d p a -> Permute (cvtF f) (cvtA d) (cvtF p) (cvtA a)
        Stencil s t f x a -> Stencil s t (cvtF f) (cvtB x) (cvtA a)
        Stencil2 s1 s2 t f x a y b
          -> Stencil2 s1 s2 t (cvtF f) (cvtB x) (cvtA a) (cvtB y) (cvtA b)
      where
        cvtA :: OpenAcc aenv s -> OpenAcc aenv s
        cvtA = kmap (replaceA sh' f' avar)
        cvtE :: Exp aenv s -> Exp aenv s
        cvtE = replaceE sh' f' avar
        cvtF :: Fun aenv s -> Fun aenv s
        cvtF = replaceF sh' f' avar
        -- Only the 'Function' boundary contains a function to rewrite.
        cvtB :: Boundary aenv s -> Boundary aenv s
        cvtB Clamp = Clamp
        cvtB Mirror = Mirror
        cvtB Wrap = Wrap
        cvtB (Constant c) = Constant c
        cvtB (Function f) = Function (cvtF f)
        -- Rewrite the body of an array function, weakening the replacement
        -- shape, function and variable past each lambda.
        cvtAF :: HasCallStack => PreOpenAfun OpenAcc aenv s -> PreOpenAfun OpenAcc aenv s
        cvtAF = cvt sh' f' avar
          where
            cvt :: forall aenv a.
                   Exp aenv sh -> Fun aenv (sh -> e) -> ArrayVar aenv (Array sh e)
                -> PreOpenAfun OpenAcc aenv a
                -> PreOpenAfun OpenAcc aenv a
            cvt sh'' f'' avar' (Abody a) = Abody $ kmap (replaceA sh'' f'' avar') a
            cvt sh'' f'' avar' (Alam lhs (af :: PreOpenAfun OpenAcc aenv1 b)) =
              Alam lhs $ cvt (weaken w sh'')
                             (weaken w f'')
                             (weaken w avar')
                             af
              where
                w :: aenv :> aenv1
                w = weakenWithLHS lhs
-- Any other left-hand side: the bound term is computed and kept as a
-- binding underneath the bindings of the body.
aletD' _ _ lhs (Embed env1 cc1) (Embed env0 cc0)
  = Stats.ruleFired "aletD/bind"
  $ Embed (PushEnv BaseEnv lhs acc1 `append` env0) cc0
  where
    acc1 = computeAcc $ Embed env1 cc1
-- | Embed an array-level conditional.
--
-- A constant predicate of 1 selects the first branch and a constant 0 the
-- second; if both branches match, the second one is used. In those cases
-- only the chosen branch is embedded. Otherwise both branches are computed
-- and the conditional is kept as a finished term.
acondD :: HasCallStack
       => MatchAcc OpenAcc
       -> EmbedAcc OpenAcc
       -> Exp aenv PrimBool
       -> OpenAcc aenv arrs
       -> OpenAcc aenv arrs
       -> Embed OpenAcc aenv arrs
acondD matchAcc embedAcc p t e =
  case p of
    Const _ 1 -> Stats.knownBranch "True" (embedAcc t)
    Const _ 0 -> Stats.knownBranch "False" (embedAcc e)
    _ | Just Refl <- matchAcc t e -> Stats.knownBranch "redundant" (embedAcc e)
      | otherwise ->
          done (Acond p (computeAcc (embedAcc t)) (computeAcc (embedAcc e)))
-- | The identity function at the given type: a lambda that binds fresh
-- variables for its argument and returns exactly those variables.
identity :: TypeR a -> OpenFun env aenv (a -> a)
identity t =
  case declareVars t of
    DeclareVars lhs _ value -> Lam lhs (Body (expVars (value weakenId)))
-- | A function from a shape index to an Int, built from 'ToIndex' with
-- respect to the given extent. The extent is weakened past the variables
-- bound for the argument.
toIndex :: ShapeR sh -> OpenExp env aenv sh -> OpenFun env aenv (sh -> Int)
toIndex shR sh =
  case declareVars (shapeType shR) of
    DeclareVars lhs k value ->
      Lam lhs (Body (ToIndex shR (weakenE k sh) (expVars (value weakenId))))
-- | A function from an Int to a shape index, built from 'FromIndex' with
-- respect to the given extent. The extent is weakened past the single
-- variable bound for the argument.
fromIndex :: ShapeR sh -> OpenExp env aenv sh -> OpenFun env aenv (Int -> sh)
fromIndex shR sh =
  Lam (LeftHandSideSingle scalarTypeInt)
      (Body (FromIndex shR
                       (weakenE (weakenSucc' weakenId) sh)
                       (Evar (Var scalarTypeInt ZeroIdx))))
-- | Combine two shape expressions component-wise with 'PrimMin', using
-- 'mkShapeBinary'.
intersect :: ShapeR sh -> OpenExp env aenv sh -> OpenExp env aenv sh -> OpenExp env aenv sh
intersect = mkShapeBinary (\a b -> PrimApp (PrimMin singleType) (Pair a b))
-- | Apply a binary operation on Ints component-wise to two shape
-- expressions.
--
-- The equations are tried in order. Explicit pairs are taken apart
-- directly; a let-binding on either side is floated outwards (weakening
-- the other side); anything else is let-bound to fresh variables for its
-- components, after which the recursion can proceed on the variables.
mkShapeBinary
    :: (forall env'. OpenExp env' aenv Int -> OpenExp env' aenv Int -> OpenExp env' aenv Int)
    -> ShapeR sh
    -> OpenExp env aenv sh
    -> OpenExp env aenv sh
    -> OpenExp env aenv sh
mkShapeBinary _ ShapeRz _ _ = Nil
mkShapeBinary f (ShapeRsnoc shR) (Pair as a) (Pair bs b) = mkShapeBinary f shR as bs `Pair` f a b
mkShapeBinary f shR (Let lhs bnd a) b = Let lhs bnd $ mkShapeBinary f shR a (weakenE (weakenWithLHS lhs) b)
mkShapeBinary f shR a (Let lhs bnd b) = Let lhs bnd $ mkShapeBinary f shR (weakenE (weakenWithLHS lhs) a) b
-- Here the first shape is neither a pair nor a let (those were handled
-- above), while the second is a pair: bind the first to variables.
mkShapeBinary f shR a b@Pair{}
  | DeclareVars lhs k value <- declareVars $ shapeType shR
  = Let lhs a $ mkShapeBinary f shR (expVars $ value weakenId) (weakenE k b)
-- Here the second shape is neither a pair nor a let: bind it to variables.
mkShapeBinary f shR a b
  | DeclareVars lhs k value <- declareVars $ shapeType shR
  = Let lhs b $ mkShapeBinary f shR (weakenE k a) (expVars $ value weakenId)
-- | Build the index function between two extents: the identity when the
-- two shape expressions match, and otherwise 'toIndex' with respect to the
-- second extent followed by 'fromIndex' with respect to the first.
reindex :: ShapeR sh'
        -> OpenExp env aenv sh'
        -> ShapeR sh
        -> OpenExp env aenv sh
        -> OpenFun env aenv (sh -> sh')
reindex shR' sh' shR sh =
  case matchOpenExp sh sh' of
    Just Refl -> identity (shapeType shR')
    Nothing   -> fromIndex shR' sh' `compose` toIndex shR sh
-- | The index function used by 'replicateD': a lambda over an index of the
-- full shape that applies 'IndexSlice' with the (weakened) slice
-- specification.
extend :: SliceIndex slix sl co sh
       -> Exp aenv slix
       -> Fun aenv (sh -> sl)
extend sliceIndex slix =
  case declareVars (shapeType (sliceDomainR sliceIndex)) of
    DeclareVars lhs k value ->
      Lam lhs (Body (IndexSlice sliceIndex (weakenE k slix) (expVars (value weakenId))))
-- | The index function used by 'sliceD': a lambda over an index of the
-- slice shape that applies 'IndexFull' with the (weakened) slice
-- specification.
restrict :: SliceIndex slix sl co sh
         -> Exp aenv slix
         -> Fun aenv (sl -> sh)
restrict sliceIndex slix =
  case declareVars (shapeType (sliceShapeR sliceIndex)) of
    DeclareVars lhs k value ->
      Lam lhs (Body (IndexFull sliceIndex (weakenE k slix) (expVars (value weakenId))))
-- | The (simplified) shape expression of an array variable.
arrayShape :: ArrayVar aenv (Array sh e) -> Exp aenv sh
arrayShape var = simplifyExp (Shape var)
-- | The function that indexes the given array variable at a shape index.
indexArray :: ArrayVar aenv (Array sh e) -> Fun aenv (sh -> e)
indexArray v@(Var (ArrayR shR _) _) =
  case declareVars (shapeType shR) of
    DeclareVars lhs _ value -> Lam lhs (Body (Index v (expVars (value weakenId))))
-- | The function that indexes the given array variable at a linear (Int)
-- index.
linearIndex :: ArrayVar aenv (Array sh e) -> Fun aenv (Int -> e)
linearIndex v =
  Lam (LeftHandSideSingle scalarTypeInt)
      (Body (LinearIndex v (Evar (Var scalarTypeInt ZeroIdx))))
-- | Unwrap an 'OpenAcc' term; this always succeeds.
extractOpenAcc :: ExtractAcc OpenAcc
extractOpenAcc acc =
  case acc of
    OpenAcc inner -> Just inner
-- | Unwrap a 'Manifest' term; any other delayed term yields 'Nothing'.
extractDelayedOpenAcc :: ExtractAcc DelayedOpenAcc
extractDelayedOpenAcc acc =
  case acc of
    Manifest inner -> Just inner
    _              -> Nothing
-- | If the term consists only of array variables (as decided by
-- 'avarsOut'), return them.
extractOpenArrayVars
    :: OpenAcc aenv a
    -> Maybe (ArrayVars aenv a)
extractOpenArrayVars acc =
  case acc of
    OpenAcc inner -> avarsOut extractOpenAcc inner
-- | If the delayed term is 'Manifest' and consists only of array variables
-- (as decided by 'avarsOut'), return them; otherwise 'Nothing'.
extractDelayedArrayVars
    :: DelayedOpenAcc aenv a
    -> Maybe (ArrayVars aenv a)
extractDelayedArrayVars acc =
  extractDelayedOpenAcc acc >>= avarsOut extractDelayedOpenAcc