{-# LANGUAGE TupleSections #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Array.Accelerate.Trafo.Shrink (
ShrinkAcc,
shrinkExp,
shrinkFun,
UsesOfAcc, usesOfPreAcc, usesOfExp,
) where
import Data.Array.Accelerate.AST
import Data.Array.Accelerate.AST.Environment
import Data.Array.Accelerate.AST.Idx
import Data.Array.Accelerate.AST.LeftHandSide
import Data.Array.Accelerate.AST.Var
import Data.Array.Accelerate.Analysis.Match
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Trafo.Substitution
import qualified Data.Array.Accelerate.Debug.Stats as Stats
import Control.Applicative hiding ( Const )
import Data.Maybe ( isJust )
import Data.Monoid
import Data.Semigroup
import Prelude hiding ( exp, seq )
-- | A contiguous run of variables in the scalar environment @env@, as
-- introduced by a single left-hand side (see 'lhsVarsRange').
--
--   * The first field is the index of the right-most variable of the run,
--     i.e. the variable bound last, which has the smallest de Bruijn index.
--   * The 'Int' is the number of variables in the run.
--   * The 'RangeTuple', when present, records the tuple shape of the
--     left-hand side. It is 'Nothing' when the left-hand side contains a
--     wildcard of non-unit type; such a run is never matched against a tuple
--     expression (see 'matchEVarsRange').
data VarsRange env =
  VarsRange !(Exists (Idx env))
            {-# UNPACK #-} !Int
            !(Maybe RangeTuple)
-- | The tuple shape of a left-hand side that contains no non-unit wildcards:
-- 'RTNil' for a unit wildcard, 'RTSingle' for a single bound variable, and
-- 'RTPair' for a pair of left-hand sides.
data RangeTuple
  = RTNil
  | RTSingle
  | RTPair !RangeTuple !RangeTuple
-- | Describe the variables bound by a left-hand side as a 'VarsRange' in the
-- extended environment.
--
-- Yields @'Left' 'Refl'@ when the left-hand side binds no variables at all
-- (the environment is unchanged). Otherwise the range starts at the variable
-- bound last and spans every variable bound by the left-hand side.
lhsVarsRange :: LeftHandSide s v env env' -> Either (env :~: env') (VarsRange env')
lhsVarsRange lhs = fmap (\start -> VarsRange start size shape) (lastBound lhs)
  where
    (size, shape) = describe lhs

    -- Index of the variable bound last, or a proof that nothing was bound.
    lastBound :: LeftHandSide s v env env' -> Either (env :~: env') (Exists (Idx env'))
    lastBound (LeftHandSideSingle _)   = Right (Exists ZeroIdx)
    lastBound (LeftHandSideWildcard _) = Left Refl
    lastBound (LeftHandSidePair l r)   =
      case lastBound r of
        -- the right component binds nothing, so look in the left one
        Left Refl -> lastBound l
        -- re-wrap: the type of the 'Left' alternative differs here
        Right ix  -> Right ix

    -- Number of bound variables, and the tuple shape when it is expressible.
    describe :: LeftHandSide s v env env' -> (Int, Maybe RangeTuple)
    describe (LeftHandSideSingle _)          = (1, Just RTSingle)
    describe (LeftHandSideWildcard TupRunit) = (0, Just RTNil)
    describe (LeftHandSideWildcard _)        = (0, Nothing)
    describe (LeftHandSidePair l r)          =
      let (nl, tl) = describe l
          (nr, tr) = describe r
      in  (nl + nr, RTPair <$> tl <*> tr)
-- | Move a 'VarsRange' under the binders introduced by a left-hand side.
-- Every variable bound by the left-hand side shifts the start index by one.
-- The length and the tuple shape of the range are unchanged.
weakenVarsRange :: LeftHandSide s v env env' -> VarsRange env -> VarsRange env'
weakenVarsRange lhs (VarsRange start len shape) = VarsRange (shift lhs start) len shape
  where
    shift :: LeftHandSide s v env env' -> Exists (Idx env) -> Exists (Idx env')
    shift (LeftHandSideSingle _)   (Exists i) = Exists (SuccIdx i)
    shift (LeftHandSideWildcard _) e          = e
    shift (LeftHandSidePair l r)   e          = shift r (shift l e)
-- | Is the expression exactly the variables of the range, rebuilt in the
-- tuple shape of the left-hand side that bound them?
--
-- Such an expression is counted as a single use of the whole range by
-- 'usesOfExp'. Always 'False' when the range has no recorded tuple shape.
matchEVarsRange :: VarsRange env -> OpenExp env aenv t -> Bool
matchEVarsRange (VarsRange (Exists start) _ (Just shape)) expr =
  isJust (walk (idxToInt start) shape expr)
  where
    -- Threads the index expected for the next variable. Right components
    -- hold the more recently bound (lower-indexed) variables, so they are
    -- visited first. Yields the next expected index on success.
    walk :: Int -> RangeTuple -> OpenExp env aenv t -> Maybe Int
    walk next RTNil          Nil               = Just next
    walk next RTSingle       (Evar (Var _ ix))
      | idxToInt ix == next                    = Just (next + 1)
    walk next (RTPair tl tr) (Pair el er)      = walk next tr er >>= \next' -> walk next' tl el
    walk _    _              _                 = Nothing
matchEVarsRange _ _ = False
-- | If the variable lies inside the range, return a usage mask with one flag
-- per variable of the range. Only the flag at this variable's position is
-- 'True'; position 0 is the variable bound last (the start of the range).
-- Variables outside the range yield 'Nothing'.
varInRange :: VarsRange env -> Var s env t -> Maybe Usages
varInRange (VarsRange (Exists start) n _) (Var _ var) =
  fmap (\pos -> [ k == pos | k <- [0 .. n - 1] ]) (offset start var)
  where
    -- Strip successors from both indices until the start of the range.
    offset :: Idx env u -> Idx env t -> Maybe Int
    offset ZeroIdx     v           = distance v 0
    offset (SuccIdx r) (SuccIdx v) = offset r v
    offset (SuccIdx _) ZeroIdx     = Nothing   -- bound after the range

    -- Distance of the variable from the start, bounded by the range length.
    distance :: Idx env t -> Int -> Maybe Int
    distance v pos
      | pos >= n  = Nothing
      | otherwise = case v of
          ZeroIdx    -> Just pos
          SuccIdx v' -> distance v' (pos + 1)
-- | How the variables of a 'VarsRange' are used.
--
--   * @'Finite' n@: the range is used @n@ times, and only ever as the
--     complete tuple (see 'matchEVarsRange'); @'Finite' 0@ means unused.
--   * 'Infinity': the complete tuple is used an unbounded number of times,
--     e.g. inside a loop (see 'loopCount').
--   * @'Impossible' us@: some variables are used individually, so the
--     binding is not inlined as a whole; @us@ records which variables are
--     used (see 'Usages').
data Count
  = Impossible !Usages
  | Infinity
  | Finite {-# UNPACK #-} !Int

-- | One flag per variable of a 'VarsRange', starting at the variable bound
-- last; a flag is 'True' when that variable is used.
type Usages = [Bool]
-- | Combine the counts of two sub-terms.
--
-- The usage masks of two 'Impossible' counts are merged point-wise.
-- An 'Impossible' count combined with a use of the complete tuple (any count
-- other than @'Finite' 0@) marks every variable as used. Otherwise
-- 'Infinity' absorbs, and 'Finite' counts add up.
--
-- NOTE: the equations overlap and are tried top to bottom; their order is
-- significant.
instance Semigroup Count where
  Impossible u1 <> Impossible u2 = Impossible $ zipWith (||) u1 u2
  Impossible u  <> Finite 0      = Impossible u
  Finite 0      <> Impossible u  = Impossible u
  Impossible u  <> _             = Impossible $ map (const True) u
  _             <> Impossible u  = Impossible $ map (const True) u
  Infinity      <> _             = Infinity
  _             <> Infinity      = Infinity
  Finite a      <> Finite b      = Finite $ a + b
-- | No uses at all: @'Finite' 0@ is the identity of '(<>)' on 'Count'.
instance Monoid Count where
  mempty = Finite 0
-- | Adjust a count for a term that may be evaluated repeatedly, such as the
-- condition or body of a 'While' loop. Any positive number of whole-tuple
-- uses becomes 'Infinity'; every other count is returned unchanged.
loopCount :: Count -> Count
loopCount c = case c of
  Finite n | n > 0 -> Infinity
  _                -> c
-- | Try to remove unused variables from a left-hand side, based on the
-- 'Count' of its variables.
--
--   * A left-hand side that is already a single wildcard is never shrunk.
--   * @'Finite' 0@: nothing is used, so the whole left-hand side becomes a
--     wildcard of the same type.
--   * @'Impossible' us@: each variable whose flag in @us@ is 'False' becomes
--     a wildcard, and pairs of wildcards are merged into one. The result is
--     'Nothing' when no variable could be removed.
--   * Any other count leaves the left-hand side as is ('Nothing').
--
-- Raises an internal error when the length of the usage mask does not match
-- the number of variables bound by the left-hand side.
shrinkLhs
    :: HasCallStack
    => Count
    -> LeftHandSide s t env1 env2
    -> Maybe (Exists (LeftHandSide s t env1))
shrinkLhs _ (LeftHandSideWildcard _) = Nothing
shrinkLhs (Finite 0) lhs = Just $ Exists $ LeftHandSideWildcard $ lhsToTupR lhs
shrinkLhs (Impossible usages) lhs = case go usages lhs of
  (True , [], lhs') -> Just lhs'
  (False, [], _ )   -> Nothing
  _                 -> internalError "Mismatch in length of usages array and LHS"
  where
    -- Walks the left-hand side right to left (the order of the usage mask).
    -- Returns whether anything changed, the unconsumed part of the mask, and
    -- the rebuilt left-hand side.
    go :: HasCallStack => Usages -> LeftHandSide s t env1 env2 -> (Bool, Usages, Exists (LeftHandSide s t env1))
    go us (LeftHandSideWildcard tp) = (False, us, Exists $ LeftHandSideWildcard tp)
    go (True : us) (LeftHandSideSingle tp) = (False, us, Exists $ LeftHandSideSingle tp)
    go (False : us) (LeftHandSideSingle tp) = (True , us, Exists $ LeftHandSideWildcard $ TupRsingle tp)
    go us (LeftHandSidePair l1 l2)
      | (c2, us' , Exists l2') <- go us l2
      , (c1, us'', Exists l1') <- go us' l1
      -- l2' still starts from the environment produced by the original l1;
      -- rebuildLHS re-bases it so it can follow the shrunk l1'
      , Exists l2'' <- rebuildLHS l2'
      = let
          lhs'
            | LeftHandSideWildcard t1 <- l1'
            , LeftHandSideWildcard t2 <- l2'' = LeftHandSideWildcard $ TupRpair t1 t2
            | otherwise = LeftHandSidePair l1' l2''
        in
          (c1 || c2, us'', Exists lhs')
    go _ _ = internalError "Empty array, mismatch in length of usages array and LHS"
shrinkLhs _ _ = Nothing
-- | Given an original left-hand side and its shrunk version (as produced by
-- 'shrinkLhs'), extend a strengthening of the environments they start from
-- to the environments they produce.
--
-- A variable kept by the shrunk left-hand side maps to its new index. A
-- variable that was replaced by a wildcard maps to 'Nothing', so
-- strengthening fails if it is still referenced.
--
-- Raises an internal error when the two left-hand sides have incompatible
-- structure.
strengthenShrunkLHS
    :: HasCallStack
    => LeftHandSide s t env1  env2
    -> LeftHandSide s t env1' env2'
    -> env1 :?> env1'
    -> env2 :?> env2'
strengthenShrunkLHS (LeftHandSideWildcard _) (LeftHandSideWildcard _) k = k
strengthenShrunkLHS (LeftHandSideSingle _)   (LeftHandSideSingle _)   k = \ix -> case ix of
  ZeroIdx     -> Just ZeroIdx
  SuccIdx ix' -> SuccIdx <$> k ix'
strengthenShrunkLHS (LeftHandSidePair lA hA) (LeftHandSidePair lB hB) k = strengthenShrunkLHS hA hB $ strengthenShrunkLHS lA lB k
-- A variable that was dropped: it has no counterpart in the new environment.
strengthenShrunkLHS (LeftHandSideSingle _)   (LeftHandSideWildcard _) k = \ix -> case ix of
  ZeroIdx     -> Nothing
  SuccIdx ix' -> k ix'
-- A pair that was merged into a single wildcard: strengthen each component
-- against a wildcard of that component's type.
strengthenShrunkLHS (LeftHandSidePair l h)   (LeftHandSideWildcard t) k = strengthenShrunkLHS h (LeftHandSideWildcard t2) $ strengthenShrunkLHS l (LeftHandSideWildcard t1) k
  where
    TupRpair t1 t2 = t
strengthenShrunkLHS (LeftHandSideWildcard _) _ _ = internalError "Second LHS defines more variables"
strengthenShrunkLHS _ _ _ = internalError "Mismatch LHS single with LHS pair"
-- | Shrink a scalar expression. Let-bindings whose variables are used as a
-- whole at most 'lIMIT' times, or whose bound expression is 'cheap', are
-- inlined. Variables that are never used are removed from let left-hand
-- sides.
--
-- Returns whether anything changed, together with the new expression.
shrinkExp :: HasCallStack => OpenExp env aenv t -> (Bool, OpenExp env aenv t)
shrinkExp = Stats.substitution "shrinkE" . first getAny . shrinkE
  where
    -- Maximum number of whole-tuple uses for which a non-cheap binding is
    -- still inlined.
    lIMIT :: Int
    lIMIT = 1

    -- Expressions that may be duplicated freely when inlining: variables,
    -- constants, and tuples or coercions of those.
    cheap :: OpenExp env aenv t -> Bool
    cheap (Evar _)       = True
    cheap (Pair e1 e2)   = cheap e1 && cheap e2
    cheap Nil            = True
    cheap Const{}        = True
    cheap PrimConst{}    = True
    cheap Undef{}        = True
    cheap (Coerce _ _ e) = cheap e
    cheap _              = False

    -- The (Any, _) pair is used as a writer applicative: 'Any True' records
    -- that something changed.
    shrinkE :: HasCallStack => OpenExp env aenv t -> (Any, OpenExp env aenv t)
    shrinkE exp = case exp of
      -- A single variable bound to another variable: substitute directly.
      Let (LeftHandSideSingle _) bnd@Evar{} body -> Stats.inline "Var" . yes $ shrinkE (inline body bnd)
      Let lhs bnd body
        -- Inline the (shrunk) binding into the (shrunk) body.
        | shouldInline -> case inlineVars lhs (snd body') (snd bnd') of
            Just inlined -> Stats.betaReduce msg . yes $ shrinkE inlined
            _            -> internalError "Unexpected failure while trying to inline some expression."
        -- Otherwise try to drop unused variables from the left-hand side.
        | Just (Exists lhs') <- shrinkLhs count lhs -> case strengthenE (strengthenShrunkLHS lhs lhs' Just) (snd body') of
            Just body'' -> (Any True, Let lhs' (snd bnd') body'')
            Nothing     -> internalError "Unexpected failure in strenthenE. Variable was analysed to be unused in usesOfExp, but appeared to be used in strenthenE."
        | otherwise -> Let lhs <$> bnd' <*> body'
        where
          shouldInline = case count of
            Finite 0     -> False  -- unused variables are handled by shrinkLhs
            Finite n     -> n <= lIMIT || cheap (snd bnd')
            Infinity     -> cheap (snd bnd')
            Impossible _ -> False  -- variables used individually: keep the binding
          bnd'  = shrinkE bnd
          body' = shrinkE body
          -- Uses are counted in the already-shrunk body. A left-hand side
          -- that binds no variables counts as unused.
          count = case lhsVarsRange lhs of
            Left _      -> Finite 0
            Right range -> usesOfExp range (snd body')
          -- NOTE(review): shouldInline is False for 'Finite 0', so the
          -- "dead exp" message appears unreachable here.
          msg = case count of
            Finite 0 -> "dead exp"
            _        -> "inline exp"
      Evar v              -> pure (Evar v)
      Const t c           -> pure (Const t c)
      Undef t             -> pure (Undef t)
      Nil                 -> pure Nil
      Pair x y            -> Pair <$> shrinkE x <*> shrinkE y
      VecPack vec e       -> VecPack vec <$> shrinkE e
      VecUnpack vec e     -> VecUnpack vec <$> shrinkE e
      IndexSlice x ix sh  -> IndexSlice x <$> shrinkE ix <*> shrinkE sh
      IndexFull x ix sl   -> IndexFull x <$> shrinkE ix <*> shrinkE sl
      ToIndex shr sh ix   -> ToIndex shr <$> shrinkE sh <*> shrinkE ix
      FromIndex shr sh i  -> FromIndex shr <$> shrinkE sh <*> shrinkE i
      Case e rhs def      -> Case <$> shrinkE e <*> sequenceA [ (t,) <$> shrinkE c | (t,c) <- rhs ] <*> shrinkMaybeE def
      Cond p t e          -> Cond <$> shrinkE p <*> shrinkE t <*> shrinkE e
      While p f x         -> While <$> shrinkF p <*> shrinkF f <*> shrinkE x
      PrimConst c         -> pure (PrimConst c)
      PrimApp f x         -> PrimApp f <$> shrinkE x
      Index a sh          -> Index a <$> shrinkE sh
      LinearIndex a i     -> LinearIndex a <$> shrinkE i
      Shape a             -> pure (Shape a)
      ShapeSize shr sh    -> ShapeSize shr <$> shrinkE sh
      Foreign repr ff f e -> Foreign repr ff <$> shrinkF f <*> shrinkE e
      Coerce t1 t2 e      -> Coerce t1 t2 <$> shrinkE e

    shrinkF :: HasCallStack => OpenFun env aenv t -> (Any, OpenFun env aenv t)
    shrinkF = first Any . shrinkFun

    shrinkMaybeE :: HasCallStack => Maybe (OpenExp env aenv t) -> (Any, Maybe (OpenExp env aenv t))
    shrinkMaybeE Nothing  = pure Nothing
    shrinkMaybeE (Just e) = Just <$> shrinkE e

    -- Map over the first component of a pair.
    first :: (a -> a') -> (a,b) -> (a',b)
    first f (x,y) = (f x, y)

    -- Mark a result as changed.
    yes :: (Any, x) -> (Any, x)
    yes (_, x) = (Any True, x)
-- | Shrink a scalar function. The body is shrunk with 'shrinkExp', and
-- lambda-bound variables that the shrunk body never uses become wildcards.
--
-- Returns whether anything changed, together with the new function. The
-- flag is only raised when the term actually changed, so fixpoint callers
-- can stop iterating.
shrinkFun :: HasCallStack => OpenFun env aenv f -> (Bool, OpenFun env aenv f)
shrinkFun (Lam lhs f) = case lhsVarsRange lhs of
  -- The left-hand side binds no variables; normalise it to one wildcard.
  -- A wildcard of any type (not only unit) is rebuilt unchanged, so it must
  -- not be reported as a change; only a tree of wildcards gets merged.
  Left Refl ->
    let changed = case lhs of
          LeftHandSideWildcard{} -> b
          _                      -> True
    in  (changed, Lam (LeftHandSideWildcard $ lhsToTupR lhs) f')
  -- Count uses in the already-shrunk body (as 'shrinkExp' does), so uses
  -- that shrinking eliminated do not keep a variable alive.
  Right range ->
    case shrinkLhs (usesOfFun range f') lhs of
      Just (Exists lhs') -> case strengthenE (strengthenShrunkLHS lhs lhs' Just) f' of
        Just f'' -> (True, Lam lhs' f'')
        Nothing  -> internalError "Unexpected failure in strengthenE. Variable was analysed to be unused in usesOfFun, but appeared to be used in strengthenE."
      Nothing -> (b, Lam lhs f')
  where
    (b, f') = shrinkFun f
shrinkFun (Body b) = Body <$> shrinkExp b
-- | A transformation of @acc@ terms that preserves both the array
-- environment and the result type; presumably the type of an array-level
-- shrinking pass (not defined in this section — confirm at use sites).
type ShrinkAcc acc = forall aenv a. acc aenv a -> acc aenv a
-- | Count how the variables of a 'VarsRange' are used in an expression.
--
-- An occurrence of the complete tuple of the range (see 'matchEVarsRange')
-- counts as one use. An occurrence of an individual variable yields an
-- 'Impossible' count recording which variable was used. Uses in the
-- condition or body of a 'While' loop are scaled with 'loopCount'.
usesOfExp :: forall env aenv t. VarsRange env -> OpenExp env aenv t -> Count
usesOfExp range = countE
  where
    none :: Count
    none = Finite 0

    countE :: OpenExp env aenv e -> Count
    countE exp
      | matchEVarsRange range exp = Finite 1
      | otherwise                 = countParts exp

    countParts :: OpenExp env aenv e -> Count
    countParts exp = case exp of
      -- a single variable, possibly inside the range
      Evar v             -> maybe none Impossible (varInRange range v)
      -- the body lives under extra binders: shift the range accordingly
      Let lhs bnd body   -> countE bnd <> usesOfExp (weakenVarsRange lhs range) body
      -- the condition and body may be evaluated many times
      While p f x        -> countE x <> loopCount (usesOfFun range p) <> loopCount (usesOfFun range f)
      -- leaves that cannot mention scalar variables
      Const _ _          -> none
      Undef _            -> none
      Nil                -> none
      PrimConst _        -> none
      Shape _            -> none
      -- everything else combines the counts of its sub-expressions
      Pair e1 e2         -> countE e1 <> countE e2
      VecPack _ e        -> countE e
      VecUnpack _ e      -> countE e
      IndexSlice _ ix sh -> countE ix <> countE sh
      IndexFull _ ix sl  -> countE ix <> countE sl
      FromIndex _ sh i   -> countE sh <> countE i
      ToIndex _ sh e     -> countE sh <> countE e
      Case e rhs def     -> countE e <> mconcat (map (countE . snd) rhs) <> maybe none countE def
      Cond p t e         -> countE p <> countE t <> countE e
      PrimApp _ x        -> countE x
      Index _ sh         -> countE sh
      LinearIndex _ i    -> countE i
      ShapeSize _ sh     -> countE sh
      Foreign _ _ _ e    -> countE e
      Coerce _ _ e       -> countE e
-- | Count how the variables of a 'VarsRange' are used in a scalar function,
-- moving the range under each lambda binder before counting in the body.
usesOfFun :: VarsRange env -> OpenFun env aenv f -> Count
usesOfFun range fun = case fun of
  Lam lhs inner -> usesOfFun (weakenVarsRange lhs range) inner
  Body e        -> usesOfExp range e
-- | A function counting the uses of an array variable in an @acc@ term.
-- The 'Bool' selects whether 'Shape' occurrences count as uses; 'usesOfPreAcc'
-- passes its own @withShape@ flag through to it.
type UsesOfAcc acc = forall aenv s t. Bool -> Idx aenv s -> acc aenv t -> Int
-- | Count the occurrences of the array variable @idx@ in one 'PreOpenAcc'
-- node, using @countAcc@ for the nested @acc@ terms.
--
-- When @withShape@ is 'False', occurrences under 'Shape' are not counted.
-- Uses in the condition and body of an 'Awhile' loop count twice, because
-- those functions may be applied repeatedly. Stencil boundaries given as a
-- function are scanned too, since that function may refer to arrays.
usesOfPreAcc
    :: forall acc aenv s t.
       Bool
    -> UsesOfAcc acc
    -> Idx aenv s
    -> PreOpenAcc acc aenv t
    -> Int
usesOfPreAcc withShape countAcc idx = count
  where
    countIdx :: Idx aenv a -> Int
    countIdx this
      | Just Refl <- matchIdx this idx = 1
      | otherwise                      = 0

    count :: PreOpenAcc acc aenv a -> Int
    count pacc = case pacc of
      Avar var                     -> countAvar var
      Alet lhs bnd body            -> countA bnd + countAcc withShape (weakenWithLHS lhs >:> idx) body
      Apair a1 a2                  -> countA a1 + countA a2
      Anil                         -> 0
      Apply _ f a                  -> countAF f idx + countA a
      Aforeign _ _ _ a             -> countA a
      Acond p t e                  -> countE p + countA t + countA e
      Awhile c f a                 -> 2 * countAF c idx + 2 * countAF f idx + countA a
      Use _ _                      -> 0
      Unit _ e                     -> countE e
      Reshape _ e a                -> countE e + countA a
      Generate _ e f               -> countE e + countF f
      Transform _ sh ix f a        -> countE sh + countF ix + countF f + countA a
      Replicate _ sh a             -> countE sh + countA a
      Slice _ a sl                 -> countE sl + countA a
      Map _ f a                    -> countF f + countA a
      ZipWith _ f a1 a2            -> countF f + countA a1 + countA a2
      Fold f z a                   -> countF f + countME z + countA a
      FoldSeg _ f z a s            -> countF f + countME z + countA a + countA s
      Scan _ f z a                 -> countF f + countME z + countA a
      Scan' _ f z a                -> countF f + countE z + countA a
      Permute f1 a1 f2 a2          -> countF f1 + countA a1 + countF f2 + countA a2
      Backpermute _ sh f a         -> countE sh + countF f + countA a
      Stencil _ _ f b a            -> countF f + countB b + countA a
      Stencil2 _ _ _ f b1 a1 b2 a2 -> countF f + countB b1 + countA a1 + countB b2 + countA a2

    countE :: OpenExp env aenv e -> Int
    countE exp = case exp of
      Let _ bnd body      -> countE bnd + countE body
      Evar _              -> 0
      Const _ _           -> 0
      Undef _             -> 0
      Nil                 -> 0
      Pair x y            -> countE x + countE y
      VecPack _ e         -> countE e
      VecUnpack _ e       -> countE e
      IndexSlice _ ix sh  -> countE ix + countE sh
      IndexFull _ ix sl   -> countE ix + countE sl
      ToIndex _ sh ix     -> countE sh + countE ix
      FromIndex _ sh i    -> countE sh + countE i
      Case e rhs def      -> countE e + sum [ countE c | (_,c) <- rhs ] + maybe 0 countE def
      Cond p t e          -> countE p + countE t + countE e
      While p f x         -> countF p + countF f + countE x
      PrimConst _         -> 0
      PrimApp _ x         -> countE x
      Index a sh          -> countAvar a + countE sh
      LinearIndex a i     -> countAvar a + countE i
      ShapeSize _ sh      -> countE sh
      Shape a
        | withShape       -> countAvar a
        | otherwise       -> 0
      Foreign _ _ _ e     -> countE e
      Coerce _ _ e        -> countE e

    countME :: Maybe (OpenExp env aenv e) -> Int
    countME = maybe 0 countE

    -- Stencil boundary conditions: only a 'Function' boundary contains a
    -- scalar function, which may index arrays (and hence use @idx@).
    countB :: Boundary aenv b -> Int
    countB (Function f) = countF f
    countB _            = 0

    countA :: acc aenv a -> Int
    countA = countAcc withShape idx

    countAvar :: ArrayVar aenv a -> Int
    countAvar (Var _ this) = countIdx this

    countAF :: PreOpenAfun acc aenv' f
            -> Idx aenv' s
            -> Int
    countAF (Alam lhs f) v = countAF f (weakenWithLHS lhs >:> v)
    countAF (Abody a)    v = countAcc withShape v a

    countF :: OpenFun env aenv f -> Int
    countF (Lam _ f) = countF f
    countF (Body b)  = countE b