{-# LANGUAGE CPP                  #-}
{-# LANGUAGE FlexibleContexts     #-}
{-# LANGUAGE FlexibleInstances    #-}
{-# LANGUAGE GADTs                #-}
{-# LANGUAGE OverloadedStrings    #-}
{-# LANGUAGE PatternGuards        #-}
{-# LANGUAGE RankNTypes           #-}
{-# LANGUAGE RecordWildCards      #-}
{-# LANGUAGE ScopedTypeVariables  #-}
{-# LANGUAGE TupleSections        #-}
{-# LANGUAGE TypeApplications     #-}
{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE ViewPatterns         #-}
-- |
-- Module      : Data.Array.Accelerate.Trafo.Simplify
-- Copyright   : [2012..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.Trafo.Simplify (

  simplifyFun,
  simplifyExp

) where

import Data.Array.Accelerate.AST
import Data.Array.Accelerate.AST.Environment
import Data.Array.Accelerate.AST.Idx
import Data.Array.Accelerate.AST.LeftHandSide
import Data.Array.Accelerate.AST.Var
import Data.Array.Accelerate.Analysis.Hash
import Data.Array.Accelerate.Analysis.Match
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Representation.Array                   ( Array, ArrayR(..) )
import Data.Array.Accelerate.Representation.Shape                   ( ShapeR(..), shapeToList )
import Data.Array.Accelerate.Representation.Tag
import Data.Array.Accelerate.Trafo.Algebra
import Data.Array.Accelerate.Trafo.Environment
import Data.Array.Accelerate.Trafo.Shrink
import Data.Array.Accelerate.Trafo.Substitution
import Data.Array.Accelerate.Type

import qualified Data.Array.Accelerate.Debug.Stats                  as Stats
import qualified Data.Array.Accelerate.Debug.Flags                  as Debug
import qualified Data.Array.Accelerate.Debug.Trace                  as Debug

import Control.Applicative                                          hiding ( Const )
import Control.Lens                                                 hiding ( Const, ix )
import Data.List                                                    ( partition )
import Data.Maybe
import Data.Monoid
import Text.Printf
import Prelude                                                      hiding ( exp, iterate )
import qualified Data.Map.Strict                                    as Map


-- Scalar optimisations
-- ====================

{--
-- Common subexpression elimination finds computations that are performed at
-- least twice on a given execution path and eliminates the second and later
-- occurrences, replacing them with uses of saved values. This implements a
-- simplified version of that idea, where we look for the expressions of the
-- form:
--
--   let x = e1 in e2
--
-- and replace all occurrences of e1 in e2 with x. This is not full redundancy
-- elimination, but good enough to catch some cases, and in particular those
-- likely to be introduced by scalar composition of terms in the fusion process.
--
-- While it may seem that common subexpression elimination is always worthwhile,
-- as it reduces the number of arithmetic operations performed, this is not
-- necessarily advantageous. The simplest case in which it may not be desirable
-- is if it causes a register to be occupied for a long time in order to hold
-- the shared expression's value, which hence reduces the number of registers
-- available for other uses. Even worse is if the value has to be spilled to
-- memory because there are insufficient registers available. We sidestep this
-- tricky and target-dependent issue by, for now, simply ignoring it.
--
localCSE :: (Kit acc, Elt a)
         => Gamma acc env env aenv
         -> OpenExp env aenv a
         -> OpenExp (env,a) aenv b
         -> Maybe (OpenExp env aenv b)
localCSE env bnd body
  | Just ix <- lookupExp env bnd = Stats.ruleFired "CSE" . Just $ inline body (Var ix)
  | otherwise                    = Nothing
--}
{--
-- Common subexpression elimination, which attempts to match the given
-- expression against something already bound in the environment. This can occur
-- due to simplification, in which case we replace the entire subterm with x.
--
-- > let x = e in .. e ..
--
globalCSE :: (Kit acc, Elt t)
          => Gamma acc env env aenv
          -> OpenExp env aenv t
          -> Maybe (OpenExp env aenv t)
globalCSE env exp
  | Just ix <- lookupExp env exp = Stats.ruleFired "CSE" . Just $ Var ix
  | otherwise                    = Nothing
--}

{--
-- Compared to regular Haskell, the scalar expression language of Accelerate is
-- rather limited in order to meet the restrictions of what can be efficiently
-- implemented on specialised hardware, such as GPUs. For example, to avoid
-- excessive SIMD divergence, we do not support any form of recursion or
-- iteration in scalar expressions. This harmonises well with the stratified
-- design of the Accelerate language: collective array operations comprise many
-- scalar computations that are executed in parallel, so for simplicity of
-- scheduling these operations we would like some assurance that each scalar
-- computation takes approximately the same time to execute as all others.
--
-- However, some computations are naturally expressed in terms of iteration. For
-- some problems, we can instead use generative techniques to implement the
-- program by defining a single step of a recurrence relation as an Accelerate
-- collective operation and using standard Haskell to unroll the loop a _fixed_
-- number of times.
--
-- However, this is outrageously slow because the intermediate values are
-- written to memory at the end of every iteration. Luckily the fusion process
-- will eliminate this intermediate memory traffic by combining the 'n'
-- collective operations into a single operation with 'n' instances of the loop
-- body. However, doing this we uncover an embarrassing secret: C compilers do
-- not compile C code, they compile _idiomatic_ C code.
--
-- This process recovers the iteration structure that was lost in the process of
-- fusing the collective operations. This allows a backend to generate explicit
-- loops in its target language.
--
recoverLoops
    :: (Kit acc, Elt b)
    => Gamma acc env env aenv
    -> OpenExp env aenv a
    -> OpenExp (env,a) aenv b
    -> Maybe (OpenExp env aenv b)
recoverLoops _ bnd e3
  -- To introduce scalar loops, we look for expressions of the form:
  --
  --   let x =
  --     let y = e1 in e2
  --   in e3
  --
  -- and if e2 and e3 are congruent, replace with:
  --
  --   iterate[2] (\y -> e2) e1
  --
  | Let e1 e2           <- bnd
  , Just Refl           <- matchEnvTop e2 e3
  , Just Refl           <- match e2 e3
  = Stats.ruleFired "loop recovery/intro" . Just
  $ Iterate (constant 2) e2 e1

  -- To merge expressions into a loop body, look for the pattern:
  --
  --   let x = iterate[n] f e1
  --   in e3
  --
  -- and if e3 matches the loop body, replace the let binding with the bare
  -- iteration with the trip count increased by one.
  --
  | Iterate n f e1      <- bnd
  , Just Refl           <- match f e3
  = Stats.ruleFired "loop recovery/merge" . Just
  $ Iterate (constant 1 `plus` n) f e1

  | otherwise
  = Nothing

  where
    plus :: OpenExp env aenv Int -> OpenExp env aenv Int -> OpenExp env aenv Int
    plus x y = PrimApp (PrimAdd numType) $ Tuple $ NilTup `SnocTup` x `SnocTup` y

    constant :: Int -> OpenExp env aenv Int
    constant i = Const ((),i)

    matchEnvTop :: (Elt s, Elt t)
                => OpenExp (env,s) aenv f
                -> OpenExp (env,t) aenv g
                -> Maybe (s :=: t)
    matchEnvTop _ _ = gcast Refl
--}


-- Walk a scalar expression applying simplifications to terms bottom-up.
--
-- TODO: Look for particular patterns of expressions that can be replaced by
--       something equivalent and simpler. In particular, indexing operations
--       introduced by the fusion transformation. This would benefit from a
--       rewrite rule schema.
--
-- TODO: We currently pass around an environment Gamma, but we do not use it.
--       It might be helpful to do some inlining if this enables other optimisations.
--       E.g., for `let x = -y in -x`, inlining would allow us to shorten it to `y`.
--       If we do not want to do inlining, we should remove the environment here.
--
simplifyOpenExp
    :: forall env aenv e.
       Gamma env env aenv
    -> OpenExp env aenv e
    -> (Bool, OpenExp env aenv e)
simplifyOpenExp :: Gamma env env aenv
-> OpenExp env aenv e -> (Bool, OpenExp env aenv e)
simplifyOpenExp Gamma env env aenv
env = (Any -> Bool)
-> (Any, OpenExp env aenv e) -> (Bool, OpenExp env aenv e)
forall a a' b. (a -> a') -> (a, b) -> (a', b)
first Any -> Bool
getAny ((Any, OpenExp env aenv e) -> (Bool, OpenExp env aenv e))
-> (OpenExp env aenv e -> (Any, OpenExp env aenv e))
-> OpenExp env aenv e
-> (Bool, OpenExp env aenv e)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. OpenExp env aenv e -> (Any, OpenExp env aenv e)
forall t. OpenExp env aenv t -> (Any, OpenExp env aenv t)
cvtE
  where
    cvtE :: OpenExp env aenv t -> (Any, OpenExp env aenv t)
    cvtE :: OpenExp env aenv t -> (Any, OpenExp env aenv t)
cvtE OpenExp env aenv t
exp = case OpenExp env aenv t
exp of
      Let ELeftHandSide bnd_t env env'
lhs OpenExp env aenv bnd_t
bnd OpenExp env' aenv t
body -> (Any
u Any -> Any -> Any
forall a. Semigroup a => a -> a -> a
<> Any
v, OpenExp env aenv t
exp')
        where
          (Any
u, OpenExp env aenv bnd_t
bnd') = OpenExp env aenv bnd_t -> (Any, OpenExp env aenv bnd_t)
forall t. OpenExp env aenv t -> (Any, OpenExp env aenv t)
cvtE OpenExp env aenv bnd_t
bnd
          (Any
v, OpenExp env aenv t
exp') = Gamma env env aenv
-> ELeftHandSide bnd_t env env'
-> OpenExp env aenv bnd_t
-> (Gamma env' env' aenv -> (Any, OpenExp env' aenv t))
-> (Any, OpenExp env aenv t)
forall env' bnd env'' t.
Gamma env' env' aenv
-> ELeftHandSide bnd env' env''
-> OpenExp env' aenv bnd
-> (Gamma env'' env'' aenv -> (Any, OpenExp env'' aenv t))
-> (Any, OpenExp env' aenv t)
cvtLet Gamma env env aenv
env ELeftHandSide bnd_t env env'
lhs OpenExp env aenv bnd_t
bnd' (\Gamma env' env' aenv
env' -> Gamma env' env' aenv
-> OpenExp env' aenv t -> (Any, OpenExp env' aenv t)
forall env' e'.
Gamma env' env' aenv
-> OpenExp env' aenv e' -> (Any, OpenExp env' aenv e')
cvtE' Gamma env' env' aenv
env' OpenExp env' aenv t
body)
      Evar ExpVar env t
var                  -> OpenExp env aenv t -> (Any, OpenExp env aenv t)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (OpenExp env aenv t -> (Any, OpenExp env aenv t))
-> OpenExp env aenv t -> (Any, OpenExp env aenv t)
forall a b. (a -> b) -> a -> b
$ ExpVar env t -> OpenExp env aenv t
forall env t aenv. ExpVar env t -> OpenExp env aenv t
Evar ExpVar env t
var
      Const ScalarType t
tp t
c                -> OpenExp env aenv t -> (Any, OpenExp env aenv t)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (OpenExp env aenv t -> (Any, OpenExp env aenv t))
-> OpenExp env aenv t -> (Any, OpenExp env aenv t)
forall a b. (a -> b) -> a -> b
$ ScalarType t -> t -> OpenExp env aenv t
forall t env aenv. ScalarType t -> t -> OpenExp env aenv t
Const ScalarType t
tp t
c
      Undef ScalarType t
tp                  -> OpenExp env aenv t -> (Any, OpenExp env aenv t)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (OpenExp env aenv t -> (Any, OpenExp env aenv t))
-> OpenExp env aenv t -> (Any, OpenExp env aenv t)
forall a b. (a -> b) -> a -> b
$ ScalarType t -> OpenExp env aenv t
forall t env aenv. ScalarType t -> OpenExp env aenv t
Undef ScalarType t
tp
      OpenExp env aenv t
Nil                       -> OpenExp env aenv () -> (Any, OpenExp env aenv ())
forall (f :: * -> *) a. Applicative f => a -> f a
pure OpenExp env aenv ()
forall env aenv. OpenExp env aenv ()
Nil
      Pair OpenExp env aenv t1
e1 OpenExp env aenv t2
e2                -> OpenExp env aenv t1
-> OpenExp env aenv t2 -> OpenExp env aenv (t1, t2)
forall env aenv t1 t2.
OpenExp env aenv t1
-> OpenExp env aenv t2 -> OpenExp env aenv (t1, t2)
Pair (OpenExp env aenv t1
 -> OpenExp env aenv t2 -> OpenExp env aenv (t1, t2))
-> (Any, OpenExp env aenv t1)
-> (Any, OpenExp env aenv t2 -> OpenExp env aenv (t1, t2))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> OpenExp env aenv t1 -> (Any, OpenExp env aenv t1)
forall t. OpenExp env aenv t -> (Any, OpenExp env aenv t)
cvtE OpenExp env aenv t1
e1 (Any, OpenExp env aenv t2 -> OpenExp env aenv (t1, t2))
-> (Any, OpenExp env aenv t2) -> (Any, OpenExp env aenv (t1, t2))
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> OpenExp env aenv t2 -> (Any, OpenExp env aenv t2)
forall t. OpenExp env aenv t -> (Any, OpenExp env aenv t)
cvtE OpenExp env aenv t2
e2
      VecPack   VecR n s tup
vec OpenExp env aenv tup
e           -> VecR n s tup -> OpenExp env aenv tup -> OpenExp env aenv (Vec n s)
forall (n :: Nat) s tup env aenv.
KnownNat n =>
VecR n s tup -> OpenExp env aenv tup -> OpenExp env aenv (Vec n s)
VecPack   VecR n s tup
vec (OpenExp env aenv tup -> OpenExp env aenv (Vec n s))
-> (Any, OpenExp env aenv tup) -> (Any, OpenExp env aenv (Vec n s))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> OpenExp env aenv tup -> (Any, OpenExp env aenv tup)
forall t. OpenExp env aenv t -> (Any, OpenExp env aenv t)
cvtE OpenExp env aenv tup
e
      VecUnpack VecR n s t
vec OpenExp env aenv (Vec n s)
e           -> VecR n s t -> OpenExp env aenv (Vec n s) -> OpenExp env aenv t
forall (n :: Nat) s tup env aenv.
KnownNat n =>
VecR n s tup -> OpenExp env aenv (Vec n s) -> OpenExp env aenv tup
VecUnpack VecR n s t
vec (OpenExp env aenv (Vec n s) -> OpenExp env aenv t)
-> (Any, OpenExp env aenv (Vec n s)) -> (Any, OpenExp env aenv t)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> OpenExp env aenv (Vec n s) -> (Any, OpenExp env aenv (Vec n s))
forall t. OpenExp env aenv t -> (Any, OpenExp env aenv t)
cvtE OpenExp env aenv (Vec n s)
e
      IndexSlice SliceIndex slix t co sh
x OpenExp env aenv slix
ix OpenExp env aenv sh
sh        -> SliceIndex slix t co sh
-> OpenExp env aenv slix
-> OpenExp env aenv sh
-> OpenExp env aenv t
forall slix sl co sh env aenv.
SliceIndex slix sl co sh
-> OpenExp env aenv slix
-> OpenExp env aenv sh
-> OpenExp env aenv sl
IndexSlice SliceIndex slix t co sh
x (OpenExp env aenv slix
 -> OpenExp env aenv sh -> OpenExp env aenv t)
-> (Any, OpenExp env aenv slix)
-> (Any, OpenExp env aenv sh -> OpenExp env aenv t)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> OpenExp env aenv slix -> (Any, OpenExp env aenv slix)
forall t. OpenExp env aenv t -> (Any, OpenExp env aenv t)
cvtE OpenExp env aenv slix
ix (Any, OpenExp env aenv sh -> OpenExp env aenv t)
-> (Any, OpenExp env aenv sh) -> (Any, OpenExp env aenv t)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> OpenExp env aenv sh -> (Any, OpenExp env aenv sh)
forall t. OpenExp env aenv t -> (Any, OpenExp env aenv t)
cvtE OpenExp env aenv sh
sh
      IndexFull SliceIndex slix sl co t
x OpenExp env aenv slix
ix OpenExp env aenv sl
sl         -> SliceIndex slix sl co t
-> OpenExp env aenv slix
-> OpenExp env aenv sl
-> OpenExp env aenv t
forall slix sl co sh env aenv.
SliceIndex slix sl co sh
-> OpenExp env aenv slix
-> OpenExp env aenv sl
-> OpenExp env aenv sh
IndexFull SliceIndex slix sl co t
x (OpenExp env aenv slix
 -> OpenExp env aenv sl -> OpenExp env aenv t)
-> (Any, OpenExp env aenv slix)
-> (Any, OpenExp env aenv sl -> OpenExp env aenv t)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> OpenExp env aenv slix -> (Any, OpenExp env aenv slix)
forall t. OpenExp env aenv t -> (Any, OpenExp env aenv t)
cvtE OpenExp env aenv slix
ix (Any, OpenExp env aenv sl -> OpenExp env aenv t)
-> (Any, OpenExp env aenv sl) -> (Any, OpenExp env aenv t)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> OpenExp env aenv sl -> (Any, OpenExp env aenv sl)
forall t. OpenExp env aenv t -> (Any, OpenExp env aenv t)
cvtE OpenExp env aenv sl
sl
      ToIndex ShapeR sh
shr OpenExp env aenv sh
sh OpenExp env aenv sh
ix         -> ShapeR sh
-> (Any, OpenExp env aenv sh)
-> (Any, OpenExp env aenv sh)
-> (Any, OpenExp env aenv Int)
forall sh.
ShapeR sh
-> (Any, OpenExp env aenv sh)
-> (Any, OpenExp env aenv sh)
-> (Any, OpenExp env aenv Int)
toIndex ShapeR sh
shr (OpenExp env aenv sh -> (Any, OpenExp env aenv sh)
forall t. OpenExp env aenv t -> (Any, OpenExp env aenv t)
cvtE OpenExp env aenv sh
sh) (OpenExp env aenv sh -> (Any, OpenExp env aenv sh)
forall t. OpenExp env aenv t -> (Any, OpenExp env aenv t)
cvtE OpenExp env aenv sh
ix)
      FromIndex ShapeR t
shr OpenExp env aenv t
sh OpenExp env aenv Int
ix       -> ShapeR t
-> (Any, OpenExp env aenv t)
-> (Any, OpenExp env aenv Int)
-> (Any, OpenExp env aenv t)
forall sh.
ShapeR sh
-> (Any, OpenExp env aenv sh)
-> (Any, OpenExp env aenv Int)
-> (Any, OpenExp env aenv sh)
fromIndex ShapeR t
shr (OpenExp env aenv t -> (Any, OpenExp env aenv t)
forall t. OpenExp env aenv t -> (Any, OpenExp env aenv t)
cvtE OpenExp env aenv t
sh) (OpenExp env aenv Int -> (Any, OpenExp env aenv Int)
forall t. OpenExp env aenv t -> (Any, OpenExp env aenv t)
cvtE OpenExp env aenv Int
ix)
      Case OpenExp env aenv TAG
e [(TAG, OpenExp env aenv t)]
rhs Maybe (OpenExp env aenv t)
def            -> (Any, OpenExp env aenv TAG)
-> (Any, [(TAG, OpenExp env aenv t)])
-> (Any, Maybe (OpenExp env aenv t))
-> (Any, OpenExp env aenv t)
forall b.
(Any, OpenExp env aenv TAG)
-> (Any, [(TAG, OpenExp env aenv b)])
-> (Any, Maybe (OpenExp env aenv b))
-> (Any, OpenExp env aenv b)
caseof (OpenExp env aenv TAG -> (Any, OpenExp env aenv TAG)
forall t. OpenExp env aenv t -> (Any, OpenExp env aenv t)
cvtE OpenExp env aenv TAG
e) ([(Any, (TAG, OpenExp env aenv t))]
-> (Any, [(TAG, OpenExp env aenv t)])
forall (t :: * -> *) (f :: * -> *) a.
(Traversable t, Applicative f) =>
t (f a) -> f (t a)
sequenceA [ (TAG
t,) (OpenExp env aenv t -> (TAG, OpenExp env aenv t))
-> (Any, OpenExp env aenv t) -> (Any, (TAG, OpenExp env aenv t))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> OpenExp env aenv t -> (Any, OpenExp env aenv t)
forall t. OpenExp env aenv t -> (Any, OpenExp env aenv t)
cvtE OpenExp env aenv t
c | (TAG
t,OpenExp env aenv t
c) <- [(TAG, OpenExp env aenv t)]
rhs ]) (Maybe (OpenExp env aenv t) -> (Any, Maybe (OpenExp env aenv t))
forall e'.
Maybe (OpenExp env aenv e') -> (Any, Maybe (OpenExp env aenv e'))
cvtMaybeE Maybe (OpenExp env aenv t)
def)
      Cond OpenExp env aenv TAG
p OpenExp env aenv t
t OpenExp env aenv t
e                -> (Any, OpenExp env aenv TAG)
-> (Any, OpenExp env aenv t)
-> (Any, OpenExp env aenv t)
-> (Any, OpenExp env aenv t)
forall t.
(Any, OpenExp env aenv TAG)
-> (Any, OpenExp env aenv t)
-> (Any, OpenExp env aenv t)
-> (Any, OpenExp env aenv t)
cond (OpenExp env aenv TAG -> (Any, OpenExp env aenv TAG)
forall t. OpenExp env aenv t -> (Any, OpenExp env aenv t)
cvtE OpenExp env aenv TAG
p) (OpenExp env aenv t -> (Any, OpenExp env aenv t)
forall t. OpenExp env aenv t -> (Any, OpenExp env aenv t)
cvtE OpenExp env aenv t
t) (OpenExp env aenv t -> (Any, OpenExp env aenv t)
forall t. OpenExp env aenv t -> (Any, OpenExp env aenv t)
cvtE OpenExp env aenv t
e)
      PrimConst PrimConst t
c               -> OpenExp env aenv t -> (Any, OpenExp env aenv t)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (OpenExp env aenv t -> (Any, OpenExp env aenv t))
-> OpenExp env aenv t -> (Any, OpenExp env aenv t)
forall a b. (a -> b) -> a -> b
$ PrimConst t -> OpenExp env aenv t
forall t env aenv. PrimConst t -> OpenExp env aenv t
PrimConst PrimConst t
c
      PrimApp PrimFun (a -> t)
f OpenExp env aenv a
x               -> (Any
uAny -> Any -> Any
forall a. Semigroup a => a -> a -> a
<>Any
v, OpenExp env aenv t
fx)
        where
          (Any
u, OpenExp env aenv a
x') = OpenExp env aenv a -> (Any, OpenExp env aenv a)
forall t. OpenExp env aenv t -> (Any, OpenExp env aenv t)
cvtE OpenExp env aenv a
x
          (Any
v, OpenExp env aenv t
fx) = Gamma env env aenv
-> PrimFun (a -> t)
-> OpenExp env aenv a
-> (Any, OpenExp env aenv t)
forall env aenv a r.
Gamma env env aenv
-> PrimFun (a -> r)
-> OpenExp env aenv a
-> (Any, OpenExp env aenv r)
evalPrimApp Gamma env env aenv
env PrimFun (a -> t)
f OpenExp env aenv a
x'
      Index ArrayVar aenv (Array dim t)
a OpenExp env aenv dim
sh                -> ArrayVar aenv (Array dim t)
-> OpenExp env aenv dim -> OpenExp env aenv t
forall aenv dim t env.
ArrayVar aenv (Array dim t)
-> OpenExp env aenv dim -> OpenExp env aenv t
Index ArrayVar aenv (Array dim t)
a (OpenExp env aenv dim -> OpenExp env aenv t)
-> (Any, OpenExp env aenv dim) -> (Any, OpenExp env aenv t)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> OpenExp env aenv dim -> (Any, OpenExp env aenv dim)
forall t. OpenExp env aenv t -> (Any, OpenExp env aenv t)
cvtE OpenExp env aenv dim
sh
      LinearIndex ArrayVar aenv (Array dim t)
a OpenExp env aenv Int
i           -> ArrayVar aenv (Array dim t)
-> OpenExp env aenv Int -> OpenExp env aenv t
forall aenv dim t env.
ArrayVar aenv (Array dim t)
-> OpenExp env aenv Int -> OpenExp env aenv t
LinearIndex ArrayVar aenv (Array dim t)
a (OpenExp env aenv Int -> OpenExp env aenv t)
-> (Any, OpenExp env aenv Int) -> (Any, OpenExp env aenv t)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> OpenExp env aenv Int -> (Any, OpenExp env aenv Int)
forall t. OpenExp env aenv t -> (Any, OpenExp env aenv t)
cvtE OpenExp env aenv Int
i
      Shape ArrayVar aenv (Array t e)
a                   -> ArrayVar aenv (Array t e) -> (Any, OpenExp env aenv t)
forall sh t.
ArrayVar aenv (Array sh t) -> (Any, OpenExp env aenv sh)
shape ArrayVar aenv (Array t e)
a
      ShapeSize ShapeR dim
shr OpenExp env aenv dim
sh          -> ShapeR dim
-> (Any, OpenExp env aenv dim) -> (Any, OpenExp env aenv Int)
forall sh.
ShapeR sh
-> (Any, OpenExp env aenv sh) -> (Any, OpenExp env aenv Int)
shapeSize ShapeR dim
shr (OpenExp env aenv dim -> (Any, OpenExp env aenv dim)
forall t. OpenExp env aenv t -> (Any, OpenExp env aenv t)
cvtE OpenExp env aenv dim
sh)
      Foreign TypeR t
tp asm (x -> t)
ff Fun () (x -> t)
f OpenExp env aenv x
e         -> TypeR t
-> asm (x -> t)
-> Fun () (x -> t)
-> OpenExp env aenv x
-> OpenExp env aenv t
forall (asm :: * -> *) y x env aenv.
Foreign asm =>
TypeR y
-> asm (x -> y)
-> Fun () (x -> y)
-> OpenExp env aenv x
-> OpenExp env aenv y
Foreign TypeR t
tp asm (x -> t)
ff (Fun () (x -> t) -> OpenExp env aenv x -> OpenExp env aenv t)
-> (Any, Fun () (x -> t))
-> (Any, OpenExp env aenv x -> OpenExp env aenv t)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (Bool -> Any) -> (Bool, Fun () (x -> t)) -> (Any, Fun () (x -> t))
forall a a' b. (a -> a') -> (a, b) -> (a', b)
first Bool -> Any
Any (Gamma () () () -> Fun () (x -> t) -> (Bool, Fun () (x -> t))
forall env aenv f.
Gamma env env aenv
-> OpenFun env aenv f -> (Bool, OpenFun env aenv f)
simplifyOpenFun Gamma () () ()
forall env env' aenv. Gamma env env' aenv
EmptyExp Fun () (x -> t)
f) (Any, OpenExp env aenv x -> OpenExp env aenv t)
-> (Any, OpenExp env aenv x) -> (Any, OpenExp env aenv t)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> OpenExp env aenv x -> (Any, OpenExp env aenv x)
forall t. OpenExp env aenv t -> (Any, OpenExp env aenv t)
cvtE OpenExp env aenv x
e
      While OpenFun env aenv (t -> TAG)
p OpenFun env aenv (t -> t)
f OpenExp env aenv t
x               -> OpenFun env aenv (t -> TAG)
-> OpenFun env aenv (t -> t)
-> OpenExp env aenv t
-> OpenExp env aenv t
forall env aenv a.
OpenFun env aenv (a -> TAG)
-> OpenFun env aenv (a -> a)
-> OpenExp env aenv a
-> OpenExp env aenv a
While (OpenFun env aenv (t -> TAG)
 -> OpenFun env aenv (t -> t)
 -> OpenExp env aenv t
 -> OpenExp env aenv t)
-> (Any, OpenFun env aenv (t -> TAG))
-> (Any,
    OpenFun env aenv (t -> t)
    -> OpenExp env aenv t -> OpenExp env aenv t)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Gamma env env aenv
-> OpenFun env aenv (t -> TAG)
-> (Any, OpenFun env aenv (t -> TAG))
forall env' f.
Gamma env' env' aenv
-> OpenFun env' aenv f -> (Any, OpenFun env' aenv f)
cvtF Gamma env env aenv
env OpenFun env aenv (t -> TAG)
p (Any,
 OpenFun env aenv (t -> t)
 -> OpenExp env aenv t -> OpenExp env aenv t)
-> (Any, OpenFun env aenv (t -> t))
-> (Any, OpenExp env aenv t -> OpenExp env aenv t)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Gamma env env aenv
-> OpenFun env aenv (t -> t) -> (Any, OpenFun env aenv (t -> t))
forall env' f.
Gamma env' env' aenv
-> OpenFun env' aenv f -> (Any, OpenFun env' aenv f)
cvtF Gamma env env aenv
env OpenFun env aenv (t -> t)
f (Any, OpenExp env aenv t -> OpenExp env aenv t)
-> (Any, OpenExp env aenv t) -> (Any, OpenExp env aenv t)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> OpenExp env aenv t -> (Any, OpenExp env aenv t)
forall t. OpenExp env aenv t -> (Any, OpenExp env aenv t)
cvtE OpenExp env aenv t
x
      Coerce ScalarType a
t1 ScalarType t
t2 OpenExp env aenv a
e            -> ScalarType a
-> ScalarType t -> OpenExp env aenv a -> OpenExp env aenv t
forall a b env aenv.
BitSizeEq a b =>
ScalarType a
-> ScalarType b -> OpenExp env aenv a -> OpenExp env aenv b
Coerce ScalarType a
t1 ScalarType t
t2 (OpenExp env aenv a -> OpenExp env aenv t)
-> (Any, OpenExp env aenv a) -> (Any, OpenExp env aenv t)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> OpenExp env aenv a -> (Any, OpenExp env aenv a)
forall t. OpenExp env aenv t -> (Any, OpenExp env aenv t)
cvtE OpenExp env aenv a
e

    cvtE' :: Gamma env' env' aenv -> OpenExp env' aenv e' -> (Any, OpenExp env' aenv e')
    cvtE' :: Gamma env' env' aenv
-> OpenExp env' aenv e' -> (Any, OpenExp env' aenv e')
cvtE' Gamma env' env' aenv
env' = (Bool -> Any)
-> (Bool, OpenExp env' aenv e') -> (Any, OpenExp env' aenv e')
forall a a' b. (a -> a') -> (a, b) -> (a', b)
first Bool -> Any
Any ((Bool, OpenExp env' aenv e') -> (Any, OpenExp env' aenv e'))
-> (OpenExp env' aenv e' -> (Bool, OpenExp env' aenv e'))
-> OpenExp env' aenv e'
-> (Any, OpenExp env' aenv e')
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Gamma env' env' aenv
-> OpenExp env' aenv e' -> (Bool, OpenExp env' aenv e')
forall env aenv e.
Gamma env env aenv
-> OpenExp env aenv e -> (Bool, OpenExp env aenv e)
simplifyOpenExp Gamma env' env' aenv
env'

    cvtF :: Gamma env' env' aenv -> OpenFun env' aenv f -> (Any, OpenFun env' aenv f)
    cvtF :: Gamma env' env' aenv
-> OpenFun env' aenv f -> (Any, OpenFun env' aenv f)
cvtF Gamma env' env' aenv
env' = (Bool -> Any)
-> (Bool, OpenFun env' aenv f) -> (Any, OpenFun env' aenv f)
forall a a' b. (a -> a') -> (a, b) -> (a', b)
first Bool -> Any
Any ((Bool, OpenFun env' aenv f) -> (Any, OpenFun env' aenv f))
-> (OpenFun env' aenv f -> (Bool, OpenFun env' aenv f))
-> OpenFun env' aenv f
-> (Any, OpenFun env' aenv f)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Gamma env' env' aenv
-> OpenFun env' aenv f -> (Bool, OpenFun env' aenv f)
forall env aenv f.
Gamma env env aenv
-> OpenFun env aenv f -> (Bool, OpenFun env aenv f)
simplifyOpenFun Gamma env' env' aenv
env'

    cvtMaybeE :: Maybe (OpenExp env aenv e') -> (Any, Maybe (OpenExp env aenv e'))
    cvtMaybeE :: Maybe (OpenExp env aenv e') -> (Any, Maybe (OpenExp env aenv e'))
cvtMaybeE Maybe (OpenExp env aenv e')
Nothing  = Maybe (OpenExp env aenv e') -> (Any, Maybe (OpenExp env aenv e'))
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe (OpenExp env aenv e')
forall a. Maybe a
Nothing
    cvtMaybeE (Just OpenExp env aenv e'
e) = OpenExp env aenv e' -> Maybe (OpenExp env aenv e')
forall a. a -> Maybe a
Just (OpenExp env aenv e' -> Maybe (OpenExp env aenv e'))
-> (Any, OpenExp env aenv e') -> (Any, Maybe (OpenExp env aenv e'))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> OpenExp env aenv e' -> (Any, OpenExp env aenv e')
forall t. OpenExp env aenv t -> (Any, OpenExp env aenv t)
cvtE OpenExp env aenv e'
e

    cvtLet :: Gamma env' env' aenv
           -> ELeftHandSide bnd env' env''
           -> OpenExp env' aenv bnd
           -> (Gamma env'' env'' aenv -> (Any, OpenExp env'' aenv t))
           -> (Any, OpenExp env' aenv t)
    cvtLet :: Gamma env' env' aenv
-> ELeftHandSide bnd env' env''
-> OpenExp env' aenv bnd
-> (Gamma env'' env'' aenv -> (Any, OpenExp env'' aenv t))
-> (Any, OpenExp env' aenv t)
cvtLet Gamma env' env' aenv
env' lhs :: ELeftHandSide bnd env' env''
lhs@(LeftHandSideSingle ScalarType bnd
_) OpenExp env' aenv bnd
bnd          Gamma env'' env'' aenv -> (Any, OpenExp env'' aenv t)
body = ELeftHandSide bnd env' env''
-> OpenExp env' aenv bnd
-> OpenExp env'' aenv t
-> OpenExp env' aenv t
forall bnd_t env env' aenv body_t.
ELeftHandSide bnd_t env env'
-> OpenExp env aenv bnd_t
-> OpenExp env' aenv body_t
-> OpenExp env aenv body_t
Let ELeftHandSide bnd env' env''
lhs OpenExp env' aenv bnd
bnd (OpenExp env'' aenv t -> OpenExp env' aenv t)
-> (Any, OpenExp env'' aenv t) -> (Any, OpenExp env' aenv t)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Gamma env'' env'' aenv -> (Any, OpenExp env'' aenv t)
body (Gamma env' (env', bnd) aenv -> Gamma (env', bnd) (env', bnd) aenv
forall env env' aenv s.
Gamma env env' aenv -> Gamma (env, s) env' aenv
incExp (Gamma env' (env', bnd) aenv -> Gamma (env', bnd) (env', bnd) aenv)
-> Gamma env' (env', bnd) aenv
-> Gamma (env', bnd) (env', bnd) aenv
forall a b. (a -> b) -> a -> b
$ Gamma env' env' aenv
env' Gamma env' env' aenv
-> OpenExp env' aenv bnd -> Gamma env' (env', bnd) aenv
forall env env' aenv t.
Gamma env env' aenv
-> OpenExp env aenv t -> Gamma env (env', t) aenv
`pushExp` OpenExp env' aenv bnd
bnd) -- Single variable on the LHS, add binding to the environment
    cvtLet Gamma env' env' aenv
env' (LeftHandSideWildcard TupR ScalarType bnd
_)   OpenExp env' aenv bnd
_            Gamma env'' env'' aenv -> (Any, OpenExp env'' aenv t)
body = Gamma env'' env'' aenv -> (Any, OpenExp env'' aenv t)
body Gamma env' env' aenv
Gamma env'' env'' aenv
env'                                 -- Binding not used, remove let binding
    cvtLet Gamma env' env' aenv
env' (LeftHandSidePair LeftHandSide ScalarType v1 env' env'
l1 LeftHandSide ScalarType v2 env' env''
l2)   (Pair OpenExp env' aenv t1
e1 OpenExp env' aenv t2
e2) Gamma env'' env'' aenv -> (Any, OpenExp env'' aenv t)
body                                             -- Split binding to multiple bindings
      = (Any -> Any)
-> (Any, OpenExp env' aenv t) -> (Any, OpenExp env' aenv t)
forall a a' b. (a -> a') -> (a, b) -> (a', b)
first (Any -> Any -> Any
forall a b. a -> b -> a
const (Any -> Any -> Any) -> Any -> Any -> Any
forall a b. (a -> b) -> a -> b
$ Bool -> Any
Any Bool
True)
      ((Any, OpenExp env' aenv t) -> (Any, OpenExp env' aenv t))
-> (Any, OpenExp env' aenv t) -> (Any, OpenExp env' aenv t)
forall a b. (a -> b) -> a -> b
$ Gamma env' env' aenv
-> LeftHandSide ScalarType v1 env' env'
-> OpenExp env' aenv v1
-> (Gamma env' env' aenv -> (Any, OpenExp env' aenv t))
-> (Any, OpenExp env' aenv t)
forall env' bnd env'' t.
Gamma env' env' aenv
-> ELeftHandSide bnd env' env''
-> OpenExp env' aenv bnd
-> (Gamma env'' env'' aenv -> (Any, OpenExp env'' aenv t))
-> (Any, OpenExp env' aenv t)
cvtLet Gamma env' env' aenv
env' LeftHandSide ScalarType v1 env' env'
l1 OpenExp env' aenv v1
OpenExp env' aenv t1
e1
      ((Gamma env' env' aenv -> (Any, OpenExp env' aenv t))
 -> (Any, OpenExp env' aenv t))
-> (Gamma env' env' aenv -> (Any, OpenExp env' aenv t))
-> (Any, OpenExp env' aenv t)
forall a b. (a -> b) -> a -> b
$ \Gamma env' env' aenv
env'' -> Gamma env' env' aenv
-> LeftHandSide ScalarType v2 env' env''
-> OpenExp env' aenv v2
-> (Gamma env'' env'' aenv -> (Any, OpenExp env'' aenv t))
-> (Any, OpenExp env' aenv t)
forall env' bnd env'' t.
Gamma env' env' aenv
-> ELeftHandSide bnd env' env''
-> OpenExp env' aenv bnd
-> (Gamma env'' env'' aenv -> (Any, OpenExp env'' aenv t))
-> (Any, OpenExp env' aenv t)
cvtLet Gamma env' env' aenv
env'' LeftHandSide ScalarType v2 env' env''
l2 ((env' :> env') -> OpenExp env' aenv t2 -> OpenExp env' aenv t2
forall (f :: * -> * -> * -> *) env env' aenv t.
SinkExp f =>
(env :> env') -> f env aenv t -> f env' aenv t
weakenE (LeftHandSide ScalarType v1 env' env' -> env' :> env'
forall (s :: * -> *) t env env'.
LeftHandSide s t env env' -> env :> env'
weakenWithLHS LeftHandSide ScalarType v1 env' env'
l1) OpenExp env' aenv t2
e2) Gamma env'' env'' aenv -> (Any, OpenExp env'' aenv t)
body
    cvtLet Gamma env' env' aenv
env' ELeftHandSide bnd env' env''
lhs                        OpenExp env' aenv bnd
bnd          Gamma env'' env'' aenv -> (Any, OpenExp env'' aenv t)
body = ELeftHandSide bnd env' env''
-> OpenExp env' aenv bnd
-> OpenExp env'' aenv t
-> OpenExp env' aenv t
forall bnd_t env env' aenv body_t.
ELeftHandSide bnd_t env env'
-> OpenExp env aenv bnd_t
-> OpenExp env' aenv body_t
-> OpenExp env aenv body_t
Let ELeftHandSide bnd env' env''
lhs OpenExp env' aenv bnd
bnd (OpenExp env'' aenv t -> OpenExp env' aenv t)
-> (Any, OpenExp env'' aenv t) -> (Any, OpenExp env' aenv t)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Gamma env'' env'' aenv -> (Any, OpenExp env'' aenv t)
body (ELeftHandSide bnd env' env''
-> Gamma env' env' aenv -> Gamma env'' env'' aenv
forall t env env' aenv.
ELeftHandSide t env env'
-> Gamma env env aenv -> Gamma env' env' aenv
lhsExpr ELeftHandSide bnd env' env''
lhs Gamma env' env' aenv
env')   -- Cannot split this binding.

    -- Simplify conditional expressions, in particular by eliminating branches
    -- when the predicate is a known constant.
    --
    cond :: (Any, OpenExp env aenv PrimBool)
         -> (Any, OpenExp env aenv t)
         -> (Any, OpenExp env aenv t)
         -> (Any, OpenExp env aenv t)
    cond c@(_, cE) tr@(_, trE) fl@(_, flE) =
      case cE of
        -- Predicate is the literal 1: only the 'then' branch is reachable.
        Const _ 1 -> Stats.knownBranch "True" (yes trE)
        -- Predicate is the literal 0: only the 'else' branch is reachable.
        Const _ 0 -> Stats.knownBranch "False" (yes flE)
        -- Otherwise, if both branches are the same expression the test is
        -- irrelevant; failing that, rebuild the conditional from the
        -- (possibly simplified) parts.
        _ | Just Refl <- matchOpenExp trE flE
            -> Stats.knownBranch "redundant" (yes flE)
          | otherwise
            -> Cond <$> c <*> tr <*> fl

    caseof :: (Any, OpenExp env aenv TAG)
           -> (Any, [(TAG, OpenExp env aenv b)])
           -> (Any, Maybe (OpenExp env aenv b))
           -> (Any, OpenExp env aenv b)
    caseof x@(_,x') xs@(_,xs') md@(_,md')
      -- Known scrutinee: select the alternative for that tag, or the default
      -- branch when no alternative lists the tag (the default covers every
      -- unlisted tag; the "introduction" rule below relies on this too).
      -- A bare 'fromJust . lookup' would crash the simplifier in the latter
      -- case. If there is neither, fall through to the remaining rules.
      | Const _ t   <- x'
      , Just r      <- lookup t xs' <|> md'
      = Stats.caseElim "known" (yes r)

      -- Only a default branch remains, so the case is redundant.
      | Just d      <- md'
      , []          <- xs'
      = Stats.caseElim "redundant" (yes d)

      -- Exactly one body is shared by several alternatives, and it is the
      -- same as the default: drop those alternatives in favour of the default.
      | Just d      <- md'
      , [(_,(_,u))] <- us
      , Just Refl   <- matchOpenExp d u
      = Stats.caseDefault "merge" $ yes (Case x' (map snd vs) (Just u))

      -- No default, and every alternative has the same body: select it.
      | Nothing     <- md'
      , []          <- vs
      , [(_,(_,u))] <- us
      = Stats.caseElim "overlap" (yes u)

      -- No default, and exactly one body is shared by several alternatives:
      -- turn that body into the default branch.
      | Nothing     <- md'
      , [(_,(_,u))] <- us
      = Stats.caseDefault "introduction" $ yes (Case x' (map snd vs) (Just u))

      | otherwise
      = Case <$> x <*> xs <*> md
      where
        -- Group the alternatives by the hash of their body, counting how many
        -- alternatives share each body. 'us' holds bodies shared by more than
        -- one alternative; 'vs' holds bodies that occur exactly once.
        (us,vs) = partition (\(n,_) -> n > 1)
                $ Map.elems
                . Map.fromListWith merge
                $ [ (hashOpenExp e, (1,(t, e))) | (t,e) <- xs' ]

        -- Combine two alternatives whose bodies hash identically. The tag is
        -- replaced by a placeholder (0xff); it is never read, because entries
        -- of 'us' are only consumed through patterns that ignore the tag.
        merge :: (Int, (TAG, OpenExp env aenv b)) -> (Int, (TAG, OpenExp env aenv b)) -> (Int, (TAG, OpenExp env aenv b))
        merge (n,(_,lhsBody)) (m,(_,rhsBody))
          = internalCheck "hashOpenExp/collision" (isJust (matchOpenExp lhsBody rhsBody))
          $ (n+m, (0xff, lhsBody))

    -- Shape manipulations
    --
    shape :: ArrayVar aenv (Array sh t) -> (Any, OpenExp env aenv sh)
    shape arr =
      case arr of
        -- A zero-dimensional array ('ShapeRz') always has the empty shape.
        Var (ArrayR ShapeRz _) _ -> Stats.ruleFired "shape/Z" (yes Nil)
        -- Otherwise keep the shape query, reporting no change.
        _                        -> pure (Shape arr)

    shapeSize :: ShapeR sh -> (Any, OpenExp env aenv sh) -> (Any, OpenExp env aenv Int)
    shapeSize shr sh@(_, shE) =
      case extractConstTuple shE of
        -- Every extent is a literal: compute the total size now.
        Just extent -> Stats.ruleFired "shapeSize/const"
                     $ yes (Const scalarTypeInt (product (shapeToList shr extent)))
        -- Otherwise keep the size computation around the (simplified) shape.
        Nothing     -> ShapeSize shr <$> sh

    toIndex :: ShapeR sh
            -> (Any, OpenExp env aenv sh)
            -> (Any, OpenExp env aenv sh)
            -> (Any, OpenExp env aenv Int)
    toIndex shr sh@(_, shE) ix@(_, ixE) =
      case ixE of
        -- toIndex sh (fromIndex sh i) ==> i, when both refer to the same shape.
        FromIndex _ shE' i
          | Just Refl <- matchOpenExp shE shE'
          -> Stats.ruleFired "toIndex/fromIndex" (yes i)
        -- Otherwise rebuild the conversion from the (simplified) operands.
        _ -> ToIndex shr <$> sh <*> ix

    fromIndex :: ShapeR sh
              -> (Any, OpenExp env aenv sh)
              -> (Any, OpenExp env aenv Int)
              -> (Any, OpenExp env aenv sh)
    fromIndex shr sh@(_, shE) ix@(_, ixE) =
      case ixE of
        -- fromIndex sh (toIndex sh i) ==> i, when both refer to the same shape.
        -- The 'Refl' also proves that the inner index has this shape's type.
        ToIndex _ shE' i
          | Just Refl <- matchOpenExp shE shE'
          -> Stats.ruleFired "fromIndex/toIndex" (yes i)
        -- Otherwise rebuild the conversion from the (simplified) operands.
        _ -> FromIndex shr <$> sh <*> ix

    -- Apply a function to the first component of a pair.
    first :: (a -> a') -> (a,b) -> (a',b)
    first f pair =
      case pair of
        (l, r) -> (f l, r)

    -- Wrap a result, flagging that a rewrite has taken place.
    yes :: x -> (Any, x)
    yes = (Any True,)

-- | Recover the Haskell value of an expression built only from 'Nil', 'Pair'
-- and 'Const' nodes (for example a literal shape). Returns 'Nothing' as soon
-- as any other constructor is encountered.
--
extractConstTuple :: OpenExp env aenv t -> Maybe t
extractConstTuple expr =
  case expr of
    Nil         -> Just ()
    Pair l r    -> do
      lv <- extractConstTuple l
      rv <- extractConstTuple r
      pure (lv, rv)
    Const _ val -> Just val
    _           -> Nothing

-- Simplification for open functions
--
simplifyOpenFun
    :: Gamma env env aenv
    -> OpenFun env aenv f
    -> (Bool, OpenFun env aenv f)
simplifyOpenFun gamma fun =
  case fun of
    -- Function body: simplify it as an ordinary expression.
    Body body    -> Body <$> simplifyOpenExp gamma body
    -- Lambda: extend the environment with its binders before descending.
    Lam lhs rest -> Lam lhs <$> simplifyOpenFun (lhsExpr lhs gamma) rest

-- | Extend a 'Gamma' with the variables bound by a left-hand side. Wildcards
-- bind nothing. Each single binder is recorded as a reference to the newly
-- bound variable ('ZeroIdx'), after 'incExp' lifts the existing entries into
-- the extended environment. Pairs are processed left component first.
--
lhsExpr :: ELeftHandSide t env env' -> Gamma env env aenv -> Gamma env' env' aenv
lhsExpr lhs gamma =
  case lhs of
    LeftHandSideWildcard _ -> gamma
    LeftHandSideSingle tp  -> pushExp (incExp gamma) (Evar (Var tp ZeroIdx))
    LeftHandSidePair l1 l2 -> lhsExpr l2 (lhsExpr l1 gamma)

-- Simplify closed expressions and functions. The process is applied
-- repeatedly until no more changes are made.
--
simplifyExp :: HasCallStack => Exp aenv t -> Exp aenv t
simplifyExp expr = iterate summariseOpenExp matchOpenExp shrinkExp simplifyOnce expr
  where
    -- A single simplifier pass over a closed expression (empty environment).
    simplifyOnce = simplifyOpenExp EmptyExp

simplifyFun :: HasCallStack => Fun aenv f -> Fun aenv f
simplifyFun fun = iterate summariseOpenFun matchOpenFun shrinkFun simplifyOnce fun
  where
    -- A single simplifier pass over a closed function (empty environment).
    simplifyOnce = simplifyOpenFun EmptyExp


-- NOTE: [Simplifier iterations]
--
-- Run the simplification pass _before_ the shrinking step. In some cases it
-- would be better to shrink first, so that simplification then completes in a
-- single step, but the converse is also true. However, shrinking can remove
-- some of the let-binding structure that is useful for the transformations
-- (e.g. loop recovery), so we want to keep this information for at least the
-- first pass.
--
-- We always apply the simplification step once. Following this, we iterate
-- shrinking and simplification until the expression no longer changes. Both
-- shrink and simplify return a boolean indicating whether any work was done; we
-- stop as soon as either returns false.
--
-- When internal checks are enabled, we also issue a warning if the iteration
-- limit is reached while further changes to the expression were still possible.
--

iterate
    :: forall f a. HasCallStack
    => (f a -> Stats)
    -> (forall s t. f s -> f t -> Maybe (s :~: t))  -- match
    -> (f a -> (Bool, f a))                         -- shrink
    -> (f a -> (Bool, f a))                         -- simplify
    -> f a
    -> f a
iterate summarise match shrink simplify start = go 1 (firstPass start)
  where
    -- Upper bound on the number of shrink/simplify rounds. To be conservative
    -- and avoid excessive run times, this is meant to be kept low.
    --
    -- TODO: make this tunable via debug flags.
    --
    maxRounds :: Int
    maxRounds = 25

    -- 'simplify', additionally recording a completed pass in the statistics.
    simplifyCounted :: f a -> (Bool, f a)
    simplifyCounted = Stats.simplifierDone . simplify

    -- The unconditional first simplification pass, run before any shrinking.
    firstPass :: f a -> f a
    firstPass x =
      Debug.trace Debug.dump_simpl_iterations (render 0 "init" x)
        (snd (traceStep 1 "simplify" (simplifyCounted x)))

    -- Alternate shrink and simplify until either reports no change, or the
    -- round limit is exceeded.
    go :: Int -> f a -> f a
    go i x0
      -- NOTE(review): the NOTE [Simplifier iterations] says to warn when
      -- changes were still possible at the limit; confirm which truth value
      -- of its condition makes 'internalWarning' fire, as this may be inverted.
      | i > maxRounds  = internalWarning "iteration limit reached" (not (x0 `sameAs` simplify x0)) x0
      | not shrunk     = x1
      | not simplified = x2
      | otherwise      = go (i+1) x2
      where
        (shrunk,     x1) = traceStep i "shrink"   (shrink x0)
        (simplified, x2) = traceStep i "simplify" (simplifyCounted x1)

    -- debugging support
    --
    -- True when the expression in the pair matches 'u', i.e. the step that
    -- produced the pair left 'u' unchanged.
    sameAs :: f a -> (Bool, f a) -> Bool
    sameAs u (_, v) = isJust (match u v)

    -- Emit a dump_simpl_iterations trace line, but only for steps that
    -- reported a change.
    traceStep :: Int -> String -> (Bool, f a) -> (Bool, f a)
    traceStep i s v@(changed, x)
      | changed   = Debug.trace Debug.dump_simpl_iterations (render i s x) v
      | otherwise = v

    render :: Int -> String -> f a -> String
    render i s x = printf "simpl-iters/%-8s [%d]: %s" s i (show (summarise x))


-- Debugging support
-- -----------------

-- | Size metrics gathered over a scalar expression or function. They are
-- rendered (via 'show') in the simplifier's iteration trace.
--
-- Every counter is a strict, unpacked 'Int':
--
--   * '_terms'   — expression / function nodes, plus a few structural extras
--   * '_types'   — type-witness constructors
--   * '_binders' — let-bound and lambda-bound binders
--   * '_vars'    — scalar variable and array references
--   * '_ops'     — primitive function applications
--
data Stats = Stats
  { _terms, _types, _binders, _vars, _ops :: {-# UNPACK #-} !Int
  }

-- | Render all five counters on one line.
--
-- The @lets@ label reports '_binders'. That counter is also bumped for lambda
-- binders (see 'summariseOpenFun'), not only for @let@ bindings.
instance Show Stats where
  show Stats{..} =
    printf "terms = %d, types = %d, lets = %d, vars = %d, primops = %d"
           _terms _types _binders _vars _ops

-- | Statistics combine by pointwise addition of every counter.
instance Semigroup Stats where
  lhs <> rhs = lhs +++ rhs

-- | The identity for pointwise addition: every counter starts at zero.
instance Monoid Stats where
  mempty = Stats { _terms = 0, _types = 0, _binders = 0, _vars = 0, _ops = 0 }

infixl 6 +++

-- | Pointwise sum of two 'Stats' records.
--
-- Both arguments are pattern-matched, so both are forced. Because the fields
-- are strict, each sum is evaluated as soon as the result is.
(+++) :: Stats -> Stats -> Stats
(+++) (Stats t1 ty1 b1 v1 o1) (Stats t2 ty2 b2 v2 o2) =
  Stats { _terms   = t1  + t2
        , _types   = ty1 + ty2
        , _binders = b1  + b2
        , _vars    = v1  + v2
        , _ops     = o1  + o2
        }
{-# INLINE (+++) #-}

-- | One lens per 'Stats' counter. The summarisers combine these with '(+~)'
-- to bump a single counter.
terms, types, binders, vars, ops :: Lens' Stats Int
terms   = lens _terms   (\s n -> s { _terms   = n })
types   = lens _types   (\s n -> s { _types   = n })
binders = lens _binders (\s n -> s { _binders = n })
vars    = lens _vars    (\s n -> s { _vars    = n })
ops     = lens _ops     (\s n -> s { _ops     = n })
{-# INLINE terms   #-}
{-# INLINE types   #-}
{-# INLINE binders #-}
{-# INLINE vars    #-}
{-# INLINE ops     #-}

-- | Summarise a scalar function.
--
-- The 'Body' contributes one term on top of its expression's summary. Each
-- 'Lam' contributes one term and one binder.
summariseOpenFun :: OpenFun env aenv f -> Stats
summariseOpenFun fun =
  case fun of
    Body body   -> (terms +~ 1) (summariseOpenExp body)
    Lam _ inner -> (binders +~ 1) . (terms +~ 1) $ summariseOpenFun inner

-- | Summarise a scalar expression.
--
-- 'sizeNode' accounts for a node's children and any extra structure it
-- carries. The node itself contributes the single 'terms' count added here.
summariseOpenExp :: OpenExp env aenv t -> Stats
summariseOpenExp expr = sizeNode expr & terms +~ 1
  where
    sizeExp :: OpenExp env aenv t -> Stats
    sizeExp = summariseOpenExp

    sizeFun :: OpenFun env aenv t -> Stats
    sizeFun = summariseOpenFun

    -- Array operands count as one variable reference. This assumes an array
    -- index; anything else should have been rejected elsewhere.
    sizeArr :: acc aenv a -> Stats
    sizeArr _ = mempty & vars +~ 1

    -- A lone type-witness constructor.
    oneType :: Stats
    oneType = mempty & types +~ 1

    sizeConst :: PrimConst c -> Stats
    sizeConst c = terms +~ 1 $ case c of
      PrimMinBound t -> sizeBounded t
      PrimMaxBound t -> sizeBounded t
      PrimPi       t -> sizeFloating t

    sizeIntegral :: IntegralType t -> Stats
    sizeIntegral _ = oneType

    sizeFloating :: FloatingType t -> Stats
    sizeFloating _ = oneType

    -- Nested witnesses count their own constructor as well as the inner one.
    sizeNum :: NumType t -> Stats
    sizeNum nt = types +~ 1 $ case nt of
      IntegralNumType t -> sizeIntegral t
      FloatingNumType t -> sizeFloating t

    sizeBounded :: BoundedType t -> Stats
    sizeBounded (IntegralBoundedType t) = types +~ 1 $ sizeIntegral t

    -- Currently unused; kept for reference:
    -- sizeScalar (SingleScalarType t) = types +~ 1 $ sizeSingle t
    -- sizeScalar (VectorScalarType t) = types +~ 1 $ sizeVector t

    sizeSingle :: SingleType t -> Stats
    sizeSingle (NumSingleType t) = types +~ 1 $ sizeNum t

    -- Currently unused; kept for reference:
    -- sizeVector (Vector2Type t)  = types +~ 1 $ sizeSingle t
    -- sizeVector (Vector3Type t)  = types +~ 1 $ sizeSingle t
    -- sizeVector (Vector4Type t)  = types +~ 1 $ sizeSingle t
    -- sizeVector (Vector8Type t)  = types +~ 1 $ sizeSingle t
    -- sizeVector (Vector16Type t) = types +~ 1 $ sizeSingle t

    -- Children and extra structure of one node. The node itself has already
    -- been counted by the caller.
    sizeNode :: OpenExp env aenv t -> Stats
    sizeNode node =
      case node of
        Let _ bnd body        -> (sizeExp bnd +++ sizeExp body) & binders +~ 1
        Evar{}                -> mempty & vars +~ 1
        Foreign _ _ _ x       -> sizeExp x & terms +~ 1   -- +1 for asm, ignore fallback impls.
        Const{}               -> mempty
        Undef _               -> mempty
        Nil                   -> mempty & terms +~ 1
        Pair lhs rhs          -> (sizeExp lhs +++ sizeExp rhs) & terms +~ 1
        VecPack   _ x         -> sizeExp x
        VecUnpack _ x         -> sizeExp x
        IndexSlice _ slix sh  -> (sizeExp slix +++ sizeExp sh) & terms +~ 1 -- +1 for sliceIndex
        IndexFull  _ slix sl  -> (sizeExp slix +++ sizeExp sl) & terms +~ 1 -- +1 for sliceIndex
        ToIndex   _ sh ix     -> sizeExp sh +++ sizeExp ix
        FromIndex _ sh ix     -> sizeExp sh +++ sizeExp ix
        Case scrut alts def   -> sizeExp scrut
                                   +++ foldMap (sizeExp . snd) alts
                                   +++ maybe mempty sizeExp def
        Cond p yes no         -> sizeExp p +++ sizeExp yes +++ sizeExp no
        While p step x        -> sizeFun p +++ sizeFun step +++ sizeExp x
        PrimConst c           -> sizeConst c
        Index a ix            -> sizeArr a +++ sizeExp ix
        LinearIndex a ix      -> sizeArr a +++ sizeExp ix
        Shape a               -> sizeArr a
        ShapeSize _ sh        -> sizeExp sh
        PrimApp prim x        -> sizePrim prim +++ sizeExp x
        Coerce _ _ x          -> sizeExp x

    -- Each primitive application is one op, plus the type witnesses it carries.
    sizePrim :: PrimFun f -> Stats
    sizePrim prim = witnesses prim & ops +~ 1
      where
        witnesses :: PrimFun f -> Stats
        witnesses p =
          case p of
            -- NumType-indexed arithmetic
            PrimAdd                t -> sizeNum t
            PrimSub                t -> sizeNum t
            PrimMul                t -> sizeNum t
            PrimNeg                t -> sizeNum t
            PrimAbs                t -> sizeNum t
            PrimSig                t -> sizeNum t

            -- IntegralType-indexed operations
            PrimQuot               t -> sizeIntegral t
            PrimRem                t -> sizeIntegral t
            PrimQuotRem            t -> sizeIntegral t
            PrimIDiv               t -> sizeIntegral t
            PrimMod                t -> sizeIntegral t
            PrimDivMod             t -> sizeIntegral t
            PrimBAnd               t -> sizeIntegral t
            PrimBOr                t -> sizeIntegral t
            PrimBXor               t -> sizeIntegral t
            PrimBNot               t -> sizeIntegral t
            PrimBShiftL            t -> sizeIntegral t
            PrimBShiftR            t -> sizeIntegral t
            PrimBRotateL           t -> sizeIntegral t
            PrimBRotateR           t -> sizeIntegral t
            PrimPopCount           t -> sizeIntegral t
            PrimCountLeadingZeros  t -> sizeIntegral t
            PrimCountTrailingZeros t -> sizeIntegral t

            -- FloatingType-indexed operations
            PrimFDiv               t -> sizeFloating t
            PrimRecip              t -> sizeFloating t
            PrimSin                t -> sizeFloating t
            PrimCos                t -> sizeFloating t
            PrimTan                t -> sizeFloating t
            PrimAsin               t -> sizeFloating t
            PrimAcos               t -> sizeFloating t
            PrimAtan               t -> sizeFloating t
            PrimSinh               t -> sizeFloating t
            PrimCosh               t -> sizeFloating t
            PrimTanh               t -> sizeFloating t
            PrimAsinh              t -> sizeFloating t
            PrimAcosh              t -> sizeFloating t
            PrimAtanh              t -> sizeFloating t
            PrimExpFloating        t -> sizeFloating t
            PrimSqrt               t -> sizeFloating t
            PrimLog                t -> sizeFloating t
            PrimFPow               t -> sizeFloating t
            PrimLogBase            t -> sizeFloating t
            PrimIsNaN              t -> sizeFloating t
            PrimIsInfinite         t -> sizeFloating t
            PrimAtan2              t -> sizeFloating t

            -- Floating -> integral rounding carries both witnesses
            PrimTruncate         fl i -> sizeFloating fl +++ sizeIntegral i
            PrimRound            fl i -> sizeFloating fl +++ sizeIntegral i
            PrimFloor            fl i -> sizeFloating fl +++ sizeIntegral i
            PrimCeiling          fl i -> sizeFloating fl +++ sizeIntegral i

            -- SingleType-indexed comparisons
            PrimLt                 t -> sizeSingle t
            PrimGt                 t -> sizeSingle t
            PrimLtEq               t -> sizeSingle t
            PrimGtEq               t -> sizeSingle t
            PrimEq                 t -> sizeSingle t
            PrimNEq                t -> sizeSingle t
            PrimMax                t -> sizeSingle t
            PrimMin                t -> sizeSingle t

            -- Logical operators carry no type witness
            PrimLAnd                 -> mempty
            PrimLOr                  -> mempty
            PrimLNot                 -> mempty

            -- Numeric conversions carry a source and a target witness
            PrimFromIntegral     i n -> sizeIntegral i +++ sizeNum n
            PrimToFloating      n fl -> sizeNum n +++ sizeFloating fl