{-# LANGUAGE TupleSections       #-}
{-# LANGUAGE CPP                 #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE PatternGuards       #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell     #-}
{-# LANGUAGE TypeApplications    #-}
{-# LANGUAGE TypeOperators       #-}
{-# LANGUAGE ViewPatterns        #-}
-- |
-- Module      : Data.Array.Accelerate.Trafo.Shrink
-- Copyright   : [2012..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--
-- The shrinking substitution arises as a restriction of beta-reduction to cases
-- where the bound variable is used zero (dead-code elimination) or one (linear
-- inlining) times. By simplifying terms, the shrinking reduction can expose
-- opportunities for further optimisation.
--
-- TODO: replace with a linear shrinking algorithm; e.g.
--
--   * Andrew Appel & Trevor Jim, "Shrinking lambda expressions in linear time".
--
--   * Nick Benton, Andrew Kennedy, Sam Lindley and Claudio Russo, "Shrinking
--     Reductions in SML.NET"
--

module Data.Array.Accelerate.Trafo.Shrink (

  -- Shrinking
  ShrinkAcc,
  shrinkExp,
  shrinkFun,

  -- Occurrence counting
  UsesOfAcc, usesOfPreAcc, usesOfExp,

) where

import Data.Array.Accelerate.AST
import Data.Array.Accelerate.AST.Environment
import Data.Array.Accelerate.AST.Idx
import Data.Array.Accelerate.AST.LeftHandSide
import Data.Array.Accelerate.AST.Var
import Data.Array.Accelerate.Analysis.Match
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Trafo.Substitution

import qualified Data.Array.Accelerate.Debug.Stats                  as Stats

import Control.Applicative                                          hiding ( Const )
import Data.Maybe                                                   ( isJust )
import Data.Monoid
import Data.Semigroup
import Prelude                                                      hiding ( exp, seq )


-- | The consecutive block of de Bruijn indices bound by one left-hand side.
data VarsRange env
  = VarsRange
      !(Exists (Idx env))    -- index of the rightmost (most recently bound, lowest-index) variable
      {-# UNPACK #-} !Int    -- how many variables the block spans
      !(Maybe RangeTuple)    -- tuple structure of the LHS; Nothing if it contains non-unit wildcards

-- | The tuple structure of a left-hand side. It is used to recognise an
-- expression that rebuilds exactly the value bound by that LHS.
data RangeTuple
  = RTNil                           -- a wildcard of unit type
  | RTSingle                        -- a single bound variable
  | RTPair !RangeTuple !RangeTuple  -- a pair of sub-structures

-- | Describe the variables bound by a left-hand side as a 'VarsRange'. If
-- the LHS binds no variables at all, return a proof that it leaves the
-- environment unchanged instead.
lhsVarsRange :: LeftHandSide s v env env' -> Either (env :~: env') (VarsRange env')
lhsVarsRange lhs = toRange <$> rightmost lhs
  where
    (size, shape) = structure lhs
    toRange ix    = VarsRange ix size shape

    -- The most recently bound variable. A pair binds its right component
    -- last, so look there first and fall back to the left component only if
    -- the right one binds nothing.
    rightmost :: LeftHandSide s v env env' -> Either (env :~: env') (Exists (Idx env'))
    rightmost (LeftHandSideWildcard _) = Left Refl
    rightmost (LeftHandSideSingle _)   = Right (Exists ZeroIdx)
    rightmost (LeftHandSidePair l r)   = case rightmost r of
      Left Refl -> rightmost l
      Right ix  -> Right ix

    -- Number of bound variables, together with the tuple structure of the
    -- LHS. There is no structure when a non-unit wildcard occurs.
    structure :: LeftHandSide s v env env' -> (Int, Maybe RangeTuple)
    structure l = case l of
      LeftHandSideWildcard TupRunit -> (0, Just RTNil)
      LeftHandSideWildcard _        -> (0, Nothing)
      LeftHandSideSingle _          -> (1, Just RTSingle)
      LeftHandSidePair a b          ->
        let (na, ta) = structure a
            (nb, tb) = structure b
        in  (na + nb, RTPair <$> ta <*> tb)

-- | Move a range of variables under the additional binders introduced by a
-- left-hand side. Only the rightmost index shifts; the size and tuple
-- structure of the range stay the same.
weakenVarsRange :: LeftHandSide s v env env' -> VarsRange env -> VarsRange env'
weakenVarsRange lhs (VarsRange start size shape) =
  VarsRange (shift lhs start) size shape
  where
    -- Each variable bound by the LHS pushes existing indices up by one.
    shift :: LeftHandSide s v e e' -> Exists (Idx e) -> Exists (Idx e')
    shift (LeftHandSideWildcard _) ix          = ix
    shift (LeftHandSideSingle _)   (Exists ix) = Exists (SuccIdx ix)
    shift (LeftHandSidePair l r)   ix          = shift r (shift l ix)

-- | Does the expression rebuild exactly the tuple of variables described by
-- the range? For example, @Pair (Evar x) (Evar y)@ matches a pair LHS that
-- binds @x@ and then @y@. The result is always 'False' when the range
-- carries no tuple structure.
matchEVarsRange :: VarsRange env -> OpenExp env aenv t -> Bool
matchEVarsRange (VarsRange (Exists start) _ (Just shape)) expr =
  isJust (walk (idxToInt start) shape expr)
  where
    -- Match the expression against the structure, visiting the right
    -- component first. The argument is the de Bruijn index the next
    -- variable must have; on success, return the index after the last
    -- matched variable.
    walk :: Int -> RangeTuple -> OpenExp env aenv t -> Maybe Int
    walk next RTNil Nil = Just next
    walk next RTSingle (Evar (Var _ ix))
      | idxToInt ix == next = Just (next + 1)
    walk next (RTPair l r) (Pair el er) = walk next r er >>= \after -> walk after l el
    walk _ _ _ = Nothing
matchEVarsRange _ _ = False

-- | If the variable lies inside the range, report which of the range's
-- variables it is. The result is a 'Usages' list with a single 'True' at
-- that variable's position, counted from the rightmost variable of the
-- range.
varInRange :: VarsRange env -> Var s env t -> Maybe Usages
varInRange (VarsRange (Exists rangeIx) n _) (Var _ varIx) =
  fmap mask (offset rangeIx varIx)
  where
    mask :: Int -> Usages
    mask j = [ k == j | k <- [0 .. n - 1] ]

    -- Strip successors from both indices in lock-step. If the start of the
    -- range runs out first, the variable is at or beyond it, and 'distance'
    -- measures how far. If the variable runs out first, it lies before the
    -- range.
    offset :: Idx env u -> Idx env t -> Maybe Int
    offset ZeroIdx        ix              = distance ix 0
    offset (SuccIdx rest) (SuccIdx rest') = offset rest rest'
    offset (SuccIdx _)    ZeroIdx         = Nothing

    -- Count the remaining successors. Give up as soon as the count reaches
    -- the size of the range.
    distance :: Idx env t -> Int -> Maybe Int
    distance ix j
      | j >= n    = Nothing
      | otherwise = case ix of
          ZeroIdx      -> Just j
          SuccIdx rest -> distance rest (j + 1)

-- | Summary of how the variables bound by one left-hand side are used
-- within a term, and in particular whether they are used together.
data Count
  -- | The variables are used separately rather than as the whole bound value
  -- (for example, the right-hand side returns a tuple and its components are
  -- used individually). This also arises when the LHS contains non-unit
  -- wildcards. Such a definition cannot be inlined. The list records which
  -- of the variables are used at all.
  = Impossible !Usages
  -- | The bound value is used inside a loop. Inlining should only proceed if
  -- the bound computation is cheap.
  | Infinity
  -- | The bound value is used, as a whole, this many times.
  | Finite {-# UNPACK #-} !Int

-- | One flag per bound variable, 'True' when that variable is used.
type Usages = [Bool]

-- | Combine the counts of two sub-terms.
--
-- * Two 'Impossible' summaries merge variable by variable.
-- * A zero count leaves an 'Impossible' summary unchanged.
-- * Any other use combined with 'Impossible' means the whole value is used
--   somewhere, so every variable is marked as used.
-- * Otherwise 'Infinity' absorbs everything, and finite counts add up.
instance Semigroup Count where
  Impossible xs <> Impossible ys = Impossible (zipWith (||) xs ys)
  Impossible xs <> Finite 0      = Impossible xs
  Finite 0      <> Impossible ys = Impossible ys
  Impossible xs <> _             = Impossible (True <$ xs)
  _             <> Impossible ys = Impossible (True <$ ys)
  Infinity      <> _             = Infinity
  _             <> Infinity      = Infinity
  Finite m      <> Finite n      = Finite (m + n)

-- | The empty count: no uses at all.
instance Monoid Count where
  mempty = Finite 0

-- | Adjust the count of a term that may be evaluated repeatedly, such as the
-- condition or body of a loop. Any positive number of uses becomes
-- 'Infinity'; zero uses, 'Infinity' and 'Impossible' are returned as is.
loopCount :: Count -> Count
loopCount count = case count of
  Finite n | n > 0 -> Infinity
  _                -> count

-- | Try to shrink a left-hand side, given how its variables are used.
--
--   * A LHS that is already a single wildcard cannot be shrunk.
--   * If nothing is used ('Finite' 0), the whole LHS becomes one wildcard.
--   * For an 'Impossible' count, every variable whose usage flag is 'False'
--     is replaced by a wildcard. The flags are consumed starting from the
--     rightmost (most recently bound) variable, the order in which
--     'varInRange' produces them.
--   * Any other count (the value is used as a whole) leaves the LHS alone.
--
-- Returns 'Nothing' when the LHS would not change.
--
shrinkLhs
    :: HasCallStack
    => Count
    -> LeftHandSide s t env1 env2
    -> Maybe (Exists (LeftHandSide s t env1))
shrinkLhs _ (LeftHandSideWildcard _) = Nothing -- We cannot shrink this
shrinkLhs (Finite 0)          lhs = Just $ Exists $ LeftHandSideWildcard $ lhsToTupR lhs -- LHS isn't used at all, replace with a wildcard
shrinkLhs (Impossible usages) lhs = case go usages lhs of
    (True , [], lhs') -> Just lhs'
    (False, [], _   ) -> Nothing -- No variables were dropped. Thus lhs == lhs'.
    -- Flags left over: the usage list is longer than the number of variables.
    _                 -> internalError "Mismatch in length of usages array and LHS"
  where
    -- Returns whether any variable was dropped, the usage flags not yet
    -- consumed, and the (possibly) shrunk LHS.
    go :: HasCallStack => Usages -> LeftHandSide s t env1 env2 -> (Bool, Usages, Exists (LeftHandSide s t env1))
    go us           (LeftHandSideWildcard tp) = (False, us, Exists $ LeftHandSideWildcard tp)
    go (True  : us) (LeftHandSideSingle tp)   = (False, us, Exists $ LeftHandSideSingle tp)
    go (False : us) (LeftHandSideSingle tp)   = (True , us, Exists $ LeftHandSideWildcard $ TupRsingle tp)
    go us           (LeftHandSidePair l1 l2)
      -- The right component binds the most recent variables, so it consumes
      -- the usage flags first.
      | (c2, us' , Exists l2') <- go us  l2
      , (c1, us'', Exists l1') <- go us' l1
      -- l2' was built against the environment of the original l1.
      -- Presumably rebuildLHS re-indexes it so it can follow the shrunk l1'
      -- (rebuildLHS is defined elsewhere; confirm).
      , Exists l2'' <- rebuildLHS l2'
      = let
          -- Two wildcards collapse into a single wildcard of the pair type.
          lhs'
            | LeftHandSideWildcard t1 <- l1'
            , LeftHandSideWildcard t2 <- l2'' = LeftHandSideWildcard $ TupRpair t1 t2
            | otherwise = LeftHandSidePair l1' l2''
        in
          (c1 || c2, us'', Exists lhs')
    go _ _ = internalError "Empty array, mismatch in length of usages array and LHS"
-- Positive 'Finite' counts and 'Infinity': the value is used as a whole.
shrinkLhs _ _ = Nothing

-- | Lift a strengthening of the environment before two left-hand sides to a
-- strengthening of the environment after them.
--
-- The first LHS should be 'larger' than the second: the second may have a
-- wildcard where the first binds variables, but not the other way around.
-- Typically the second LHS is a shrunk version of the first. Variables bound
-- by the first LHS but dropped by the second strengthen to 'Nothing'.
--
strengthenShrunkLHS
    :: HasCallStack
    => LeftHandSide s t env1 env2
    -> LeftHandSide s t env1' env2'
    -> env1 :?> env1'
    -> env2 :?> env2'
strengthenShrunkLHS (LeftHandSideWildcard _) (LeftHandSideWildcard _) k = k
strengthenShrunkLHS (LeftHandSideSingle _)   (LeftHandSideSingle _)   k = \ix -> case ix of
  ZeroIdx     -> Just ZeroIdx
  SuccIdx ix' -> SuccIdx <$> k ix'
strengthenShrunkLHS (LeftHandSidePair lA hA) (LeftHandSidePair lB hB) k = strengthenShrunkLHS hA hB $ strengthenShrunkLHS lA lB k
strengthenShrunkLHS (LeftHandSideSingle _)   (LeftHandSideWildcard _) k = \ix -> case ix of
  -- This variable was dropped by the second LHS; it has no counterpart.
  ZeroIdx     -> Nothing
  SuccIdx ix' -> k ix'
strengthenShrunkLHS (LeftHandSidePair l h)   (LeftHandSideWildcard t) k = case t of
  -- Split a wildcard of pair type into one wildcard per component, so each
  -- component of the first LHS can be compared against it. An explicit case
  -- replaces the former partial pattern binding, so a malformed type
  -- representation fails with a call stack rather than an irrefutable-pattern
  -- error.
  TupRpair t1 t2 -> strengthenShrunkLHS h (LeftHandSideWildcard t2) $ strengthenShrunkLHS l (LeftHandSideWildcard t1) k
  _              -> internalError "Wildcard of a pair type does not have a pair type representation"
strengthenShrunkLHS (LeftHandSideWildcard _) _                        _ = internalError "Second LHS defines more variables"
strengthenShrunkLHS _                        _                        _ = internalError "Mismatch LHS single with LHS pair"


-- Shrinking
-- =========

-- The shrinking substitution for scalar expressions. This is a restricted
-- instance of beta-reduction to cases where the bound variable is used zero
-- (dead-code elimination) or one (linear inlining) times.
--
-- Concretely, each 'Let' is handled as follows (after its binding and body
-- have been shrunk):
--
--   * A single variable bound to another variable is always substituted.
--   * The binding is inlined into the body when the bound value is used, as
--     a whole, at least once and either at most 'lIMIT' times or the bound
--     expression is 'cheap'. A cheap binding is inlined even when it is used
--     inside a loop.
--   * Otherwise, variables of the LHS that the body does not use are
--     replaced by wildcards (via 'shrinkLhs'). The bound expression itself
--     is kept.
--
-- The returned 'Bool' is 'True' if any of these reductions was performed.
--
shrinkExp :: HasCallStack => OpenExp env aenv t -> (Bool, OpenExp env aenv t)
shrinkExp = Stats.substitution "shrinkE" . first getAny . shrinkE
  where
    -- If the bound value is used (as a whole) at least once and at most this
    -- many times, it will be inlined into the body. Bindings that are not
    -- used at all are not inlined; 'shrinkLhs' turns their LHS into a
    -- wildcard instead.
    --
    lIMIT :: Int
    lIMIT = 1

    -- Expressions cheap enough to duplicate. They are inlined however often
    -- they are used, including from within a loop.
    cheap :: OpenExp env aenv t -> Bool
    cheap (Evar _)       = True
    cheap (Pair e1 e2)   = cheap e1 && cheap e2
    cheap Nil            = True
    cheap Const{}        = True
    cheap PrimConst{}    = True
    cheap Undef{}        = True
    cheap (Coerce _ _ e) = cheap e
    cheap _              = False

    -- Shrink an expression, recording in 'Any' whether anything changed.
    shrinkE :: HasCallStack => OpenExp env aenv t -> (Any, OpenExp env aenv t)
    shrinkE exp = case exp of
      -- Binding one variable to another: always substitute.
      Let (LeftHandSideSingle _) bnd@Evar{} body -> Stats.inline "Var"   . yes $ shrinkE (inline body bnd)
      Let lhs bnd body
        | shouldInline -> case inlineVars lhs (snd body') (snd bnd') of
            Just inlined -> Stats.betaReduce msg . yes $ shrinkE inlined
            _            -> internalError "Unexpected failure while trying to inline some expression."
        -- Some or all variables of the LHS are unused. Replace them by
        -- wildcards and remove them from the environment of the body.
        | Just (Exists lhs') <- shrinkLhs count lhs -> case strengthenE (strengthenShrunkLHS lhs lhs' Just) (snd body') of
           Just body'' -> (Any True, Let lhs' (snd bnd') body'')
           Nothing     -> internalError "Unexpected failure in strenthenE. Variable was analysed to be unused in usesOfExp, but appeared to be used in strenthenE."
        | otherwise    -> Let lhs <$> bnd' <*> body'
        where
          shouldInline = case count of
            Finite 0     -> False -- Handled by shrinkLhs
            Finite n     -> n <= lIMIT || cheap (snd bnd')
            Infinity     ->               cheap (snd bnd')
            Impossible _ -> False

          bnd'  = shrinkE bnd
          body' = shrinkE body

          -- If the lhs includes non-trivial wildcards (the last field of range is Nothing),
          -- then we cannot inline the binding. We can only check which variables are not used,
          -- to detect unused variables.
          --
          -- If the lhs does not include non-trivial wildcards (the last field of range is a Just),
          -- we can both analyse whether we can inline the binding, and check which variables are
          -- not used, to detect unused variables.
          --
          -- Uses are counted in the already-shrunk body.
          --
          count = case lhsVarsRange lhs of
            Left _      -> Finite 0
            Right range -> usesOfExp range (snd body')

          -- NOTE(review): 'msg' is only used when shouldInline holds, which is
          -- never the case for 'Finite 0', so "dead exp" looks unreachable.
          msg = case count of
            Finite 0 -> "dead exp"
            _        -> "inline exp"   -- forced inlining when lIMIT > 1
      --
      Evar v                    -> pure (Evar v)
      Const t c                 -> pure (Const t c)
      Undef t                   -> pure (Undef t)
      Nil                       -> pure Nil
      Pair x y                  -> Pair <$> shrinkE x <*> shrinkE y
      VecPack   vec e           -> VecPack   vec <$> shrinkE e
      VecUnpack vec e           -> VecUnpack vec <$> shrinkE e
      IndexSlice x ix sh        -> IndexSlice x <$> shrinkE ix <*> shrinkE sh
      IndexFull x ix sl         -> IndexFull x <$> shrinkE ix <*> shrinkE sl
      ToIndex shr sh ix         -> ToIndex shr <$> shrinkE sh <*> shrinkE ix
      FromIndex shr sh i        -> FromIndex shr <$> shrinkE sh <*> shrinkE i
      Case e rhs def            -> Case <$> shrinkE e <*> sequenceA [ (t,) <$> shrinkE c | (t,c) <- rhs ] <*> shrinkMaybeE def
      Cond p t e                -> Cond <$> shrinkE p <*> shrinkE t <*> shrinkE e
      While p f x               -> While <$> shrinkF p <*> shrinkF f <*> shrinkE x
      PrimConst c               -> pure (PrimConst c)
      PrimApp f x               -> PrimApp f <$> shrinkE x
      Index a sh                -> Index a <$> shrinkE sh
      LinearIndex a i           -> LinearIndex a <$> shrinkE i
      Shape a                   -> pure (Shape a)
      ShapeSize shr sh          -> ShapeSize shr <$> shrinkE sh
      Foreign repr ff f e       -> Foreign repr ff <$> shrinkF f <*> shrinkE e
      Coerce t1 t2 e            -> Coerce t1 t2 <$> shrinkE e

    -- Shrink a scalar function, converting the change flag to 'Any'.
    shrinkF :: HasCallStack => OpenFun env aenv t -> (Any, OpenFun env aenv t)
    shrinkF = first Any . shrinkFun

    shrinkMaybeE :: HasCallStack => Maybe (OpenExp env aenv t) -> (Any, Maybe (OpenExp env aenv t))
    shrinkMaybeE Nothing  = pure Nothing
    shrinkMaybeE (Just e) = Just <$> shrinkE e

    -- Map over the first component of a pair.
    first :: (a -> a') -> (a,b) -> (a',b)
    first f (x,y) = (f x, y)

    -- Mark a result as changed.
    yes :: (Any, x) -> (Any, x)
    yes (_, x) = (Any True, x)

-- | The shrinking substitution for scalar functions. The body is shrunk with
-- 'shrinkExp'. Variables of each lambda-bound LHS that the shrunk body does
-- not use are replaced by wildcards.
--
-- The returned 'Bool' is 'True' if the function was changed.
--
shrinkFun :: HasCallStack => OpenFun env aenv f -> (Bool, OpenFun env aenv f)
shrinkFun (Lam lhs f) = case lhsVarsRange lhs of
  -- The LHS binds no variables, so normalise it to a single wildcard. That is
  -- only a change if it was not already a single wildcard, of any type (not
  -- just unit). Otherwise an unchanged function would be reported as changed
  -- on every pass.
  Left Refl ->
    let b' = case lhs of
                LeftHandSideWildcard{} -> b
                _                      -> True
    in (b', Lam (LeftHandSideWildcard $ lhsToTupR lhs) f')
  Right range ->
    let
      -- Analyse the shrunk body, which is also what 'strengthenE' is applied
      -- to below. Variables that became unused while shrinking the body are
      -- then removed in the same pass.
      count = usesOfFun range f'
    in case shrinkLhs count lhs of
        Just (Exists lhs') -> case strengthenE (strengthenShrunkLHS lhs lhs' Just) f' of
          Just f'' -> (True, Lam lhs' f'')
          Nothing  -> internalError "Unexpected failure in strenthenE. Variable was analysed to be unused in usesOfExp, but appeared to be used in strenthenE."
        Nothing -> (b, Lam lhs f')
  where
    (b, f') = shrinkFun f

shrinkFun (Body b) = Body <$> shrinkExp b

-- | A shrinking pass over array computations: a rewrite that preserves both
-- the array environment and the result type of its argument.
--
-- Shrinking of array computations is further limited to dead-code
-- elimination only. The main reason is that linear inlining may inline array
-- computations into scalar expressions, which is generally not desirable.
--
type ShrinkAcc acc
  = forall aenv a. acc aenv a -> acc aenv a

{--
type ReduceAcc acc = forall aenv s t. acc aenv s -> acc (aenv,s) t -> Maybe (PreOpenAcc acc aenv t)

shrinkPreAcc
    :: forall acc aenv arrs. ShrinkAcc acc -> ReduceAcc acc
    -> PreOpenAcc acc aenv arrs
    -> PreOpenAcc acc aenv arrs
shrinkPreAcc shrinkAcc reduceAcc = Stats.substitution "shrinkA" shrinkA
  where
    shrinkA :: PreOpenAcc acc aenv' a -> PreOpenAcc acc aenv' a
    shrinkA pacc = case pacc of
      Alet lhs bnd body
        | Just reduct <- reduceAcc bnd' body'   -> shrinkA reduct
        | otherwise                             -> Alet lhs bnd' body'
        where
          bnd'  = shrinkAcc bnd
          body' = shrinkAcc body
      --
      Avar ix                   -> Avar ix
      Apair a1 a2               -> Apair (shrinkAcc a1) (shrinkAcc a2)
      Anil                      -> Anil
      Apply repr f a            -> Apply repr (shrinkAF f) (shrinkAcc a)
      Aforeign ff af a          -> Aforeign ff af (shrinkAcc a)
      Acond p t e               -> Acond (shrinkE p) (shrinkAcc t) (shrinkAcc e)
      Awhile p f a              -> Awhile (shrinkAF p) (shrinkAF f) (shrinkAcc a)
      Use repr a                -> Use repr a
      Unit e                    -> Unit (shrinkE e)
      Reshape e a               -> Reshape (shrinkE e) (shrinkAcc a)
      Generate e f              -> Generate (shrinkE e) (shrinkF f)
      Transform sh ix f a       -> Transform (shrinkE sh) (shrinkF ix) (shrinkF f) (shrinkAcc a)
      Replicate sl slix a       -> Replicate sl (shrinkE slix) (shrinkAcc a)
      Slice sl a slix           -> Slice sl (shrinkAcc a) (shrinkE slix)
      Map f a                   -> Map (shrinkF f) (shrinkAcc a)
      ZipWith f a1 a2           -> ZipWith (shrinkF f) (shrinkAcc a1) (shrinkAcc a2)
      Fold f z a                -> Fold (shrinkF f) (shrinkE z) (shrinkAcc a)
      Fold1 f a                 -> Fold1 (shrinkF f) (shrinkAcc a)
      FoldSeg f z a b           -> FoldSeg (shrinkF f) (shrinkE z) (shrinkAcc a) (shrinkAcc b)
      Fold1Seg f a b            -> Fold1Seg (shrinkF f) (shrinkAcc a) (shrinkAcc b)
      Scanl f z a               -> Scanl (shrinkF f) (shrinkE z) (shrinkAcc a)
      Scanl' f z a              -> Scanl' (shrinkF f) (shrinkE z) (shrinkAcc a)
      Scanl1 f a                -> Scanl1 (shrinkF f) (shrinkAcc a)
      Scanr f z a               -> Scanr (shrinkF f) (shrinkE z) (shrinkAcc a)
      Scanr' f z a              -> Scanr' (shrinkF f) (shrinkE z) (shrinkAcc a)
      Scanr1 f a                -> Scanr1 (shrinkF f) (shrinkAcc a)
      Permute f1 a1 f2 a2       -> Permute (shrinkF f1) (shrinkAcc a1) (shrinkF f2) (shrinkAcc a2)
      Backpermute sh f a        -> Backpermute (shrinkE sh) (shrinkF f) (shrinkAcc a)
      Stencil f b a             -> Stencil (shrinkF f) b (shrinkAcc a)
      Stencil2 f b1 a1 b2 a2    -> Stencil2 (shrinkF f) b1 (shrinkAcc a1) b2 (shrinkAcc a2)
      -- Collect s                 -> Collect (shrinkS s)

{--
    shrinkS :: PreOpenSeq acc aenv' senv a -> PreOpenSeq acc aenv' senv a
    shrinkS seq =
      case seq of
        Producer p s -> Producer (shrinkP p) (shrinkS s)
        Consumer c   -> Consumer (shrinkC c)
        Reify ix     -> Reify ix

    shrinkP :: Producer acc aenv' senv a -> Producer acc aenv' senv a
    shrinkP p =
      case p of
        StreamIn arrs        -> StreamIn arrs
        ToSeq sl slix a      -> ToSeq sl slix (shrinkAcc a)
        MapSeq f x           -> MapSeq (shrinkAF f) x
        ChunkedMapSeq f x    -> ChunkedMapSeq (shrinkAF f) x
        ZipWithSeq f x y     -> ZipWithSeq (shrinkAF f) x y
        ScanSeq f e x        -> ScanSeq (shrinkF f) (shrinkE e) x

    shrinkC :: Consumer acc aenv' senv a -> Consumer acc aenv' senv a
    shrinkC c =
      case c of
        FoldSeq f e x        -> FoldSeq (shrinkF f) (shrinkE e) x
        FoldSeqFlatten f a x -> FoldSeqFlatten (shrinkAF f) (shrinkAcc a) x
        Stuple t             -> Stuple (shrinkCT t)

    shrinkCT :: Atuple (Consumer acc aenv' senv) t -> Atuple (Consumer acc aenv' senv) t
    shrinkCT NilAtup        = NilAtup
    shrinkCT (SnocAtup t c) = SnocAtup (shrinkCT t) (shrinkC c)
--}

    shrinkE :: OpenExp env aenv' t -> OpenExp env aenv' t
    shrinkE exp = case exp of
      Let bnd body              -> Let (shrinkE bnd) (shrinkE body)
      Var idx                   -> Var idx
      Const c                   -> Const c
      Undef                     -> Undef
      Tuple t                   -> Tuple (shrinkT t)
      Prj tup e                 -> Prj tup (shrinkE e)
      IndexNil                  -> IndexNil
      IndexCons sl sz           -> IndexCons (shrinkE sl) (shrinkE sz)
      IndexHead sh              -> IndexHead (shrinkE sh)
      IndexTail sh              -> IndexTail (shrinkE sh)
      IndexSlice x ix sh        -> IndexSlice x (shrinkE ix) (shrinkE sh)
      IndexFull x ix sl         -> IndexFull x (shrinkE ix) (shrinkE sl)
      IndexAny                  -> IndexAny
      ToIndex sh ix             -> ToIndex (shrinkE sh) (shrinkE ix)
      FromIndex sh i            -> FromIndex (shrinkE sh) (shrinkE i)
      Cond p t e                -> Cond (shrinkE p) (shrinkE t) (shrinkE e)
      While p f x               -> While (shrinkF p) (shrinkF f) (shrinkE x)
      PrimConst c               -> PrimConst c
      PrimApp f x               -> PrimApp f (shrinkE x)
      Index a sh                -> Index (shrinkAcc a) (shrinkE sh)
      LinearIndex a i           -> LinearIndex (shrinkAcc a) (shrinkE i)
      Shape a                   -> Shape (shrinkAcc a)
      ShapeSize sh              -> ShapeSize (shrinkE sh)
      Intersect sh sz           -> Intersect (shrinkE sh) (shrinkE sz)
      Union sh sz               -> Union (shrinkE sh) (shrinkE sz)
      Foreign ff f e            -> Foreign ff (shrinkF f) (shrinkE e)
      Coerce e                  -> Coerce (shrinkE e)

    shrinkF :: OpenFun env aenv' f -> OpenFun env aenv' f
    shrinkF (Lam f)  = Lam (shrinkF f)
    shrinkF (Body b) = Body (shrinkE b)

    shrinkT :: Tuple (OpenExp env aenv') t -> Tuple (OpenExp env aenv') t
    shrinkT NilTup        = NilTup
    shrinkT (SnocTup t e) = shrinkT t `SnocTup` shrinkE e

    shrinkAF :: PreOpenAfun acc aenv' f -> PreOpenAfun acc aenv' f
    shrinkAF (Alam lhs f) = Alam lhs (shrinkAF f)
    shrinkAF (Abody a) = Abody (shrinkAcc a)
--}

-- Occurrence Counting
-- ===================

-- Count how the in-scope scalar variables described by the given range occur
-- in a term. A sub-term that rebuilds exactly the bound value counts as one
-- use of the whole value. Any other occurrence of one of the variables
-- yields an 'Impossible' count, recording which variable was used.
--
usesOfExp :: forall env aenv t. VarsRange env -> OpenExp env aenv t -> Count
usesOfExp range = go
  where
    go :: OpenExp env aenv e -> Count
    go expr
      | matchEVarsRange range expr = Finite 1
      | otherwise                  = case expr of
          Evar v              -> maybe mempty Impossible (varInRange range v)
          -- The body of a 'Let' lives under extra binders.
          Let lhs bnd body    -> go bnd <> usesOfExp (weakenVarsRange lhs range) body
          -- Leaves that cannot mention a scalar variable.
          Const _ _           -> mempty
          Undef _             -> mempty
          Nil                 -> mempty
          PrimConst _         -> mempty
          Shape _             -> mempty
          -- Structural recursion.
          Pair e1 e2          -> go e1 <> go e2
          VecPack   _ e       -> go e
          VecUnpack _ e       -> go e
          IndexSlice _ ix sh  -> go ix <> go sh
          IndexFull _ ix sl   -> go ix <> go sl
          FromIndex _ sh i    -> go sh <> go i
          ToIndex _ sh e      -> go sh <> go e
          Case e rhs def      -> go e <> foldMap (go . snd) rhs <> foldMap go def
          Cond p th el        -> go p <> go th <> go el
          -- The condition and step function may run many times.
          While p f x         -> go x <> loopCount (usesOfFun range p) <> loopCount (usesOfFun range f)
          PrimApp _ x         -> go x
          Index _ sh          -> go sh
          LinearIndex _ i     -> go i
          ShapeSize _ sh      -> go sh
          Foreign _ _ _ e     -> go e
          Coerce _ _ e        -> go e

-- | 'usesOfExp' lifted to scalar functions. The range is weakened under
-- each lambda binder before the body is analysed.
usesOfFun :: VarsRange env -> OpenFun env aenv f -> Count
usesOfFun range fun = case fun of
  Lam lhs body -> usesOfFun (weakenVarsRange lhs range) body
  Body e       -> usesOfExp range e

-- | A function counting the occurrences, in an array term, of the array
-- variable bound at the given environment index. When the 'Bool' argument
-- is 'True', uses of the variable inside 'Shape' expressions are included
-- in the total; otherwise they are ignored.
--
type UsesOfAcc acc
  = forall aenv s t. Bool -> Idx aenv s -> acc aenv t -> Int

-- | Count the occurrences of the array variable at index @idx@ in an array
-- computation. Nested array terms are counted with the supplied 'UsesOfAcc'
-- function. If the first argument is 'True', uses of the variable inside
-- 'Shape' expressions are included in the count; otherwise they are ignored.
--
-- XXX: Should this be converted to use the above 'Count' semigroup?
--
usesOfPreAcc
    :: forall acc aenv s t.
       Bool
    -> UsesOfAcc  acc
    -> Idx            aenv s
    -> PreOpenAcc acc aenv t
    -> Int
usesOfPreAcc withShape countAcc idx = count
  where
    -- 1 if the given index is the variable being counted, 0 otherwise.
    countIdx :: Idx aenv a -> Int
    countIdx this
        | Just Refl <- matchIdx this idx = 1
        | otherwise                      = 0

    count :: PreOpenAcc acc aenv a -> Int
    count pacc = case pacc of
      Avar var                   -> countAvar var
      --
      -- Under the binder the index is weakened to the extended environment.
      Alet lhs bnd body          -> countA bnd + countAcc withShape (weakenWithLHS lhs >:> idx) body
      Apair a1 a2                -> countA a1 + countA a2
      Anil                       -> 0
      Apply _ f a                -> countAF f idx + countA a
      Aforeign _ _ _ a           -> countA a
      Acond p t e                -> countE p + countA t + countA e
      -- Body and condition of the while loop may be evaluated multiple times.
      -- We multiply the usage count, as a practical solution to this. As
      -- we will check whether the count is at most 1, we will thus never
      -- inline variables used in while loops.
      Awhile c f a               -> 2 * countAF c idx + 2 * countAF f idx + countA a
      Use _ _                    -> 0
      Unit _ e                   -> countE e
      Reshape _ e a              -> countE e  + countA a
      Generate _ e f             -> countE e  + countF f
      Transform _ sh ix f a      -> countE sh + countF ix + countF f  + countA a
      Replicate _ sh a           -> countE sh + countA a
      Slice _ a sl               -> countE sl + countA a
      Map _ f a                  -> countF f  + countA a
      ZipWith _ f a1 a2          -> countF f  + countA a1 + countA a2
      Fold f z a                 -> countF f  + countME z + countA a
      FoldSeg _ f z a s          -> countF f  + countME z + countA a  + countA s
      Scan  _ f z a              -> countF f  + countME z + countA a
      Scan' _ f z a              -> countF f  + countE z  + countA a
      Permute f1 a1 f2 a2        -> countF f1 + countA a1 + countF f2 + countA a2
      Backpermute _ sh f a       -> countE sh + countF f  + countA a
      Stencil _ _ f _ a          -> countF f  + countA a
      Stencil2 _ _ _ f _ a1 _ a2 -> countF f  + countA a1 + countA a2
      -- Collect s                 -> countS s

    -- Uses of the array variable inside a scalar expression: via 'Index',
    -- 'LinearIndex' and (when withShape is set) 'Shape'. A scalar 'Let' does
    -- not change the array environment, so 'idx' needs no weakening.
    -- NOTE(review): unlike 'Awhile' above, uses inside a scalar 'While' are
    -- not multiplied; confirm this is intended.
    countE :: OpenExp env aenv e -> Int
    countE exp = case exp of
      Let _ bnd body             -> countE bnd + countE body
      Evar _                     -> 0
      Const _ _                  -> 0
      Undef _                    -> 0
      Nil                        -> 0
      Pair x y                   -> countE x + countE y
      VecPack   _ e              -> countE e
      VecUnpack _ e              -> countE e
      IndexSlice _ ix sh         -> countE ix + countE sh
      IndexFull _ ix sl          -> countE ix + countE sl
      ToIndex _ sh ix            -> countE sh + countE ix
      FromIndex _ sh i           -> countE sh + countE i
      Case e rhs def             -> countE e  + sum [ countE c | (_,c) <- rhs ] + maybe 0 countE def
      Cond p t e                 -> countE p  + countE t + countE e
      While p f x                -> countF p  + countF f + countE x
      PrimConst _                -> 0
      PrimApp _ x                -> countE x
      Index a sh                 -> countAvar a + countE sh
      LinearIndex a i            -> countAvar a + countE i
      ShapeSize _ sh             -> countE sh
      Shape a
        | withShape              -> countAvar a
        | otherwise              -> 0
      Foreign _ _ _ e            -> countE e
      Coerce _ _ e               -> countE e

    -- An absent optional expression contributes nothing.
    countME :: Maybe (OpenExp env aenv e) -> Int
    countME = maybe 0 countE

    countA :: acc aenv a -> Int
    countA = countAcc withShape idx

    countAvar :: ArrayVar aenv a -> Int
    countAvar (Var _ this) = countIdx this

    -- Uses inside an array function; the index is weakened under each binder.
    countAF :: PreOpenAfun acc aenv' f
            -> Idx aenv' s
            -> Int
    countAF (Alam lhs f) v = countAF f (weakenWithLHS lhs >:> v)
    countAF (Abody a)    v = countAcc withShape v a

    -- Scalar lambdas bind only scalar variables, so the array index is unchanged.
    countF :: OpenFun env aenv f -> Int
    countF (Lam _ f) = countF f
    countF (Body  b) = countE b

{--
    countS :: PreOpenSeq acc aenv senv arrs -> Int
    countS seq =
      case seq of
        Producer p s -> countP p + countS s
        Consumer c   -> countC c
        Reify _      -> 0

    countP :: Producer acc aenv senv arrs -> Int
    countP p =
      case p of
        StreamIn _           -> 0
        ToSeq _ _ a          -> countA a
        MapSeq f _           -> countAF f idx
        ChunkedMapSeq f _    -> countAF f idx
        ZipWithSeq f _ _     -> countAF f idx
        ScanSeq f e _        -> countF f + countE e

    countC :: Consumer acc aenv senv arrs -> Int
    countC c =
      case c of
        FoldSeq f e _        -> countF f + countE e
        FoldSeqFlatten f a _ -> countAF f idx + countA a
        Stuple t             -> countCT t

    countCT :: Atuple (Consumer acc aenv senv) t' -> Int
    countCT NilAtup        = 0
    countCT (SnocAtup t c) = countCT t + countC c
--}