{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE FlexibleInstances   #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE LambdaCase          #-}
{-# LANGUAGE OverloadedLists     #-}
{-# LANGUAGE PatternGuards       #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving  #-}
{-# LANGUAGE TupleSections       #-}
{-# LANGUAGE TypeApplications    #-}
{-# LANGUAGE TypeFamilies        #-}
{-# LANGUAGE TypeOperators       #-}
{-# OPTIONS_GHC -fno-warn-name-shadowing #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module      : Data.Array.Accelerate.Trafo.Sharing
-- Copyright   : [2008..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--
-- This module implements HOAS to de Bruijn conversion of array expressions
-- while incorporating sharing information.
--

module Data.Array.Accelerate.Trafo.Sharing (

  -- * HOAS to de Bruijn conversion
  convertAcc, convertAccWith,

  Afunction, AfunctionR, ArraysFunctionR, AfunctionRepr(..), afunctionRepr,
  convertAfun, convertAfunWith,

  Function, FunctionR, EltFunctionR, FunctionRepr(..), functionRepr,
  convertExp, convertExpWith,
  convertFun, convertFunWith,

  -- convertSeq

) where

import Data.Array.Accelerate.AST                                    hiding ( PreOpenAcc(..), OpenAcc(..), Acc, OpenExp(..), Exp, Boundary(..), HasArraysR(..), showPreAccOp )
import Data.Array.Accelerate.AST.Environment
import Data.Array.Accelerate.AST.Idx
import Data.Array.Accelerate.AST.LeftHandSide
import Data.Array.Accelerate.AST.Var
import Data.Array.Accelerate.Analysis.Match
import Data.Array.Accelerate.Debug.Flags                            as Debug
import Data.Array.Accelerate.Debug.Trace                            as Debug
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Representation.Array                   ( Array, ArraysR, ArrayR(..), showArraysR )
import Data.Array.Accelerate.Representation.Shape                   hiding ( zip )
import Data.Array.Accelerate.Representation.Stencil
import Data.Array.Accelerate.Representation.Tag
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Smart                                  as Smart hiding ( StencilR )
import Data.Array.Accelerate.Sugar.Array                            hiding ( Array, ArraysR, (!!) )
import Data.Array.Accelerate.Sugar.Elt
import Data.Array.Accelerate.Trafo.Config
import Data.Array.Accelerate.Trafo.Substitution
import Data.Array.Accelerate.Trafo.Var
import Data.Array.Accelerate.Type
import Data.BitSet                                                  ( (\\), member )
import qualified Data.Array.Accelerate.AST                          as AST
import qualified Data.Array.Accelerate.Representation.Stencil       as R
import qualified Data.Array.Accelerate.Sugar.Array                  as Sugar

import Control.Applicative                                          hiding ( Const )
import Control.Lens                                                 ( over, mapped, _1, _2 )
import Control.Monad.Fix
import Data.Function                                                ( on )
import Data.Hashable
import Data.List                                                    ( elemIndex, findIndex, groupBy, intercalate, partition )
import Data.Maybe
import Data.Monoid                                                  ( Any(..) )
import System.IO.Unsafe                                             ( unsafePerformIO )
import System.Mem.StableName
import Text.Printf
import qualified Data.HashMap.Strict                                as Map
import qualified Data.HashSet                                       as Set
import qualified Data.HashTable.IO                                  as Hash
import qualified Data.IntMap                                        as IntMap
import Prelude


-- Layouts
-- -------

-- | A layout of an environment has an entry for each entry of the environment.
-- Each entry in the layout holds the de Bruijn index that refers to the
-- corresponding entry in the environment.
--
-- The first index @env@ is the environment in which the stored variables are
-- valid; the second index @env'@ is the environment being described, built up
-- one left-hand side at a time.
--
data Layout s env env' where
  -- | The layout of the empty environment.
  EmptyLayout :: Layout s env ()
  -- | Extend a layout for @env1@ with a left-hand side binding a value of type
  -- @t@, giving a layout for @env2@. The 'Vars' (valid in @env@) are the
  -- variables that refer to what this left-hand side binds.
  PushLayout  :: Layout s env env1
              -> LeftHandSide s t env1 env2
              -> Vars s env t
              -> Layout s env env2

-- | Layout of a scalar-expression environment.
type ELayout = Layout ScalarType
-- | Layout of an array environment.
type ArrayLayout = Layout ArrayR


-- | Look up the entry at position @n@ of an environment layout and return the
-- variables it binds, after checking that they have the expected type @tp@.
--
-- Position 0 is the most recently pushed entry (the outermost 'PushLayout');
-- each step towards 'EmptyLayout' consumes one position.
--
-- The first argument describes the call site and is appended to the message
-- of any 'internalError' raised here. An error is raised when the layout runs
-- out of entries before position @n@ is reached, or when the entry found binds
-- a type that @matchTp@ does not equate with @tp@ (both types are rendered with
-- @showTp@ in that message).
--
prjIdx :: forall s t env env1. HasCallStack
       => String
       -> (forall t'. TupR s t' -> ShowS)
       -> (forall u v. TupR s u -> TupR s v -> Maybe (u :~: v))
       -> TupR s t
       -> Int
       -> Layout s env env1
       -> Vars s env t
prjIdx context showTp matchTp tp = walk
  where
    -- Polymorphic recursion over the described environment: each step strips
    -- one 'PushLayout', changing the second layout index.
    walk :: forall env'. HasCallStack => Int -> Layout s env env' -> Vars s env t
    walk idx layout =
      case layout of
        EmptyLayout ->
          failWith "environment does not contain index"
        PushLayout rest lhs vars
          | idx /= 0                       -> walk (idx - 1) rest
          | Just Refl <- matchTp tp actual -> vars
          | otherwise                      ->
              failWith $ printf "couldn't match expected type `%s' with actual type `%s'"
                                (showTp tp     "")
                                (showTp actual "")
          where
            -- Type of the values bound by this entry's left-hand side.
            actual = lhsToTupR lhs

    -- Abort with the given reason, annotated with the caller-supplied context.
    failWith :: HasCallStack => String -> a
    failWith reason = internalError (printf "%s\nin the context: %s" reason context)

-- | Apply a weakening to every variable stored in a layout.
--
-- The shape of the layout (its entries and their left-hand sides) is kept as
-- is; only the 'Vars' of each entry are re-indexed with 'weakenVars', moving
-- the layout from environment @env1@ to environment @env2@. In this module it
-- is applied to an existing layout just before a new binder is pushed on top.
--
incLayout :: env1 :> env2 -> Layout s env1 env' -> Layout s env2 env'
incLayout k layout =
  case layout of
    EmptyLayout             -> EmptyLayout
    PushLayout inner lhs vs -> PushLayout (incLayout k inner) lhs (weakenVars k vs)

-- | Number of entries ('PushLayout' nodes) in a layout.
--
-- An entry's left-hand side may bind a tuple of variables, so this counts
-- binders rather than individual variables.
--
sizeLayout :: Layout s env env' -> Int
sizeLayout = count 0
  where
    -- Strict accumulator; the layout's described environment shrinks by one
    -- binder per step, hence the polymorphic signature.
    count :: Int -> Layout s0 e0 e1 -> Int
    count !acc EmptyLayout           = acc
    count !acc (PushLayout rest _ _) = count (acc + 1) rest

-- Conversion from HOAS to de Bruijn computation AST
-- =================================================

-- Array computations
-- ------------------

-- | Convert a closed array expression to de Bruijn form while also
-- incorporating sharing information.
--
-- Uses 'defaultOptions'; see 'convertAccWith' to supply a different 'Config'.
--
convertAcc :: HasCallStack => Acc arrs -> AST.Acc (Sugar.ArraysR arrs)
convertAcc acc = convertAccWith defaultOptions acc

-- | Like 'convertAcc', but with an explicit 'Config'.
--
-- The expression is closed, so conversion starts from an empty array
-- environment layout.
--
convertAccWith :: HasCallStack => Config -> Acc arrs -> AST.Acc (Sugar.ArraysR arrs)
convertAccWith config wrapped =
  case wrapped of
    Acc smartAcc -> convertOpenAcc config EmptyLayout smartAcc


-- | Convert a closed function over array computations to de Bruijn form,
-- while incorporating sharing information.
--
-- Uses 'defaultOptions'; see 'convertAfunWith' to supply a different 'Config'.
--
convertAfun :: HasCallStack => Afunction f => f -> AST.Afun (ArraysFunctionR f)
convertAfun f = convertAfunWith defaultOptions f

-- | Like 'convertAfun', but with an explicit 'Config'.
--
-- The function is closed, so conversion starts from an empty array
-- environment layout.
--
convertAfunWith :: HasCallStack => Afunction f => Config -> f -> AST.Afun (ArraysFunctionR f)
convertAfunWith config f = convertOpenAfun config EmptyLayout f

-- | Value-level witness of the shape of an array function: one
-- 'AfunctionReprLam' per argument, ending in 'AfunctionReprBody'.
--
-- The three indices describe the same function at three levels:
--
--   * @f@      — the surface type, built from 'Acc' (e.g. @Acc a -> Acc b@);
--   * @ar@     — the same function over 'Arrays' types (e.g. @a -> b@);
--   * @areprr@ — the same function over representation types
--                (e.g. @Sugar.ArraysR a -> Sugar.ArraysR b@).
--
data AfunctionRepr f ar areprr where
  -- | The result of the function: a single array computation.
  AfunctionReprBody
    :: Arrays b => AfunctionRepr (Acc b) b (Sugar.ArraysR b)

  -- | One more array-computation argument in front of the rest of the function.
  AfunctionReprLam
    :: Arrays a
    => AfunctionRepr b br breprr
    -> AfunctionRepr (Acc a -> b) (a -> br) (Sugar.ArraysR a -> breprr)

-- | Functions over array computations (@Acc a -> ... -> Acc b@) that can be
-- converted from HOAS into de Bruijn form.
--
-- Convert a HOAS fragment into de Bruijn form, binding variables into the typed
-- environment layout one binder at a time.
--
-- NOTE: Because we convert one binder at a time left-to-right, the bound
--       variables ('vars') will have de Bruijn index _zero_ as the outermost
--       binding, and thus go to the end of the list.
--
class Afunction f where
  -- | The function type expressed over 'Arrays' types instead of 'Acc'.
  type AfunctionR f
  -- | The function type expressed over representation types ('Sugar.ArraysR').
  type ArraysFunctionR f
  -- | Value-level witness of the structure of @f@.
  afunctionRepr   :: HasCallStack => AfunctionRepr f (AfunctionR f) (ArraysFunctionR f)
  -- | Convert @f@ into a de Bruijn array function, given the layout of the
  -- enclosing array environment @aenv@.
  convertOpenAfun :: HasCallStack => Config -> ArrayLayout aenv aenv -> f -> AST.OpenAfun aenv (ArraysFunctionR f)

-- | A function taking one more array computation: bind the argument with a
-- lambda, then convert the remainder of the function under the extended
-- layout.
--
instance (Arrays a, Afunction r) => Afunction (Acc a -> r) where
  type AfunctionR      (Acc a -> r) = a -> AfunctionR r
  type ArraysFunctionR (Acc a -> r) = Sugar.ArraysR a -> ArraysFunctionR r

  afunctionRepr = AfunctionReprLam (afunctionRepr @r)

  convertOpenAfun config alyt f =
    let argRepr = Sugar.arraysR @a
    in  case declareVars argRepr of
          DeclareVars lhs k value ->
            -- The HOAS argument is a placeholder 'Atag' carrying the current
            -- layout size, which identifies the level of the new binder.
            -- Existing layout entries are weakened past that binder before it
            -- is pushed on top.
            let placeholder = Acc (SmartAcc (Atag argRepr (sizeLayout alyt)))
                alyt'       = PushLayout (incLayout k alyt) lhs (value weakenId)
                rest        = convertOpenAfun config alyt' (f placeholder)
            in  Alam lhs rest

-- | The body of an array function: a single array computation, converted
-- directly under the layout built up by the enclosing lambdas.
--
instance Arrays b => Afunction (Acc b) where
  type AfunctionR      (Acc b) = b
  type ArraysFunctionR (Acc b) = Sugar.ArraysR b

  afunctionRepr = AfunctionReprBody

  convertOpenAfun config alyt wrapped =
    case wrapped of
      Acc body -> Abody (convertOpenAcc config alyt body)

-- | Convert a closed unary function over 'SmartAcc' terms (rather than 'Acc')
-- into a de Bruijn array function.
--
-- The argument is bound by a single lambda whose left-hand side is derived
-- from its representation @repr@.
--
convertSmartAfun1
    :: HasCallStack
    => Config
    -> ArraysR a
    -> (SmartAcc a -> SmartAcc b)
    -> AST.Afun (a -> b)
convertSmartAfun1 config repr f =
  case declareVars repr of
    DeclareVars lhs _ value ->
      -- The function is closed, so the argument is the only binder in scope:
      -- its placeholder 'Atag' is at level 0 and it is pushed onto an empty
      -- layout.
      let arg   = SmartAcc (Atag repr 0)
          alyt' = PushLayout EmptyLayout lhs (value weakenId)
      in  Alam lhs (Abody (convertOpenAcc config alyt' (f arg)))

-- | Convert an open array expression to de Bruijn form while also
-- incorporating sharing information.
--
-- The binders already present in @alyt@ are handed to sharing recovery as
-- free variables, identified by level and listed from the innermost
-- (@size - 1@) down to the outermost (@0@).
--
convertOpenAcc
    :: HasCallStack
    => Config
    -> ArrayLayout aenv aenv
    -> SmartAcc arrs
    -> AST.OpenAcc aenv arrs
convertOpenAcc config alyt acc = convertSharingAcc config alyt initialEnv sharingAcc
  where
    level :: Int
    level = sizeLayout alyt

    freeLevels :: [Int]
    freeLevels = [level - 1, level - 2 .. 0]

    (sharingAcc, initialEnv) = recoverSharingAcc config level freeLevels acc


-- | Convert an array expression with given array environment layout and sharing information into
-- de Bruijn form while recovering sharing at the same time (by introducing appropriate let
-- bindings).  The latter implements the third phase of sharing recovery.
--
-- The sharing environment 'aenv' (a list of 'StableSharingAcc') keeps track of all currently bound
-- sharing variables, keeping them in reverse chronological order (outermost variable is at the end
-- of the list).
--
convertSharingAcc
    :: forall aenv arrs. HasCallStack
    => Config
    -> ArrayLayout aenv aenv
    -> [StableSharingAcc]
    -> ScopedAcc arrs
    -> AST.OpenAcc aenv arrs
convertSharingAcc :: Config
-> ArrayLayout aenv aenv
-> [StableSharingAcc]
-> ScopedAcc arrs
-> OpenAcc aenv arrs
convertSharingAcc Config
_ ArrayLayout aenv aenv
alyt [StableSharingAcc]
aenv (ScopedAcc [StableSharingAcc]
lams (AvarSharing StableAccName arrs
sa ArraysR arrs
repr))
  | Just Int
i <- (StableSharingAcc -> Bool) -> [StableSharingAcc] -> Maybe Int
forall a. (a -> Bool) -> [a] -> Maybe Int
findIndex (StableAccName arrs -> StableSharingAcc -> Bool
forall arrs. StableAccName arrs -> StableSharingAcc -> Bool
matchStableAcc StableAccName arrs
sa) [StableSharingAcc]
aenv'
  = InjectAcc OpenAcc -> ArrayVars aenv arrs -> OpenAcc aenv arrs
forall (acc :: * -> * -> *) aenv arrs.
InjectAcc acc -> ArrayVars aenv arrs -> acc aenv arrs
avarsIn InjectAcc OpenAcc
AST.OpenAcc
  (ArrayVars aenv arrs -> OpenAcc aenv arrs)
-> ArrayVars aenv arrs -> OpenAcc aenv arrs
forall a b. (a -> b) -> a -> b
$ String
-> (forall t'. TupR ArrayR t' -> ShowS)
-> (forall u v. TupR ArrayR u -> TupR ArrayR v -> Maybe (u :~: v))
-> ArraysR arrs
-> Int
-> ArrayLayout aenv aenv
-> ArrayVars aenv arrs
forall (s :: * -> *) t env env1.
HasCallStack =>
String
-> (forall t'. TupR s t' -> ShowS)
-> (forall u v. TupR s u -> TupR s v -> Maybe (u :~: v))
-> TupR s t
-> Int
-> Layout s env env1
-> Vars s env t
prjIdx (String
ctxt String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
"; i = " String -> ShowS
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show Int
i) forall t'. TupR ArrayR t' -> ShowS
showArraysR forall u v. TupR ArrayR u -> TupR ArrayR v -> Maybe (u :~: v)
matchArraysR ArraysR arrs
repr Int
i ArrayLayout aenv aenv
alyt
  | [StableSharingAcc] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [StableSharingAcc]
aenv'
  = String -> OpenAcc aenv arrs
forall a. HasCallStack => String -> a
error (String -> OpenAcc aenv arrs) -> String -> OpenAcc aenv arrs
forall a b. (a -> b) -> a -> b
$ String
"Cyclic definition of a value of type 'Acc' (sa = " String -> ShowS
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show (StableAccName arrs -> Int
forall t. StableNameHeight t -> Int
hashStableNameHeight StableAccName arrs
sa) String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
")"
  | Bool
otherwise
  = String -> OpenAcc aenv arrs
forall a. HasCallStack => String -> a
internalError String
err
  where
    aenv' :: [StableSharingAcc]
aenv' = [StableSharingAcc]
lams [StableSharingAcc] -> [StableSharingAcc] -> [StableSharingAcc]
forall a. [a] -> [a] -> [a]
++ [StableSharingAcc]
aenv
    ctxt :: String
ctxt = String
"shared 'Acc' tree with stable name " String -> ShowS
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show (StableAccName arrs -> Int
forall t. StableNameHeight t -> Int
hashStableNameHeight StableAccName arrs
sa)
    err :: String
err  = String
"inconsistent valuation @ " String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
ctxt String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
";\n  aenv = " String -> ShowS
forall a. [a] -> [a] -> [a]
++ [StableSharingAcc] -> String
forall a. Show a => a -> String
show [StableSharingAcc]
aenv'

convertSharingAcc Config
config ArrayLayout aenv aenv
alyt [StableSharingAcc]
aenv (ScopedAcc [StableSharingAcc]
lams (AletSharing sa :: StableSharingAcc
sa@(StableSharingAcc (StableAccName arrs
_ :: StableAccName as) SharingAcc ScopedAcc ScopedExp arrs
boundAcc) ScopedAcc arrs
bodyAcc))
  = case TupR ArrayR arrs -> DeclareVars ArrayR arrs aenv
forall (s :: * -> *) t env. TupR s t -> DeclareVars s t env
declareVars (TupR ArrayR arrs -> DeclareVars ArrayR arrs aenv)
-> TupR ArrayR arrs -> DeclareVars ArrayR arrs aenv
forall a b. (a -> b) -> a -> b
$ OpenAcc aenv arrs -> TupR ArrayR arrs
forall (f :: * -> * -> *) aenv a.
HasArraysR f =>
f aenv a -> ArraysR a
AST.arraysR OpenAcc aenv arrs
bound of
      DeclareVars LeftHandSide ArrayR arrs aenv env'
lhs aenv :> env'
k forall env''. (env' :> env'') -> Vars ArrayR env'' arrs
value ->
        let
          alyt' :: Layout ArrayR env' env'
alyt' = Layout ArrayR env' aenv
-> LeftHandSide ArrayR arrs aenv env'
-> Vars ArrayR env' arrs
-> Layout ArrayR env' env'
forall (s :: * -> *) env env1 t env2.
Layout s env env1
-> LeftHandSide s t env1 env2 -> Vars s env t -> Layout s env env2
PushLayout ((aenv :> env') -> ArrayLayout aenv aenv -> Layout ArrayR env' aenv
forall env1 env2 (s :: * -> *) env'.
(env1 :> env2) -> Layout s env1 env' -> Layout s env2 env'
incLayout aenv :> env'
k ArrayLayout aenv aenv
alyt) LeftHandSide ArrayR arrs aenv env'
lhs ((env' :> env') -> Vars ArrayR env' arrs
forall env''. (env' :> env'') -> Vars ArrayR env'' arrs
value env' :> env'
forall env. env :> env
weakenId)
        in
          PreOpenAcc OpenAcc aenv arrs -> OpenAcc aenv arrs
InjectAcc OpenAcc
AST.OpenAcc (PreOpenAcc OpenAcc aenv arrs -> OpenAcc aenv arrs)
-> PreOpenAcc OpenAcc aenv arrs -> OpenAcc aenv arrs
forall a b. (a -> b) -> a -> b
$ LeftHandSide ArrayR arrs aenv env'
-> OpenAcc aenv arrs
-> OpenAcc env' arrs
-> PreOpenAcc OpenAcc aenv arrs
forall bndArrs aenv aenv' (acc :: * -> * -> *) bodyArrs.
ALeftHandSide bndArrs aenv aenv'
-> acc aenv bndArrs
-> acc aenv' bodyArrs
-> PreOpenAcc acc aenv bodyArrs
AST.Alet
            LeftHandSide ArrayR arrs aenv env'
lhs
            OpenAcc aenv arrs
bound
            (Config
-> Layout ArrayR env' env'
-> [StableSharingAcc]
-> ScopedAcc arrs
-> OpenAcc env' arrs
forall aenv arrs.
HasCallStack =>
Config
-> ArrayLayout aenv aenv
-> [StableSharingAcc]
-> ScopedAcc arrs
-> OpenAcc aenv arrs
convertSharingAcc Config
config Layout ArrayR env' env'
alyt' (StableSharingAcc
saStableSharingAcc -> [StableSharingAcc] -> [StableSharingAcc]
forall a. a -> [a] -> [a]
:[StableSharingAcc]
aenv') ScopedAcc arrs
bodyAcc)
  where
    aenv' :: [StableSharingAcc]
aenv' = [StableSharingAcc]
lams [StableSharingAcc] -> [StableSharingAcc] -> [StableSharingAcc]
forall a. [a] -> [a] -> [a]
++ [StableSharingAcc]
aenv
    bound :: OpenAcc aenv arrs
bound = Config
-> ArrayLayout aenv aenv
-> [StableSharingAcc]
-> ScopedAcc arrs
-> OpenAcc aenv arrs
forall aenv arrs.
HasCallStack =>
Config
-> ArrayLayout aenv aenv
-> [StableSharingAcc]
-> ScopedAcc arrs
-> OpenAcc aenv arrs
convertSharingAcc Config
config ArrayLayout aenv aenv
alyt [StableSharingAcc]
aenv' ([StableSharingAcc]
-> SharingAcc ScopedAcc ScopedExp arrs -> ScopedAcc arrs
forall t.
[StableSharingAcc]
-> SharingAcc ScopedAcc ScopedExp t -> ScopedAcc t
ScopedAcc [] SharingAcc ScopedAcc ScopedExp arrs
boundAcc)

convertSharingAcc Config
config ArrayLayout aenv aenv
alyt [StableSharingAcc]
aenv (ScopedAcc [StableSharingAcc]
lams (AccSharing StableAccName arrs
_ PreSmartAcc ScopedAcc ScopedExp arrs
preAcc))
  = PreOpenAcc OpenAcc aenv arrs -> OpenAcc aenv arrs
InjectAcc OpenAcc
AST.OpenAcc
  (PreOpenAcc OpenAcc aenv arrs -> OpenAcc aenv arrs)
-> PreOpenAcc OpenAcc aenv arrs -> OpenAcc aenv arrs
forall a b. (a -> b) -> a -> b
$ let aenv' :: [StableSharingAcc]
aenv' = [StableSharingAcc]
lams [StableSharingAcc] -> [StableSharingAcc] -> [StableSharingAcc]
forall a. [a] -> [a] -> [a]
++ [StableSharingAcc]
aenv

        cvtA :: ScopedAcc a -> AST.OpenAcc aenv a
        cvtA :: ScopedAcc a -> OpenAcc aenv a
cvtA = Config
-> ArrayLayout aenv aenv
-> [StableSharingAcc]
-> ScopedAcc a
-> OpenAcc aenv a
forall aenv arrs.
HasCallStack =>
Config
-> ArrayLayout aenv aenv
-> [StableSharingAcc]
-> ScopedAcc arrs
-> OpenAcc aenv arrs
convertSharingAcc Config
config ArrayLayout aenv aenv
alyt [StableSharingAcc]
aenv'

        cvtE :: ScopedExp t -> AST.Exp aenv t
        cvtE :: ScopedExp t -> Exp aenv t
cvtE = Config
-> ELayout () ()
-> ArrayLayout aenv aenv
-> [StableSharingExp]
-> [StableSharingAcc]
-> ScopedExp t
-> Exp aenv t
forall t env aenv.
HasCallStack =>
Config
-> ELayout env env
-> ArrayLayout aenv aenv
-> [StableSharingExp]
-> [StableSharingAcc]
-> ScopedExp t
-> OpenExp env aenv t
convertSharingExp Config
config ELayout () ()
forall (s :: * -> *) env. Layout s env ()
EmptyLayout ArrayLayout aenv aenv
alyt [] [StableSharingAcc]
aenv'

        cvtF1 :: TypeR a -> (SmartExp a -> ScopedExp b) -> AST.Fun aenv (a -> b)
        cvtF1 :: TypeR a -> (SmartExp a -> ScopedExp b) -> Fun aenv (a -> b)
cvtF1 = Config
-> ArrayLayout aenv aenv
-> [StableSharingAcc]
-> TypeR a
-> (SmartExp a -> ScopedExp b)
-> Fun aenv (a -> b)
forall aenv a b.
HasCallStack =>
Config
-> ArrayLayout aenv aenv
-> [StableSharingAcc]
-> TypeR a
-> (SmartExp a -> ScopedExp b)
-> Fun aenv (a -> b)
convertSharingFun1 Config
config ArrayLayout aenv aenv
alyt [StableSharingAcc]
aenv'

        cvtF2 :: TypeR a -> TypeR b -> (SmartExp a -> SmartExp b -> ScopedExp c) -> AST.Fun aenv (a -> b -> c)
        cvtF2 :: TypeR a
-> TypeR b
-> (SmartExp a -> SmartExp b -> ScopedExp c)
-> Fun aenv (a -> b -> c)
cvtF2 = Config
-> ArrayLayout aenv aenv
-> [StableSharingAcc]
-> TypeR a
-> TypeR b
-> (SmartExp a -> SmartExp b -> ScopedExp c)
-> Fun aenv (a -> b -> c)
forall aenv a b c.
HasCallStack =>
Config
-> ArrayLayout aenv aenv
-> [StableSharingAcc]
-> TypeR a
-> TypeR b
-> (SmartExp a -> SmartExp b -> ScopedExp c)
-> Fun aenv (a -> b -> c)
convertSharingFun2 Config
config ArrayLayout aenv aenv
alyt [StableSharingAcc]
aenv'

        cvtAfun1 :: ArraysR a -> (SmartAcc a -> ScopedAcc b) -> AST.OpenAfun aenv (a -> b)
        cvtAfun1 :: ArraysR a -> (SmartAcc a -> ScopedAcc b) -> OpenAfun aenv (a -> b)
cvtAfun1 = Config
-> ArrayLayout aenv aenv
-> [StableSharingAcc]
-> ArraysR a
-> (SmartAcc a -> ScopedAcc b)
-> OpenAfun aenv (a -> b)
forall aenv a b.
HasCallStack =>
Config
-> ArrayLayout aenv aenv
-> [StableSharingAcc]
-> ArraysR a
-> (SmartAcc a -> ScopedAcc b)
-> OpenAfun aenv (a -> b)
convertSharingAfun1 Config
config ArrayLayout aenv aenv
alyt [StableSharingAcc]
aenv'

        cvtAprj :: forall a b c. PairIdx (a, b) c -> ScopedAcc (a, b) -> AST.OpenAcc aenv c
        cvtAprj :: PairIdx (a, b) c -> ScopedAcc (a, b) -> OpenAcc aenv c
cvtAprj PairIdx (a, b) c
ix ScopedAcc (a, b)
a = PairIdx (a, b) c -> OpenAcc aenv (a, b) -> OpenAcc aenv c
forall a b c aenv1.
PairIdx (a, b) c -> OpenAcc aenv1 (a, b) -> OpenAcc aenv1 c
cvtAprj' PairIdx (a, b) c
ix (OpenAcc aenv (a, b) -> OpenAcc aenv c)
-> OpenAcc aenv (a, b) -> OpenAcc aenv c
forall a b. (a -> b) -> a -> b
$ ScopedAcc (a, b) -> OpenAcc aenv (a, b)
forall a. ScopedAcc a -> OpenAcc aenv a
cvtA ScopedAcc (a, b)
a

        cvtAprj' :: forall a b c aenv1. PairIdx (a, b) c -> AST.OpenAcc aenv1 (a, b) -> AST.OpenAcc aenv1 c
        cvtAprj' :: PairIdx (a, b) c -> OpenAcc aenv1 (a, b) -> OpenAcc aenv1 c
cvtAprj' PairIdx (a, b) c
PairIdxLeft  (AST.OpenAcc (AST.Apair OpenAcc aenv1 as
a OpenAcc aenv1 bs
_)) = OpenAcc aenv1 c
OpenAcc aenv1 as
a
        cvtAprj' PairIdx (a, b) c
PairIdxRight (AST.OpenAcc (AST.Apair OpenAcc aenv1 as
_ OpenAcc aenv1 bs
b)) = OpenAcc aenv1 c
OpenAcc aenv1 bs
b
        cvtAprj' PairIdx (a, b) c
ix OpenAcc aenv1 (a, b)
a = case TupR ArrayR (a, b) -> DeclareVars ArrayR (a, b) aenv1
forall (s :: * -> *) t env. TupR s t -> DeclareVars s t env
declareVars (TupR ArrayR (a, b) -> DeclareVars ArrayR (a, b) aenv1)
-> TupR ArrayR (a, b) -> DeclareVars ArrayR (a, b) aenv1
forall a b. (a -> b) -> a -> b
$ OpenAcc aenv1 (a, b) -> TupR ArrayR (a, b)
forall (f :: * -> * -> *) aenv a.
HasArraysR f =>
f aenv a -> ArraysR a
AST.arraysR OpenAcc aenv1 (a, b)
a of
          DeclareVars LeftHandSide ArrayR (a, b) aenv1 env'
lhs aenv1 :> env'
_ forall env''. (env' :> env'') -> Vars ArrayR env'' (a, b)
value ->
            PreOpenAcc OpenAcc aenv1 c -> OpenAcc aenv1 c
InjectAcc OpenAcc
AST.OpenAcc (PreOpenAcc OpenAcc aenv1 c -> OpenAcc aenv1 c)
-> PreOpenAcc OpenAcc aenv1 c -> OpenAcc aenv1 c
forall a b. (a -> b) -> a -> b
$ LeftHandSide ArrayR (a, b) aenv1 env'
-> OpenAcc aenv1 (a, b)
-> OpenAcc env' c
-> PreOpenAcc OpenAcc aenv1 c
forall bndArrs aenv aenv' (acc :: * -> * -> *) bodyArrs.
ALeftHandSide bndArrs aenv aenv'
-> acc aenv bndArrs
-> acc aenv' bodyArrs
-> PreOpenAcc acc aenv bodyArrs
AST.Alet LeftHandSide ArrayR (a, b) aenv1 env'
lhs OpenAcc aenv1 (a, b)
a (OpenAcc env' c -> PreOpenAcc OpenAcc aenv1 c)
-> OpenAcc env' c -> PreOpenAcc OpenAcc aenv1 c
forall a b. (a -> b) -> a -> b
$ PairIdx (a, b) c -> OpenAcc env' (a, b) -> OpenAcc env' c
forall a b c aenv1.
PairIdx (a, b) c -> OpenAcc aenv1 (a, b) -> OpenAcc aenv1 c
cvtAprj' PairIdx (a, b) c
ix (OpenAcc env' (a, b) -> OpenAcc env' c)
-> OpenAcc env' (a, b) -> OpenAcc env' c
forall a b. (a -> b) -> a -> b
$ InjectAcc OpenAcc -> ArrayVars env' (a, b) -> OpenAcc env' (a, b)
forall (acc :: * -> * -> *) aenv arrs.
InjectAcc acc -> ArrayVars aenv arrs -> acc aenv arrs
avarsIn InjectAcc OpenAcc
AST.OpenAcc (ArrayVars env' (a, b) -> OpenAcc env' (a, b))
-> ArrayVars env' (a, b) -> OpenAcc env' (a, b)
forall a b. (a -> b) -> a -> b
$ (env' :> env') -> ArrayVars env' (a, b)
forall env''. (env' :> env'') -> Vars ArrayR env'' (a, b)
value env' :> env'
forall env. env :> env
weakenId
    in
    case PreSmartAcc ScopedAcc ScopedExp arrs
preAcc of

      Atag ArraysR arrs
repr Int
i
        -> let AST.OpenAcc PreOpenAcc OpenAcc aenv arrs
a = InjectAcc OpenAcc -> ArrayVars aenv arrs -> OpenAcc aenv arrs
forall (acc :: * -> * -> *) aenv arrs.
InjectAcc acc -> ArrayVars aenv arrs -> acc aenv arrs
avarsIn InjectAcc OpenAcc
AST.OpenAcc (ArrayVars aenv arrs -> OpenAcc aenv arrs)
-> ArrayVars aenv arrs -> OpenAcc aenv arrs
forall a b. (a -> b) -> a -> b
$ String
-> (forall t'. TupR ArrayR t' -> ShowS)
-> (forall u v. TupR ArrayR u -> TupR ArrayR v -> Maybe (u :~: v))
-> ArraysR arrs
-> Int
-> ArrayLayout aenv aenv
-> ArrayVars aenv arrs
forall (s :: * -> *) t env env1.
HasCallStack =>
String
-> (forall t'. TupR s t' -> ShowS)
-> (forall u v. TupR s u -> TupR s v -> Maybe (u :~: v))
-> TupR s t
-> Int
-> Layout s env env1
-> Vars s env t
prjIdx (String
"de Bruijn conversion tag " String -> ShowS
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show Int
i) forall t'. TupR ArrayR t' -> ShowS
showArraysR forall u v. TupR ArrayR u -> TupR ArrayR v -> Maybe (u :~: v)
matchArraysR ArraysR arrs
repr Int
i ArrayLayout aenv aenv
alyt
           in  PreOpenAcc OpenAcc aenv arrs
a

      Pipe ArraysR as
reprA ArraysR bs
reprB ArraysR arrs
reprC (SmartAcc as -> ScopedAcc bs
afun1 :: SmartAcc as -> ScopedAcc bs) (SmartAcc bs -> ScopedAcc arrs
afun2 :: SmartAcc bs -> ScopedAcc cs) ScopedAcc as
acc
        | DeclareVars LeftHandSide ArrayR bs aenv env'
lhs aenv :> env'
k forall env''. (env' :> env'') -> Vars ArrayR env'' bs
value <- ArraysR bs -> DeclareVars ArrayR bs aenv
forall (s :: * -> *) t env. TupR s t -> DeclareVars s t env
declareVars ArraysR bs
reprB ->
          let
            noStableSharing :: StableSharingAcc
noStableSharing = StableAccName ()
-> SharingAcc ScopedAcc ScopedExp () -> StableSharingAcc
forall arrs.
StableAccName arrs
-> SharingAcc ScopedAcc ScopedExp arrs -> StableSharingAcc
StableSharingAcc StableAccName ()
forall arrs. StableAccName arrs
noStableAccName (forall a. HasCallStack => a
forall (acc :: * -> *) (exp :: * -> *). SharingAcc acc exp ()
undefined :: SharingAcc acc exp ())
            boundAcc :: PreOpenAcc OpenAcc aenv bs
boundAcc = ArraysR bs
-> PreOpenAfun OpenAcc aenv (as -> bs)
-> OpenAcc aenv as
-> PreOpenAcc OpenAcc aenv bs
forall arrs2 (acc :: * -> * -> *) aenv arrs1.
ArraysR arrs2
-> PreOpenAfun acc aenv (arrs1 -> arrs2)
-> acc aenv arrs1
-> PreOpenAcc acc aenv arrs2
AST.Apply ArraysR bs
reprB (ArraysR as
-> (SmartAcc as -> ScopedAcc bs)
-> PreOpenAfun OpenAcc aenv (as -> bs)
forall a b.
ArraysR a -> (SmartAcc a -> ScopedAcc b) -> OpenAfun aenv (a -> b)
cvtAfun1 ArraysR as
reprA SmartAcc as -> ScopedAcc bs
afun1) (ScopedAcc as -> OpenAcc aenv as
forall a. ScopedAcc a -> OpenAcc aenv a
cvtA ScopedAcc as
acc)
            alyt' :: Layout ArrayR env' env'
alyt'   = Layout ArrayR env' aenv
-> LeftHandSide ArrayR bs aenv env'
-> Vars ArrayR env' bs
-> Layout ArrayR env' env'
forall (s :: * -> *) env env1 t env2.
Layout s env env1
-> LeftHandSide s t env1 env2 -> Vars s env t -> Layout s env env2
PushLayout ((aenv :> env') -> ArrayLayout aenv aenv -> Layout ArrayR env' aenv
forall env1 env2 (s :: * -> *) env'.
(env1 :> env2) -> Layout s env1 env' -> Layout s env2 env'
incLayout aenv :> env'
k ArrayLayout aenv aenv
alyt) LeftHandSide ArrayR bs aenv env'
lhs ((env' :> env') -> Vars ArrayR env' bs
forall env''. (env' :> env'') -> Vars ArrayR env'' bs
value env' :> env'
forall env. env :> env
weakenId)
            bodyAcc :: PreOpenAcc OpenAcc env' arrs
bodyAcc = ArraysR arrs
-> PreOpenAfun OpenAcc env' (bs -> arrs)
-> OpenAcc env' bs
-> PreOpenAcc OpenAcc env' arrs
forall arrs2 (acc :: * -> * -> *) aenv arrs1.
ArraysR arrs2
-> PreOpenAfun acc aenv (arrs1 -> arrs2)
-> acc aenv arrs1
-> PreOpenAcc acc aenv arrs2
AST.Apply ArraysR arrs
reprC
                        (Config
-> Layout ArrayR env' env'
-> [StableSharingAcc]
-> ArraysR bs
-> (SmartAcc bs -> ScopedAcc arrs)
-> PreOpenAfun OpenAcc env' (bs -> arrs)
forall aenv a b.
HasCallStack =>
Config
-> ArrayLayout aenv aenv
-> [StableSharingAcc]
-> ArraysR a
-> (SmartAcc a -> ScopedAcc b)
-> OpenAfun aenv (a -> b)
convertSharingAfun1 Config
config Layout ArrayR env' env'
alyt' (StableSharingAcc
noStableSharing StableSharingAcc -> [StableSharingAcc] -> [StableSharingAcc]
forall a. a -> [a] -> [a]
: [StableSharingAcc]
aenv') ArraysR bs
reprB SmartAcc bs -> ScopedAcc arrs
afun2)
                        (InjectAcc OpenAcc -> Vars ArrayR env' bs -> OpenAcc env' bs
forall (acc :: * -> * -> *) aenv arrs.
InjectAcc acc -> ArrayVars aenv arrs -> acc aenv arrs
avarsIn InjectAcc OpenAcc
AST.OpenAcc (Vars ArrayR env' bs -> OpenAcc env' bs)
-> Vars ArrayR env' bs -> OpenAcc env' bs
forall a b. (a -> b) -> a -> b
$ (env' :> env') -> Vars ArrayR env' bs
forall env''. (env' :> env'') -> Vars ArrayR env'' bs
value env' :> env'
forall env. env :> env
weakenId)
          in LeftHandSide ArrayR bs aenv env'
-> OpenAcc aenv bs
-> OpenAcc env' arrs
-> PreOpenAcc OpenAcc aenv arrs
forall bndArrs aenv aenv' (acc :: * -> * -> *) bodyArrs.
ALeftHandSide bndArrs aenv aenv'
-> acc aenv bndArrs
-> acc aenv' bodyArrs
-> PreOpenAcc acc aenv bodyArrs
AST.Alet LeftHandSide ArrayR bs aenv env'
lhs (PreOpenAcc OpenAcc aenv bs -> OpenAcc aenv bs
InjectAcc OpenAcc
AST.OpenAcc PreOpenAcc OpenAcc aenv bs
boundAcc) (PreOpenAcc OpenAcc env' arrs -> OpenAcc env' arrs
InjectAcc OpenAcc
AST.OpenAcc PreOpenAcc OpenAcc env' arrs
bodyAcc)

      Aforeign ArraysR arrs
repr asm (as -> arrs)
ff SmartAcc as -> SmartAcc arrs
afun ScopedAcc as
acc
        -> ArraysR arrs
-> asm (as -> arrs)
-> PreAfun OpenAcc (as -> arrs)
-> OpenAcc aenv as
-> PreOpenAcc OpenAcc aenv arrs
forall (asm :: * -> *) bs as (acc :: * -> * -> *) aenv.
Foreign asm =>
ArraysR bs
-> asm (as -> bs)
-> PreAfun acc (as -> bs)
-> acc aenv as
-> PreOpenAcc acc aenv bs
AST.Aforeign ArraysR arrs
repr asm (as -> arrs)
ff (Config
-> ArraysR as
-> (SmartAcc as -> SmartAcc arrs)
-> PreAfun OpenAcc (as -> arrs)
forall a b.
HasCallStack =>
Config -> ArraysR a -> (SmartAcc a -> SmartAcc b) -> Afun (a -> b)
convertSmartAfun1 Config
config (ScopedAcc as -> ArraysR as
forall (f :: * -> *) a. HasArraysR f => f a -> ArraysR a
Smart.arraysR ScopedAcc as
acc) SmartAcc as -> SmartAcc arrs
afun) (ScopedAcc as -> OpenAcc aenv as
forall a. ScopedAcc a -> OpenAcc aenv a
cvtA ScopedAcc as
acc)

      Acond ScopedExp PrimBool
b ScopedAcc arrs
acc1 ScopedAcc arrs
acc2           -> Exp aenv PrimBool
-> OpenAcc aenv arrs
-> OpenAcc aenv arrs
-> PreOpenAcc OpenAcc aenv arrs
forall aenv (acc :: * -> * -> *) arrs.
Exp aenv PrimBool
-> acc aenv arrs -> acc aenv arrs -> PreOpenAcc acc aenv arrs
AST.Acond (ScopedExp PrimBool -> Exp aenv PrimBool
forall t. ScopedExp t -> Exp aenv t
cvtE ScopedExp PrimBool
b) (ScopedAcc arrs -> OpenAcc aenv arrs
forall a. ScopedAcc a -> OpenAcc aenv a
cvtA ScopedAcc arrs
acc1) (ScopedAcc arrs -> OpenAcc aenv arrs
forall a. ScopedAcc a -> OpenAcc aenv a
cvtA ScopedAcc arrs
acc2)
      Awhile ArraysR arrs
reprA SmartAcc arrs -> ScopedAcc (Scalar PrimBool)
pred SmartAcc arrs -> ScopedAcc arrs
iter ScopedAcc arrs
init -> PreOpenAfun OpenAcc aenv (arrs -> Scalar PrimBool)
-> PreOpenAfun OpenAcc aenv (arrs -> arrs)
-> OpenAcc aenv arrs
-> PreOpenAcc OpenAcc aenv arrs
forall (acc :: * -> * -> *) aenv arrs.
PreOpenAfun acc aenv (arrs -> Scalar PrimBool)
-> PreOpenAfun acc aenv (arrs -> arrs)
-> acc aenv arrs
-> PreOpenAcc acc aenv arrs
AST.Awhile (ArraysR arrs
-> (SmartAcc arrs -> ScopedAcc (Scalar PrimBool))
-> PreOpenAfun OpenAcc aenv (arrs -> Scalar PrimBool)
forall a b.
ArraysR a -> (SmartAcc a -> ScopedAcc b) -> OpenAfun aenv (a -> b)
cvtAfun1 ArraysR arrs
reprA SmartAcc arrs -> ScopedAcc (Scalar PrimBool)
pred) (ArraysR arrs
-> (SmartAcc arrs -> ScopedAcc arrs)
-> PreOpenAfun OpenAcc aenv (arrs -> arrs)
forall a b.
ArraysR a -> (SmartAcc a -> ScopedAcc b) -> OpenAfun aenv (a -> b)
cvtAfun1 ArraysR arrs
reprA SmartAcc arrs -> ScopedAcc arrs
iter) (ScopedAcc arrs -> OpenAcc aenv arrs
forall a. ScopedAcc a -> OpenAcc aenv a
cvtA ScopedAcc arrs
init)
      PreSmartAcc ScopedAcc ScopedExp arrs
Anil                        -> PreOpenAcc OpenAcc aenv arrs
forall (acc :: * -> * -> *) aenv. PreOpenAcc acc aenv ()
AST.Anil
      Apair ScopedAcc arrs1
acc1 ScopedAcc arrs2
acc2             -> OpenAcc aenv arrs1
-> OpenAcc aenv arrs2 -> PreOpenAcc OpenAcc aenv (arrs1, arrs2)
forall (acc :: * -> * -> *) aenv as bs.
acc aenv as -> acc aenv bs -> PreOpenAcc acc aenv (as, bs)
AST.Apair (ScopedAcc arrs1 -> OpenAcc aenv arrs1
forall a. ScopedAcc a -> OpenAcc aenv a
cvtA ScopedAcc arrs1
acc1) (ScopedAcc arrs2 -> OpenAcc aenv arrs2
forall a. ScopedAcc a -> OpenAcc aenv a
cvtA ScopedAcc arrs2
acc2)
      Aprj PairIdx (arrs1, arrs2) arrs
ix ScopedAcc (arrs1, arrs2)
a                   -> let AST.OpenAcc PreOpenAcc OpenAcc aenv arrs
a' = PairIdx (arrs1, arrs2) arrs
-> ScopedAcc (arrs1, arrs2) -> OpenAcc aenv arrs
forall a b c.
PairIdx (a, b) c -> ScopedAcc (a, b) -> OpenAcc aenv c
cvtAprj PairIdx (arrs1, arrs2) arrs
ix ScopedAcc (arrs1, arrs2)
a
                                     in PreOpenAcc OpenAcc aenv arrs
a'
      Use ArrayR (Array sh e)
repr Array sh e
array              -> ArrayR (Array sh e)
-> Array sh e -> PreOpenAcc OpenAcc aenv (Array sh e)
forall sh e (acc :: * -> * -> *) aenv.
ArrayR (Array sh e)
-> Array sh e -> PreOpenAcc acc aenv (Array sh e)
AST.Use ArrayR (Array sh e)
repr Array sh e
array
      Unit TypeR e
tp ScopedExp e
e                   -> TypeR e -> Exp aenv e -> PreOpenAcc OpenAcc aenv (Scalar e)
forall e aenv (acc :: * -> * -> *).
TypeR e -> Exp aenv e -> PreOpenAcc acc aenv (Scalar e)
AST.Unit TypeR e
tp (ScopedExp e -> Exp aenv e
forall t. ScopedExp t -> Exp aenv t
cvtE ScopedExp e
e)
      Generate repr :: ArrayR (Array sh e)
repr@(ArrayR ShapeR sh
shr TypeR e
_) ScopedExp sh
sh SmartExp sh -> ScopedExp e
f
                                  -> ArrayR (Array sh e)
-> Exp aenv sh
-> Fun aenv (sh -> e)
-> PreOpenAcc OpenAcc aenv (Array sh e)
forall sh e aenv (acc :: * -> * -> *).
ArrayR (Array sh e)
-> Exp aenv sh
-> Fun aenv (sh -> e)
-> PreOpenAcc acc aenv (Array sh e)
AST.Generate ArrayR (Array sh e)
repr (ScopedExp sh -> Exp aenv sh
forall t. ScopedExp t -> Exp aenv t
cvtE ScopedExp sh
sh) (TypeR sh -> (SmartExp sh -> ScopedExp e) -> Fun aenv (sh -> e)
forall a b.
TypeR a -> (SmartExp a -> ScopedExp b) -> Fun aenv (a -> b)
cvtF1 (ShapeR sh -> TypeR sh
forall sh. ShapeR sh -> TypeR sh
shapeType ShapeR sh
shr) SmartExp sh -> ScopedExp e
SmartExp sh -> ScopedExp e
f)
      Reshape ShapeR sh
shr ScopedExp sh
e ScopedAcc (Array sh' e)
acc           -> ShapeR sh
-> Exp aenv sh
-> OpenAcc aenv (Array sh' e)
-> PreOpenAcc OpenAcc aenv (Array sh e)
forall sh aenv (acc :: * -> * -> *) sh' e.
ShapeR sh
-> Exp aenv sh
-> acc aenv (Array sh' e)
-> PreOpenAcc acc aenv (Array sh e)
AST.Reshape ShapeR sh
shr (ScopedExp sh -> Exp aenv sh
forall t. ScopedExp t -> Exp aenv t
cvtE ScopedExp sh
e) (ScopedAcc (Array sh' e) -> OpenAcc aenv (Array sh' e)
forall a. ScopedAcc a -> OpenAcc aenv a
cvtA ScopedAcc (Array sh' e)
acc)
      Replicate SliceIndex slix sl co sh
si ScopedExp slix
ix ScopedAcc (Array sl e)
acc         -> SliceIndex slix sl co sh
-> Exp aenv slix
-> OpenAcc aenv (Array sl e)
-> PreOpenAcc OpenAcc aenv (Array sh e)
forall slix sl co sh aenv (acc :: * -> * -> *) e.
SliceIndex slix sl co sh
-> Exp aenv slix
-> acc aenv (Array sl e)
-> PreOpenAcc acc aenv (Array sh e)
AST.Replicate SliceIndex slix sl co sh
si (ScopedExp slix -> Exp aenv slix
forall t. ScopedExp t -> Exp aenv t
cvtE ScopedExp slix
ix) (ScopedAcc (Array sl e) -> OpenAcc aenv (Array sl e)
forall a. ScopedAcc a -> OpenAcc aenv a
cvtA ScopedAcc (Array sl e)
acc)
      Slice SliceIndex slix sl co sh
si ScopedAcc (Array sh e)
acc ScopedExp slix
ix             -> SliceIndex slix sl co sh
-> OpenAcc aenv (Array sh e)
-> Exp aenv slix
-> PreOpenAcc OpenAcc aenv (Array sl e)
forall slix sl co sh (acc :: * -> * -> *) aenv e.
SliceIndex slix sl co sh
-> acc aenv (Array sh e)
-> Exp aenv slix
-> PreOpenAcc acc aenv (Array sl e)
AST.Slice SliceIndex slix sl co sh
si (ScopedAcc (Array sh e) -> OpenAcc aenv (Array sh e)
forall a. ScopedAcc a -> OpenAcc aenv a
cvtA ScopedAcc (Array sh e)
acc) (ScopedExp slix -> Exp aenv slix
forall t. ScopedExp t -> Exp aenv t
cvtE ScopedExp slix
ix)
      Map TypeR e
t1 TypeR e'
t2 SmartExp e -> ScopedExp e'
f ScopedAcc (Array sh e)
acc             -> TypeR e'
-> Fun aenv (e -> e')
-> OpenAcc aenv (Array sh e)
-> PreOpenAcc OpenAcc aenv (Array sh e')
forall e' aenv e (acc :: * -> * -> *) sh.
TypeR e'
-> Fun aenv (e -> e')
-> acc aenv (Array sh e)
-> PreOpenAcc acc aenv (Array sh e')
AST.Map TypeR e'
t2 (TypeR e -> (SmartExp e -> ScopedExp e') -> Fun aenv (e -> e')
forall a b.
TypeR a -> (SmartExp a -> ScopedExp b) -> Fun aenv (a -> b)
cvtF1 TypeR e
t1 SmartExp e -> ScopedExp e'
f) (ScopedAcc (Array sh e) -> OpenAcc aenv (Array sh e)
forall a. ScopedAcc a -> OpenAcc aenv a
cvtA ScopedAcc (Array sh e)
acc)
      ZipWith TypeR e1
t1 TypeR e2
t2 TypeR e3
t3 SmartExp e1 -> SmartExp e2 -> ScopedExp e3
f ScopedAcc (Array sh e1)
acc1 ScopedAcc (Array sh e2)
acc2
                                  -> TypeR e3
-> Fun aenv (e1 -> e2 -> e3)
-> OpenAcc aenv (Array sh e1)
-> OpenAcc aenv (Array sh e2)
-> PreOpenAcc OpenAcc aenv (Array sh e3)
forall e3 aenv e1 e2 (acc :: * -> * -> *) sh.
TypeR e3
-> Fun aenv (e1 -> e2 -> e3)
-> acc aenv (Array sh e1)
-> acc aenv (Array sh e2)
-> PreOpenAcc acc aenv (Array sh e3)
AST.ZipWith TypeR e3
t3 (TypeR e1
-> TypeR e2
-> (SmartExp e1 -> SmartExp e2 -> ScopedExp e3)
-> Fun aenv (e1 -> e2 -> e3)
forall a b c.
TypeR a
-> TypeR b
-> (SmartExp a -> SmartExp b -> ScopedExp c)
-> Fun aenv (a -> b -> c)
cvtF2 TypeR e1
t1 TypeR e2
t2 SmartExp e1 -> SmartExp e2 -> ScopedExp e3
f) (ScopedAcc (Array sh e1) -> OpenAcc aenv (Array sh e1)
forall a. ScopedAcc a -> OpenAcc aenv a
cvtA ScopedAcc (Array sh e1)
acc1) (ScopedAcc (Array sh e2) -> OpenAcc aenv (Array sh e2)
forall a. ScopedAcc a -> OpenAcc aenv a
cvtA ScopedAcc (Array sh e2)
acc2)
      Fold TypeR e
tp SmartExp e -> SmartExp e -> ScopedExp e
f Maybe (ScopedExp e)
e ScopedAcc (Array (sh, Int) e)
acc             -> Fun aenv (e -> e -> e)
-> Maybe (Exp aenv e)
-> OpenAcc aenv (Array (sh, Int) e)
-> PreOpenAcc OpenAcc aenv (Array sh e)
forall aenv e (acc :: * -> * -> *) i.
Fun aenv (e -> e -> e)
-> Maybe (Exp aenv e)
-> acc aenv (Array (i, Int) e)
-> PreOpenAcc acc aenv (Array i e)
AST.Fold (TypeR e
-> TypeR e
-> (SmartExp e -> SmartExp e -> ScopedExp e)
-> Fun aenv (e -> e -> e)
forall a b c.
TypeR a
-> TypeR b
-> (SmartExp a -> SmartExp b -> ScopedExp c)
-> Fun aenv (a -> b -> c)
cvtF2 TypeR e
tp TypeR e
tp SmartExp e -> SmartExp e -> ScopedExp e
f) (ScopedExp e -> Exp aenv e
forall t. ScopedExp t -> Exp aenv t
cvtE (ScopedExp e -> Exp aenv e)
-> Maybe (ScopedExp e) -> Maybe (Exp aenv e)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Maybe (ScopedExp e)
e) (ScopedAcc (Array (sh, Int) e) -> OpenAcc aenv (Array (sh, Int) e)
forall a. ScopedAcc a -> OpenAcc aenv a
cvtA ScopedAcc (Array (sh, Int) e)
acc)
      FoldSeg IntegralType i
i TypeR e
tp SmartExp e -> SmartExp e -> ScopedExp e
f Maybe (ScopedExp e)
e ScopedAcc (Array (sh, Int) e)
acc1 ScopedAcc (Segments i)
acc2  -> IntegralType i
-> Fun aenv (e -> e -> e)
-> Maybe (Exp aenv e)
-> OpenAcc aenv (Array (sh, Int) e)
-> OpenAcc aenv (Segments i)
-> PreOpenAcc OpenAcc aenv (Array (sh, Int) e)
forall i aenv e (acc :: * -> * -> *) e.
IntegralType i
-> Fun aenv (e -> e -> e)
-> Maybe (Exp aenv e)
-> acc aenv (Array (e, Int) e)
-> acc aenv (Segments i)
-> PreOpenAcc acc aenv (Array (e, Int) e)
AST.FoldSeg IntegralType i
i (TypeR e
-> TypeR e
-> (SmartExp e -> SmartExp e -> ScopedExp e)
-> Fun aenv (e -> e -> e)
forall a b c.
TypeR a
-> TypeR b
-> (SmartExp a -> SmartExp b -> ScopedExp c)
-> Fun aenv (a -> b -> c)
cvtF2 TypeR e
tp TypeR e
tp SmartExp e -> SmartExp e -> ScopedExp e
f) (ScopedExp e -> Exp aenv e
forall t. ScopedExp t -> Exp aenv t
cvtE (ScopedExp e -> Exp aenv e)
-> Maybe (ScopedExp e) -> Maybe (Exp aenv e)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Maybe (ScopedExp e)
e) (ScopedAcc (Array (sh, Int) e) -> OpenAcc aenv (Array (sh, Int) e)
forall a. ScopedAcc a -> OpenAcc aenv a
cvtA ScopedAcc (Array (sh, Int) e)
acc1) (ScopedAcc (Segments i) -> OpenAcc aenv (Segments i)
forall a. ScopedAcc a -> OpenAcc aenv a
cvtA ScopedAcc (Segments i)
acc2)
      Scan  Direction
d TypeR e
tp SmartExp e -> SmartExp e -> ScopedExp e
f Maybe (ScopedExp e)
e ScopedAcc (Array (sh, Int) e)
acc          -> Direction
-> Fun aenv (e -> e -> e)
-> Maybe (Exp aenv e)
-> OpenAcc aenv (Array (sh, Int) e)
-> PreOpenAcc OpenAcc aenv (Array (sh, Int) e)
forall aenv e (acc :: * -> * -> *) sh.
Direction
-> Fun aenv (e -> e -> e)
-> Maybe (Exp aenv e)
-> acc aenv (Array (sh, Int) e)
-> PreOpenAcc acc aenv (Array (sh, Int) e)
AST.Scan  Direction
d (TypeR e
-> TypeR e
-> (SmartExp e -> SmartExp e -> ScopedExp e)
-> Fun aenv (e -> e -> e)
forall a b c.
TypeR a
-> TypeR b
-> (SmartExp a -> SmartExp b -> ScopedExp c)
-> Fun aenv (a -> b -> c)
cvtF2 TypeR e
tp TypeR e
tp SmartExp e -> SmartExp e -> ScopedExp e
f) (ScopedExp e -> Exp aenv e
forall t. ScopedExp t -> Exp aenv t
cvtE (ScopedExp e -> Exp aenv e)
-> Maybe (ScopedExp e) -> Maybe (Exp aenv e)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Maybe (ScopedExp e)
e) (ScopedAcc (Array (sh, Int) e) -> OpenAcc aenv (Array (sh, Int) e)
forall a. ScopedAcc a -> OpenAcc aenv a
cvtA ScopedAcc (Array (sh, Int) e)
acc)
      Scan' Direction
d TypeR e
tp SmartExp e -> SmartExp e -> ScopedExp e
f ScopedExp e
e ScopedAcc (Array (sh, Int) e)
acc          -> Direction
-> Fun aenv (e -> e -> e)
-> Exp aenv e
-> OpenAcc aenv (Array (sh, Int) e)
-> PreOpenAcc OpenAcc aenv (Array (sh, Int) e, Array sh e)
forall aenv e (acc :: * -> * -> *) sh.
Direction
-> Fun aenv (e -> e -> e)
-> Exp aenv e
-> acc aenv (Array (sh, Int) e)
-> PreOpenAcc acc aenv (Array (sh, Int) e, Array sh e)
AST.Scan' Direction
d (TypeR e
-> TypeR e
-> (SmartExp e -> SmartExp e -> ScopedExp e)
-> Fun aenv (e -> e -> e)
forall a b c.
TypeR a
-> TypeR b
-> (SmartExp a -> SmartExp b -> ScopedExp c)
-> Fun aenv (a -> b -> c)
cvtF2 TypeR e
tp TypeR e
tp SmartExp e -> SmartExp e -> ScopedExp e
f) (ScopedExp e -> Exp aenv e
forall t. ScopedExp t -> Exp aenv t
cvtE ScopedExp e
e)     (ScopedAcc (Array (sh, Int) e) -> OpenAcc aenv (Array (sh, Int) e)
forall a. ScopedAcc a -> OpenAcc aenv a
cvtA ScopedAcc (Array (sh, Int) e)
acc)
      Permute (ArrayR ShapeR sh
shr TypeR e
tp) SmartExp e -> SmartExp e -> ScopedExp e
f ScopedAcc (Array sh' e)
dftAcc SmartExp sh -> ScopedExp (PrimMaybe sh')
perm ScopedAcc (Array sh e)
acc
                                  -> Fun aenv (e -> e -> e)
-> OpenAcc aenv (Array sh' e)
-> Fun aenv (sh -> PrimMaybe sh')
-> OpenAcc aenv (Array sh e)
-> PreOpenAcc OpenAcc aenv (Array sh' e)
forall aenv e (acc :: * -> * -> *) sh' sh.
Fun aenv (e -> e -> e)
-> acc aenv (Array sh' e)
-> Fun aenv (sh -> PrimMaybe sh')
-> acc aenv (Array sh e)
-> PreOpenAcc acc aenv (Array sh' e)
AST.Permute (TypeR e
-> TypeR e
-> (SmartExp e -> SmartExp e -> ScopedExp e)
-> Fun aenv (e -> e -> e)
forall a b c.
TypeR a
-> TypeR b
-> (SmartExp a -> SmartExp b -> ScopedExp c)
-> Fun aenv (a -> b -> c)
cvtF2 TypeR e
tp TypeR e
tp SmartExp e -> SmartExp e -> ScopedExp e
SmartExp e -> SmartExp e -> ScopedExp e
f) (ScopedAcc (Array sh' e) -> OpenAcc aenv (Array sh' e)
forall a. ScopedAcc a -> OpenAcc aenv a
cvtA ScopedAcc (Array sh' e)
dftAcc) (TypeR sh
-> (SmartExp sh -> ScopedExp (PrimMaybe sh'))
-> Fun aenv (sh -> PrimMaybe sh')
forall a b.
TypeR a -> (SmartExp a -> ScopedExp b) -> Fun aenv (a -> b)
cvtF1 (ShapeR sh -> TypeR sh
forall sh. ShapeR sh -> TypeR sh
shapeType ShapeR sh
shr) SmartExp sh -> ScopedExp (PrimMaybe sh')
SmartExp sh -> ScopedExp (PrimMaybe sh')
perm) (ScopedAcc (Array sh e) -> OpenAcc aenv (Array sh e)
forall a. ScopedAcc a -> OpenAcc aenv a
cvtA ScopedAcc (Array sh e)
acc)
      Backpermute ShapeR sh'
shr ScopedExp sh'
newDim SmartExp sh' -> ScopedExp sh
perm ScopedAcc (Array sh e)
acc
                                  -> ShapeR sh'
-> Exp aenv sh'
-> Fun aenv (sh' -> sh)
-> OpenAcc aenv (Array sh e)
-> PreOpenAcc OpenAcc aenv (Array sh' e)
forall sh' aenv sh (acc :: * -> * -> *) e.
ShapeR sh'
-> Exp aenv sh'
-> Fun aenv (sh' -> sh)
-> acc aenv (Array sh e)
-> PreOpenAcc acc aenv (Array sh' e)
AST.Backpermute ShapeR sh'
shr (ScopedExp sh' -> Exp aenv sh'
forall t. ScopedExp t -> Exp aenv t
cvtE ScopedExp sh'
newDim) (TypeR sh' -> (SmartExp sh' -> ScopedExp sh) -> Fun aenv (sh' -> sh)
forall a b.
TypeR a -> (SmartExp a -> ScopedExp b) -> Fun aenv (a -> b)
cvtF1 (ShapeR sh' -> TypeR sh'
forall sh. ShapeR sh -> TypeR sh
shapeType ShapeR sh'
shr) SmartExp sh' -> ScopedExp sh
perm) (ScopedAcc (Array sh e) -> OpenAcc aenv (Array sh e)
forall a. ScopedAcc a -> OpenAcc aenv a
cvtA ScopedAcc (Array sh e)
acc)
      Stencil StencilR sh a stencil
stencil TypeR b
tp SmartExp stencil -> ScopedExp b
f PreBoundary ScopedAcc ScopedExp (Array sh a)
boundary ScopedAcc (Array sh a)
acc
        -> StencilR sh a stencil
-> TypeR b
-> Fun aenv (stencil -> b)
-> Boundary aenv (Array sh a)
-> OpenAcc aenv (Array sh a)
-> PreOpenAcc OpenAcc aenv (Array sh b)
forall sh e stencil sh aenv (acc :: * -> * -> *).
StencilR sh e stencil
-> TypeR sh
-> Fun aenv (stencil -> sh)
-> Boundary aenv (Array sh e)
-> acc aenv (Array sh e)
-> PreOpenAcc acc aenv (Array sh sh)
AST.Stencil StencilR sh a stencil
stencil
                       TypeR b
tp
                       (Config
-> ArrayLayout aenv aenv
-> [StableSharingAcc]
-> StencilR sh a stencil
-> (SmartExp stencil -> ScopedExp b)
-> Fun aenv (stencil -> b)
forall aenv sh a stencil b.
HasCallStack =>
Config
-> ArrayLayout aenv aenv
-> [StableSharingAcc]
-> StencilR sh a stencil
-> (SmartExp stencil -> ScopedExp b)
-> Fun aenv (stencil -> b)
convertSharingStencilFun1 Config
config ArrayLayout aenv aenv
alyt [StableSharingAcc]
aenv' StencilR sh a stencil
stencil SmartExp stencil -> ScopedExp b
f)
                       (Config
-> ArrayLayout aenv aenv
-> [StableSharingAcc]
-> ShapeR sh
-> PreBoundary ScopedAcc ScopedExp (Array sh a)
-> Boundary aenv (Array sh a)
forall aenv sh e.
HasCallStack =>
Config
-> ArrayLayout aenv aenv
-> [StableSharingAcc]
-> ShapeR sh
-> PreBoundary ScopedAcc ScopedExp (Array sh e)
-> Boundary aenv (Array sh e)
convertSharingBoundary Config
config ArrayLayout aenv aenv
alyt [StableSharingAcc]
aenv' (StencilR sh a stencil -> ShapeR sh
forall sh e pat. StencilR sh e pat -> ShapeR sh
stencilShapeR StencilR sh a stencil
stencil) PreBoundary ScopedAcc ScopedExp (Array sh a)
boundary)
                       (ScopedAcc (Array sh a) -> OpenAcc aenv (Array sh a)
forall a. ScopedAcc a -> OpenAcc aenv a
cvtA ScopedAcc (Array sh a)
acc)
      Stencil2 StencilR sh a stencil1
stencil1 StencilR sh b stencil2
stencil2 TypeR c
tp SmartExp stencil1 -> SmartExp stencil2 -> ScopedExp c
f PreBoundary ScopedAcc ScopedExp (Array sh a)
bndy1 ScopedAcc (Array sh a)
acc1 PreBoundary ScopedAcc ScopedExp (Array sh b)
bndy2 ScopedAcc (Array sh b)
acc2
        | ShapeR sh
shr <- StencilR sh a stencil1 -> ShapeR sh
forall sh e pat. StencilR sh e pat -> ShapeR sh
stencilShapeR StencilR sh a stencil1
stencil1
        -> StencilR sh a stencil1
-> StencilR sh b stencil2
-> TypeR c
-> Fun aenv (stencil1 -> stencil2 -> c)
-> Boundary aenv (Array sh a)
-> OpenAcc aenv (Array sh a)
-> Boundary aenv (Array sh b)
-> OpenAcc aenv (Array sh b)
-> PreOpenAcc OpenAcc aenv (Array sh c)
forall sh a stencil1 b stencil2 c aenv (acc :: * -> * -> *).
StencilR sh a stencil1
-> StencilR sh b stencil2
-> TypeR c
-> Fun aenv (stencil1 -> stencil2 -> c)
-> Boundary aenv (Array sh a)
-> acc aenv (Array sh a)
-> Boundary aenv (Array sh b)
-> acc aenv (Array sh b)
-> PreOpenAcc acc aenv (Array sh c)
AST.Stencil2 StencilR sh a stencil1
stencil1
                        StencilR sh b stencil2
stencil2
                        TypeR c
tp
                        (Config
-> ArrayLayout aenv aenv
-> [StableSharingAcc]
-> StencilR sh a stencil1
-> StencilR sh b stencil2
-> (SmartExp stencil1 -> SmartExp stencil2 -> ScopedExp c)
-> Fun aenv (stencil1 -> stencil2 -> c)
forall aenv sh a stencil1 b stencil2 c.
HasCallStack =>
Config
-> ArrayLayout aenv aenv
-> [StableSharingAcc]
-> StencilR sh a stencil1
-> StencilR sh b stencil2
-> (SmartExp stencil1 -> SmartExp stencil2 -> ScopedExp c)
-> Fun aenv (stencil1 -> stencil2 -> c)
convertSharingStencilFun2 Config
config ArrayLayout aenv aenv
alyt [StableSharingAcc]
aenv' StencilR sh a stencil1
stencil1 StencilR sh b stencil2
stencil2 SmartExp stencil1 -> SmartExp stencil2 -> ScopedExp c
f)
                        (Config
-> ArrayLayout aenv aenv
-> [StableSharingAcc]
-> ShapeR sh
-> PreBoundary ScopedAcc ScopedExp (Array sh a)
-> Boundary aenv (Array sh a)
forall aenv sh e.
HasCallStack =>
Config
-> ArrayLayout aenv aenv
-> [StableSharingAcc]
-> ShapeR sh
-> PreBoundary ScopedAcc ScopedExp (Array sh e)
-> Boundary aenv (Array sh e)
convertSharingBoundary Config
config ArrayLayout aenv aenv
alyt [StableSharingAcc]
aenv' ShapeR sh
shr PreBoundary ScopedAcc ScopedExp (Array sh a)
bndy1)
                        (ScopedAcc (Array sh a) -> OpenAcc aenv (Array sh a)
forall a. ScopedAcc a -> OpenAcc aenv a
cvtA ScopedAcc (Array sh a)
acc1)
                        (Config
-> ArrayLayout aenv aenv
-> [StableSharingAcc]
-> ShapeR sh
-> PreBoundary ScopedAcc ScopedExp (Array sh b)
-> Boundary aenv (Array sh b)
forall aenv sh e.
HasCallStack =>
Config
-> ArrayLayout aenv aenv
-> [StableSharingAcc]
-> ShapeR sh
-> PreBoundary ScopedAcc ScopedExp (Array sh e)
-> Boundary aenv (Array sh e)
convertSharingBoundary Config
config ArrayLayout aenv aenv
alyt [StableSharingAcc]
aenv' ShapeR sh
shr PreBoundary ScopedAcc ScopedExp (Array sh b)
bndy2)
                        (ScopedAcc (Array sh b) -> OpenAcc aenv (Array sh b)
forall a. ScopedAcc a -> OpenAcc aenv a
cvtA ScopedAcc (Array sh b)
acc2)
      -- Collect seq -> AST.Collect (convertSharingSeq config alyt EmptyLayout aenv' [] seq)

{--
-- Sequence expressions
-- --------------------

-- | Convert a closed sequence expression to de Bruijn form while incorporating
-- sharing information.
--
convertSeq
    :: Typeable s
    => Bool             -- ^ recover sharing of array computations ?
    -> Bool             -- ^ recover sharing of scalar expressions ?
    -> Bool             -- ^ recover sharing of sequence computations ?
    -> Bool             -- ^ always float array computations out of expressions?
    -> Seq s            -- ^ computation to be converted
    -> AST.Seq s
convertSeq shareAcc shareExp shareSeq floatAcc seq
  = let config = Config shareAcc shareExp shareSeq floatAcc
        (sharingSeq, initialEnv) = recoverSharingSeq config seq
    in
    convertSharingSeq config EmptyLayout EmptyLayout [] initialEnv sharingSeq

convertSharingSeq
    :: forall aenv senv arrs.
       Config
    -> Layout aenv aenv
    -> Layout senv senv
    -> [StableSharingAcc]
    -> [StableSharingSeq]
    -> ScopedSeq arrs
    -> AST.PreOpenSeq AST.OpenAcc aenv senv arrs
convertSharingSeq _ _ slyt _ senv (ScopedSeq (SvarSharing sn))
  | Just i <- findIndex (matchStableSeq sn) senv
  = AST.Reify $ prjIdx (ctxt ++ "; i = " ++ show i) i slyt
  | null senv
  = error $ "Cyclic definition of a value of type 'Seq' (sa = " ++
            show (hashStableNameHeight sn) ++ ")"
  | otherwise
  = $internalError "convertSharingSeq" err
  where
    ctxt = "shared 'Seq' tree with stable name " ++ show (hashStableNameHeight sn)
    err  = "inconsistent valuation @ " ++ ctxt ++ ";\n  senv = " ++ show senv
convertSharingSeq config alyt slyt aenv senv (ScopedSeq (SletSharing sa@(StableSharingSeq _ (SeqSharing _ boundSeq)) bodySeq))
  = convSeq boundSeq bodySeq
  where
    convSeq :: forall bnd body.
               PreSeq ScopedAcc ScopedSeq ScopedExp bnd
            -> ScopedSeq body
            -> AST.PreOpenSeq AST.OpenAcc aenv senv body
    convSeq bnd body =
      case bnd of
        StreamIn arrs               -> producer $ AST.StreamIn arrs
        ToSeq slix acc              -> producer $ mkToSeq slix (cvtA acc)
        MapSeq afun x               -> producer $ AST.MapSeq (cvtAF1 afun) (asIdx x)
        ZipWithSeq afun x y         -> producer $ AST.ZipWithSeq (cvtAF2 afun) (asIdx x) (asIdx y)
        ScanSeq fun e x             -> producer $ AST.ScanSeq (cvtF2 fun) (cvtE e) (asIdx x)
        _                           -> $internalError "convertSharingSeq:convSeq" "Consumer appears to have been let bound"
      where
        producer :: Arrays a
                 => AST.Producer AST.OpenAcc aenv senv a
                 -> AST.PreOpenSeq AST.OpenAcc aenv senv body
        producer p = AST.Producer p $ convertSharingSeq config alyt slyt' aenv (sa:senv) body
          where
            slyt' = incLayout slyt `PushLayout` ZeroIdx

        asIdx :: (HasCallStack, Arrays a)
              => ScopedSeq [a]
              -> Idx senv a
        asIdx (ScopedSeq (SvarSharing sn))
          | Just i <- findIndex (matchStableSeq sn) senv
          = prjIdx (ctxt ++ "; i = " ++ show i) i slyt
          | null senv
          = error $ "Cyclic definition of a value of type 'Seq' (sa = " ++
                    show (hashStableNameHeight sn) ++ ")"
          | otherwise
          = $internalError "convertSharingSeq" err
          where
            ctxt = "shared 'Seq' tree with stable name " ++ show (hashStableNameHeight sn)
            err  = "inconsistent valuation @ " ++ ctxt ++ ";\n  senv = " ++ show senv
        asIdx _
          = $internalError "convertSharingSeq:asIdx" "Sequence computation not in A-normal form"

        cvtA :: forall a. Arrays a => ScopedAcc a -> AST.OpenAcc aenv a
        cvtA acc = convertSharingAcc config alyt aenv acc

        cvtE :: forall t. Elt t => ScopedExp t -> AST.Exp aenv t
        cvtE = convertSharingExp config EmptyLayout alyt [] aenv

        cvtF2 :: (Elt a, Elt b, Elt c) => (Exp a -> Exp b -> ScopedExp c) -> AST.Fun aenv (a -> b -> c)
        cvtF2 = convertSharingFun2 config alyt aenv

        cvtAF1 :: forall a b. (Arrays a, Arrays b) => (Acc a -> ScopedAcc b) -> OpenAfun aenv (a -> b)
        cvtAF1 afun = convertSharingAfun1 config alyt aenv afun

        cvtAF2 :: forall a b c. (Arrays a, Arrays b, Arrays c) => (Acc a -> Acc b -> ScopedAcc c) -> OpenAfun aenv (a -> b -> c)
        cvtAF2 afun = convertSharingAfun2 config alyt aenv afun

convertSharingSeq _ _ _ _ _ (ScopedSeq (SletSharing _ _))
 = $internalError "convertSharingSeq" "Sequence computation not in A-normal form"

convertSharingSeq config alyt slyt aenv senv s
  = cvtC s
  where
    cvtC :: ScopedSeq a -> AST.PreOpenSeq AST.OpenAcc aenv senv a
    cvtC (ScopedSeq (SeqSharing _ s)) =
      case s of
        FoldSeq fun e x                    -> AST.Consumer $ AST.FoldSeq (cvtF2 fun) (cvtE e) (asIdx x)
        FoldSeqFlatten afun acc x          -> AST.Consumer $ AST.FoldSeqFlatten (cvtAF3 afun) (cvtA acc) (asIdx x)
        Stuple t                           -> AST.Consumer $ AST.Stuple (cvtST t)
        _                                  -> $internalError "convertSharingSeq" "Producer has not been let bound"
    cvtC _ = $internalError "convertSharingSeq" "Unreachable"

    asIdx :: Arrays a
          => ScopedSeq [a]
          -> Idx senv a
    asIdx (ScopedSeq (SvarSharing sn))
      | Just i <- findIndex (matchStableSeq sn) senv
      = prjIdx (ctxt ++ "; i = " ++ show i) i slyt
      | null senv
      = error $ "Cyclic definition of a value of type 'Seq' (sa = " ++
                show (hashStableNameHeight sn) ++ ")"
      | otherwise
      = $internalError "convertSharingSeq" err
      where
        ctxt = "shared 'Seq' tree with stable name " ++ show (hashStableNameHeight sn)
        err  = "inconsistent valuation @ " ++ ctxt ++ ";\n  senv = " ++ show senv
    asIdx _
      = $internalError "convertSharingSeq:asIdx" "Sequence computation not in A-normal form"

    cvtA :: forall a. Arrays a => ScopedAcc a -> AST.OpenAcc aenv a
    cvtA acc = convertSharingAcc config alyt aenv acc

    cvtE :: forall t. Elt t => ScopedExp t -> AST.Exp aenv t
    cvtE = convertSharingExp config EmptyLayout alyt [] aenv

    cvtF2 :: (Elt a, Elt b, Elt c) => (Exp a -> Exp b -> ScopedExp c) -> AST.Fun aenv (a -> b -> c)
    cvtF2 = convertSharingFun2 config alyt aenv

    cvtAF3 :: forall a b c d. (Arrays a, Arrays b, Arrays c, Arrays d) => (Acc a -> Acc b -> Acc c -> ScopedAcc d) -> OpenAfun aenv (a -> b -> c -> d)
    cvtAF3 afun = convertSharingAfun3 config alyt aenv afun

    cvtST :: Atuple ScopedSeq t -> Atuple (AST.Consumer AST.OpenAcc aenv senv) t
    cvtST NilAtup        = NilAtup
    cvtST (SnocAtup t c) | AST.Consumer c' <- cvtC c
                         = SnocAtup (cvtST t) c'
                         | otherwise
                         = $internalError "convertSharingSeq" "Unreachable"
--}

-- | Convert a unary array function, whose body has already been through the
-- sharing-recovery scoping phase, into de Bruijn form.
--
-- Fresh array variables for the argument (with representation @reprA@) are
-- declared via 'declareVars'. The existing layout is weakened past them and
-- extended with the new binding. The body is then converted with
-- 'convertSharingAcc' under that layout.
--
convertSharingAfun1
    :: forall aenv a b. HasCallStack
    => Config
    -> ArrayLayout aenv aenv
    -> [StableSharingAcc]
    -> ArraysR a
    -> (SmartAcc a -> ScopedAcc b)
    -> OpenAfun aenv (a -> b)
convertSharingAfun1 config alyt aenv reprA f
  | DeclareVars lhs k value <- declareVars reprA
  = let
      -- shift the existing entries past the new binding, then push the
      -- variables bound by 'lhs'
      alyt' = PushLayout (incLayout k alyt) lhs (value weakenId)
      -- NOTE(review): the argument is 'undefined', so this relies on the
      -- scoped function never forcing it (the argument is reached through the
      -- layout instead) — confirm against the scoping phase.
      body  = f undefined
    in
      Alam lhs (Abody (convertSharingAcc config alyt' aenv body))

-- | Convert a boundary condition
--
-- 'Clamp', 'Mirror', 'Wrap' and 'Constant' map directly onto their
-- 'AST.Boundary' counterparts. A 'Function' boundary has its index function
-- converted with 'convertSharingFun1', with the argument type obtained from
-- @shr@ via 'shapeType'.
--
convertSharingBoundary
    :: forall aenv sh e. HasCallStack
    => Config
    -> ArrayLayout aenv aenv
    -> [StableSharingAcc]
    -> ShapeR sh
    -> PreBoundary ScopedAcc ScopedExp (Array sh e)
    -> AST.Boundary aenv (Array sh e)
convertSharingBoundary config alyt aenv shr = cvt
  where
    cvt :: PreBoundary ScopedAcc ScopedExp (Array sh e) -> AST.Boundary aenv (Array sh e)
    cvt = \case
      Clamp       -> AST.Clamp
      Mirror      -> AST.Mirror
      Wrap        -> AST.Wrap
      Constant v  -> AST.Constant v
      Function f  -> AST.Function $ convertSharingFun1 config alyt aenv (shapeType shr) f


-- mkToSeq :: forall slsix slix e aenv senv. (Division slsix, DivisionSlice slsix ~ slix, Elt e, Elt slix, Slice slix)
--         => slsix
--         -> AST.OpenAcc              aenv (Array (FullShape  slix) e)
--         -> AST.Producer AST.OpenAcc aenv senv (Array (SliceShape slix) e)
-- mkToSeq _ = AST.ToSeq (sliceIndex slix) (Proxy :: Proxy slix)
--   where
--     slix = undefined :: slix


-- Scalar functions
-- ----------------

-- | Convert a closed scalar function to de Bruijn form while incorporating
-- sharing information.
--
-- The current design requires all free variables to be bound at the outermost
-- level --- we have no general apply term, and so lambdas are always outermost.
-- In higher-order abstract syntax, this represents an n-ary, polyvariadic
-- function.
--
-- This uses 'defaultOptions' with the 'seq_sharing' and 'acc_sharing' flags
-- removed. Use 'convertFunWith' to supply a different 'Config'.
--
convertFun :: (HasCallStack, Function f) => f -> AST.Fun () (EltFunctionR f)
convertFun
  = convertFunWith
  $ defaultOptions { options = options defaultOptions \\ [seq_sharing, acc_sharing] }

-- | As 'convertFun', but with the given 'Config'. Conversion starts from an
-- empty scalar layout, so the function must be closed.
--
convertFunWith :: (HasCallStack, Function f) => Config -> f -> AST.Fun () (EltFunctionR f)
convertFunWith config = convertOpenFun config EmptyLayout

-- | Witness of the shape of a HOAS scalar function @f@. It relates @f@ to the
-- same function over plain element types (@r@) and over their 'EltR'
-- representation types (@reprr@).
data FunctionRepr f r reprr where
  -- | The function body: a single @'Exp' res@.
  FunctionReprBody
    :: Elt res
    => FunctionRepr (Exp res) res (EltR res)

  -- | One more @'Exp' arg@ argument in front of the function described by
  -- the inner witness.
  FunctionReprLam
    :: Elt arg
    => FunctionRepr inner innerR innerRepr
    -> FunctionRepr (Exp arg -> inner) (arg -> innerR) (EltR arg -> innerRepr)

-- | HOAS scalar functions that can be converted to de Bruijn form. The
-- instances in this module cover curried chains of @'Exp' a -> ...@ ending
-- in an @'Exp' b@.
class Function fun where
  -- | The function type over plain element types.
  type FunctionR fun
  -- | The function type over element representation types ('EltR').
  type EltFunctionR fun

  -- | Witness describing the shape of @fun@.
  functionRepr
    :: HasCallStack
    => FunctionRepr fun (FunctionR fun) (EltFunctionR fun)

  -- | Convert under the given scalar layout. The result has the unit array
  -- environment.
  convertOpenFun
    :: HasCallStack
    => Config
    -> ELayout scope scope
    -> fun
    -> AST.OpenFun scope () (EltFunctionR fun)

-- | A lambda: bind the @'Exp' a@ argument, then convert the rest of the
-- function under the extended layout.
instance (Elt a, Function r) => Function (Exp a -> r) where
  type FunctionR (Exp a -> r) = a -> FunctionR r
  type EltFunctionR (Exp a -> r) = EltR a -> EltFunctionR r

  functionRepr = FunctionReprLam $ functionRepr @r

  convertOpenFun config lyt f
    | tp <- eltR @a
    , DeclareVars lhs k value <- declareVars tp
    = let
        -- The argument is a 'Tag' whose level is the current layout size.
        -- This is the level the new binding occupies in 'lyt''.
        e    = Exp $ SmartExp $ Tag tp $ sizeLayout lyt
        -- weaken the existing layout past the new binding, then push it
        lyt' = PushLayout (incLayout k lyt) lhs (value weakenId)
      in
        Lam lhs $ convertOpenFun config lyt' $ f e

-- | The function body: convert the expression under the layout accumulated
-- by the enclosing lambdas.
instance Elt b => Function (Exp b) where
  type FunctionR (Exp b) = b
  type EltFunctionR (Exp b) = EltR b

  functionRepr = FunctionReprBody

  convertOpenFun config lyt (Exp body) = Body $ convertOpenExp config lyt body

-- | Convert a closed unary function over 'SmartExp' terms to de Bruijn form.
--
-- The function is converted in an otherwise empty scalar environment, so its
-- argument is a 'Tag' at level 0.
--
convertSmartFun
    :: HasCallStack
    => Config
    -> TypeR a
    -> (SmartExp a -> SmartExp b)
    -> AST.Fun () (a -> b)
convertSmartFun config tp f
  | DeclareVars lhs _ value <- declareVars tp
  = let
      e    = SmartExp $ Tag tp 0
      lyt' = PushLayout EmptyLayout lhs (value weakenId)
    in
      Lam lhs $ Body $ convertOpenExp config lyt' $ f e

-- Scalar expressions
-- ------------------

-- | Convert a closed scalar expression to de Bruijn form while incorporating
-- sharing information.
--
-- This uses 'defaultOptions' with the 'seq_sharing' and 'acc_sharing' flags
-- removed. Use 'convertExpWith' to supply a different 'Config'.
--
convertExp
    :: HasCallStack
    => Exp e
    -> AST.Exp () (EltR e)
convertExp
  = convertExpWith
  $ defaultOptions { options = options defaultOptions \\ [seq_sharing, acc_sharing] }

-- | As 'convertExp', but with the given 'Config'. Conversion starts from an
-- empty scalar layout, so the expression must be closed.
--
convertExpWith
      :: HasCallStack
      => Config
      -> Exp e
      -> AST.Exp () (EltR e)
convertExpWith config (Exp e) = convertOpenExp config EmptyLayout e

-- | Convert an open scalar expression under the given scalar layout.
--
-- Sharing recovery runs first, with the tags of the variables already bound
-- in @lyt@ passed as free variables. The result is then converted with
-- 'convertSharingExp' under an empty array layout, with no array-level
-- sharing variables in scope.
--
convertOpenExp
    :: HasCallStack
    => Config
    -> ELayout env env
    -> SmartExp e
    -> AST.OpenExp env () e
convertOpenExp config lyt exp =
  let lvl                      = sizeLayout lyt
      -- tags of the variables already in scope, in descending order
      -- (empty when lvl == 0)
      fvs                      = [lvl-1, lvl-2 .. 0]
      (sharingExp, initialEnv) = recoverSharingExp config lvl fvs exp
  in
  convertSharingExp config lyt EmptyLayout initialEnv [] sharingExp


-- | Convert an open expression with given environment layouts and sharing information into
-- de Bruijn form while recovering sharing at the same time (by introducing appropriate let
-- bindings).  The latter implements the third phase of sharing recovery.
--
-- The sharing environments 'env' and 'aenv' keep track of all currently bound sharing variables,
-- keeping them in reverse chronological order (outermost variable is at the end of the list).
--
convertSharingExp
    :: forall t env aenv. HasCallStack
    => Config
    -> ELayout env env          -- scalar environment
    -> ArrayLayout aenv aenv    -- array environment
    -> [StableSharingExp]       -- currently bound sharing variables of expressions
    -> [StableSharingAcc]       -- currently bound sharing variables of array computations
    -> ScopedExp t              -- expression to be converted
    -> AST.OpenExp env aenv t
convertSharingExp :: Config
-> ELayout env env
-> ArrayLayout aenv aenv
-> [StableSharingExp]
-> [StableSharingAcc]
-> ScopedExp t
-> OpenExp env aenv t
convertSharingExp Config
config ELayout env env
lyt ArrayLayout aenv aenv
alyt [StableSharingExp]
env [StableSharingAcc]
aenv exp :: ScopedExp t
exp@(ScopedExp [StableSharingExp]
lams SharingExp ScopedAcc ScopedExp t
_) = ScopedExp t -> OpenExp env aenv t
forall t'. HasCallStack => ScopedExp t' -> OpenExp env aenv t'
cvt ScopedExp t
exp
  where
    -- scalar environment with any lambda bound variables this expression is rooted in
    env' :: [StableSharingExp]
env' = [StableSharingExp]
lams [StableSharingExp] -> [StableSharingExp] -> [StableSharingExp]
forall a. [a] -> [a] -> [a]
++ [StableSharingExp]
env

    cvt :: HasCallStack => ScopedExp t' -> AST.OpenExp env aenv t'
    cvt :: ScopedExp t' -> OpenExp env aenv t'
cvt (ScopedExp [StableSharingExp]
_ (VarSharing StableExpName t'
se TypeR t'
tp))
      | Just Int
i <- (StableSharingExp -> Bool) -> [StableSharingExp] -> Maybe Int
forall a. (a -> Bool) -> [a] -> Maybe Int
findIndex (StableExpName t' -> StableSharingExp -> Bool
forall t. StableExpName t -> StableSharingExp -> Bool
matchStableExp StableExpName t'
se) [StableSharingExp]
env' = ExpVars env t' -> OpenExp env aenv t'
forall env t aenv. ExpVars env t -> OpenExp env aenv t
expVars (String
-> (forall t'. TupR ScalarType t' -> ShowS)
-> (forall u v.
    TupR ScalarType u -> TupR ScalarType v -> Maybe (u :~: v))
-> TypeR t'
-> Int
-> ELayout env env
-> ExpVars env t'
forall (s :: * -> *) t env env1.
HasCallStack =>
String
-> (forall t'. TupR s t' -> ShowS)
-> (forall u v. TupR s u -> TupR s v -> Maybe (u :~: v))
-> TupR s t
-> Int
-> Layout s env env1
-> Vars s env t
prjIdx (Int -> String
ctx Int
i) forall a. Show a => a -> ShowS
forall t'. TupR ScalarType t' -> ShowS
shows forall u v.
TupR ScalarType u -> TupR ScalarType v -> Maybe (u :~: v)
matchTypeR TypeR t'
tp Int
i ELayout env env
lyt)
      | Bool
otherwise                                    = String -> OpenExp env aenv t'
forall a. HasCallStack => String -> a
internalError String
msg
      where
        ctx :: Int -> String
ctx Int
i = String -> Int -> Int -> String
forall r. PrintfType r => String -> r
printf String
"shared 'Exp' tree with stable name %d; i=%d" (StableExpName t' -> Int
forall t. StableNameHeight t -> Int
hashStableNameHeight StableExpName t'
se) Int
i
        msg :: String
msg   = [String] -> String
unlines
          [ if [StableSharingExp] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [StableSharingExp]
env'
               then String -> Int -> String
forall r. PrintfType r => String -> r
printf String
"cyclic definition of a value of type 'Exp' (sa=%d)" (StableExpName t' -> Int
forall t. StableNameHeight t -> Int
hashStableNameHeight StableExpName t'
se)
               else String -> Int -> ShowS
forall r. PrintfType r => String -> r
printf String
"inconsistent valuation at shared 'Exp' tree (sa=%d; env=%s)" (StableExpName t' -> Int
forall t. StableNameHeight t -> Int
hashStableNameHeight StableExpName t'
se) ([StableSharingExp] -> String
forall a. Show a => a -> String
show [StableSharingExp]
env')
          , String
Item [String]
""
          , String
Item [String]
"Note that this error usually arises due to the presence of nested data"
          , String
Item [String]
"parallelism; when a parallel computation attempts to initiate new parallel"
          , String
Item [String]
"work _which depends on_ a scalar variable given by the first computation."
          , String
Item [String]
""
          , String
Item [String]
"For example, suppose we wish to sum the columns of a two-dimensional array."
          , String
Item [String]
"You might think to do this in the following (incorrect) way: by constructing"
          , String
Item [String]
"a vector using 'generate' where at each index we 'slice' out the"
          , String
Item [String]
"corresponding column of the matrix and 'sum' it:"
          , String
Item [String]
""
          , String
Item [String]
"> sum_columns_ndp :: Num a => Acc (Matrix a) -> Acc (Vector a)"
          , String
Item [String]
"> sum_columns_ndp mat ="
          , String
Item [String]
">   let I2 rows cols = shape mat"
          , String
Item [String]
">   in  generate (I1 cols)"
          , String
Item [String]
">                (\\(I1 col) -> the $ sum (slice mat (lift (Z :. All :. col))))"
          , String
Item [String]
""
          , String
Item [String]
"However, since both 'generate' and 'slice' are data-parallel operators, and"
          , String
Item [String]
"moreover that 'slice' _depends on_ the argument 'col' given to it by the"
          , String
Item [String]
"'generate' function, this operation requires nested parallelism and is thus"
          , String
Item [String]
"not (at this time) permitted. The clue that this definition is invalid is"
          , String
Item [String]
"that in order to create a program which will be accepted by the type checker,"
          , String
Item [String]
"we had to use the function 'the' to retrieve the result of the parallel"
          , String
Item [String]
"'sum', effectively concealing that this is a collective operation in order to"
          , String
Item [String]
"match the type expected by 'generate'."
          , String
Item [String]
""
          , String
Item [String]
"To solve this particular example, we can make use of the fact that (most)"
          , String
Item [String]
"collective operations in Accelerate are _rank polymorphic_. The 'sum'"
          , String
Item [String]
"operation reduces along the innermost dimension of an array of arbitrary"
          , String
Item [String]
"rank, reducing the dimensionality of the array by one. To reduce the array"
          , String
Item [String]
"column-wise then, we first need to simply 'transpose' the array:"
          , String
Item [String]
""
          , String
Item [String]
"> sum_columns :: Num a => Acc (Matrix a) -> Acc (Vector a)"
          , String
Item [String]
"> sum_columns = sum . transpose"
          , String
Item [String]
""
          , String
Item [String]
"If you feel like this is not the cause of your error, or you would like some"
          , String
Item [String]
"advice locating the problem and perhaps with a workaround, feel free to"
          , String
Item [String]
"submit an issue at the above URL."
          ]

    cvt (ScopedExp [StableSharingExp]
_ (LetSharing se :: StableSharingExp
se@(StableSharingExp StableExpName t
_ SharingExp ScopedAcc ScopedExp t
boundExp) ScopedExp t'
bodyExp))
      | DeclareVars LeftHandSide ScalarType t env env'
lhs env :> env'
k forall env''. (env' :> env'') -> Vars ScalarType env'' t
value <- TupR ScalarType t -> DeclareVars ScalarType t env
forall (s :: * -> *) t env. TupR s t -> DeclareVars s t env
declareVars (TupR ScalarType t -> DeclareVars ScalarType t env)
-> TupR ScalarType t -> DeclareVars ScalarType t env
forall a b. (a -> b) -> a -> b
$ SharingExp ScopedAcc ScopedExp t -> TupR ScalarType t
forall (f :: * -> *) t.
(HasTypeR f, HasCallStack) =>
f t -> TypeR t
typeR SharingExp ScopedAcc ScopedExp t
boundExp
      = let
          lyt' :: Layout ScalarType env' env'
lyt' = Layout ScalarType env' env
-> LeftHandSide ScalarType t env env'
-> Vars ScalarType env' t
-> Layout ScalarType env' env'
forall (s :: * -> *) env env1 t env2.
Layout s env env1
-> LeftHandSide s t env1 env2 -> Vars s env t -> Layout s env env2
PushLayout ((env :> env') -> ELayout env env -> Layout ScalarType env' env
forall env1 env2 (s :: * -> *) env'.
(env1 :> env2) -> Layout s env1 env' -> Layout s env2 env'
incLayout env :> env'
k ELayout env env
lyt) LeftHandSide ScalarType t env env'
lhs ((env' :> env') -> Vars ScalarType env' t
forall env''. (env' :> env'') -> Vars ScalarType env'' t
value env' :> env'
forall env. env :> env
weakenId)
        in
          LeftHandSide ScalarType t env env'
-> OpenExp env aenv t
-> OpenExp env' aenv t'
-> OpenExp env aenv t'
forall bnd_t env env' aenv body_t.
ELeftHandSide bnd_t env env'
-> OpenExp env aenv bnd_t
-> OpenExp env' aenv body_t
-> OpenExp env aenv body_t
AST.Let LeftHandSide ScalarType t env env'
lhs (ScopedExp t -> OpenExp env aenv t
forall t'. HasCallStack => ScopedExp t' -> OpenExp env aenv t'
cvt ([StableSharingExp]
-> SharingExp ScopedAcc ScopedExp t -> ScopedExp t
forall t.
[StableSharingExp]
-> SharingExp ScopedAcc ScopedExp t -> ScopedExp t
ScopedExp [] SharingExp ScopedAcc ScopedExp t
boundExp)) (Config
-> Layout ScalarType env' env'
-> ArrayLayout aenv aenv
-> [StableSharingExp]
-> [StableSharingAcc]
-> ScopedExp t'
-> OpenExp env' aenv t'
forall t env aenv.
HasCallStack =>
Config
-> ELayout env env
-> ArrayLayout aenv aenv
-> [StableSharingExp]
-> [StableSharingAcc]
-> ScopedExp t
-> OpenExp env aenv t
convertSharingExp Config
config Layout ScalarType env' env'
lyt' ArrayLayout aenv aenv
alyt (StableSharingExp
seStableSharingExp -> [StableSharingExp] -> [StableSharingExp]
forall a. a -> [a] -> [a]
:[StableSharingExp]
env') [StableSharingAcc]
aenv ScopedExp t'
bodyExp)
    cvt (ScopedExp [StableSharingExp]
_ (ExpSharing StableExpName t'
_ PreSmartExp ScopedAcc ScopedExp t'
pexp))
      = case PreSmartExp ScopedAcc ScopedExp t'
pexp of
          Tag TypeR t'
tp Int
i              -> ExpVars env t' -> OpenExp env aenv t'
forall env t aenv. ExpVars env t -> OpenExp env aenv t
expVars (ExpVars env t' -> OpenExp env aenv t')
-> ExpVars env t' -> OpenExp env aenv t'
forall a b. (a -> b) -> a -> b
$ String
-> (forall t'. TupR ScalarType t' -> ShowS)
-> (forall u v.
    TupR ScalarType u -> TupR ScalarType v -> Maybe (u :~: v))
-> TypeR t'
-> Int
-> ELayout env env
-> ExpVars env t'
forall (s :: * -> *) t env env1.
HasCallStack =>
String
-> (forall t'. TupR s t' -> ShowS)
-> (forall u v. TupR s u -> TupR s v -> Maybe (u :~: v))
-> TupR s t
-> Int
-> Layout s env env1
-> Vars s env t
prjIdx (String
"de Bruijn conversion tag " String -> ShowS
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show Int
i) forall a. Show a => a -> ShowS
forall t'. TupR ScalarType t' -> ShowS
shows forall u v.
TupR ScalarType u -> TupR ScalarType v -> Maybe (u :~: v)
matchTypeR TypeR t'
tp Int
i ELayout env env
lyt
          Match TagR t'
_ ScopedExp t'
e             -> ScopedExp t' -> OpenExp env aenv t'
forall t'. HasCallStack => ScopedExp t' -> OpenExp env aenv t'
cvt ScopedExp t'
e  -- XXX: this should probably be an error
          Const ScalarType t'
tp t'
v            -> ScalarType t' -> t' -> OpenExp env aenv t'
forall t env aenv. ScalarType t -> t -> OpenExp env aenv t
AST.Const ScalarType t'
tp t'
v
          Undef ScalarType t'
tp              -> ScalarType t' -> OpenExp env aenv t'
forall t env aenv. ScalarType t -> OpenExp env aenv t
AST.Undef ScalarType t'
tp
          Prj PairIdx (t1, t2) t'
idx ScopedExp (t1, t2)
e             -> PairIdx (t1, t2) t'
-> OpenExp env aenv (t1, t2) -> OpenExp env aenv t'
forall a b c env1 aenv1.
PairIdx (a, b) c
-> OpenExp env1 aenv1 (a, b) -> OpenExp env1 aenv1 c
cvtPrj PairIdx (t1, t2) t'
idx (ScopedExp (t1, t2) -> OpenExp env aenv (t1, t2)
forall t'. HasCallStack => ScopedExp t' -> OpenExp env aenv t'
cvt ScopedExp (t1, t2)
e)
          PreSmartExp ScopedAcc ScopedExp t'
Nil                   -> OpenExp env aenv t'
forall env aenv. OpenExp env aenv ()
AST.Nil
          Pair ScopedExp t1
e1 ScopedExp t2
e2            -> OpenExp env aenv t1
-> OpenExp env aenv t2 -> OpenExp env aenv (t1, t2)
forall env aenv t1 t2.
OpenExp env aenv t1
-> OpenExp env aenv t2 -> OpenExp env aenv (t1, t2)
AST.Pair (ScopedExp t1 -> OpenExp env aenv t1
forall t'. HasCallStack => ScopedExp t' -> OpenExp env aenv t'
cvt ScopedExp t1
e1) (ScopedExp t2 -> OpenExp env aenv t2
forall t'. HasCallStack => ScopedExp t' -> OpenExp env aenv t'
cvt ScopedExp t2
e2)
          VecPack   VecR n s tup
vec ScopedExp tup
e       -> VecR n s tup -> OpenExp env aenv tup -> OpenExp env aenv (Vec n s)
forall (n :: Nat) s tup env aenv.
KnownNat n =>
VecR n s tup -> OpenExp env aenv tup -> OpenExp env aenv (Vec n s)
AST.VecPack   VecR n s tup
vec (ScopedExp tup -> OpenExp env aenv tup
forall t'. HasCallStack => ScopedExp t' -> OpenExp env aenv t'
cvt ScopedExp tup
e)
          VecUnpack VecR n s t'
vec ScopedExp (Vec n s)
e       -> VecR n s t' -> OpenExp env aenv (Vec n s) -> OpenExp env aenv t'
forall (n :: Nat) s tup env aenv.
KnownNat n =>
VecR n s tup -> OpenExp env aenv (Vec n s) -> OpenExp env aenv tup
AST.VecUnpack VecR n s t'
vec (ScopedExp (Vec n s) -> OpenExp env aenv (Vec n s)
forall t'. HasCallStack => ScopedExp t' -> OpenExp env aenv t'
cvt ScopedExp (Vec n s)
e)
          ToIndex ShapeR sh
shr ScopedExp sh
sh ScopedExp sh
ix     -> ShapeR sh
-> OpenExp env aenv sh
-> OpenExp env aenv sh
-> OpenExp env aenv Int
forall sh env aenv.
ShapeR sh
-> OpenExp env aenv sh
-> OpenExp env aenv sh
-> OpenExp env aenv Int
AST.ToIndex ShapeR sh
shr (ScopedExp sh -> OpenExp env aenv sh
forall t'. HasCallStack => ScopedExp t' -> OpenExp env aenv t'
cvt ScopedExp sh
sh) (ScopedExp sh -> OpenExp env aenv sh
forall t'. HasCallStack => ScopedExp t' -> OpenExp env aenv t'
cvt ScopedExp sh
ix)
          FromIndex ShapeR t'
shr ScopedExp t'
sh ScopedExp Int
e    -> ShapeR t'
-> OpenExp env aenv t'
-> OpenExp env aenv Int
-> OpenExp env aenv t'
forall sh env aenv.
ShapeR sh
-> OpenExp env aenv sh
-> OpenExp env aenv Int
-> OpenExp env aenv sh
AST.FromIndex ShapeR t'
shr (ScopedExp t' -> OpenExp env aenv t'
forall t'. HasCallStack => ScopedExp t' -> OpenExp env aenv t'
cvt ScopedExp t'
sh) (ScopedExp Int -> OpenExp env aenv Int
forall t'. HasCallStack => ScopedExp t' -> OpenExp env aenv t'
cvt ScopedExp Int
e)
          Case ScopedExp a
e [(TagR a, ScopedExp t')]
rhs            -> OpenExp env aenv a
-> [(TagR a, OpenExp env aenv t')] -> OpenExp env aenv t'
forall env' aenv' a b.
HasCallStack =>
OpenExp env' aenv' a
-> [(TagR a, OpenExp env' aenv' b)] -> OpenExp env' aenv' b
cvtCase (ScopedExp a -> OpenExp env aenv a
forall t'. HasCallStack => ScopedExp t' -> OpenExp env aenv t'
cvt ScopedExp a
e) (ASetter
  [(TagR a, ScopedExp t')]
  [(TagR a, OpenExp env aenv t')]
  (ScopedExp t')
  (OpenExp env aenv t')
-> (ScopedExp t' -> OpenExp env aenv t')
-> [(TagR a, ScopedExp t')]
-> [(TagR a, OpenExp env aenv t')]
forall s t a b. ASetter s t a b -> (a -> b) -> s -> t
over (((TagR a, ScopedExp t') -> Identity (TagR a, OpenExp env aenv t'))
-> [(TagR a, ScopedExp t')]
-> Identity [(TagR a, OpenExp env aenv t')]
forall (f :: * -> *) a b. Functor f => Setter (f a) (f b) a b
mapped (((TagR a, ScopedExp t') -> Identity (TagR a, OpenExp env aenv t'))
 -> [(TagR a, ScopedExp t')]
 -> Identity [(TagR a, OpenExp env aenv t')])
-> ((ScopedExp t' -> Identity (OpenExp env aenv t'))
    -> (TagR a, ScopedExp t')
    -> Identity (TagR a, OpenExp env aenv t'))
-> ASetter
     [(TagR a, ScopedExp t')]
     [(TagR a, OpenExp env aenv t')]
     (ScopedExp t')
     (OpenExp env aenv t')
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (ScopedExp t' -> Identity (OpenExp env aenv t'))
-> (TagR a, ScopedExp t') -> Identity (TagR a, OpenExp env aenv t')
forall s t a b. Field2 s t a b => Lens s t a b
_2) ScopedExp t' -> OpenExp env aenv t'
forall t'. HasCallStack => ScopedExp t' -> OpenExp env aenv t'
cvt [(TagR a, ScopedExp t')]
rhs)
          Cond ScopedExp PrimBool
e1 ScopedExp t'
e2 ScopedExp t'
e3         -> OpenExp env aenv PrimBool
-> OpenExp env aenv t'
-> OpenExp env aenv t'
-> OpenExp env aenv t'
forall env aenv t.
OpenExp env aenv PrimBool
-> OpenExp env aenv t -> OpenExp env aenv t -> OpenExp env aenv t
AST.Cond (ScopedExp PrimBool -> OpenExp env aenv PrimBool
forall t'. HasCallStack => ScopedExp t' -> OpenExp env aenv t'
cvt ScopedExp PrimBool
e1) (ScopedExp t' -> OpenExp env aenv t'
forall t'. HasCallStack => ScopedExp t' -> OpenExp env aenv t'
cvt ScopedExp t'
e2) (ScopedExp t' -> OpenExp env aenv t'
forall t'. HasCallStack => ScopedExp t' -> OpenExp env aenv t'
cvt ScopedExp t'
e3)
          While TypeR t'
tp SmartExp t' -> ScopedExp PrimBool
p SmartExp t' -> ScopedExp t'
it ScopedExp t'
i       -> OpenFun env aenv (t' -> PrimBool)
-> OpenFun env aenv (t' -> t')
-> OpenExp env aenv t'
-> OpenExp env aenv t'
forall env aenv a.
OpenFun env aenv (a -> PrimBool)
-> OpenFun env aenv (a -> a)
-> OpenExp env aenv a
-> OpenExp env aenv a
AST.While (TypeR t'
-> (SmartExp t' -> ScopedExp PrimBool)
-> OpenFun env aenv (t' -> PrimBool)
forall a b.
HasCallStack =>
TypeR a -> (SmartExp a -> ScopedExp b) -> OpenFun env aenv (a -> b)
cvtFun1 TypeR t'
tp SmartExp t' -> ScopedExp PrimBool
p) (TypeR t'
-> (SmartExp t' -> ScopedExp t') -> OpenFun env aenv (t' -> t')
forall a b.
HasCallStack =>
TypeR a -> (SmartExp a -> ScopedExp b) -> OpenFun env aenv (a -> b)
cvtFun1 TypeR t'
tp SmartExp t' -> ScopedExp t'
it) (ScopedExp t' -> OpenExp env aenv t'
forall t'. HasCallStack => ScopedExp t' -> OpenExp env aenv t'
cvt ScopedExp t'
i)
          PrimConst PrimConst t'
c           -> PrimConst t' -> OpenExp env aenv t'
forall t env aenv. PrimConst t -> OpenExp env aenv t
AST.PrimConst PrimConst t'
c
          PrimApp PrimFun (a -> t')
f ScopedExp a
e           -> PrimFun (a -> t') -> OpenExp env aenv a -> OpenExp env aenv t'
forall a r env' aenv'.
HasCallStack =>
PrimFun (a -> r) -> OpenExp env' aenv' a -> OpenExp env' aenv' r
cvtPrimFun PrimFun (a -> t')
f (ScopedExp a -> OpenExp env aenv a
forall t'. HasCallStack => ScopedExp t' -> OpenExp env aenv t'
cvt ScopedExp a
e)
          Index TypeR t'
_ ScopedAcc (Array sh t')
a ScopedExp sh
e           -> ArrayVar aenv (Array sh t')
-> OpenExp env aenv sh -> OpenExp env aenv t'
forall aenv dim t env.
ArrayVar aenv (Array dim t)
-> OpenExp env aenv dim -> OpenExp env aenv t
AST.Index (ScopedAcc (Array sh t') -> ArrayVar aenv (Array sh t')
forall a. HasCallStack => ScopedAcc a -> ArrayVar aenv a
cvtAvar ScopedAcc (Array sh t')
a) (ScopedExp sh -> OpenExp env aenv sh
forall t'. HasCallStack => ScopedExp t' -> OpenExp env aenv t'
cvt ScopedExp sh
e)
          LinearIndex TypeR t'
_ ScopedAcc (Array sh t')
a ScopedExp Int
i     -> ArrayVar aenv (Array sh t')
-> OpenExp env aenv Int -> OpenExp env aenv t'
forall aenv dim t env.
ArrayVar aenv (Array dim t)
-> OpenExp env aenv Int -> OpenExp env aenv t
AST.LinearIndex (ScopedAcc (Array sh t') -> ArrayVar aenv (Array sh t')
forall a. HasCallStack => ScopedAcc a -> ArrayVar aenv a
cvtAvar ScopedAcc (Array sh t')
a) (ScopedExp Int -> OpenExp env aenv Int
forall t'. HasCallStack => ScopedExp t' -> OpenExp env aenv t'
cvt ScopedExp Int
i)
          Shape ShapeR t'
_ ScopedAcc (Array t' e)
a             -> ArrayVar aenv (Array t' e) -> OpenExp env aenv t'
forall aenv dim e env.
ArrayVar aenv (Array dim e) -> OpenExp env aenv dim
AST.Shape (ScopedAcc (Array t' e) -> ArrayVar aenv (Array t' e)
forall a. HasCallStack => ScopedAcc a -> ArrayVar aenv a
cvtAvar ScopedAcc (Array t' e)
a)
          ShapeSize ShapeR sh
shr ScopedExp sh
e       -> ShapeR sh -> OpenExp env aenv sh -> OpenExp env aenv Int
forall dim env aenv.
ShapeR dim -> OpenExp env aenv dim -> OpenExp env aenv Int
AST.ShapeSize ShapeR sh
shr (ScopedExp sh -> OpenExp env aenv sh
forall t'. HasCallStack => ScopedExp t' -> OpenExp env aenv t'
cvt ScopedExp sh
e)
          Foreign TypeR t'
repr asm (x -> t')
ff SmartExp x -> SmartExp t'
f ScopedExp x
e   -> TypeR t'
-> asm (x -> t')
-> Fun () (x -> t')
-> OpenExp env aenv x
-> OpenExp env aenv t'
forall (asm :: * -> *) y x env aenv.
Foreign asm =>
TypeR y
-> asm (x -> y)
-> Fun () (x -> y)
-> OpenExp env aenv x
-> OpenExp env aenv y
AST.Foreign TypeR t'
repr asm (x -> t')
ff (Config
-> TypeR x -> (SmartExp x -> SmartExp t') -> Fun () (x -> t')
forall a b.
HasCallStack =>
Config -> TypeR a -> (SmartExp a -> SmartExp b) -> Fun () (a -> b)
convertSmartFun Config
config (ScopedExp x -> TypeR x
forall (f :: * -> *) t.
(HasTypeR f, HasCallStack) =>
f t -> TypeR t
typeR ScopedExp x
e) SmartExp x -> SmartExp t'
f) (ScopedExp x -> OpenExp env aenv x
forall t'. HasCallStack => ScopedExp t' -> OpenExp env aenv t'
cvt ScopedExp x
e)
          Coerce ScalarType a
t1 ScalarType t'
t2 ScopedExp a
e        -> ScalarType a
-> ScalarType t' -> OpenExp env aenv a -> OpenExp env aenv t'
forall a b env aenv.
BitSizeEq a b =>
ScalarType a
-> ScalarType b -> OpenExp env aenv a -> OpenExp env aenv b
AST.Coerce ScalarType a
t1 ScalarType t'
t2 (ScopedExp a -> OpenExp env aenv a
forall t'. HasCallStack => ScopedExp t' -> OpenExp env aenv t'
cvt ScopedExp a
e)

    -- Project one component out of a pair-valued expression.
    --
    -- If the expression is syntactically an 'AST.Pair', the requested
    -- component is selected directly. Otherwise the expression is let-bound
    -- to fresh variables and the projection is retried on those variables
    -- (presumably 'expVars' of a pair type yields an 'AST.Pair', so the
    -- recursive call hits one of the first two equations).
    cvtPrj :: forall a b c env1 aenv1. PairIdx (a, b) c -> AST.OpenExp env1 aenv1 (a, b) -> AST.OpenExp env1 aenv1 c
    cvtPrj PairIdxLeft  (AST.Pair a _) = a
    cvtPrj PairIdxRight (AST.Pair _ b) = b
    cvtPrj ix a
      | DeclareVars lhs _ value <- declareVars (AST.expType a)
      = AST.Let lhs a (cvtPrj ix (expVars (value weakenId)))

    -- Convert an array term that occurs inside this scalar expression. It
    -- uses the enclosing configuration, the array layout, and the list of
    -- currently bound array sharing variables.
    cvtA :: HasCallStack => ScopedAcc a -> AST.OpenAcc aenv a
    cvtA = convertSharingAcc config alyt aenv

    -- Convert an array term used inside a scalar expression and return the
    -- array variable it denotes. Any converted term that is not a plain
    -- 'AST.Avar' is treated as an internal error: such computations are
    -- expected to have been floated out of the expression already.
    cvtAvar :: HasCallStack => ScopedAcc a -> AST.ArrayVar aenv a
    cvtAvar a =
      case cvtA a of
        AST.OpenAcc (AST.Avar var) -> var
        _                          -> internalError "Expected array computation in expression to be floated out"

    -- Convert a unary scalar function, e.g. the condition or step function
    -- of a 'While' loop, into a de Bruijn lambda. It declares fresh variables
    -- for the argument and pushes them onto the (weakened) scalar layout.
    -- The body is then converted in that extended environment.
    cvtFun1 :: HasCallStack => TypeR a -> (SmartExp a -> ScopedExp b) -> AST.OpenFun env aenv (a -> b)
    cvtFun1 tp f
      | DeclareVars lhs k value <- declareVars tp
      = let
          -- scalar layout extended with the lambda-bound argument variables
          lyt' = PushLayout (incLayout k lyt) lhs (value weakenId)
          -- The argument is a placeholder. Presumably the body already refers
          -- to its argument through a tag embedded during sharing recovery,
          -- so it is never forced.
          body = f undefined
        in
          Lam lhs $ Body $ convertSharingExp config lyt' alyt env' aenv body

    -- Push primitive function applications down through let bindings so that
    -- they are adjacent to their arguments. It looks a bit nicer this way.
    --
    -- Each 'AST.Let' wrapping the argument is kept as is. The application is
    -- rebuilt around the innermost body instead.
    --
    cvtPrimFun :: HasCallStack => AST.PrimFun (a -> r) -> AST.OpenExp env' aenv' a -> AST.OpenExp env' aenv' r
    cvtPrimFun f (AST.Let lhs bnd body) = AST.Let lhs bnd (cvtPrimFun f body)
    cvtPrimFun f x                      = AST.PrimApp f x

    -- Convert the flat list of equations into a nested case statement
    -- directly on the tag variables.
    --
    -- If the scrutinee is already an 'AST.Pair', the tags are projected
    -- structurally from it. Otherwise the scrutinee is let-bound to fresh
    -- variables first. Every right-hand side then moves under that binding,
    -- so each one is weakened by the new left-hand side.
    --
    cvtCase :: HasCallStack => AST.OpenExp env' aenv' a -> [(TagR a, AST.OpenExp env' aenv' b)] -> AST.OpenExp env' aenv' b
    cvtCase s es
      | AST.Pair{} <- s
      = nested s es
      --
      | DeclareVars lhs _ value <- declareVars (AST.expType s)
      = let scrut = expVars (value weakenId)
            alts  = over (mapped . _2) (weakenE (weakenWithLHS lhs)) es
        in
        AST.Let lhs s (nested scrut alts)
      where
        -- Build the decision tree for a list of alternatives.
        --
        -- A single alternative is returned as is. Otherwise, adjacent
        -- alternatives that agree up to their first constructor tag are
        -- grouped ('groupBy' only merges neighbours). The result branches on
        -- that tag with an 'AST.Case' and recurses into each group with the
        -- tag erased by 'ignore', so the next tag is examined. An empty list
        -- of alternatives is rejected explicitly rather than crashing later
        -- inside a partial 'head'.
        nested :: HasCallStack => AST.OpenExp env' aenv' a -> [(TagR a, AST.OpenExp env' aenv' b)] -> AST.OpenExp env' aenv' b
        nested _     []               = internalError "cvtCase: no alternatives"
        nested _     [(_, r)]         = r
        nested scrut alts@((t0, _):_) =
          let groups = groupBy (eqT `on` fst) alts
              -- groupBy never yields an empty group, so this drops nothing
              tags   = [ firstT t | ((t, _) : _) <- groups ]
              e      = prjT t0 scrut
              rhs    = map (nested scrut . map (over _1 ignore)) groups
          in
          AST.Case e (zip tags rhs) Nothing

        -- Extract the variable representing this particular tag from the
        -- scrutinee. This is safe because we let-bind the argument first.
        --
        -- The search walks the 'TagR' and the 'AST.Pair' structure of the
        -- scrutinee together and returns the component sitting at the first
        -- constructor tag. Finding none is an internal error.
        prjT :: TagR a -> AST.OpenExp env' aenv' a -> AST.OpenExp env' aenv' TAG
        prjT tag scrut =
          case go tag scrut of
            Just t  -> t
            Nothing -> internalError "cvtCase: constructor tag not found in scrutinee"
          where
            go :: TagR t -> AST.OpenExp env aenv t -> Maybe (AST.OpenExp env aenv TAG)
            go TagRtag{}        (AST.Pair l _) = Just l
            go (TagRpair ta tb) (AST.Pair l r) =
              case go ta l of
                Just t  -> Just t
                Nothing -> go tb r
            go _                _              = Nothing

        -- Equality up to the first constructor tag encountered.
        --
        -- The 'Any' flag records whether a constructor tag has been reached.
        -- In a pair, once the left component contains a tag its verdict is
        -- final. Otherwise the right component decides.
        eqT :: TagR a -> TagR a -> Bool
        eqT x y = snd (go x y)
          where
            go :: TagR t -> TagR t -> (Any, Bool)
            go TagRunit         TagRunit         = no True
            go TagRsingle{}     TagRsingle{}     = no True
            go TagRundef{}      TagRundef{}      = no True
            go (TagRtag v1 _)   (TagRtag v2 _)   = yes (v1 == v2)
            go (TagRpair a1 b1) (TagRpair a2 b2) =
              case go a1 a2 of
                (Any True, same) -> yes same
                _                -> go b1 b2
            go _                _                = no False

        -- The value of the first constructor tag encountered, searching pairs
        -- left component first. An alternative without any constructor tag is
        -- an internal error.
        firstT :: TagR a -> TAG
        firstT tag =
          case go tag of
            Just v  -> v
            Nothing -> internalError "cvtCase: alternative has no constructor tag"
          where
            go :: TagR t -> Maybe TAG
            go (TagRtag v _)  = Just v
            go (TagRpair a b) = maybe (go b) Just (go a)
            go _              = Nothing

        -- Replace the first constructor tag encountered with a regular
        -- scalar tag, so that that tag will be ignored in the recursive
        -- case.
        --
        -- The 'Any' flag records whether the replacement has happened yet. In
        -- a pair, the right component is only rewritten when the left one
        -- contained no constructor tag.
        ignore :: TagR t -> TagR t
        ignore = snd . go
          where
            go :: TagR t -> (Any, TagR t)
            go TagRunit         = no  TagRunit
            go (TagRsingle st)  = no  (TagRsingle st)
            go (TagRundef st)   = no  (TagRundef st)
            go (TagRtag _ a)    = yes (TagRpair (TagRundef scalarType) a)
            go (TagRpair a1 a2) =
              case go a1 of
                (Any True, a1') -> yes (TagRpair a1' a2)
                (_,        a1') -> TagRpair a1' <$> go a2

        yes :: x -> (Any, x)
        yes :: x -> (Any, x)
yes x
e = (Bool -> Any
Any Bool
True, x
e)

        no :: x -> (Any, x)
        no :: x -> (Any, x)
no = x -> (Any, x)
forall (f :: * -> *) a. Applicative f => a -> f a
pure


-- | Convert a unary function.
--
-- A fresh left-hand side is declared for the argument type, and the HOAS body
-- is converted in a scalar layout that binds exactly that argument.  No scalar
-- sharing variables are in scope at the start of the body (the list of bound
-- 'StableSharingExp's passed to 'convertSharingExp' is empty).
--
convertSharingFun1
    :: HasCallStack
    => Config
    -> ArrayLayout aenv aenv
    -> [StableSharingAcc]       -- currently bound array sharing-variables
    -> TypeR a
    -> (SmartExp a -> ScopedExp b)
    -> AST.Fun aenv (a -> b)
convertSharingFun1 config alyt aenv tp f =
  case declareVars tp of
    DeclareVars lhs _ value ->
      let
          -- The argument is never inspected: its 'tag' was already embedded in Phase 1
          arg  = SmartExp undefined
          lyt  = PushLayout EmptyLayout lhs (value weakenId)
          body = convertSharingExp config lyt alyt [] aenv (f arg)
      in
      Lam lhs (Body body)

-- | Convert a binary function.
--
-- Both arguments get fresh left-hand sides.  The first argument's variables are
-- weakened past the second binder (using the weakening returned when declaring
-- the second argument), so both remain addressable in the body.  No scalar
-- sharing variables are in scope at the start of the body.
--
convertSharingFun2
    :: HasCallStack
    => Config
    -> ArrayLayout aenv aenv
    -> [StableSharingAcc]       -- currently bound array sharing-variables
    -> TypeR a
    -> TypeR b
    -> (SmartExp a -> SmartExp b -> ScopedExp c)
    -> AST.Fun aenv (a -> b -> c)
convertSharingFun2 config alyt aenv ta tb f =
  case declareVars ta of
    DeclareVars lhs1 _ value1 ->
      case declareVars tb of
        DeclareVars lhs2 k2 value2 ->
          let
              -- Neither argument is inspected: their 'tag's were already embedded in Phase 1
              argA  = SmartExp undefined
              argB  = SmartExp undefined
              -- First argument, weakened past the binder of the second...
              outer = PushLayout EmptyLayout lhs1 (value1 k2)
              -- ...and the second argument pushed on top of it
              inner = PushLayout outer lhs2 (value2 weakenId)
              body  = convertSharingExp config inner alyt [] aenv (f argA argB)
          in
          Lam lhs1 (Lam lhs2 (Body body))

-- | Convert a unary stencil function.
--
-- The stencil argument is an ordinary scalar value whose type is obtained from
-- the stencil representation via 'R.stencilR', so conversion is delegated to
-- 'convertSharingFun1' at that type.
--
convertSharingStencilFun1
    :: HasCallStack
    => Config
    -> ArrayLayout aenv aenv
    -> [StableSharingAcc]               -- currently bound array sharing-variables
    -> R.StencilR sh a stencil
    -> (SmartExp stencil -> ScopedExp b)
    -> AST.Fun aenv (stencil -> b)
convertSharingStencilFun1 config alyt aenv sR1 stencil =
  convertSharingFun1 config alyt aenv patternTp stencil
  where
    patternTp = R.stencilR sR1

-- | Convert a binary stencil function.
--
-- Each stencil argument's scalar type is obtained from its representation via
-- 'R.stencilR', and conversion is delegated to 'convertSharingFun2'.
--
convertSharingStencilFun2
    :: HasCallStack
    => Config
    -> ArrayLayout aenv aenv
    -> [StableSharingAcc]               -- currently bound array sharing-variables
    -> R.StencilR sh a stencil1
    -> R.StencilR sh b stencil2
    -> (SmartExp stencil1 -> SmartExp stencil2 -> ScopedExp c)
    -> AST.Fun aenv (stencil1 -> stencil2 -> c)
convertSharingStencilFun2 config alyt aenv sR1 sR2 stencil =
  convertSharingFun2 config alyt aenv patternTp1 patternTp2 stencil
  where
    patternTp1 = R.stencilR sR1
    patternTp2 = R.stencilR sR2


-- Sharing recovery
-- ================

-- Sharing recovery proceeds in two phases:
--
-- /Phase One: build the occurrence map/
--
-- This is a top-down traversal of the AST that computes a map from AST nodes to the number of
-- occurrences of that AST node in the overall Accelerate program.  An occurrences count of two or
-- more indicates sharing.
--
-- IMPORTANT: To avoid unfolding the sharing, we do not descend into subtrees that we have
--   previously encountered.  Hence, the complexity is proportional to the number of nodes in the
--   tree /with/ sharing.  Consequently, the occurrence count is that in the tree with sharing
--   as well.
--
-- During computation of the occurrences, the tree is annotated with stable names on every node
-- using 'AccSharing' constructors and all but the first occurrence of shared subtrees are pruned
-- using 'AvarSharing' constructors (see 'SharingAcc' below).  This phase is impure as it is based
-- on stable names.
--
-- We use a hash table (instead of 'Data.Map') as computing stable names forces us to live in IO
-- anyway.  Once the computation of occurrence counts is complete, we freeze the hash table into
-- a 'Data.Map'.
--
-- (Implemented by 'makeOccMap*'.)
--
-- /Phase Two: determine scopes and inject sharing information/
--
-- This is a bottom-up traversal that determines the scope for every binding to be introduced
-- to share a subterm.  It uses the occurrence map to determine, for every shared subtree, the
-- lowest AST node at which the binding for that shared subtree can be placed (using a
-- 'AletSharing' constructor)— it's the meet of all the shared subtree occurrences.
--
-- The second phase also replaces the first occurrence of each shared subtree with an
-- 'AvarSharing' node and floats the shared subtree up to its binding point.
--
--  (Implemented by 'determineScopes*'.)
--
-- /Sharing recovery for expressions/
--
-- We recover sharing for each expression (including function bodies) independently of any other
-- expression — i.e., we cannot share scalar expressions across array computations.  Hence, during
-- Phase One, we mark all scalar expression nodes with a stable name and compute one occurrence map
-- for every scalar expression (including functions) that occurs in an array computation.  These
-- occurrence maps are added to the root of scalar expressions using 'RootExp'.
--
-- NB: We do not need to worry that sharing recovery will try to float a shared subexpression past a
--     binder that occurs in that subexpression.  Why?  Otherwise, the binder would already occur
--     out of scope in the original source program.
--
-- /Lambda bound variables/
--
-- During sharing recovery, lambda bound variables appear in the form of 'Atag' and 'Tag' data
-- constructors.  The tag values are determined during Phase One of sharing recovery by computing
-- the /level/ of each variable at its binding occurrence.  The level at the root of the AST is 0
-- and increases by one with each lambda on each path through the AST.

-- Stable names
-- ------------

-- Opaque stable name for AST nodes — used to key the occurrence map.
--
-- The node's result type 't' is hidden by the constructor, so stable names of nodes with
-- different result types (but the same node kind 'c') can be stored in a single table.
--
data StableASTName c where
  StableASTName :: StableName (c t) -> StableASTName c

-- Renders only the hash of the wrapped stable name.
instance Show (StableASTName c) where
  show (StableASTName sn) = show (hashStableName sn)

-- Equality delegates to 'eqStableName', which can compare stable names of different types
-- (the node result types are existentially hidden here).
instance Eq (StableASTName c) where
  (==) (StableASTName x) (StableASTName y) = x `eqStableName` y

-- Hashing delegates to the wrapped stable name's own 'Hashable' instance.
instance Hashable (StableASTName c) where
  hashWithSalt salt (StableASTName name) = salt `hashWithSalt` name

-- | Make a stable name for an AST node, forcing the node to WHNF first.
--
-- 'makeStableName' may return a different name for an object before and after it is
-- evaluated, so forcing first keeps the names of the same node consistent.
--
makeStableAST :: c t -> IO (StableName (c t))
makeStableAST e = makeStableName $! e

-- Stable name for an AST node paired with the height of the AST rooted at that node.
--
-- Equality on this type compares only the stable name; the height is what 'higherSNH'
-- compares.
--
data StableNameHeight t = StableNameHeight (StableName t) Int

-- Equality ignores the recorded heights and compares only the stable names.
instance Eq (StableNameHeight t) where
  StableNameHeight x _ == StableNameHeight y _ = x `eqStableName` y

-- | Is the first AST strictly taller than the second?  Only the recorded heights are
-- compared; the stable names are ignored.
--
higherSNH :: StableNameHeight t1 -> StableNameHeight t2 -> Bool
higherSNH (StableNameHeight _ h1) (StableNameHeight _ h2) = h2 < h1

-- | Hash of the stable name, ignoring the height.
--
hashStableNameHeight :: StableNameHeight t -> Int
hashStableNameHeight snh =
  case snh of
    StableNameHeight sn _ -> hashStableName sn

-- Mutable occurrence map
-- ----------------------

-- Hash table keyed on the stable names of AST nodes; the 'c' parameter of 'ASTHashTable'
-- selects the kind of node (e.g. 'SmartAcc' or 'SmartExp').
--
type HashTable key val = Hash.BasicHashTable key val
type ASTHashTable c v  = HashTable (StableASTName c) v

-- Mutable hashtable version of the occurrence map, which associates each AST node with the pair
-- (occurrence count, height of the AST rooted at that node) — see 'enterOcc'.
--
type OccMapHash c = ASTHashTable c (Int, Int)

-- | Allocate an empty mutable hash table keyed on the stable names of AST nodes.
--
newASTHashTable :: IO (ASTHashTable c v)
newASTHashTable = Hash.new

-- | Record one occurrence of an AST node in a mutable occurrence map.
--
-- * First occurrence: the node is inserted with count 1 and the supplied 'height', and
--   'Nothing' is returned.
--
-- * Repeated occurrence: the stored count is incremented, the 'height' argument is ignored,
--   and @Just h@ is returned, where @h@ is the height stored when the node was first entered.
--
-- The incremented count is forced before it is written back.  Otherwise a heavily shared node
-- would accumulate a chain of @(+1)@ thunks inside the table until the map is frozen.
--
enterOcc :: OccMapHash c -> StableASTName c -> Int -> IO (Maybe Int)
enterOcc occMap sa height = Hash.mutate occMap sa bump
  where
    bump Nothing             = (Just (1, height), Nothing)
    bump (Just (n, heightS)) =
      let !n' = n + 1
      in  (Just (n', heightS), Just heightS)


-- Immutable occurrence map
-- ------------------------

-- Immutable version of the occurrence map (storing the occurrence count only, not the height).  We
-- use the 'StableName' hash to index an 'IntMap' and disambiguate 'StableName's with identical
-- hashes explicitly, storing them in a list in the 'IntMap'.
--
-- Each list therefore holds (stable name, occurrence count) pairs whose names all share the
-- 'IntMap' key as their hash.
--
type OccMap c = IntMap.IntMap [(StableASTName c, Int)]

-- | Turn a mutable occurrence map into an immutable one.
--
-- Heights are dropped, and entries are bucketed by the hash of their stable name.  Buckets are
-- built with 'IntMap.fromListWith', so every entry lands in the bucket for its hash no matter
-- where 'Hash.toList' places it.  The previous approach grouped only *adjacent* equal hashes and
-- then used 'IntMap.fromList', which keeps only the last value per key.  That silently discarded
-- an earlier group whenever two same-hash entries were not listed next to each other.
--
-- @flip (++)@ appends later entries after earlier ones, so the order within a bucket follows
-- the 'Hash.toList' order.
--
freezeOccMap :: OccMapHash c -> IO (OccMap c)
freezeOccMap oc
  = do
      ocl <- Hash.toList oc
      traceChunk "OccMap" (show ocl)

      return . IntMap.fromListWith (flip (++))
             . map toBucketEntry
             $ ocl
  where
    -- Key an entry by its stable-name hash, keeping only the occurrence count.
    toBucketEntry (sa, (cnt, _)) = (key sa, [(sa, cnt)])

    key (StableASTName sn) = hashStableName sn

-- | Look up the occurrence count of an AST node by its stable name.  The hash of the name
-- selects the bucket, and the bucket is searched for the exact name.  A node that is not in the
-- map is reported as occurring once.
--
lookupWithASTName :: OccMap c -> StableASTName c -> Int
lookupWithASTName oc sa@(StableASTName sn) =
  case IntMap.lookup (hashStableName sn) oc of
    Nothing     -> 1
    Just bucket -> fromMaybe 1 (Prelude.lookup sa bucket)

-- | Look up the occurrence count of a sharing-annotated array computation, using only its stable
-- name (the height is ignored).  If the key does not exist in the map, the count is 1.
--
lookupWithSharingAcc :: OccMap SmartAcc -> StableSharingAcc -> Int
lookupWithSharingAcc oc (StableSharingAcc (StableNameHeight sn _) _) =
  oc `lookupWithASTName` StableASTName sn

-- | Look up the occurrence count of a sharing-annotated scalar expression, using only its stable
-- name (the height is ignored).  If the key does not exist in the map, the count is 1.
--
lookupWithSharingExp :: OccMap SmartExp -> StableSharingExp -> Int
lookupWithSharingExp oc (StableSharingExp (StableNameHeight sn _) _) =
  oc `lookupWithASTName` StableASTName sn


-- Stable 'SmartAcc' nodes
-- ------------------

-- Stable name for 'SmartAcc' nodes including the height of the AST.  The type index 't' is the
-- array result type of the named computation.
--
type StableAccName t = StableNameHeight (SmartAcc t)

-- Interleave sharing annotations into an array computation AST.  Subtrees can be marked as being
-- represented by a variable (referring to a shared subtree) using 'AvarSharing' and as being
-- prefixed by a let binding (for a shared subtree) using 'AletSharing'.
--
data SharingAcc acc exp arrs where
  -- Reference to a shared subtree by its stable name.  The array representation is stored
  -- explicitly because there is no subtree here to compute it from.
  AvarSharing :: StableAccName arrs -> ArraysR arrs             -> SharingAcc acc exp arrs
  -- Let binding of the given shared subtree, scoping over the body 'acc arrs'.
  AletSharing :: StableSharingAcc -> acc arrs                   -> SharingAcc acc exp arrs
  -- An ordinary AST node, annotated with its stable name.
  AccSharing  :: StableAccName arrs -> PreSmartAcc acc exp arrs -> SharingAcc acc exp arrs

-- The array representation comes from the variable's stored representation, or otherwise from
-- the wrapped body/node.
instance HasArraysR acc => HasArraysR (SharingAcc acc exp) where
  arraysR = \case
    AvarSharing _ repr -> repr
    AletSharing _ body -> Smart.arraysR body
    AccSharing  _ pacc -> Smart.arraysR pacc


-- Array expression with sharing annotations whose shared values have not yet been scoped; i.e. it
-- contains no let bindings. If the expression is rooted in a function, the list contains the tags
-- of the variables bound by the immediately surrounding lambdas.  Embedded scalar expressions are
-- represented as 'RootExp'.
data UnscopedAcc t = UnscopedAcc [Int] (SharingAcc UnscopedAcc RootExp t)

-- The representation is that of the wrapped sharing-annotated computation; the lambda tags are
-- irrelevant here.
instance HasArraysR UnscopedAcc where
  arraysR (UnscopedAcc _ sharing) = Smart.arraysR sharing


-- Array expression with sharing. For expressions rooted in functions the list holds a sorted
-- environment corresponding to the variables bound in the immediately surrounding lambdas.
data ScopedAcc t = ScopedAcc [StableSharingAcc] (SharingAcc ScopedAcc ScopedExp t)

-- The representation is that of the wrapped sharing-annotated computation; the environment list
-- is irrelevant here.
instance HasArraysR ScopedAcc where
  arraysR (ScopedAcc _ sharing) = Smart.arraysR sharing


-- Stable name for an array computation associated with its sharing-annotated version.
--
-- The computation's result type is existentially hidden, so values of this type can be kept in
-- homogeneous lists such as the lists of currently bound array sharing-variables.
--
data StableSharingAcc where
  StableSharingAcc :: StableAccName arrs
                   -> SharingAcc ScopedAcc ScopedExp arrs
                   -> StableSharingAcc

-- Renders only the hash of the stable name; the height and the annotated computation are
-- omitted.
instance Show StableSharingAcc where
  show (StableSharingAcc sn _) = show (hashStableNameHeight sn)

-- Equality compares only the stable names, ignoring heights and annotated computations.
-- 'eqStableName' is required because the two sides may have different (existentially hidden)
-- result types.
instance Eq StableSharingAcc where
  (==) (StableSharingAcc (StableNameHeight sn1 _) _)
       (StableSharingAcc (StableNameHeight sn2 _) _)
    = sn1 `eqStableName` sn2

-- | Is the first computation's AST strictly taller than the second's?  Only the recorded heights
-- are compared.
--
higherSSA :: StableSharingAcc -> StableSharingAcc -> Bool
higherSSA (StableSharingAcc sn1 _) (StableSharingAcc sn2 _) = higherSNH sn1 sn2

-- | Test whether the given stable name identifies the array computation carried by a
-- 'StableSharingAcc'.  Heights are ignored.
--
matchStableAcc :: StableAccName arrs -> StableSharingAcc -> Bool
matchStableAcc (StableNameHeight sn1 _) ssa =
  case ssa of
    StableSharingAcc (StableNameHeight sn2 _) _ -> sn1 `eqStableName` sn2

-- | Dummy entry for environments, used for unused variables.
--
-- 'unsafePerformIO' creates a stable name once, at top level.  The NOINLINE pragma is the usual
-- idiom to stop GHC from inlining (and thereby duplicating) that computation, so all uses share
-- one name.  The named object is 'undefined'; 'makeStableName' does not evaluate its argument, so
-- it is never forced here.  The recorded height is 0.
--
{-# NOINLINE noStableAccName #-}
noStableAccName :: StableAccName arrs
noStableAccName = unsafePerformIO $ do
  sn <- makeStableName undefined
  return (StableNameHeight sn 0)

-- Stable 'Exp' nodes
-- ------------------

-- Stable name for scalar expression ('SmartExp') nodes including the height of the AST.  The type
-- index 't' is the result type of the named expression.
--
type StableExpName t = StableNameHeight (SmartExp t)

-- Interleave sharing annotations into a scalar expression AST, in the same manner as
-- 'SharingAcc' does for array computations:
--
--   * 'VarSharing' carries a stable name together with the expression's type;
--   * 'LetSharing' carries a 'StableSharingExp' together with an expression of the result
--     type (presumably a let-binding of the shared subterm around that body);
--   * 'ExpSharing' carries a stable name together with an ordinary expression node.
--
data SharingExp acc exp e where
  VarSharing :: StableExpName e  -> TypeR e               -> SharingExp acc exp e
  LetSharing :: StableSharingExp -> exp e                 -> SharingExp acc exp e
  ExpSharing :: StableExpName e  -> PreSmartExp acc exp e -> SharingExp acc exp e

-- The type of a sharing-annotated expression. A variable stores its type directly; the
-- other two forms defer to the 'HasTypeR' instance of the term they wrap.
--
instance HasTypeR exp => HasTypeR (SharingExp acc exp) where
  typeR = \case
    VarSharing _ tp   -> tp
    LetSharing _ body -> Smart.typeR body
    ExpSharing _ node -> Smart.typeR node

-- A scalar expression AST with sharing annotations but no scoping yet; i.e. it contains no
-- 'LetSharing' constructors. If the expression is rooted in a function, the list contains
-- the tags of the variables bound by the immediately surrounding lambdas.
--
data UnscopedExp e where
  UnscopedExp :: [Int] -> SharingExp UnscopedAcc UnscopedExp e -> UnscopedExp e

-- The type of an unscoped expression is the type of the sharing-annotated term it wraps;
-- the list of lambda tags plays no part.
--
instance HasTypeR UnscopedExp where
  typeR = \(UnscopedExp _ inner) -> Smart.typeR inner

-- A scalar expression AST with sharing. For expressions rooted in functions, the list
-- holds a sorted environment corresponding to the variables bound in the immediately
-- surrounding lambdas.
--
data ScopedExp e where
  ScopedExp :: [StableSharingExp] -> SharingExp ScopedAcc ScopedExp e -> ScopedExp e

-- The type of a scoped expression is the type of the sharing-annotated term it wraps;
-- the environment list plays no part.
--
instance HasTypeR ScopedExp where
  typeR = \(ScopedExp _ inner) -> Smart.typeR inner

-- Expressions rooted in 'SmartAcc' computations.
--
-- * When counting occurrences, the root of every expression embedded in a 'SmartAcc' is
--   annotated with an occurrence map for that one expression (excluding any subterms that
--   are rooted in embedded 'SmartAcc's).
--
data RootExp e where
  RootExp :: OccMap SmartExp -> UnscopedExp e -> RootExp e

-- Stable name for an expression, paired with its sharing-annotated version. The expression
-- type is existentially hidden, so entries of different types can share one list (as in the
-- environment of 'ScopedExp').
--
data StableSharingExp where
  StableSharingExp :: StableExpName e -> SharingExp ScopedAcc ScopedExp e -> StableSharingExp

-- Render an entry as the value computed by 'hashStableNameHeight' for its stable name.
-- The sharing-annotated expression itself is not shown.
--
instance Show StableSharingExp where
  show (StableSharingExp name _) = show (hashStableNameHeight name)

-- Two entries are equal when their underlying 'StableName's are equal. The recorded heights
-- and the sharing-annotated expressions are ignored. 'eqStableName' is used because the two
-- entries may hide different expression types.
--
instance Eq StableSharingExp where
  (==) (StableSharingExp (StableNameHeight lhs _) _)
       (StableSharingExp (StableNameHeight rhs _) _)
    = lhs `eqStableName` rhs

-- Order two entries by their stable names using 'higherSNH' (presumably a height
-- comparison; see its definition). The sharing-annotated expressions are ignored.
--
higherSSE :: StableSharingExp -> StableSharingExp -> Bool
higherSSE (StableSharingExp lhs _) (StableSharingExp rhs _) = higherSNH lhs rhs

-- Test whether the given stable name matches the stable name recorded for an expression
-- with sharing. Only the underlying 'StableName's are compared (via 'eqStableName', since
-- the entry hides its expression type); both heights are ignored.
--
matchStableExp :: StableExpName t -> StableSharingExp -> Bool
matchStableExp (StableNameHeight needle _) entry =
  case entry of
    StableSharingExp (StableNameHeight candidate _) _ -> eqStableName needle candidate

-- Dummy entry for environments to be used for unused variables.
--
-- The stable name is made for 'undefined'. 'makeStableName' does not evaluate its argument,
-- so the placeholder value is never forced. The height is fixed at 0. 'unsafePerformIO'
-- together with NOINLINE makes this a single top-level value, so every use sees the same
-- stable name.
--
{-# NOINLINE noStableExpName #-}
noStableExpName :: StableExpName t
noStableExpName = unsafePerformIO $ do
  placeholder <- makeStableName undefined
  return (StableNameHeight placeholder 0)


{--
-- Stable 'Seq' nodes
-- ------------------

-- Stable name for 'Seq' nodes including the height of the AST.
--
type StableSeqName arrs = StableNameHeight (Seq arrs)

-- Interleave sharing annotations into a sequence computation AST in the same manner as SharingAcc
-- and SharingExp.
--
data SharingSeq acc seq exp arrs where
  SvarSharing :: (Typeable arrs, Arrays arrs)
              => StableSeqName [arrs]                       -> SharingSeq acc seq exp [arrs]
  SletSharing :: StableSharingSeq -> seq t                  -> SharingSeq acc seq exp t
  SeqSharing  :: Typeable arrs
              => StableSeqName arrs -> PreSeq acc seq exp arrs -> SharingSeq acc seq exp arrs

-- Array expression with sharing but shared values have not been scoped; i.e. no let bindings. If
-- the expression is rooted in a function, the list contains the tags of the variables bound by the
-- immediate surrounding lambdas.
data UnscopedSeq t = UnscopedSeq (SharingSeq UnscopedAcc UnscopedSeq RootExp t)

-- Array expression with sharing. For expressions rooted in functions the list holds a sorted
-- environment corresponding to the variables bound in the immediate surounding lambdas.
data ScopedSeq t = ScopedSeq (SharingSeq ScopedAcc ScopedSeq ScopedExp t)

-- Sequences rooted in 'Acc' computations.
--
-- * When counting occurrences, the root of every sequence embedded in an 'Acc' is annotated by
--   an occurrence map for that one expression (excluding any subterms that are rooted in embedded
--   'Acc's.)
--
data RootSeq t = RootSeq (OccMap Seq) (UnscopedSeq t)

-- Stable name for an array computation associated with its sharing-annotated version.
--
data StableSharingSeq where
  StableSharingSeq :: Typeable arrs
                   => StableSeqName arrs
                   -> SharingSeq ScopedAcc ScopedSeq ScopedExp arrs
                   -> StableSharingSeq

instance Show StableSharingSeq where
  show (StableSharingSeq sn _) = show $ hashStableNameHeight sn

instance Eq StableSharingSeq where
  StableSharingSeq sn1 _ == StableSharingSeq sn2 _
    | Just sn1' <- gcast sn1 = sn1' == sn2
    | otherwise              = False

higherSSS :: StableSharingSeq -> StableSharingSeq -> Bool
StableSharingSeq sn1 _ `higherSSS` StableSharingSeq sn2 _ = sn1 `higherSNH` sn2

-- Test whether the given stable names matches an array computation with sharing.
--
matchStableSeq :: Typeable arrs => StableSeqName arrs -> StableSharingSeq -> Bool
matchStableSeq sn1 (StableSharingSeq sn2 _)
  | Just sn1' <- gcast sn1 = sn1' == sn2
  | otherwise              = False
--}


-- Occurrence counting
-- ===================

-- Compute the 'SmartAcc' occurrence map, mark all nodes (both 'Seq' and 'Exp' nodes) with stable names,
-- and drop repeated occurrences of shared 'SmartAcc' and 'Exp' subtrees (Phase One).
--
-- We compute a single 'SmartAcc' occurrence map for the whole AST, but one 'Exp' occurrence map for each
-- sub-expression rooted in a 'SmartAcc' operation.  This is because we cannot float 'Exp' subtrees across
-- 'SmartAcc' operations, but we can float 'SmartAcc' subtrees out of 'Exp' expressions.
--
-- Note [Traversing functions and side effects]
-- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-- We need to descend into function bodies to build the 'OccMap' with all occurrences in the
-- function bodies.  Due to the side effects in the construction of the occurrence map and, more
-- importantly, the dependence of the second phase on /global/ occurrence information, we may not
-- delay the body traversals by putting them under a lambda.  Hence, we apply each function to
-- traverse its body, and use a /dummy abstraction/ of the result.
--
-- For example, given a function 'f', we traverse 'f (Tag 0)', which yields a transformed body 'e'.
-- As the result of the traversal of the overall function, we use 'const e'.  Hence, it is crucial
-- that the 'Tag' supplied during the initial traversal is already the one required by the HOAS to
-- de Bruijn conversion in 'convertSharingAcc' — any subsequent application of 'const e' will only
-- yield 'e' with the embedded 'Tag 0' of the original application.  During sharing recovery, we
-- float /all/ free variables ('Atag' and 'Tag') out to construct the initial environment for
-- producing de Bruijn indices, which replaces them by 'AvarSharing' or 'VarSharing' nodes.  Hence,
-- the tag values only serve the purpose of determining the ordering in that initial environment.
-- They are /not/ directly used to compute the de Bruijn indices.
--
-- Entry point of Phase One for an array computation. It builds a fresh occurrence hash
-- table, traverses 'acc' with 'makeOccMapSharingAcc' to fill the table and produce the
-- unscoped sharing-annotated term, then freezes the table. The height returned by the
-- traversal is discarded. Entry and exit are reported with 'traceLine'.
--
makeOccMapAcc
    :: HasCallStack
    => Config
    -> Level
    -> SmartAcc arrs
    -> IO (UnscopedAcc arrs, OccMap SmartAcc)
makeOccMapAcc config lvl acc = do
  traceLine "makeOccMapAcc" "Enter"
  occTable            <- newASTHashTable
  (unscoped, _height) <- makeOccMapSharingAcc config occTable lvl acc
  frozen              <- freezeOccMap occTable
  traceLine "makeOccMapAcc" "Exit"
  pure (unscoped, frozen)


makeOccMapSharingAcc
    :: HasCallStack
    => Config
    -> OccMapHash SmartAcc
    -> Level
    -> SmartAcc arrs
    -> IO (UnscopedAcc arrs, Int)
makeOccMapSharingAcc :: Config
-> OccMapHash SmartAcc
-> Int
-> SmartAcc arrs
-> IO (UnscopedAcc arrs, Int)
makeOccMapSharingAcc Config
config OccMapHash SmartAcc
accOccMap = Int -> SmartAcc arrs -> IO (UnscopedAcc arrs, Int)
forall arrs.
HasCallStack =>
Int -> SmartAcc arrs -> IO (UnscopedAcc arrs, Int)
traverseAcc
  where
    traverseFun1
        :: HasCallStack
        => Level
        -> TypeR a
        -> (SmartExp a -> SmartExp b)
        -> IO (SmartExp a -> RootExp b, Int)
    traverseFun1 :: Int
-> TypeR a
-> (SmartExp a -> SmartExp b)
-> IO (SmartExp a -> RootExp b, Int)
traverseFun1 = Config
-> OccMapHash SmartAcc
-> Int
-> TypeR a
-> (SmartExp a -> SmartExp b)
-> IO (SmartExp a -> RootExp b, Int)
forall a b.
HasCallStack =>
Config
-> OccMapHash SmartAcc
-> Int
-> TypeR a
-> (SmartExp a -> SmartExp b)
-> IO (SmartExp a -> RootExp b, Int)
makeOccMapFun1 Config
config OccMapHash SmartAcc
accOccMap

    traverseFun2
        :: HasCallStack
        => Level
        -> TypeR a
        -> TypeR b
        -> (SmartExp a -> SmartExp b -> SmartExp c)
        -> IO (SmartExp a -> SmartExp b -> RootExp c, Int)
    traverseFun2 :: Int
-> TypeR a
-> TypeR b
-> (SmartExp a -> SmartExp b -> SmartExp c)
-> IO (SmartExp a -> SmartExp b -> RootExp c, Int)
traverseFun2 = Config
-> OccMapHash SmartAcc
-> Int
-> TypeR a
-> TypeR b
-> (SmartExp a -> SmartExp b -> SmartExp c)
-> IO (SmartExp a -> SmartExp b -> RootExp c, Int)
forall a b c.
HasCallStack =>
Config
-> OccMapHash SmartAcc
-> Int
-> TypeR a
-> TypeR b
-> (SmartExp a -> SmartExp b -> SmartExp c)
-> IO (SmartExp a -> SmartExp b -> RootExp c, Int)
makeOccMapFun2 Config
config OccMapHash SmartAcc
accOccMap

    traverseAfun1
        :: HasCallStack
        => Level
        -> ArraysR a
        -> (SmartAcc a -> SmartAcc b)
        -> IO (SmartAcc a -> UnscopedAcc b, Int)
    traverseAfun1 :: Int
-> ArraysR a
-> (SmartAcc a -> SmartAcc b)
-> IO (SmartAcc a -> UnscopedAcc b, Int)
traverseAfun1 = Config
-> OccMapHash SmartAcc
-> Int
-> ArraysR a
-> (SmartAcc a -> SmartAcc b)
-> IO (SmartAcc a -> UnscopedAcc b, Int)
forall a b.
HasCallStack =>
Config
-> OccMapHash SmartAcc
-> Int
-> ArraysR a
-> (SmartAcc a -> SmartAcc b)
-> IO (SmartAcc a -> UnscopedAcc b, Int)
makeOccMapAfun1 Config
config OccMapHash SmartAcc
accOccMap

    traverseExp
      :: HasCallStack
      => Level
      -> SmartExp e
      -> IO (RootExp e, Int)
    traverseExp :: Int -> SmartExp e -> IO (RootExp e, Int)
traverseExp = Config
-> OccMapHash SmartAcc -> Int -> SmartExp e -> IO (RootExp e, Int)
forall e.
HasCallStack =>
Config
-> OccMapHash SmartAcc -> Int -> SmartExp e -> IO (RootExp e, Int)
makeOccMapExp Config
config OccMapHash SmartAcc
accOccMap

    traverseBoundary
        :: HasCallStack
        => Level
        -> ShapeR sh
        -> PreBoundary SmartAcc SmartExp (Array sh e)
        -> IO (PreBoundary UnscopedAcc RootExp (Array sh e), Int)
    traverseBoundary :: Int
-> ShapeR sh
-> PreBoundary SmartAcc SmartExp (Array sh e)
-> IO (PreBoundary UnscopedAcc RootExp (Array sh e), Int)
traverseBoundary Int
lvl ShapeR sh
shr PreBoundary SmartAcc SmartExp (Array sh e)
bndy =
      case PreBoundary SmartAcc SmartExp (Array sh e)
bndy of
        PreBoundary SmartAcc SmartExp (Array sh e)
Clamp      -> (PreBoundary UnscopedAcc RootExp (Array sh e), Int)
-> IO (PreBoundary UnscopedAcc RootExp (Array sh e), Int)
forall (m :: * -> *) a. Monad m => a -> m a
return (PreBoundary UnscopedAcc RootExp (Array sh e)
forall (acc :: * -> *) (exp :: * -> *) t. PreBoundary acc exp t
Clamp, Int
0)
        PreBoundary SmartAcc SmartExp (Array sh e)
Mirror     -> (PreBoundary UnscopedAcc RootExp (Array sh e), Int)
-> IO (PreBoundary UnscopedAcc RootExp (Array sh e), Int)
forall (m :: * -> *) a. Monad m => a -> m a
return (PreBoundary UnscopedAcc RootExp (Array sh e)
forall (acc :: * -> *) (exp :: * -> *) t. PreBoundary acc exp t
Mirror, Int
0)
        PreBoundary SmartAcc SmartExp (Array sh e)
Wrap       -> (PreBoundary UnscopedAcc RootExp (Array sh e), Int)
-> IO (PreBoundary UnscopedAcc RootExp (Array sh e), Int)
forall (m :: * -> *) a. Monad m => a -> m a
return (PreBoundary UnscopedAcc RootExp (Array sh e)
forall (acc :: * -> *) (exp :: * -> *) t. PreBoundary acc exp t
Wrap, Int
0)
        Constant e
v -> (PreBoundary UnscopedAcc RootExp (Array sh e), Int)
-> IO (PreBoundary UnscopedAcc RootExp (Array sh e), Int)
forall (m :: * -> *) a. Monad m => a -> m a
return (e -> PreBoundary UnscopedAcc RootExp (Array sh e)
forall e (acc :: * -> *) (exp :: * -> *) sh.
e -> PreBoundary acc exp (Array sh e)
Constant e
v, Int
0)
        Function SmartExp sh -> SmartExp e
f -> do
          (SmartExp sh -> RootExp e
f', Int
h) <- Int
-> TypeR sh
-> (SmartExp sh -> SmartExp e)
-> IO (SmartExp sh -> RootExp e, Int)
forall a b.
HasCallStack =>
Int
-> TypeR a
-> (SmartExp a -> SmartExp b)
-> IO (SmartExp a -> RootExp b, Int)
traverseFun1 Int
lvl (ShapeR sh -> TypeR sh
forall sh. ShapeR sh -> TypeR sh
shapeType ShapeR sh
shr) SmartExp sh -> SmartExp e
SmartExp sh -> SmartExp e
f
          (PreBoundary UnscopedAcc RootExp (Array sh e), Int)
-> IO (PreBoundary UnscopedAcc RootExp (Array sh e), Int)
forall (m :: * -> *) a. Monad m => a -> m a
return ((SmartExp sh -> RootExp e)
-> PreBoundary UnscopedAcc RootExp (Array sh e)
forall sh (exp :: * -> *) e (acc :: * -> *).
(SmartExp sh -> exp e) -> PreBoundary acc exp (Array sh e)
Function SmartExp sh -> RootExp e
f', Int
h)

    -- traverseSeq :: forall arrs. Typeable arrs
    --             => Level -> Seq arrs
    --             -> IO (RootSeq arrs, Int)
    -- traverseSeq = makeOccMapRootSeq config accOccMap

    traverseAcc
        :: forall arrs. HasCallStack
        => Level
        -> SmartAcc arrs
        -> IO (UnscopedAcc arrs, Int)
    traverseAcc :: Int -> SmartAcc arrs -> IO (UnscopedAcc arrs, Int)
traverseAcc Int
lvl acc :: SmartAcc arrs
acc@(SmartAcc PreSmartAcc SmartAcc SmartExp arrs
pacc)
      = ((UnscopedAcc arrs, Int) -> IO (UnscopedAcc arrs, Int))
-> IO (UnscopedAcc arrs, Int)
forall (m :: * -> *) a. MonadFix m => (a -> m a) -> m a
mfix (((UnscopedAcc arrs, Int) -> IO (UnscopedAcc arrs, Int))
 -> IO (UnscopedAcc arrs, Int))
-> ((UnscopedAcc arrs, Int) -> IO (UnscopedAcc arrs, Int))
-> IO (UnscopedAcc arrs, Int)
forall a b. (a -> b) -> a -> b
$ \ ~(UnscopedAcc arrs
_, Int
height) -> do
          -- Compute stable name and enter it into the occurrence map
          --
          StableName (SmartAcc arrs)
sn                         <- SmartAcc arrs -> IO (StableName (SmartAcc arrs))
forall (c :: * -> *) t. c t -> IO (StableName (c t))
makeStableAST SmartAcc arrs
acc
          Maybe Int
heightIfRepeatedOccurrence <- OccMapHash SmartAcc
-> StableASTName SmartAcc -> Int -> IO (Maybe Int)
forall (c :: * -> *).
OccMapHash c -> StableASTName c -> Int -> IO (Maybe Int)
enterOcc OccMapHash SmartAcc
accOccMap (StableName (SmartAcc arrs) -> StableASTName SmartAcc
forall (c :: * -> *) t. StableName (c t) -> StableASTName c
StableASTName StableName (SmartAcc arrs)
sn) Int
height

          String -> String -> IO ()
traceLine (PreSmartAcc SmartAcc SmartExp arrs -> String
forall (acc :: * -> *) (exp :: * -> *) arrs.
PreSmartAcc acc exp arrs -> String
showPreAccOp PreSmartAcc SmartAcc SmartExp arrs
pacc) (String -> IO ()) -> String -> IO ()
forall a b. (a -> b) -> a -> b
$ do
            let hash :: String
hash = Int -> String
forall a. Show a => a -> String
show (StableName (SmartAcc arrs) -> Int
forall a. StableName a -> Int
hashStableName StableName (SmartAcc arrs)
sn)
            case Maybe Int
heightIfRepeatedOccurrence of
              Just Int
height -> String
"REPEATED occurrence (sn = " String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
hash String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
"; height = " String -> ShowS
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show Int
height String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
")"
              Maybe Int
Nothing     -> String
"first occurrence (sn = " String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
hash String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
")"

          -- Reconstruct the computation in shared form.
          --
          -- In case of a repeated occurrence, the height comes from the occurrence map; otherwise
          -- it is computed by the traversal function passed in 'newAcc'. See also 'enterOcc'.
          --
          let reconstruct :: IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
                          -> IO (UnscopedAcc arrs, Int)
              reconstruct :: IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
-> IO (UnscopedAcc arrs, Int)
reconstruct IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
newAcc
                = case Maybe Int
heightIfRepeatedOccurrence of
                    Just Int
height | Flag
acc_sharing Flag -> BitSet Word32 Flag -> Bool
forall a c. (Enum a, Bits c) => a -> BitSet c a -> Bool
`member` Config -> BitSet Word32 Flag
options Config
config
                      -> (UnscopedAcc arrs, Int) -> IO (UnscopedAcc arrs, Int)
forall (m :: * -> *) a. Monad m => a -> m a
return ([Int] -> SharingAcc UnscopedAcc RootExp arrs -> UnscopedAcc arrs
forall t.
[Int] -> SharingAcc UnscopedAcc RootExp t -> UnscopedAcc t
UnscopedAcc [] (StableAccName arrs
-> ArraysR arrs -> SharingAcc UnscopedAcc RootExp arrs
forall arrs (acc :: * -> *) (exp :: * -> *).
StableAccName arrs -> ArraysR arrs -> SharingAcc acc exp arrs
AvarSharing (StableName (SmartAcc arrs) -> Int -> StableAccName arrs
forall t. StableName t -> Int -> StableNameHeight t
StableNameHeight StableName (SmartAcc arrs)
sn Int
height) (PreSmartAcc SmartAcc SmartExp arrs -> ArraysR arrs
forall (f :: * -> *) a. HasArraysR f => f a -> ArraysR a
Smart.arraysR PreSmartAcc SmartAcc SmartExp arrs
pacc)), Int
height)
                    Maybe Int
_ -> do (PreSmartAcc UnscopedAcc RootExp arrs
acc, Int
height) <- IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
newAcc
                            (UnscopedAcc arrs, Int) -> IO (UnscopedAcc arrs, Int)
forall (m :: * -> *) a. Monad m => a -> m a
return ([Int] -> SharingAcc UnscopedAcc RootExp arrs -> UnscopedAcc arrs
forall t.
[Int] -> SharingAcc UnscopedAcc RootExp t -> UnscopedAcc t
UnscopedAcc [] (StableAccName arrs
-> PreSmartAcc UnscopedAcc RootExp arrs
-> SharingAcc UnscopedAcc RootExp arrs
forall arrs (acc :: * -> *) (exp :: * -> *).
StableAccName arrs
-> PreSmartAcc acc exp arrs -> SharingAcc acc exp arrs
AccSharing (StableName (SmartAcc arrs) -> Int -> StableAccName arrs
forall t. StableName t -> Int -> StableNameHeight t
StableNameHeight StableName (SmartAcc arrs)
sn Int
height) PreSmartAcc UnscopedAcc RootExp arrs
acc), Int
height)

          IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
-> IO (UnscopedAcc arrs, Int)
reconstruct (IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
 -> IO (UnscopedAcc arrs, Int))
-> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
-> IO (UnscopedAcc arrs, Int)
forall a b. (a -> b) -> a -> b
$ case PreSmartAcc SmartAcc SmartExp arrs
pacc of
            Atag ArraysR arrs
repr Int
i                 -> (PreSmartAcc UnscopedAcc RootExp arrs, Int)
-> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
forall (m :: * -> *) a. Monad m => a -> m a
return (ArraysR arrs -> Int -> PreSmartAcc UnscopedAcc RootExp arrs
forall as (acc :: * -> *) (exp :: * -> *).
ArraysR as -> Int -> PreSmartAcc acc exp as
Atag ArraysR arrs
repr Int
i, Int
0)           -- height is 0!
            Pipe ArraysR as
repr1 ArraysR bs
repr2 ArraysR arrs
repr3 SmartAcc as -> SmartAcc bs
afun1 SmartAcc bs -> SmartAcc arrs
afun2 SmartAcc as
acc
                                        -> do
                                             (SmartAcc as -> UnscopedAcc bs
afun1', Int
h1) <- Int
-> ArraysR as
-> (SmartAcc as -> SmartAcc bs)
-> IO (SmartAcc as -> UnscopedAcc bs, Int)
forall a b.
HasCallStack =>
Int
-> ArraysR a
-> (SmartAcc a -> SmartAcc b)
-> IO (SmartAcc a -> UnscopedAcc b, Int)
traverseAfun1 Int
lvl ArraysR as
repr1 SmartAcc as -> SmartAcc bs
afun1
                                             (SmartAcc bs -> UnscopedAcc arrs
afun2', Int
h2) <- Int
-> ArraysR bs
-> (SmartAcc bs -> SmartAcc arrs)
-> IO (SmartAcc bs -> UnscopedAcc arrs, Int)
forall a b.
HasCallStack =>
Int
-> ArraysR a
-> (SmartAcc a -> SmartAcc b)
-> IO (SmartAcc a -> UnscopedAcc b, Int)
traverseAfun1 Int
lvl ArraysR bs
repr2 SmartAcc bs -> SmartAcc arrs
afun2
                                             (UnscopedAcc as
acc', Int
h3)   <- Int -> SmartAcc as -> IO (UnscopedAcc as, Int)
forall arrs.
HasCallStack =>
Int -> SmartAcc arrs -> IO (UnscopedAcc arrs, Int)
traverseAcc Int
lvl SmartAcc as
acc
                                             (PreSmartAcc UnscopedAcc RootExp arrs, Int)
-> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
forall (m :: * -> *) a. Monad m => a -> m a
return (ArraysR as
-> ArraysR bs
-> ArraysR arrs
-> (SmartAcc as -> UnscopedAcc bs)
-> (SmartAcc bs -> UnscopedAcc arrs)
-> UnscopedAcc as
-> PreSmartAcc UnscopedAcc RootExp arrs
forall as bs cs (acc :: * -> *) (exp :: * -> *).
ArraysR as
-> ArraysR bs
-> ArraysR cs
-> (SmartAcc as -> acc bs)
-> (SmartAcc bs -> acc cs)
-> acc as
-> PreSmartAcc acc exp cs
Pipe ArraysR as
repr1 ArraysR bs
repr2 ArraysR arrs
repr3 SmartAcc as -> UnscopedAcc bs
afun1' SmartAcc bs -> UnscopedAcc arrs
afun2' UnscopedAcc as
acc'
                                                    , Int
h1 Int -> Int -> Int
forall a. Ord a => a -> a -> a
`max` Int
h2 Int -> Int -> Int
forall a. Ord a => a -> a -> a
`max` Int
h3 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1)
            Aforeign ArraysR arrs
repr asm (as -> arrs)
ff SmartAcc as -> SmartAcc arrs
afun SmartAcc as
acc   -> (UnscopedAcc as -> PreSmartAcc UnscopedAcc RootExp arrs)
-> SmartAcc as -> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
forall arrs'.
HasCallStack =>
(UnscopedAcc arrs' -> PreSmartAcc UnscopedAcc RootExp arrs)
-> SmartAcc arrs' -> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
travA (ArraysR arrs
-> asm (as -> arrs)
-> (SmartAcc as -> SmartAcc arrs)
-> UnscopedAcc as
-> PreSmartAcc UnscopedAcc RootExp arrs
forall (asm :: * -> *) bs as (acc :: * -> *) (exp :: * -> *).
Foreign asm =>
ArraysR bs
-> asm (as -> bs)
-> (SmartAcc as -> SmartAcc bs)
-> acc as
-> PreSmartAcc acc exp bs
Aforeign ArraysR arrs
repr asm (as -> arrs)
ff SmartAcc as -> SmartAcc arrs
afun) SmartAcc as
acc
            Acond SmartExp PrimBool
e SmartAcc arrs
acc1 SmartAcc arrs
acc2           -> do
                                             (RootExp PrimBool
e'   , Int
h1) <- Int -> SmartExp PrimBool -> IO (RootExp PrimBool, Int)
forall e. HasCallStack => Int -> SmartExp e -> IO (RootExp e, Int)
traverseExp Int
lvl SmartExp PrimBool
e
                                             (UnscopedAcc arrs
acc1', Int
h2) <- Int -> SmartAcc arrs -> IO (UnscopedAcc arrs, Int)
forall arrs.
HasCallStack =>
Int -> SmartAcc arrs -> IO (UnscopedAcc arrs, Int)
traverseAcc Int
lvl SmartAcc arrs
acc1
                                             (UnscopedAcc arrs
acc2', Int
h3) <- Int -> SmartAcc arrs -> IO (UnscopedAcc arrs, Int)
forall arrs.
HasCallStack =>
Int -> SmartAcc arrs -> IO (UnscopedAcc arrs, Int)
traverseAcc Int
lvl SmartAcc arrs
acc2
                                             (PreSmartAcc UnscopedAcc RootExp arrs, Int)
-> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
forall (m :: * -> *) a. Monad m => a -> m a
return (RootExp PrimBool
-> UnscopedAcc arrs
-> UnscopedAcc arrs
-> PreSmartAcc UnscopedAcc RootExp arrs
forall (exp :: * -> *) (acc :: * -> *) as.
exp PrimBool -> acc as -> acc as -> PreSmartAcc acc exp as
Acond RootExp PrimBool
e' UnscopedAcc arrs
acc1' UnscopedAcc arrs
acc2', Int
h1 Int -> Int -> Int
forall a. Ord a => a -> a -> a
`max` Int
h2 Int -> Int -> Int
forall a. Ord a => a -> a -> a
`max` Int
h3 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1)
            Awhile ArraysR arrs
repr SmartAcc arrs -> SmartAcc (Scalar PrimBool)
pred SmartAcc arrs -> SmartAcc arrs
iter SmartAcc arrs
init  -> do
                                             (SmartAcc arrs -> UnscopedAcc (Scalar PrimBool)
pred', Int
h1) <- Int
-> ArraysR arrs
-> (SmartAcc arrs -> SmartAcc (Scalar PrimBool))
-> IO (SmartAcc arrs -> UnscopedAcc (Scalar PrimBool), Int)
forall a b.
HasCallStack =>
Int
-> ArraysR a
-> (SmartAcc a -> SmartAcc b)
-> IO (SmartAcc a -> UnscopedAcc b, Int)
traverseAfun1 Int
lvl ArraysR arrs
repr SmartAcc arrs -> SmartAcc (Scalar PrimBool)
pred
                                             (SmartAcc arrs -> UnscopedAcc arrs
iter', Int
h2) <- Int
-> ArraysR arrs
-> (SmartAcc arrs -> SmartAcc arrs)
-> IO (SmartAcc arrs -> UnscopedAcc arrs, Int)
forall a b.
HasCallStack =>
Int
-> ArraysR a
-> (SmartAcc a -> SmartAcc b)
-> IO (SmartAcc a -> UnscopedAcc b, Int)
traverseAfun1 Int
lvl ArraysR arrs
repr SmartAcc arrs -> SmartAcc arrs
iter
                                             (UnscopedAcc arrs
init', Int
h3) <- Int -> SmartAcc arrs -> IO (UnscopedAcc arrs, Int)
forall arrs.
HasCallStack =>
Int -> SmartAcc arrs -> IO (UnscopedAcc arrs, Int)
traverseAcc Int
lvl SmartAcc arrs
init
                                             (PreSmartAcc UnscopedAcc RootExp arrs, Int)
-> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
forall (m :: * -> *) a. Monad m => a -> m a
return (ArraysR arrs
-> (SmartAcc arrs -> UnscopedAcc (Scalar PrimBool))
-> (SmartAcc arrs -> UnscopedAcc arrs)
-> UnscopedAcc arrs
-> PreSmartAcc UnscopedAcc RootExp arrs
forall arrs (acc :: * -> *) (exp :: * -> *).
ArraysR arrs
-> (SmartAcc arrs -> acc (Scalar PrimBool))
-> (SmartAcc arrs -> acc arrs)
-> acc arrs
-> PreSmartAcc acc exp arrs
Awhile ArraysR arrs
repr SmartAcc arrs -> UnscopedAcc (Scalar PrimBool)
pred' SmartAcc arrs -> UnscopedAcc arrs
iter' UnscopedAcc arrs
init'
                                                    , Int
h1 Int -> Int -> Int
forall a. Ord a => a -> a -> a
`max` Int
h2 Int -> Int -> Int
forall a. Ord a => a -> a -> a
`max` Int
h3 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1)

            PreSmartAcc SmartAcc SmartExp arrs
Anil                        -> (PreSmartAcc UnscopedAcc RootExp (), Int)
-> IO (PreSmartAcc UnscopedAcc RootExp (), Int)
forall (m :: * -> *) a. Monad m => a -> m a
return (PreSmartAcc UnscopedAcc RootExp ()
forall (acc :: * -> *) (exp :: * -> *). PreSmartAcc acc exp ()
Anil, Int
0)
            Apair SmartAcc arrs1
acc1 SmartAcc arrs2
acc2             -> do
                                             (UnscopedAcc arrs1
a', Int
h1) <- Int -> SmartAcc arrs1 -> IO (UnscopedAcc arrs1, Int)
forall arrs.
HasCallStack =>
Int -> SmartAcc arrs -> IO (UnscopedAcc arrs, Int)
traverseAcc Int
lvl SmartAcc arrs1
acc1
                                             (UnscopedAcc arrs2
b', Int
h2) <- Int -> SmartAcc arrs2 -> IO (UnscopedAcc arrs2, Int)
forall arrs.
HasCallStack =>
Int -> SmartAcc arrs -> IO (UnscopedAcc arrs, Int)
traverseAcc Int
lvl SmartAcc arrs2
acc2
                                             (PreSmartAcc UnscopedAcc RootExp (arrs1, arrs2), Int)
-> IO (PreSmartAcc UnscopedAcc RootExp (arrs1, arrs2), Int)
forall (m :: * -> *) a. Monad m => a -> m a
return (UnscopedAcc arrs1
-> UnscopedAcc arrs2
-> PreSmartAcc UnscopedAcc RootExp (arrs1, arrs2)
forall (acc :: * -> *) arrs1 arrs2 (exp :: * -> *).
acc arrs1 -> acc arrs2 -> PreSmartAcc acc exp (arrs1, arrs2)
Apair UnscopedAcc arrs1
a' UnscopedAcc arrs2
b', Int
h1 Int -> Int -> Int
forall a. Ord a => a -> a -> a
`max` Int
h2 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1)
            Aprj PairIdx (arrs1, arrs2) arrs
ix SmartAcc (arrs1, arrs2)
a                   -> (UnscopedAcc (arrs1, arrs2)
 -> PreSmartAcc UnscopedAcc RootExp arrs)
-> SmartAcc (arrs1, arrs2)
-> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
forall arrs'.
HasCallStack =>
(UnscopedAcc arrs' -> PreSmartAcc UnscopedAcc RootExp arrs)
-> SmartAcc arrs' -> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
travA (PairIdx (arrs1, arrs2) arrs
-> UnscopedAcc (arrs1, arrs2)
-> PreSmartAcc UnscopedAcc RootExp arrs
forall arrs1 arrs2 arrs (acc :: * -> *) (exp :: * -> *).
PairIdx (arrs1, arrs2) arrs
-> acc (arrs1, arrs2) -> PreSmartAcc acc exp arrs
Aprj PairIdx (arrs1, arrs2) arrs
ix) SmartAcc (arrs1, arrs2)
a

            Use ArrayR (Array sh e)
repr Array sh e
arr                -> (PreSmartAcc UnscopedAcc RootExp (Array sh e), Int)
-> IO (PreSmartAcc UnscopedAcc RootExp (Array sh e), Int)
forall (m :: * -> *) a. Monad m => a -> m a
return (ArrayR (Array sh e)
-> Array sh e -> PreSmartAcc UnscopedAcc RootExp (Array sh e)
forall sh e (acc :: * -> *) (exp :: * -> *).
ArrayR (Array sh e)
-> Array sh e -> PreSmartAcc acc exp (Array sh e)
Use ArrayR (Array sh e)
repr Array sh e
arr, Int
1)
            Unit TypeR e
tp SmartExp e
e                   -> do
                                             (RootExp e
e', Int
h) <- Int -> SmartExp e -> IO (RootExp e, Int)
forall e. HasCallStack => Int -> SmartExp e -> IO (RootExp e, Int)
traverseExp Int
lvl SmartExp e
e
                                             (PreSmartAcc UnscopedAcc RootExp (Scalar e), Int)
-> IO (PreSmartAcc UnscopedAcc RootExp (Scalar e), Int)
forall (m :: * -> *) a. Monad m => a -> m a
return (TypeR e -> RootExp e -> PreSmartAcc UnscopedAcc RootExp (Scalar e)
forall e (exp :: * -> *) (acc :: * -> *).
TypeR e -> exp e -> PreSmartAcc acc exp (Scalar e)
Unit TypeR e
tp RootExp e
e', Int
h Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1)
            Generate repr :: ArrayR (Array sh e)
repr@(ArrayR ShapeR sh
shr TypeR e
_) SmartExp sh
e SmartExp sh -> SmartExp e
f
                                        -> do
                                             (RootExp sh
e', Int
h1) <- Int -> SmartExp sh -> IO (RootExp sh, Int)
forall e. HasCallStack => Int -> SmartExp e -> IO (RootExp e, Int)
traverseExp Int
lvl SmartExp sh
e
                                             (SmartExp sh -> RootExp e
f', Int
h2) <- Int
-> TypeR sh
-> (SmartExp sh -> SmartExp e)
-> IO (SmartExp sh -> RootExp e, Int)
forall a b.
HasCallStack =>
Int
-> TypeR a
-> (SmartExp a -> SmartExp b)
-> IO (SmartExp a -> RootExp b, Int)
traverseFun1 Int
lvl (ShapeR sh -> TypeR sh
forall sh. ShapeR sh -> TypeR sh
shapeType ShapeR sh
shr) SmartExp sh -> SmartExp e
SmartExp sh -> SmartExp e
f
                                             (PreSmartAcc UnscopedAcc RootExp (Array sh e), Int)
-> IO (PreSmartAcc UnscopedAcc RootExp (Array sh e), Int)
forall (m :: * -> *) a. Monad m => a -> m a
return (ArrayR (Array sh e)
-> RootExp sh
-> (SmartExp sh -> RootExp e)
-> PreSmartAcc UnscopedAcc RootExp (Array sh e)
forall sh e (exp :: * -> *) (acc :: * -> *).
ArrayR (Array sh e)
-> exp sh
-> (SmartExp sh -> exp e)
-> PreSmartAcc acc exp (Array sh e)
Generate ArrayR (Array sh e)
repr RootExp sh
e' SmartExp sh -> RootExp e
SmartExp sh -> RootExp e
f', Int
h1 Int -> Int -> Int
forall a. Ord a => a -> a -> a
`max` Int
h2 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1)
            Reshape ShapeR sh
shr SmartExp sh
e SmartAcc (Array sh' e)
acc           -> (RootExp sh
 -> UnscopedAcc (Array sh' e)
 -> PreSmartAcc UnscopedAcc RootExp arrs)
-> SmartExp sh
-> SmartAcc (Array sh' e)
-> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
forall b arrs'.
HasCallStack =>
(RootExp b
 -> UnscopedAcc arrs' -> PreSmartAcc UnscopedAcc RootExp arrs)
-> SmartExp b
-> SmartAcc arrs'
-> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
travEA (ShapeR sh
-> RootExp sh
-> UnscopedAcc (Array sh' e)
-> PreSmartAcc UnscopedAcc RootExp (Array sh e)
forall sh (exp :: * -> *) (acc :: * -> *) sh' e.
ShapeR sh
-> exp sh -> acc (Array sh' e) -> PreSmartAcc acc exp (Array sh e)
Reshape ShapeR sh
shr) SmartExp sh
e SmartAcc (Array sh' e)
acc
            Replicate SliceIndex slix sl co sh
si SmartExp slix
e SmartAcc (Array sl e)
acc          -> (RootExp slix
 -> UnscopedAcc (Array sl e)
 -> PreSmartAcc UnscopedAcc RootExp arrs)
-> SmartExp slix
-> SmartAcc (Array sl e)
-> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
forall b arrs'.
HasCallStack =>
(RootExp b
 -> UnscopedAcc arrs' -> PreSmartAcc UnscopedAcc RootExp arrs)
-> SmartExp b
-> SmartAcc arrs'
-> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
travEA (SliceIndex slix sl co sh
-> RootExp slix
-> UnscopedAcc (Array sl e)
-> PreSmartAcc UnscopedAcc RootExp (Array sh e)
forall slix sl co sh (exp :: * -> *) (acc :: * -> *) e.
SliceIndex slix sl co sh
-> exp slix -> acc (Array sl e) -> PreSmartAcc acc exp (Array sh e)
Replicate SliceIndex slix sl co sh
si) SmartExp slix
e SmartAcc (Array sl e)
acc
            Slice SliceIndex slix sl co sh
si SmartAcc (Array sh e)
acc SmartExp slix
e              -> (RootExp slix
 -> UnscopedAcc (Array sh e)
 -> PreSmartAcc UnscopedAcc RootExp arrs)
-> SmartExp slix
-> SmartAcc (Array sh e)
-> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
forall b arrs'.
HasCallStack =>
(RootExp b
 -> UnscopedAcc arrs' -> PreSmartAcc UnscopedAcc RootExp arrs)
-> SmartExp b
-> SmartAcc arrs'
-> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
travEA ((UnscopedAcc (Array sh e)
 -> RootExp slix -> PreSmartAcc UnscopedAcc RootExp (Array sl e))
-> RootExp slix
-> UnscopedAcc (Array sh e)
-> PreSmartAcc UnscopedAcc RootExp (Array sl e)
forall a b c. (a -> b -> c) -> b -> a -> c
flip ((UnscopedAcc (Array sh e)
  -> RootExp slix -> PreSmartAcc UnscopedAcc RootExp (Array sl e))
 -> RootExp slix
 -> UnscopedAcc (Array sh e)
 -> PreSmartAcc UnscopedAcc RootExp (Array sl e))
-> (UnscopedAcc (Array sh e)
    -> RootExp slix -> PreSmartAcc UnscopedAcc RootExp (Array sl e))
-> RootExp slix
-> UnscopedAcc (Array sh e)
-> PreSmartAcc UnscopedAcc RootExp (Array sl e)
forall a b. (a -> b) -> a -> b
$ SliceIndex slix sl co sh
-> UnscopedAcc (Array sh e)
-> RootExp slix
-> PreSmartAcc UnscopedAcc RootExp (Array sl e)
forall slix sl co sh (acc :: * -> *) e' (exp :: * -> *).
SliceIndex slix sl co sh
-> acc (Array sh e')
-> exp slix
-> PreSmartAcc acc exp (Array sl e')
Slice SliceIndex slix sl co sh
si) SmartExp slix
e SmartAcc (Array sh e)
acc
            Map TypeR e
t1 TypeR e'
t2 SmartExp e -> SmartExp e'
f SmartAcc (Array sh e)
acc             -> do
                                             (SmartExp e -> RootExp e'
f'  , Int
h1) <- Int
-> TypeR e
-> (SmartExp e -> SmartExp e')
-> IO (SmartExp e -> RootExp e', Int)
forall a b.
HasCallStack =>
Int
-> TypeR a
-> (SmartExp a -> SmartExp b)
-> IO (SmartExp a -> RootExp b, Int)
traverseFun1 Int
lvl TypeR e
t1 SmartExp e -> SmartExp e'
f
                                             (UnscopedAcc (Array sh e)
acc', Int
h2) <- Int -> SmartAcc (Array sh e) -> IO (UnscopedAcc (Array sh e), Int)
forall arrs.
HasCallStack =>
Int -> SmartAcc arrs -> IO (UnscopedAcc arrs, Int)
traverseAcc Int
lvl SmartAcc (Array sh e)
acc
                                             (PreSmartAcc UnscopedAcc RootExp (Array sh e'), Int)
-> IO (PreSmartAcc UnscopedAcc RootExp (Array sh e'), Int)
forall (m :: * -> *) a. Monad m => a -> m a
return (TypeR e
-> TypeR e'
-> (SmartExp e -> RootExp e')
-> UnscopedAcc (Array sh e)
-> PreSmartAcc UnscopedAcc RootExp (Array sh e')
forall e e' (exp :: * -> *) (acc :: * -> *) sh.
TypeR e
-> TypeR e'
-> (SmartExp e -> exp e')
-> acc (Array sh e)
-> PreSmartAcc acc exp (Array sh e')
Map TypeR e
t1 TypeR e'
t2 SmartExp e -> RootExp e'
f' UnscopedAcc (Array sh e)
acc', Int
h1 Int -> Int -> Int
forall a. Ord a => a -> a -> a
`max` Int
h2 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1)
            ZipWith TypeR e1
t1 TypeR e2
t2 TypeR e3
t3 SmartExp e1 -> SmartExp e2 -> SmartExp e3
f SmartAcc (Array sh e1)
acc1 SmartAcc (Array sh e2)
acc2
                                        -> ((SmartExp e1 -> SmartExp e2 -> RootExp e3)
 -> UnscopedAcc (Array sh e1)
 -> UnscopedAcc (Array sh e2)
 -> PreSmartAcc UnscopedAcc RootExp arrs)
-> TypeR e1
-> TypeR e2
-> (SmartExp e1 -> SmartExp e2 -> SmartExp e3)
-> SmartAcc (Array sh e1)
-> SmartAcc (Array sh e2)
-> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
forall b c d arrs1 arrs2.
HasCallStack =>
((SmartExp b -> SmartExp c -> RootExp d)
 -> UnscopedAcc arrs1
 -> UnscopedAcc arrs2
 -> PreSmartAcc UnscopedAcc RootExp arrs)
-> TypeR b
-> TypeR c
-> (SmartExp b -> SmartExp c -> SmartExp d)
-> SmartAcc arrs1
-> SmartAcc arrs2
-> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
travF2A2 (TypeR e1
-> TypeR e2
-> TypeR e3
-> (SmartExp e1 -> SmartExp e2 -> RootExp e3)
-> UnscopedAcc (Array sh e1)
-> UnscopedAcc (Array sh e2)
-> PreSmartAcc UnscopedAcc RootExp (Array sh e3)
forall e1 e2 e3 (exp :: * -> *) (acc :: * -> *) sh.
TypeR e1
-> TypeR e2
-> TypeR e3
-> (SmartExp e1 -> SmartExp e2 -> exp e3)
-> acc (Array sh e1)
-> acc (Array sh e2)
-> PreSmartAcc acc exp (Array sh e3)
ZipWith TypeR e1
t1 TypeR e2
t2 TypeR e3
t3) TypeR e1
t1 TypeR e2
t2 SmartExp e1 -> SmartExp e2 -> SmartExp e3
f SmartAcc (Array sh e1)
acc1 SmartAcc (Array sh e2)
acc2
            Fold TypeR e
tp SmartExp e -> SmartExp e -> SmartExp e
f Maybe (SmartExp e)
e SmartAcc (Array (sh, Int) e)
acc             -> ((SmartExp e -> SmartExp e -> RootExp e)
 -> Maybe (RootExp e)
 -> UnscopedAcc (Array (sh, Int) e)
 -> PreSmartAcc UnscopedAcc RootExp arrs)
-> TypeR e
-> TypeR e
-> (SmartExp e -> SmartExp e -> SmartExp e)
-> Maybe (SmartExp e)
-> SmartAcc (Array (sh, Int) e)
-> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
forall b c d e arrs'.
HasCallStack =>
((SmartExp b -> SmartExp c -> RootExp d)
 -> Maybe (RootExp e)
 -> UnscopedAcc arrs'
 -> PreSmartAcc UnscopedAcc RootExp arrs)
-> TypeR b
-> TypeR c
-> (SmartExp b -> SmartExp c -> SmartExp d)
-> Maybe (SmartExp e)
-> SmartAcc arrs'
-> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
travF2MEA (TypeR e
-> (SmartExp e -> SmartExp e -> RootExp e)
-> Maybe (RootExp e)
-> UnscopedAcc (Array (sh, Int) e)
-> PreSmartAcc UnscopedAcc RootExp (Array sh e)
forall e (exp :: * -> *) (acc :: * -> *) i.
TypeR e
-> (SmartExp e -> SmartExp e -> exp e)
-> Maybe (exp e)
-> acc (Array (i, Int) e)
-> PreSmartAcc acc exp (Array i e)
Fold TypeR e
tp) TypeR e
tp TypeR e
tp SmartExp e -> SmartExp e -> SmartExp e
f Maybe (SmartExp e)
e SmartAcc (Array (sh, Int) e)
acc
            FoldSeg IntegralType i
i TypeR e
tp SmartExp e -> SmartExp e -> SmartExp e
f Maybe (SmartExp e)
e SmartAcc (Array (sh, Int) e)
acc1 SmartAcc (Segments i)
acc2  -> do
                                             (SmartExp e -> SmartExp e -> RootExp e
f'   , Int
h1) <- Int
-> TypeR e
-> TypeR e
-> (SmartExp e -> SmartExp e -> SmartExp e)
-> IO (SmartExp e -> SmartExp e -> RootExp e, Int)
forall a b c.
HasCallStack =>
Int
-> TypeR a
-> TypeR b
-> (SmartExp a -> SmartExp b -> SmartExp c)
-> IO (SmartExp a -> SmartExp b -> RootExp c, Int)
traverseFun2 Int
lvl TypeR e
tp TypeR e
tp SmartExp e -> SmartExp e -> SmartExp e
f
                                             (Maybe (RootExp e)
e'   , Int
h2) <- Maybe (SmartExp e) -> IO (Maybe (RootExp e), Int)
forall t.
HasCallStack =>
Maybe (SmartExp t) -> IO (Maybe (RootExp t), Int)
travME Maybe (SmartExp e)
e
                                             (UnscopedAcc (Array (sh, Int) e)
acc1', Int
h3) <- Int
-> SmartAcc (Array (sh, Int) e)
-> IO (UnscopedAcc (Array (sh, Int) e), Int)
forall arrs.
HasCallStack =>
Int -> SmartAcc arrs -> IO (UnscopedAcc arrs, Int)
traverseAcc Int
lvl SmartAcc (Array (sh, Int) e)
acc1
                                             (UnscopedAcc (Segments i)
acc2', Int
h4) <- Int -> SmartAcc (Segments i) -> IO (UnscopedAcc (Segments i), Int)
forall arrs.
HasCallStack =>
Int -> SmartAcc arrs -> IO (UnscopedAcc arrs, Int)
traverseAcc Int
lvl SmartAcc (Segments i)
acc2
                                             (PreSmartAcc UnscopedAcc RootExp (Array (sh, Int) e), Int)
-> IO (PreSmartAcc UnscopedAcc RootExp (Array (sh, Int) e), Int)
forall (m :: * -> *) a. Monad m => a -> m a
return (IntegralType i
-> TypeR e
-> (SmartExp e -> SmartExp e -> RootExp e)
-> Maybe (RootExp e)
-> UnscopedAcc (Array (sh, Int) e)
-> UnscopedAcc (Segments i)
-> PreSmartAcc UnscopedAcc RootExp (Array (sh, Int) e)
forall i e (exp :: * -> *) (acc :: * -> *) sh.
IntegralType i
-> TypeR e
-> (SmartExp e -> SmartExp e -> exp e)
-> Maybe (exp e)
-> acc (Array (sh, Int) e)
-> acc (Segments i)
-> PreSmartAcc acc exp (Array (sh, Int) e)
FoldSeg IntegralType i
i TypeR e
tp SmartExp e -> SmartExp e -> RootExp e
f' Maybe (RootExp e)
e' UnscopedAcc (Array (sh, Int) e)
acc1' UnscopedAcc (Segments i)
acc2',
                                                     Int
h1 Int -> Int -> Int
forall a. Ord a => a -> a -> a
`max` Int
h2 Int -> Int -> Int
forall a. Ord a => a -> a -> a
`max` Int
h3 Int -> Int -> Int
forall a. Ord a => a -> a -> a
`max` Int
h4 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1)
            Scan  Direction
d TypeR e
tp SmartExp e -> SmartExp e -> SmartExp e
f Maybe (SmartExp e)
e SmartAcc (Array (sh, Int) e)
acc          -> ((SmartExp e -> SmartExp e -> RootExp e)
 -> Maybe (RootExp e)
 -> UnscopedAcc (Array (sh, Int) e)
 -> PreSmartAcc UnscopedAcc RootExp arrs)
-> TypeR e
-> TypeR e
-> (SmartExp e -> SmartExp e -> SmartExp e)
-> Maybe (SmartExp e)
-> SmartAcc (Array (sh, Int) e)
-> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
forall b c d e arrs'.
HasCallStack =>
((SmartExp b -> SmartExp c -> RootExp d)
 -> Maybe (RootExp e)
 -> UnscopedAcc arrs'
 -> PreSmartAcc UnscopedAcc RootExp arrs)
-> TypeR b
-> TypeR c
-> (SmartExp b -> SmartExp c -> SmartExp d)
-> Maybe (SmartExp e)
-> SmartAcc arrs'
-> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
travF2MEA (Direction
-> TypeR e
-> (SmartExp e -> SmartExp e -> RootExp e)
-> Maybe (RootExp e)
-> UnscopedAcc (Array (sh, Int) e)
-> PreSmartAcc UnscopedAcc RootExp (Array (sh, Int) e)
forall e (exp :: * -> *) (acc :: * -> *) e.
Direction
-> TypeR e
-> (SmartExp e -> SmartExp e -> exp e)
-> Maybe (exp e)
-> acc (Array (e, Int) e)
-> PreSmartAcc acc exp (Array (e, Int) e)
Scan  Direction
d TypeR e
tp) TypeR e
tp TypeR e
tp SmartExp e -> SmartExp e -> SmartExp e
f Maybe (SmartExp e)
e SmartAcc (Array (sh, Int) e)
acc
            Scan' Direction
d TypeR e
tp SmartExp e -> SmartExp e -> SmartExp e
f SmartExp e
e SmartAcc (Array (sh, Int) e)
acc          -> ((SmartExp e -> SmartExp e -> RootExp e)
 -> RootExp e
 -> UnscopedAcc (Array (sh, Int) e)
 -> PreSmartAcc UnscopedAcc RootExp arrs)
-> TypeR e
-> TypeR e
-> (SmartExp e -> SmartExp e -> SmartExp e)
-> SmartExp e
-> SmartAcc (Array (sh, Int) e)
-> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
forall b c d e arrs'.
HasCallStack =>
((SmartExp b -> SmartExp c -> RootExp d)
 -> RootExp e
 -> UnscopedAcc arrs'
 -> PreSmartAcc UnscopedAcc RootExp arrs)
-> TypeR b
-> TypeR c
-> (SmartExp b -> SmartExp c -> SmartExp d)
-> SmartExp e
-> SmartAcc arrs'
-> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
travF2EA (Direction
-> TypeR e
-> (SmartExp e -> SmartExp e -> RootExp e)
-> RootExp e
-> UnscopedAcc (Array (sh, Int) e)
-> PreSmartAcc UnscopedAcc RootExp (Array (sh, Int) e, Array sh e)
forall e (exp :: * -> *) (acc :: * -> *) sh.
Direction
-> TypeR e
-> (SmartExp e -> SmartExp e -> exp e)
-> exp e
-> acc (Array (sh, Int) e)
-> PreSmartAcc acc exp (Array (sh, Int) e, Array sh e)
Scan' Direction
d TypeR e
tp) TypeR e
tp TypeR e
tp SmartExp e -> SmartExp e -> SmartExp e
f SmartExp e
e SmartAcc (Array (sh, Int) e)
acc
            Permute repr :: ArrayR (Array sh e)
repr@(ArrayR ShapeR sh
shr TypeR e
tp) SmartExp e -> SmartExp e -> SmartExp e
c SmartAcc (Array sh' e)
acc1 SmartExp sh -> SmartExp (PrimMaybe sh')
p SmartAcc (Array sh e)
acc2
                                        -> do
                                             (SmartExp e -> SmartExp e -> RootExp e
c'   , Int
h1) <- Int
-> TypeR e
-> TypeR e
-> (SmartExp e -> SmartExp e -> SmartExp e)
-> IO (SmartExp e -> SmartExp e -> RootExp e, Int)
forall a b c.
HasCallStack =>
Int
-> TypeR a
-> TypeR b
-> (SmartExp a -> SmartExp b -> SmartExp c)
-> IO (SmartExp a -> SmartExp b -> RootExp c, Int)
traverseFun2 Int
lvl TypeR e
tp TypeR e
tp SmartExp e -> SmartExp e -> SmartExp e
SmartExp e -> SmartExp e -> SmartExp e
c
                                             (SmartExp sh -> RootExp (PrimMaybe sh')
p'   , Int
h2) <- Int
-> TypeR sh
-> (SmartExp sh -> SmartExp (PrimMaybe sh'))
-> IO (SmartExp sh -> RootExp (PrimMaybe sh'), Int)
forall a b.
HasCallStack =>
Int
-> TypeR a
-> (SmartExp a -> SmartExp b)
-> IO (SmartExp a -> RootExp b, Int)
traverseFun1 Int
lvl (ShapeR sh -> TypeR sh
forall sh. ShapeR sh -> TypeR sh
shapeType ShapeR sh
shr) SmartExp sh -> SmartExp (PrimMaybe sh')
SmartExp sh -> SmartExp (PrimMaybe sh')
p
                                             (UnscopedAcc (Array sh' e)
acc1', Int
h3) <- Int
-> SmartAcc (Array sh' e) -> IO (UnscopedAcc (Array sh' e), Int)
forall arrs.
HasCallStack =>
Int -> SmartAcc arrs -> IO (UnscopedAcc arrs, Int)
traverseAcc Int
lvl SmartAcc (Array sh' e)
acc1
                                             (UnscopedAcc (Array sh e)
acc2', Int
h4) <- Int -> SmartAcc (Array sh e) -> IO (UnscopedAcc (Array sh e), Int)
forall arrs.
HasCallStack =>
Int -> SmartAcc arrs -> IO (UnscopedAcc arrs, Int)
traverseAcc Int
lvl SmartAcc (Array sh e)
acc2
                                             (PreSmartAcc UnscopedAcc RootExp (Array sh' e), Int)
-> IO (PreSmartAcc UnscopedAcc RootExp (Array sh' e), Int)
forall (m :: * -> *) a. Monad m => a -> m a
return (ArrayR (Array sh e)
-> (SmartExp e -> SmartExp e -> RootExp e)
-> UnscopedAcc (Array sh' e)
-> (SmartExp sh -> RootExp (PrimMaybe sh'))
-> UnscopedAcc (Array sh e)
-> PreSmartAcc UnscopedAcc RootExp (Array sh' e)
forall sh e (exp :: * -> *) (acc :: * -> *) sh'.
ArrayR (Array sh e)
-> (SmartExp e -> SmartExp e -> exp e)
-> acc (Array sh' e)
-> (SmartExp sh -> exp (PrimMaybe sh'))
-> acc (Array sh e)
-> PreSmartAcc acc exp (Array sh' e)
Permute ArrayR (Array sh e)
repr SmartExp e -> SmartExp e -> RootExp e
SmartExp e -> SmartExp e -> RootExp e
c' UnscopedAcc (Array sh' e)
acc1' SmartExp sh -> RootExp (PrimMaybe sh')
SmartExp sh -> RootExp (PrimMaybe sh')
p' UnscopedAcc (Array sh e)
acc2',
                                                     Int
h1 Int -> Int -> Int
forall a. Ord a => a -> a -> a
`max` Int
h2 Int -> Int -> Int
forall a. Ord a => a -> a -> a
`max` Int
h3 Int -> Int -> Int
forall a. Ord a => a -> a -> a
`max` Int
h4 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1)
            Backpermute ShapeR sh'
shr SmartExp sh'
e SmartExp sh' -> SmartExp sh
p SmartAcc (Array sh e)
acc     -> do
                                             (RootExp sh'
e'  , Int
h1) <- Int -> SmartExp sh' -> IO (RootExp sh', Int)
forall e. HasCallStack => Int -> SmartExp e -> IO (RootExp e, Int)
traverseExp Int
lvl SmartExp sh'
e
                                             (SmartExp sh' -> RootExp sh
p'  , Int
h2) <- Int
-> TypeR sh'
-> (SmartExp sh' -> SmartExp sh)
-> IO (SmartExp sh' -> RootExp sh, Int)
forall a b.
HasCallStack =>
Int
-> TypeR a
-> (SmartExp a -> SmartExp b)
-> IO (SmartExp a -> RootExp b, Int)
traverseFun1 Int
lvl (ShapeR sh' -> TypeR sh'
forall sh. ShapeR sh -> TypeR sh
shapeType ShapeR sh'
shr) SmartExp sh' -> SmartExp sh
p
                                             (UnscopedAcc (Array sh e)
acc', Int
h3) <- Int -> SmartAcc (Array sh e) -> IO (UnscopedAcc (Array sh e), Int)
forall arrs.
HasCallStack =>
Int -> SmartAcc arrs -> IO (UnscopedAcc arrs, Int)
traverseAcc Int
lvl SmartAcc (Array sh e)
acc
                                             (PreSmartAcc UnscopedAcc RootExp (Array sh' e), Int)
-> IO (PreSmartAcc UnscopedAcc RootExp (Array sh' e), Int)
forall (m :: * -> *) a. Monad m => a -> m a
return (ShapeR sh'
-> RootExp sh'
-> (SmartExp sh' -> RootExp sh)
-> UnscopedAcc (Array sh e)
-> PreSmartAcc UnscopedAcc RootExp (Array sh' e)
forall sh' (exp :: * -> *) sh (acc :: * -> *) e.
ShapeR sh'
-> exp sh'
-> (SmartExp sh' -> exp sh)
-> acc (Array sh e)
-> PreSmartAcc acc exp (Array sh' e)
Backpermute ShapeR sh'
shr RootExp sh'
e' SmartExp sh' -> RootExp sh
p' UnscopedAcc (Array sh e)
acc', Int
h1 Int -> Int -> Int
forall a. Ord a => a -> a -> a
`max` Int
h2 Int -> Int -> Int
forall a. Ord a => a -> a -> a
`max` Int
h3 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1)
            Stencil StencilR sh a stencil
s TypeR b
tp SmartExp stencil -> SmartExp b
f PreBoundary SmartAcc SmartExp (Array sh a)
bnd SmartAcc (Array sh a)
acc      -> do
                                             (SmartExp stencil -> RootExp b
f'  , Int
h1) <- Config
-> OccMapHash SmartAcc
-> StencilR sh a stencil
-> Int
-> (SmartExp stencil -> SmartExp b)
-> IO (SmartExp stencil -> RootExp b, Int)
forall sh a b stencil.
HasCallStack =>
Config
-> OccMapHash SmartAcc
-> StencilR sh a stencil
-> Int
-> (SmartExp stencil -> SmartExp b)
-> IO (SmartExp stencil -> RootExp b, Int)
makeOccMapStencil1 Config
config OccMapHash SmartAcc
accOccMap StencilR sh a stencil
s Int
lvl SmartExp stencil -> SmartExp b
f
                                             (PreBoundary UnscopedAcc RootExp (Array sh a)
bnd', Int
h2) <- Int
-> ShapeR sh
-> PreBoundary SmartAcc SmartExp (Array sh a)
-> IO (PreBoundary UnscopedAcc RootExp (Array sh a), Int)
forall sh e.
HasCallStack =>
Int
-> ShapeR sh
-> PreBoundary SmartAcc SmartExp (Array sh e)
-> IO (PreBoundary UnscopedAcc RootExp (Array sh e), Int)
traverseBoundary Int
lvl (StencilR sh a stencil -> ShapeR sh
forall sh e pat. StencilR sh e pat -> ShapeR sh
stencilShapeR StencilR sh a stencil
s) PreBoundary SmartAcc SmartExp (Array sh a)
bnd
                                             (UnscopedAcc (Array sh a)
acc', Int
h3) <- Int -> SmartAcc (Array sh a) -> IO (UnscopedAcc (Array sh a), Int)
forall arrs.
HasCallStack =>
Int -> SmartAcc arrs -> IO (UnscopedAcc arrs, Int)
traverseAcc Int
lvl SmartAcc (Array sh a)
acc
                                             (PreSmartAcc UnscopedAcc RootExp (Array sh b), Int)
-> IO (PreSmartAcc UnscopedAcc RootExp (Array sh b), Int)
forall (m :: * -> *) a. Monad m => a -> m a
return (StencilR sh a stencil
-> TypeR b
-> (SmartExp stencil -> RootExp b)
-> PreBoundary UnscopedAcc RootExp (Array sh a)
-> UnscopedAcc (Array sh a)
-> PreSmartAcc UnscopedAcc RootExp (Array sh b)
forall sh a stencil sh (exp :: * -> *) (acc :: * -> *).
StencilR sh a stencil
-> TypeR sh
-> (SmartExp stencil -> exp sh)
-> PreBoundary acc exp (Array sh a)
-> acc (Array sh a)
-> PreSmartAcc acc exp (Array sh sh)
Stencil StencilR sh a stencil
s TypeR b
tp SmartExp stencil -> RootExp b
f' PreBoundary UnscopedAcc RootExp (Array sh a)
bnd' UnscopedAcc (Array sh a)
acc', Int
h1 Int -> Int -> Int
forall a. Ord a => a -> a -> a
`max` Int
h2 Int -> Int -> Int
forall a. Ord a => a -> a -> a
`max` Int
h3 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1)
            Stencil2 StencilR sh a stencil1
s1 StencilR sh b stencil2
s2 TypeR c
tp SmartExp stencil1 -> SmartExp stencil2 -> SmartExp c
f PreBoundary SmartAcc SmartExp (Array sh a)
bnd1 SmartAcc (Array sh a)
acc1
                              PreBoundary SmartAcc SmartExp (Array sh b)
bnd2 SmartAcc (Array sh b)
acc2 -> do
                                             let shr :: ShapeR sh
shr = StencilR sh a stencil1 -> ShapeR sh
forall sh e pat. StencilR sh e pat -> ShapeR sh
stencilShapeR StencilR sh a stencil1
s1
                                             (SmartExp stencil1 -> SmartExp stencil2 -> RootExp c
f'   , Int
h1) <- Config
-> OccMapHash SmartAcc
-> StencilR sh a stencil1
-> StencilR sh b stencil2
-> Int
-> (SmartExp stencil1 -> SmartExp stencil2 -> SmartExp c)
-> IO (SmartExp stencil1 -> SmartExp stencil2 -> RootExp c, Int)
forall sh a b c stencil1 stencil2.
HasCallStack =>
Config
-> OccMapHash SmartAcc
-> StencilR sh a stencil1
-> StencilR sh b stencil2
-> Int
-> (SmartExp stencil1 -> SmartExp stencil2 -> SmartExp c)
-> IO (SmartExp stencil1 -> SmartExp stencil2 -> RootExp c, Int)
makeOccMapStencil2 Config
config OccMapHash SmartAcc
accOccMap StencilR sh a stencil1
s1 StencilR sh b stencil2
s2 Int
lvl SmartExp stencil1 -> SmartExp stencil2 -> SmartExp c
f
                                             (PreBoundary UnscopedAcc RootExp (Array sh a)
bnd1', Int
h2) <- Int
-> ShapeR sh
-> PreBoundary SmartAcc SmartExp (Array sh a)
-> IO (PreBoundary UnscopedAcc RootExp (Array sh a), Int)
forall sh e.
HasCallStack =>
Int
-> ShapeR sh
-> PreBoundary SmartAcc SmartExp (Array sh e)
-> IO (PreBoundary UnscopedAcc RootExp (Array sh e), Int)
traverseBoundary Int
lvl ShapeR sh
shr PreBoundary SmartAcc SmartExp (Array sh a)
bnd1
                                             (UnscopedAcc (Array sh a)
acc1', Int
h3) <- Int -> SmartAcc (Array sh a) -> IO (UnscopedAcc (Array sh a), Int)
forall arrs.
HasCallStack =>
Int -> SmartAcc arrs -> IO (UnscopedAcc arrs, Int)
traverseAcc Int
lvl SmartAcc (Array sh a)
acc1
                                             (PreBoundary UnscopedAcc RootExp (Array sh b)
bnd2', Int
h4) <- Int
-> ShapeR sh
-> PreBoundary SmartAcc SmartExp (Array sh b)
-> IO (PreBoundary UnscopedAcc RootExp (Array sh b), Int)
forall sh e.
HasCallStack =>
Int
-> ShapeR sh
-> PreBoundary SmartAcc SmartExp (Array sh e)
-> IO (PreBoundary UnscopedAcc RootExp (Array sh e), Int)
traverseBoundary Int
lvl ShapeR sh
shr PreBoundary SmartAcc SmartExp (Array sh b)
bnd2
                                             (UnscopedAcc (Array sh b)
acc2', Int
h5) <- Int -> SmartAcc (Array sh b) -> IO (UnscopedAcc (Array sh b), Int)
forall arrs.
HasCallStack =>
Int -> SmartAcc arrs -> IO (UnscopedAcc arrs, Int)
traverseAcc Int
lvl SmartAcc (Array sh b)
acc2
                                             (PreSmartAcc UnscopedAcc RootExp (Array sh c), Int)
-> IO (PreSmartAcc UnscopedAcc RootExp (Array sh c), Int)
forall (m :: * -> *) a. Monad m => a -> m a
return (StencilR sh a stencil1
-> StencilR sh b stencil2
-> TypeR c
-> (SmartExp stencil1 -> SmartExp stencil2 -> RootExp c)
-> PreBoundary UnscopedAcc RootExp (Array sh a)
-> UnscopedAcc (Array sh a)
-> PreBoundary UnscopedAcc RootExp (Array sh b)
-> UnscopedAcc (Array sh b)
-> PreSmartAcc UnscopedAcc RootExp (Array sh c)
forall sh a stencil1 b stencil2 c (exp :: * -> *) (acc :: * -> *).
StencilR sh a stencil1
-> StencilR sh b stencil2
-> TypeR c
-> (SmartExp stencil1 -> SmartExp stencil2 -> exp c)
-> PreBoundary acc exp (Array sh a)
-> acc (Array sh a)
-> PreBoundary acc exp (Array sh b)
-> acc (Array sh b)
-> PreSmartAcc acc exp (Array sh c)
Stencil2 StencilR sh a stencil1
s1 StencilR sh b stencil2
s2 TypeR c
tp SmartExp stencil1 -> SmartExp stencil2 -> RootExp c
f' PreBoundary UnscopedAcc RootExp (Array sh a)
bnd1' UnscopedAcc (Array sh a)
acc1' PreBoundary UnscopedAcc RootExp (Array sh b)
bnd2' UnscopedAcc (Array sh b)
acc2',
                                                     Int
h1 Int -> Int -> Int
forall a. Ord a => a -> a -> a
`max` Int
h2 Int -> Int -> Int
forall a. Ord a => a -> a -> a
`max` Int
h3 Int -> Int -> Int
forall a. Ord a => a -> a -> a
`max` Int
h4 Int -> Int -> Int
forall a. Ord a => a -> a -> a
`max` Int
h5 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1)
            -- Collect s                   -> do
            --                                  (s', h) <- traverseSeq lvl s
            --                                  return (Collect s', h + 1)


      where
        -- Traverse the single array-valued child of a node and rebuild the
        -- node with the supplied constructor function. The reported height is
        -- one more than the height returned for the child. The height is left
        -- as a lazy thunk, exactly as before (no extra forcing).
        travA :: HasCallStack
              => (UnscopedAcc arrs' -> PreSmartAcc UnscopedAcc RootExp arrs)
              -> SmartAcc arrs'
              -> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
        travA mkNode arg =
          traverseAcc lvl arg >>= \(arg', height) ->
            return (mkNode arg', height + 1)

        -- Traverse a scalar-expression child and then an array child, and
        -- rebuild the node with the supplied constructor function. The
        -- traversals are sequenced in IO; keep the expression-then-array order.
        -- The reported height is one more than the taller of the two children.
        travEA :: HasCallStack
               => (RootExp b -> UnscopedAcc arrs' -> PreSmartAcc UnscopedAcc RootExp arrs)
               -> SmartExp b
               -> SmartAcc arrs'
               -> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
        travEA mkNode e a = do
          (e', hE) <- traverseExp lvl e
          (a', hA) <- traverseAcc lvl a
          let height = max hE hA + 1
          return (mkNode e' a', height)

        -- Traverse, in order, a binary scalar function (whose argument types
        -- are given by the two TypeR witnesses), a scalar-expression child and
        -- an array child, then rebuild the node with the supplied constructor.
        -- The reported height is one more than the tallest of the three.
        travF2EA
            :: HasCallStack
            => ((SmartExp b -> SmartExp c -> RootExp d) -> RootExp e -> UnscopedAcc arrs' -> PreSmartAcc UnscopedAcc RootExp arrs)
            -> TypeR b
            -> TypeR c
            -> (SmartExp b -> SmartExp c -> SmartExp d)
            -> SmartExp e
            -> SmartAcc arrs'
            -> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
        travF2EA mkNode tyB tyC f e a = do
          (f', hF) <- traverseFun2 lvl tyB tyC f
          (e', hE) <- traverseExp  lvl e
          (a', hA) <- traverseAcc  lvl a
          let height = max (max hF hE) hA + 1
          return (mkNode f' e' a', height)

        -- Like travF2EA, but the scalar-expression child is optional. It is
        -- traversed with travME, which also yields a height for the Nothing
        -- case. Children are visited function, optional expression, then
        -- array; the reported height is one more than the tallest of the three.
        travF2MEA
            :: HasCallStack
            => ((SmartExp b -> SmartExp c -> RootExp d) -> Maybe (RootExp e) -> UnscopedAcc arrs' -> PreSmartAcc UnscopedAcc RootExp arrs)
            -> TypeR b
            -> TypeR c
            -> (SmartExp b -> SmartExp c -> SmartExp d)
            -> Maybe (SmartExp e)
            -> SmartAcc arrs'
            -> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
        travF2MEA mkNode tyB tyC f me a = do
          (f',  hF) <- traverseFun2 lvl tyB tyC f
          (me', hE) <- travME me
          (a',  hA) <- traverseAcc lvl a
          let height = max (max hF hE) hA + 1
          return (mkNode f' me' a', height)

        -- Traverse an optional scalar expression at the current level 'lvl'.
        -- An absent expression is returned unchanged with height 0.
        --
        travME :: HasCallStack => Maybe (SmartExp t) -> IO (Maybe (RootExp t), Int)
        travME Nothing  = return (Nothing, 0)
        travME (Just e) = do
          (e', c) <- traverseExp lvl e
          return (Just e', c)

        -- Traverse the pieces of a node that holds a binary scalar function
        -- and two array arguments. Each piece is traversed at the current
        -- level 'lvl', and the results are combined with the constructor 'c'.
        -- The node's height is one more than the largest height of its pieces.
        --
        travF2A2
            :: HasCallStack
            => ((SmartExp b -> SmartExp c -> RootExp d) -> UnscopedAcc arrs1 -> UnscopedAcc arrs2 -> PreSmartAcc UnscopedAcc RootExp arrs)
            -> TypeR b
            -> TypeR c
            -> (SmartExp b -> SmartExp c -> SmartExp d)
            -> SmartAcc arrs1
            -> SmartAcc arrs2
            -> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
        travF2A2 c t1 t2 fun acc1 acc2
          = do
              (fun' , h1) <- traverseFun2 lvl t1 t2 fun
              (acc1', h2) <- traverseAcc lvl acc1
              (acc2', h3) <- traverseAcc lvl acc2
              return (c fun' acc1' acc2', h1 `max` h2 `max` h3 + 1)

-- Generate occurrence information for a unary array function.
--
-- The function is applied to a placeholder argument, an 'Atag' carrying the
-- level 'lvl'. The resulting body is then traversed one level deeper
-- (@lvl+1@). The returned function ignores its argument and yields the body
-- annotated with the single tag 'lvl'. It is returned together with the height
-- of the body.
--
-- NOTE: the binding pattern @UnscopedAcc [] body@ expects the traversal to
-- report an empty tag list for the body. Any other result is a pattern-match
-- failure in 'IO'.
--
makeOccMapAfun1
    :: HasCallStack
    => Config
    -> OccMapHash SmartAcc
    -> Level
    -> ArraysR a
    -> (SmartAcc a -> SmartAcc b)
    -> IO (SmartAcc a -> UnscopedAcc b, Int)
makeOccMapAfun1 config accOccMap lvl repr f = do
  let x = SmartAcc (Atag repr lvl)
  --
  (UnscopedAcc [] body, height) <- makeOccMapSharingAcc config accOccMap (lvl+1) (f x)
  return (const (UnscopedAcc [lvl] body), height)

{--
makeOccMapAfun2 :: (Arrays a, Arrays b, Typeable c)
                => Config
                -> OccMapHash Acc
                -> Level
                -> (Acc a -> Acc b -> Acc c)
                -> IO (Acc a -> Acc b -> UnscopedAcc c, Int)
makeOccMapAfun2 config accOccMap lvl f = do
  let x = Acc (Atag (lvl + 1))
      y = Acc (Atag (lvl + 0))
  --
  (UnscopedAcc [] body, height) <- makeOccMapSharingAcc config accOccMap (lvl+2) (f x y)
  return (\ _ _ -> (UnscopedAcc [lvl, lvl+1] body), height)

makeOccMapAfun3 :: (Arrays a, Arrays b, Arrays c, Typeable d)
                => Config
                -> OccMapHash Acc
                -> Level
                -> (Acc a -> Acc b -> Acc c -> Acc d)
                -> IO (Acc a -> Acc b -> Acc c -> UnscopedAcc d, Int)
makeOccMapAfun3 config accOccMap lvl f = do
  let x = Acc (Atag (lvl + 2))
      y = Acc (Atag (lvl + 1))
      z = Acc (Atag (lvl + 0))
  --
  (UnscopedAcc [] body, height) <- makeOccMapSharingAcc config accOccMap (lvl+3) (f x y z)
  return (\ _ _ _ -> (UnscopedAcc [lvl, lvl+1, lvl+2] body), height)
--}

-- Generate occurrence information for scalar functions and expressions. These
-- are helper functions wrapping around 'makeOccMapRootExp' with more specific
-- types.
--
-- See Note [Traversing functions and side effects]
--
-- Generate occurrence information for a closed scalar expression (one that
-- introduces no new free scalar variables) at the given level. This is
-- 'makeOccMapRootExp' with an empty list of free-variable tags.
--
makeOccMapExp
    :: HasCallStack
    => Config
    -> OccMapHash SmartAcc
    -> Level
    -> SmartExp e
    -> IO (RootExp e, Int)
makeOccMapExp config accOccMap lvl = makeOccMapRootExp config accOccMap lvl []

-- Generate occurrence information for a unary scalar function.
--
-- The function is applied to a placeholder argument, a 'Tag' of type 'tp' at
-- level 'lvl'. The body is then processed by 'makeOccMapRootExp' at level
-- @lvl+1@, with @[lvl]@ as its free-variable tags. The returned function
-- ignores its argument and always yields that body.
--
makeOccMapFun1
    :: HasCallStack
    => Config
    -> OccMapHash SmartAcc
    -> Level
    -> TypeR a
    -> (SmartExp a -> SmartExp b)
    -> IO (SmartExp a -> RootExp b, Int)
makeOccMapFun1 config accOccMap lvl tp f = do
  let x = SmartExp (Tag tp lvl)
  --
  (body, height) <- makeOccMapRootExp config accOccMap (lvl+1) [lvl] (f x)
  return (const body, height)

-- Generate occurrence information for a binary scalar function.
--
-- The first argument is a placeholder 'Tag' at level @lvl+1@, and the second
-- is a placeholder 'Tag' at level 'lvl'. The body is processed by
-- 'makeOccMapRootExp' at level @lvl+2@, with @[lvl, lvl+1]@ as its
-- free-variable tags. The returned function ignores both of its arguments.
--
makeOccMapFun2
    :: HasCallStack
    => Config
    -> OccMapHash SmartAcc
    -> Level
    -> TypeR a
    -> TypeR b
    -> (SmartExp a -> SmartExp b -> SmartExp c)
    -> IO (SmartExp a -> SmartExp b -> RootExp c, Int)
makeOccMapFun2 config accOccMap lvl t1 t2 f = do
  let x = SmartExp (Tag t1 (lvl+1))
      y = SmartExp (Tag t2 lvl)
  --
  (body, height) <- makeOccMapRootExp config accOccMap (lvl+2) [lvl, lvl+1] (f x y)
  return (\_ _ -> body, height)

-- Generate occurrence information for a stencil function over one array.
--
-- The stencil is applied to a placeholder 'Tag' at level 'lvl', whose type is
-- taken from the stencil representation 's' via 'R.stencilR'. The body is then
-- processed by 'makeOccMapRootExp' at level @lvl+1@, with @[lvl]@ as its
-- free-variable tags. The returned function ignores its argument.
--
makeOccMapStencil1
    :: forall sh a b stencil. HasCallStack
    => Config
    -> OccMapHash SmartAcc
    -> R.StencilR sh a stencil
    -> Level
    -> (SmartExp stencil -> SmartExp b)
    -> IO (SmartExp stencil -> RootExp b, Int)
makeOccMapStencil1 config accOccMap s lvl stencil = do
  let x = SmartExp (Tag (R.stencilR s) lvl)
  --
  (body, height) <- makeOccMapRootExp config accOccMap (lvl+1) [lvl] (stencil x)
  return (const body, height)

-- Generate occurrence information for a stencil function over two arrays.
--
-- The first stencil argument is a placeholder 'Tag' at level @lvl+1@, typed
-- from 'sR1'. The second is a placeholder 'Tag' at level 'lvl', typed from
-- 'sR2'. The body is processed by 'makeOccMapRootExp' at level @lvl+2@, with
-- @[lvl, lvl+1]@ as its free-variable tags. The returned function ignores both
-- of its arguments.
--
makeOccMapStencil2
    :: forall sh a b c stencil1 stencil2. HasCallStack
    => Config
    -> OccMapHash SmartAcc
    -> R.StencilR sh a stencil1
    -> R.StencilR sh b stencil2
    -> Level
    -> (SmartExp stencil1 -> SmartExp stencil2 -> SmartExp c)
    -> IO (SmartExp stencil1 -> SmartExp stencil2 -> RootExp c, Int)
makeOccMapStencil2 config accOccMap sR1 sR2 lvl stencil = do
  let x = SmartExp (Tag (R.stencilR sR1) (lvl+1))
      y = SmartExp (Tag (R.stencilR sR2) lvl)
  --
  (body, height) <- makeOccMapRootExp config accOccMap (lvl+2) [lvl, lvl+1] (stencil x y)
  return (\_ _ -> body, height)


-- Generate sharing information for expressions embedded in Acc computations.
-- Expressions are annotated with:
--
--  1) the tags of free scalar variables (for scalar functions)
--  2) a local occurrence map for that expression.
--
makeOccMapRootExp
    :: HasCallStack
    => Config
    -> OccMapHash SmartAcc
    -> Level                            -- The level of currently bound scalar variables
    -> [Int]                            -- The tags of newly introduced free scalar variables in this expression
    -> SmartExp e
    -> IO (RootExp e, Int)
makeOccMapRootExp config accOccMap lvl fvs exp = do
  traceLine "makeOccMapRootExp" "Enter"
  -- Each root expression gets its own, fresh expression occurrence map; the
  -- array occurrence map 'accOccMap' is shared with the caller.
  expOccMap                     <- newASTHashTable
  -- The pattern expects the traversal to return an empty tag list; any other
  -- result is a pattern-match failure in 'IO'.
  (UnscopedExp [] exp', height) <- makeOccMapSharingExp config accOccMap expOccMap lvl exp
  frozenExpOccMap               <- freezeOccMap expOccMap
  traceLine "makeOccMapRootExp" "Exit"
  -- Attach the caller-supplied free-variable tags and the frozen local map.
  return (RootExp frozenExpOccMap (UnscopedExp fvs exp'), height)


-- Generate sharing information for an open scalar expression.
--
makeOccMapSharingExp
    :: HasCallStack
    => Config
    -> OccMapHash SmartAcc
    -> OccMapHash SmartExp
    -> Level                            -- The level of currently bound variables
    -> SmartExp e
    -> IO (UnscopedExp e, Int)
makeOccMapSharingExp :: Config
-> OccMapHash SmartAcc
-> OccMapHash SmartExp
-> Int
-> SmartExp e
-> IO (UnscopedExp e, Int)
makeOccMapSharingExp Config
config OccMapHash SmartAcc
accOccMap OccMapHash SmartExp
expOccMap = Int -> SmartExp e -> IO (UnscopedExp e, Int)
forall a.
HasCallStack =>
Int -> SmartExp a -> IO (UnscopedExp a, Int)
travE
  where
    travE :: forall a. HasCallStack => Level -> SmartExp a -> IO (UnscopedExp a, Int)
    travE :: Int -> SmartExp a -> IO (UnscopedExp a, Int)
travE Int
lvl exp :: SmartExp a
exp@(SmartExp PreSmartExp SmartAcc SmartExp a
pexp)
      = ((UnscopedExp a, Int) -> IO (UnscopedExp a, Int))
-> IO (UnscopedExp a, Int)
forall (m :: * -> *) a. MonadFix m => (a -> m a) -> m a
mfix (((UnscopedExp a, Int) -> IO (UnscopedExp a, Int))
 -> IO (UnscopedExp a, Int))
-> ((UnscopedExp a, Int) -> IO (UnscopedExp a, Int))
-> IO (UnscopedExp a, Int)
forall a b. (a -> b) -> a -> b
$ \ ~(UnscopedExp a
_, Int
height) -> do
          -- Compute stable name and enter it into the occurrence map
          --
          StableName (SmartExp a)
sn                         <- SmartExp a -> IO (StableName (SmartExp a))
forall (c :: * -> *) t. c t -> IO (StableName (c t))
makeStableAST SmartExp a
exp
          Maybe Int
heightIfRepeatedOccurrence <- OccMapHash SmartExp
-> StableASTName SmartExp -> Int -> IO (Maybe Int)
forall (c :: * -> *).
OccMapHash c -> StableASTName c -> Int -> IO (Maybe Int)
enterOcc OccMapHash SmartExp
expOccMap (StableName (SmartExp a) -> StableASTName SmartExp
forall (c :: * -> *) t. StableName (c t) -> StableASTName c
StableASTName StableName (SmartExp a)
sn) Int
height

          String -> String -> IO ()
traceLine (PreSmartExp SmartAcc SmartExp a -> String
forall (acc :: * -> *) (exp :: * -> *) t.
PreSmartExp acc exp t -> String
showPreExpOp PreSmartExp SmartAcc SmartExp a
pexp) (String -> IO ()) -> String -> IO ()
forall a b. (a -> b) -> a -> b
$ do
            let hash :: String
hash = Int -> String
forall a. Show a => a -> String
show (StableName (SmartExp a) -> Int
forall a. StableName a -> Int
hashStableName StableName (SmartExp a)
sn)
            case Maybe Int
heightIfRepeatedOccurrence of
              Just Int
height -> String
"REPEATED occurrence (sn = " String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
hash String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
"; height = " String -> ShowS
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show Int
height String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
")"
              Maybe Int
Nothing     -> String
"first occurrence (sn = " String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
hash String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
")"

          -- Reconstruct the computation in shared form.
          --
          -- In case of a repeated occurrence, the height comes from the occurrence map; otherwise
          -- it is computed by the traversal function passed in 'newExp'.  See also 'enterOcc'.
          --
          let reconstruct :: IO (PreSmartExp UnscopedAcc UnscopedExp a, Int)
                          -> IO (UnscopedExp a, Int)
              reconstruct :: IO (PreSmartExp UnscopedAcc UnscopedExp a, Int)
-> IO (UnscopedExp a, Int)
reconstruct IO (PreSmartExp UnscopedAcc UnscopedExp a, Int)
newExp
                = case Maybe Int
heightIfRepeatedOccurrence of
                    Just Int
height | Flag
exp_sharing Flag -> BitSet Word32 Flag -> Bool
forall a c. (Enum a, Bits c) => a -> BitSet c a -> Bool
`member` Config -> BitSet Word32 Flag
options Config
config
                      -> (UnscopedExp a, Int) -> IO (UnscopedExp a, Int)
forall (m :: * -> *) a. Monad m => a -> m a
return ([Int] -> SharingExp UnscopedAcc UnscopedExp a -> UnscopedExp a
forall t.
[Int] -> SharingExp UnscopedAcc UnscopedExp t -> UnscopedExp t
UnscopedExp [] (StableExpName a -> TypeR a -> SharingExp UnscopedAcc UnscopedExp a
forall t (acc :: * -> *) (exp :: * -> *).
StableExpName t -> TypeR t -> SharingExp acc exp t
VarSharing (StableName (SmartExp a) -> Int -> StableExpName a
forall t. StableName t -> Int -> StableNameHeight t
StableNameHeight StableName (SmartExp a)
sn Int
height) (PreSmartExp SmartAcc SmartExp a -> TypeR a
forall (f :: * -> *) t.
(HasTypeR f, HasCallStack) =>
f t -> TypeR t
typeR PreSmartExp SmartAcc SmartExp a
pexp)), Int
height)
                    Maybe Int
_ -> do (PreSmartExp UnscopedAcc UnscopedExp a
exp, Int
height) <- IO (PreSmartExp UnscopedAcc UnscopedExp a, Int)
newExp
                            (UnscopedExp a, Int) -> IO (UnscopedExp a, Int)
forall (m :: * -> *) a. Monad m => a -> m a
return ([Int] -> SharingExp UnscopedAcc UnscopedExp a -> UnscopedExp a
forall t.
[Int] -> SharingExp UnscopedAcc UnscopedExp t -> UnscopedExp t
UnscopedExp [] (StableExpName a
-> PreSmartExp UnscopedAcc UnscopedExp a
-> SharingExp UnscopedAcc UnscopedExp a
forall t (acc :: * -> *) (exp :: * -> *).
StableExpName t -> PreSmartExp acc exp t -> SharingExp acc exp t
ExpSharing (StableName (SmartExp a) -> Int -> StableExpName a
forall t. StableName t -> Int -> StableNameHeight t
StableNameHeight StableName (SmartExp a)
sn Int
height) PreSmartExp UnscopedAcc UnscopedExp a
exp), Int
height)

          IO (PreSmartExp UnscopedAcc UnscopedExp a, Int)
-> IO (UnscopedExp a, Int)
reconstruct (IO (PreSmartExp UnscopedAcc UnscopedExp a, Int)
 -> IO (UnscopedExp a, Int))
-> IO (PreSmartExp UnscopedAcc UnscopedExp a, Int)
-> IO (UnscopedExp a, Int)
forall a b. (a -> b) -> a -> b
$ case PreSmartExp SmartAcc SmartExp a
pexp of
            Tag TypeR a
tp Int
i            -> (PreSmartExp UnscopedAcc UnscopedExp a, Int)
-> IO (PreSmartExp UnscopedAcc UnscopedExp a, Int)
forall (m :: * -> *) a. Monad m => a -> m a
return (TypeR a -> Int -> PreSmartExp UnscopedAcc UnscopedExp a
forall t (acc :: * -> *) (exp :: * -> *).
TypeR t -> Int -> PreSmartExp acc exp t
Tag TypeR a
tp Int
i, Int
0)      -- height is 0!
            Const ScalarType a
tp a
c          -> (PreSmartExp UnscopedAcc UnscopedExp a, Int)
-> IO (PreSmartExp UnscopedAcc UnscopedExp a, Int)
forall (m :: * -> *) a. Monad m => a -> m a
return (ScalarType a -> a -> PreSmartExp UnscopedAcc UnscopedExp a
forall t (acc :: * -> *) (exp :: * -> *).
ScalarType t -> t -> PreSmartExp acc exp t
Const ScalarType a
tp a
c, Int
1)
            Undef ScalarType a
tp            -> (PreSmartExp UnscopedAcc UnscopedExp a, Int)
-> IO (PreSmartExp UnscopedAcc UnscopedExp a, Int)
forall (m :: * -> *) a. Monad m => a -> m a
return (ScalarType a -> PreSmartExp UnscopedAcc UnscopedExp a
forall t (acc :: * -> *) (exp :: * -> *).
ScalarType t -> PreSmartExp acc exp t
Undef ScalarType a
tp, Int
1)
            PreSmartExp SmartAcc SmartExp a
Nil                 -> (PreSmartExp UnscopedAcc UnscopedExp (), Int)
-> IO (PreSmartExp UnscopedAcc UnscopedExp (), Int)
forall (m :: * -> *) a. Monad m => a -> m a
return (PreSmartExp UnscopedAcc UnscopedExp ()
forall (acc :: * -> *) (exp :: * -> *). PreSmartExp acc exp ()
Nil, Int
1)
            Pair SmartExp t1
e1 SmartExp t2
e2          -> (UnscopedExp t1
 -> UnscopedExp t2 -> PreSmartExp UnscopedAcc UnscopedExp (t1, t2))
-> SmartExp t1
-> SmartExp t2
-> IO (PreSmartExp UnscopedAcc UnscopedExp (t1, t2), Int)
forall b c r.
HasCallStack =>
(UnscopedExp b -> UnscopedExp c -> r)
-> SmartExp b -> SmartExp c -> IO (r, Int)
travE2 UnscopedExp t1
-> UnscopedExp t2 -> PreSmartExp UnscopedAcc UnscopedExp (t1, t2)
forall (exp :: * -> *) t1 t2 (acc :: * -> *).
exp t1 -> exp t2 -> PreSmartExp acc exp (t1, t2)
Pair SmartExp t1
e1 SmartExp t2
e2
            Prj PairIdx (t1, t2) a
i SmartExp (t1, t2)
e             -> (UnscopedExp (t1, t2) -> PreSmartExp UnscopedAcc UnscopedExp a)
-> SmartExp (t1, t2)
-> IO (PreSmartExp UnscopedAcc UnscopedExp a, Int)
forall b r.
HasCallStack =>
(UnscopedExp b -> r) -> SmartExp b -> IO (r, Int)
travE1 (PairIdx (t1, t2) a
-> UnscopedExp (t1, t2) -> PreSmartExp UnscopedAcc UnscopedExp a
forall t1 t2 t (exp :: * -> *) (acc :: * -> *).
PairIdx (t1, t2) t -> exp (t1, t2) -> PreSmartExp acc exp t
Prj PairIdx (t1, t2) a
i) SmartExp (t1, t2)
e
            VecPack   VecR n s tup
vec SmartExp tup
e     -> (UnscopedExp tup -> PreSmartExp UnscopedAcc UnscopedExp (Vec n s))
-> SmartExp tup
-> IO (PreSmartExp UnscopedAcc UnscopedExp (Vec n s), Int)
forall b r.
HasCallStack =>
(UnscopedExp b -> r) -> SmartExp b -> IO (r, Int)
travE1 (VecR n s tup
-> UnscopedExp tup -> PreSmartExp UnscopedAcc UnscopedExp (Vec n s)
forall (n :: Nat) s tup (exp :: * -> *) (acc :: * -> *).
KnownNat n =>
VecR n s tup -> exp tup -> PreSmartExp acc exp (Vec n s)
VecPack   VecR n s tup
vec) SmartExp tup
e
            VecUnpack VecR n s a
vec SmartExp (Vec n s)
e     -> (UnscopedExp (Vec n s) -> PreSmartExp UnscopedAcc UnscopedExp a)
-> SmartExp (Vec n s)
-> IO (PreSmartExp UnscopedAcc UnscopedExp a, Int)
forall b r.
HasCallStack =>
(UnscopedExp b -> r) -> SmartExp b -> IO (r, Int)
travE1 (VecR n s a
-> UnscopedExp (Vec n s) -> PreSmartExp UnscopedAcc UnscopedExp a
forall (n :: Nat) s tup (exp :: * -> *) (acc :: * -> *).
KnownNat n =>
VecR n s tup -> exp (Vec n s) -> PreSmartExp acc exp tup
VecUnpack VecR n s a
vec) SmartExp (Vec n s)
e
            ToIndex ShapeR sh
shr SmartExp sh
sh SmartExp sh
ix   -> (UnscopedExp sh
 -> UnscopedExp sh -> PreSmartExp UnscopedAcc UnscopedExp Int)
-> SmartExp sh
-> SmartExp sh
-> IO (PreSmartExp UnscopedAcc UnscopedExp Int, Int)
forall b c r.
HasCallStack =>
(UnscopedExp b -> UnscopedExp c -> r)
-> SmartExp b -> SmartExp c -> IO (r, Int)
travE2 (ShapeR sh
-> UnscopedExp sh
-> UnscopedExp sh
-> PreSmartExp UnscopedAcc UnscopedExp Int
forall sh (exp :: * -> *) (acc :: * -> *).
ShapeR sh -> exp sh -> exp sh -> PreSmartExp acc exp Int
ToIndex ShapeR sh
shr) SmartExp sh
sh SmartExp sh
ix
            FromIndex ShapeR a
shr SmartExp a
sh SmartExp Int
e  -> (UnscopedExp a
 -> UnscopedExp Int -> PreSmartExp UnscopedAcc UnscopedExp a)
-> SmartExp a
-> SmartExp Int
-> IO (PreSmartExp UnscopedAcc UnscopedExp a, Int)
forall b c r.
HasCallStack =>
(UnscopedExp b -> UnscopedExp c -> r)
-> SmartExp b -> SmartExp c -> IO (r, Int)
travE2 (ShapeR a
-> UnscopedExp a
-> UnscopedExp Int
-> PreSmartExp UnscopedAcc UnscopedExp a
forall sh (exp :: * -> *) (acc :: * -> *).
ShapeR sh -> exp sh -> exp Int -> PreSmartExp acc exp sh
FromIndex ShapeR a
shr) SmartExp a
sh SmartExp Int
e
            Match TagR a
t SmartExp a
e           -> (UnscopedExp a -> PreSmartExp UnscopedAcc UnscopedExp a)
-> SmartExp a -> IO (PreSmartExp UnscopedAcc UnscopedExp a, Int)
forall b r.
HasCallStack =>
(UnscopedExp b -> r) -> SmartExp b -> IO (r, Int)
travE1 (TagR a -> UnscopedExp a -> PreSmartExp UnscopedAcc UnscopedExp a
forall t (exp :: * -> *) (acc :: * -> *).
TagR t -> exp t -> PreSmartExp acc exp t
Match TagR a
t) SmartExp a
e
            Case SmartExp a
e [(TagR a, SmartExp a)]
rhs          -> do
                                     (UnscopedExp a
e',   Int
h1) <- Int -> SmartExp a -> IO (UnscopedExp a, Int)
forall a.
HasCallStack =>
Int -> SmartExp a -> IO (UnscopedExp a, Int)
travE Int
lvl SmartExp a
e
                                     ([(TagR a, UnscopedExp a)]
rhs', [Int]
h2) <- [((TagR a, UnscopedExp a), Int)]
-> ([(TagR a, UnscopedExp a)], [Int])
forall a b. [(a, b)] -> ([a], [b])
unzip ([((TagR a, UnscopedExp a), Int)]
 -> ([(TagR a, UnscopedExp a)], [Int]))
-> IO [((TagR a, UnscopedExp a), Int)]
-> IO ([(TagR a, UnscopedExp a)], [Int])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [IO ((TagR a, UnscopedExp a), Int)]
-> IO [((TagR a, UnscopedExp a), Int)]
forall (t :: * -> *) (m :: * -> *) a.
(Traversable t, Monad m) =>
t (m a) -> m (t a)
sequence [ (UnscopedExp a -> (TagR a, UnscopedExp a))
-> SmartExp a -> IO ((TagR a, UnscopedExp a), Int)
forall b r.
HasCallStack =>
(UnscopedExp b -> r) -> SmartExp b -> IO (r, Int)
travE1 (TagR a
t,) SmartExp a
c | (TagR a
t,SmartExp a
c) <- [(TagR a, SmartExp a)]
rhs ]
                                     (PreSmartExp UnscopedAcc UnscopedExp a, Int)
-> IO (PreSmartExp UnscopedAcc UnscopedExp a, Int)
forall (m :: * -> *) a. Monad m => a -> m a
return (UnscopedExp a
-> [(TagR a, UnscopedExp a)]
-> PreSmartExp UnscopedAcc UnscopedExp a
forall (exp :: * -> *) a b (acc :: * -> *).
exp a -> [(TagR a, exp b)] -> PreSmartExp acc exp b
Case UnscopedExp a
e' [(TagR a, UnscopedExp a)]
rhs', Int
h1 Int -> Int -> Int
forall a. Ord a => a -> a -> a
`max` [Int] -> Int
forall (t :: * -> *) a. (Foldable t, Ord a) => t a -> a
maximum [Int]
h2 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1)
            Cond SmartExp PrimBool
e1 SmartExp a
e2 SmartExp a
e3       -> (UnscopedExp PrimBool
 -> UnscopedExp a
 -> UnscopedExp a
 -> PreSmartExp UnscopedAcc UnscopedExp a)
-> SmartExp PrimBool
-> SmartExp a
-> SmartExp a
-> IO (PreSmartExp UnscopedAcc UnscopedExp a, Int)
forall b c d r.
HasCallStack =>
(UnscopedExp b -> UnscopedExp c -> UnscopedExp d -> r)
-> SmartExp b -> SmartExp c -> SmartExp d -> IO (r, Int)
travE3 UnscopedExp PrimBool
-> UnscopedExp a
-> UnscopedExp a
-> PreSmartExp UnscopedAcc UnscopedExp a
forall (exp :: * -> *) t (acc :: * -> *).
exp PrimBool -> exp t -> exp t -> PreSmartExp acc exp t
Cond SmartExp PrimBool
e1 SmartExp a
e2 SmartExp a
e3
            While TypeR a
t SmartExp a -> SmartExp PrimBool
p SmartExp a -> SmartExp a
iter SmartExp a
init -> do
                                     (SmartExp a -> UnscopedExp PrimBool
p'   , Int
h1) <- Int
-> TypeR a
-> (SmartExp a -> SmartExp PrimBool)
-> IO (SmartExp a -> UnscopedExp PrimBool, Int)
forall b.
HasCallStack =>
Int
-> TypeR a
-> (SmartExp a -> SmartExp b)
-> IO (SmartExp a -> UnscopedExp b, Int)
traverseFun1 Int
lvl TypeR a
t SmartExp a -> SmartExp PrimBool
p
                                     (SmartExp a -> UnscopedExp a
iter', Int
h2) <- Int
-> TypeR a
-> (SmartExp a -> SmartExp a)
-> IO (SmartExp a -> UnscopedExp a, Int)
forall b.
HasCallStack =>
Int
-> TypeR a
-> (SmartExp a -> SmartExp b)
-> IO (SmartExp a -> UnscopedExp b, Int)
traverseFun1 Int
lvl TypeR a
t SmartExp a -> SmartExp a
iter
                                     (UnscopedExp a
init', Int
h3) <- Int -> SmartExp a -> IO (UnscopedExp a, Int)
forall a.
HasCallStack =>
Int -> SmartExp a -> IO (UnscopedExp a, Int)
travE Int
lvl SmartExp a
init
                                     (PreSmartExp UnscopedAcc UnscopedExp a, Int)
-> IO (PreSmartExp UnscopedAcc UnscopedExp a, Int)
forall (m :: * -> *) a. Monad m => a -> m a
return (TypeR a
-> (SmartExp a -> UnscopedExp PrimBool)
-> (SmartExp a -> UnscopedExp a)
-> UnscopedExp a
-> PreSmartExp UnscopedAcc UnscopedExp a
forall t (exp :: * -> *) (acc :: * -> *).
TypeR t
-> (SmartExp t -> exp PrimBool)
-> (SmartExp t -> exp t)
-> exp t
-> PreSmartExp acc exp t
While TypeR a
t SmartExp a -> UnscopedExp PrimBool
p' SmartExp a -> UnscopedExp a
iter' UnscopedExp a
init', Int
h1 Int -> Int -> Int
forall a. Ord a => a -> a -> a
`max` Int
h2 Int -> Int -> Int
forall a. Ord a => a -> a -> a
`max` Int
h3 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1)
            PrimConst PrimConst a
c         -> (PreSmartExp UnscopedAcc UnscopedExp a, Int)
-> IO (PreSmartExp UnscopedAcc UnscopedExp a, Int)
forall (m :: * -> *) a. Monad m => a -> m a
return (PrimConst a -> PreSmartExp UnscopedAcc UnscopedExp a
forall t (acc :: * -> *) (exp :: * -> *).
PrimConst t -> PreSmartExp acc exp t
PrimConst PrimConst a
c, Int
1)
            PrimApp PrimFun (a -> a)
p SmartExp a
e         -> (UnscopedExp a -> PreSmartExp UnscopedAcc UnscopedExp a)
-> SmartExp a -> IO (PreSmartExp UnscopedAcc UnscopedExp a, Int)
forall b r.
HasCallStack =>
(UnscopedExp b -> r) -> SmartExp b -> IO (r, Int)
travE1 (PrimFun (a -> a)
-> UnscopedExp a -> PreSmartExp UnscopedAcc UnscopedExp a
forall a r (exp :: * -> *) (acc :: * -> *).
PrimFun (a -> r) -> exp a -> PreSmartExp acc exp r
PrimApp PrimFun (a -> a)
p) SmartExp a
e
            Index TypeR a
tp SmartAcc (Array sh a)
a SmartExp sh
e        -> (UnscopedAcc (Array sh a)
 -> UnscopedExp sh -> PreSmartExp UnscopedAcc UnscopedExp a)
-> SmartAcc (Array sh a)
-> SmartExp sh
-> IO (PreSmartExp UnscopedAcc UnscopedExp a, Int)
forall b c r.
HasCallStack =>
(UnscopedAcc b -> UnscopedExp c -> r)
-> SmartAcc b -> SmartExp c -> IO (r, Int)
travAE (TypeR a
-> UnscopedAcc (Array sh a)
-> UnscopedExp sh
-> PreSmartExp UnscopedAcc UnscopedExp a
forall t (acc :: * -> *) sh (exp :: * -> *).
TypeR t -> acc (Array sh t) -> exp sh -> PreSmartExp acc exp t
Index TypeR a
tp) SmartAcc (Array sh a)
a SmartExp sh
e
            LinearIndex TypeR a
tp SmartAcc (Array sh a)
a SmartExp Int
i  -> (UnscopedAcc (Array sh a)
 -> UnscopedExp Int -> PreSmartExp UnscopedAcc UnscopedExp a)
-> SmartAcc (Array sh a)
-> SmartExp Int
-> IO (PreSmartExp UnscopedAcc UnscopedExp a, Int)
forall b c r.
HasCallStack =>
(UnscopedAcc b -> UnscopedExp c -> r)
-> SmartAcc b -> SmartExp c -> IO (r, Int)
travAE (TypeR a
-> UnscopedAcc (Array sh a)
-> UnscopedExp Int
-> PreSmartExp UnscopedAcc UnscopedExp a
forall t (acc :: * -> *) sh (exp :: * -> *).
TypeR t -> acc (Array sh t) -> exp Int -> PreSmartExp acc exp t
LinearIndex TypeR a
tp) SmartAcc (Array sh a)
a SmartExp Int
i
            Shape ShapeR a
shr SmartAcc (Array a e)
a         -> (UnscopedAcc (Array a e) -> PreSmartExp UnscopedAcc UnscopedExp a)
-> SmartAcc (Array a e)
-> IO (PreSmartExp UnscopedAcc UnscopedExp a, Int)
forall b r.
HasCallStack =>
(UnscopedAcc b -> r) -> SmartAcc b -> IO (r, Int)
travA (ShapeR a
-> UnscopedAcc (Array a e) -> PreSmartExp UnscopedAcc UnscopedExp a
forall sh (acc :: * -> *) e (exp :: * -> *).
ShapeR sh -> acc (Array sh e) -> PreSmartExp acc exp sh
Shape ShapeR a
shr) SmartAcc (Array a e)
a
            ShapeSize ShapeR sh
shr SmartExp sh
e     -> (UnscopedExp sh -> PreSmartExp UnscopedAcc UnscopedExp Int)
-> SmartExp sh -> IO (PreSmartExp UnscopedAcc UnscopedExp Int, Int)
forall b r.
HasCallStack =>
(UnscopedExp b -> r) -> SmartExp b -> IO (r, Int)
travE1 (ShapeR sh
-> UnscopedExp sh -> PreSmartExp UnscopedAcc UnscopedExp Int
forall sh (exp :: * -> *) (acc :: * -> *).
ShapeR sh -> exp sh -> PreSmartExp acc exp Int
ShapeSize ShapeR sh
shr) SmartExp sh
e
            Foreign TypeR a
tp asm (x -> a)
ff SmartExp x -> SmartExp a
f SmartExp x
e   -> do
                                      (UnscopedExp x
e', Int
h) <- Int -> SmartExp x -> IO (UnscopedExp x, Int)
forall a.
HasCallStack =>
Int -> SmartExp a -> IO (UnscopedExp a, Int)
travE Int
lvl SmartExp x
e
                                      (PreSmartExp UnscopedAcc UnscopedExp a, Int)
-> IO (PreSmartExp UnscopedAcc UnscopedExp a, Int)
forall (m :: * -> *) a. Monad m => a -> m a
return  (TypeR a
-> asm (x -> a)
-> (SmartExp x -> SmartExp a)
-> UnscopedExp x
-> PreSmartExp UnscopedAcc UnscopedExp a
forall (asm :: * -> *) y x (exp :: * -> *) (acc :: * -> *).
Foreign asm =>
TypeR y
-> asm (x -> y)
-> (SmartExp x -> SmartExp y)
-> exp x
-> PreSmartExp acc exp y
Foreign TypeR a
tp asm (x -> a)
ff SmartExp x -> SmartExp a
f UnscopedExp x
e', Int
hInt -> Int -> Int
forall a. Num a => a -> a -> a
+Int
1)
            Coerce ScalarType a
t1 ScalarType a
t2 SmartExp a
e      -> (UnscopedExp a -> PreSmartExp UnscopedAcc UnscopedExp a)
-> SmartExp a -> IO (PreSmartExp UnscopedAcc UnscopedExp a, Int)
forall b r.
HasCallStack =>
(UnscopedExp b -> r) -> SmartExp b -> IO (r, Int)
travE1 (ScalarType a
-> ScalarType a
-> UnscopedExp a
-> PreSmartExp UnscopedAcc UnscopedExp a
forall a b (exp :: * -> *) (acc :: * -> *).
BitSizeEq a b =>
ScalarType a -> ScalarType b -> exp a -> PreSmartExp acc exp b
Coerce ScalarType a
t1 ScalarType a
t2) SmartExp a
e

      where
        -- Traverse an array computation that occurs inside this scalar
        -- expression, recording its occurrences in the array occurrence map
        -- 'accOccMap' (via 'makeOccMapSharingAcc' with the same 'config').
        traverseAcc :: HasCallStack => Level -> SmartAcc arrs -> IO (UnscopedAcc arrs, Int)
        traverseAcc level acc = makeOccMapSharingAcc config accOccMap level acc

        -- Traverse the body of a unary scalar function.
        --
        -- The function is applied to a 'Tag' of type 'tp' at the given level, and
        -- the resulting body is traversed one level deeper. The traversed body is
        -- expected to carry an empty level list; the pattern bind below fails in
        -- 'IO' otherwise. The returned function ignores its argument and yields
        -- the body annotated with the level that the lambda binds. The reported
        -- height is one more than the body's height.
        traverseFun1
            :: HasCallStack
            => Level
            -> TypeR a
            -> (SmartExp a -> SmartExp b)
            -> IO (SmartExp a -> UnscopedExp b, Int)
        traverseFun1 level tp f
          = do
              let tag = SmartExp (Tag tp level)
              (UnscopedExp [] body, bodyHeight) <- travE (level + 1) (f tag)
              return (\_ -> UnscopedExp [level] body, bodyHeight + 1)


        -- Traverse one scalar operand at the current level and rebuild the node
        -- with 'rebuild'; the node is one taller than its operand.
        travE1 :: HasCallStack => (UnscopedExp b -> r) -> SmartExp b -> IO (r, Int)
        travE1 rebuild operand = do
          (operand', height) <- travE lvl operand
          pure (rebuild operand', height + 1)

        -- Traverse two scalar operands (left to right) at the current level and
        -- rebuild the node with 'rebuild'; the node is one taller than its
        -- tallest operand.
        travE2 :: HasCallStack
               => (UnscopedExp b -> UnscopedExp c -> r)
               -> SmartExp b
               -> SmartExp c
               -> IO (r, Int)
        travE2 rebuild x y = do
          (x', hx) <- travE lvl x
          (y', hy) <- travE lvl y
          pure (rebuild x' y', max hx hy + 1)

        -- Traverse three scalar operands (in order) at the current level and
        -- rebuild the node with 'rebuild'; the node is one taller than its
        -- tallest operand.
        travE3 :: HasCallStack
               => (UnscopedExp b -> UnscopedExp c -> UnscopedExp d -> r)
               -> SmartExp b
               -> SmartExp c
               -> SmartExp d
               -> IO (r, Int)
        travE3 rebuild x y z = do
          (x', hx) <- travE lvl x
          (y', hy) <- travE lvl y
          (z', hz) <- travE lvl z
          pure (rebuild x' y' z', max hx (max hy hz) + 1)

        -- Traverse an embedded array operand at the current level and rebuild
        -- the node with 'rebuild'; the node is one taller than the operand.
        travA :: HasCallStack => (UnscopedAcc b -> r) -> SmartAcc b -> IO (r, Int)
        travA rebuild arr = do
          (arr', height) <- traverseAcc lvl arr
          pure (rebuild arr', height + 1)

        -- Traverse an embedded array operand followed by a scalar operand at the
        -- current level and rebuild the node with 'rebuild'; the node is one
        -- taller than its tallest operand.
        travAE :: HasCallStack
               => (UnscopedAcc b -> UnscopedExp c -> r)
               -> SmartAcc b
               -> SmartExp c
               -> IO (r, Int)
        travAE rebuild arr ix = do
          (arr', hArr) <- traverseAcc lvl arr
          (ix' , hIx ) <- travE lvl ix
          pure (rebuild arr' ix', max hArr hIx + 1)

{--
makeOccMapRootSeq
    :: Typeable arrs
    => Config
    -> OccMapHash Acc
    -> Level
    -> Seq arrs
    -> IO (RootSeq arrs, Int)
makeOccMapRootSeq config accOccMap lvl seq = do
  traceLine "makeOccMapRootSeq" "Enter"
  seqOccMap       <- newASTHashTable
  (seq', height)  <- makeOccMapSharingSeq config accOccMap seqOccMap lvl seq
  frozenSeqOccMap <- freezeOccMap seqOccMap
  traceLine "makeOccMapRootSeq" "Exit"
  return (RootSeq frozenSeqOccMap seq', height)

-- Generate sharing information for an open sequence expression.
--
makeOccMapSharingSeq
    :: Typeable e
    => Config
    -> OccMapHash Acc
    -> OccMapHash Seq
    -> Level                            -- The level of currently bound variables
    -> Seq e
    -> IO (UnscopedSeq e, Int)
makeOccMapSharingSeq config accOccMap seqOccMap = traverseSeq
  where
    traverseAcc :: Typeable arrs => Level -> Acc arrs -> IO (UnscopedAcc arrs, Int)
    traverseAcc = makeOccMapSharingAcc config accOccMap

    traverseAfun1 :: (Arrays a, Typeable b) => Level -> (Acc a -> Acc b) -> IO (Acc a -> UnscopedAcc b, Int)
    traverseAfun1 = makeOccMapAfun1 config accOccMap

    traverseAfun2 :: (Arrays a, Arrays b, Typeable c) => Level -> (Acc a -> Acc b -> Acc c) -> IO (Acc a -> Acc b -> UnscopedAcc c, Int)
    traverseAfun2 = makeOccMapAfun2 config accOccMap

    traverseAfun3 :: (Arrays a, Arrays b, Arrays c, Typeable d) => Level -> (Acc a -> Acc b -> Acc c -> Acc d) -> IO (Acc a -> Acc b -> Acc c -> UnscopedAcc d, Int)
    traverseAfun3 = makeOccMapAfun3 config accOccMap

    traverseExp :: Typeable e => Level -> Exp e -> IO (RootExp e, Int)
    traverseExp = makeOccMapExp config accOccMap

    traverseFun2 :: (Elt a, Elt b, Typeable c)
                 => Level
                 -> (Exp a -> Exp b -> Exp c)
                 -> IO (Exp a -> Exp b -> RootExp c, Int)
    traverseFun2 = makeOccMapFun2 config accOccMap

    traverseTup :: Level -> Atuple Seq tup -> IO (Atuple UnscopedSeq tup, Int)
    traverseTup _   NilAtup          = return (NilAtup, 1)
    traverseTup lvl (SnocAtup tup s) = do
                                        (tup', h1) <- traverseTup lvl tup
                                        (s'  , h2) <- traverseSeq lvl s
                                        return (SnocAtup tup' s', h1 `max` h2 + 1)

    traverseSeq :: forall arrs. Typeable arrs => Level -> Seq arrs -> IO (UnscopedSeq arrs, Int)
    traverseSeq lvl acc@(Seq seq)
      = mfix $ \ ~(_, height) -> do
          -- Compute stable name and enter it into the occurrence map
          --
          sn                         <- makeStableAST acc
          heightIfRepeatedOccurrence <- enterOcc seqOccMap (StableASTName sn) height

          traceLine (showPreSeqOp seq) $ do
            let hash = show (hashStableName sn)
            case heightIfRepeatedOccurrence of
              Just height -> "REPEATED occurrence (sn = " ++ hash ++ "; height = " ++ show height ++ ")"
              Nothing     -> "first occurrence (sn = " ++ hash ++ ")"

          -- Reconstruct the computation in shared form.
          --
          -- In case of a repeated occurrence, the height comes from the occurrence map; otherwise
          -- it is computed by the traversal function passed in 'newAcc'. See also 'enterOcc'.
          --
          -- NB: This function can only be used in the case alternatives below; outside of the
          --     case we cannot discharge the 'Arrays arrs' constraint.
          --
          let producer :: (arrs ~ [a], Arrays a)
                       => IO (PreSeq UnscopedAcc UnscopedSeq RootExp arrs, Int)
                       -> IO (UnscopedSeq arrs, Int)
              producer newSeq
                = case heightIfRepeatedOccurrence of
                    Just height | recoverSeqSharing config
                      -> return (UnscopedSeq (SvarSharing (StableNameHeight sn height)), height)
                    _ -> do (seq, height) <- newSeq
                            return (UnscopedSeq (SeqSharing (StableNameHeight sn height) seq), height)

          let consumer :: IO (PreSeq UnscopedAcc UnscopedSeq RootExp arrs, Int)
                       -> IO (UnscopedSeq arrs, Int)
              consumer newSeq
                = do (seq, height) <- newSeq
                     return (UnscopedSeq (SeqSharing (StableNameHeight sn height) seq), height)

          case seq of
            StreamIn arrs -> producer $ return (StreamIn arrs, 1)
            ToSeq sl acc -> producer $ do
              (acc', h1) <- traverseAcc lvl acc
              return (ToSeq sl acc', h1 + 1)
            MapSeq afun s -> producer $ do
              (afun', h1) <- traverseAfun1 lvl afun
              (s'   , h2) <- traverseSeq lvl s
              return (MapSeq afun' s', h1 `max` h2 + 1)
            ZipWithSeq afun s1 s2 -> producer $ do
              (afun', h1) <- traverseAfun2 lvl afun
              (s1'  , h2) <- traverseSeq lvl s1
              (s2'  , h3) <- traverseSeq lvl s2
              return (ZipWithSeq afun' s1' s2', h1 `max` h2 `max` h3 + 1)
            ScanSeq fun e s -> producer $ do
              (fun', h1) <- traverseFun2 lvl fun
              (e',  h2) <- traverseExp lvl e
              (s'   , h3) <- traverseSeq lvl s
              return (ScanSeq fun' e' s', h1 `max` h2 `max` h3 + 1)
            FoldSeq fun e s -> consumer $ do
              (fun', h1) <- traverseFun2 lvl fun
              (e'  , h2) <- traverseExp lvl e
              (s'  , h3) <- traverseSeq lvl s
              return (FoldSeq fun' e' s', h1 `max` h2 `max` h3 + 1)
            FoldSeqFlatten afun acc s -> consumer $ do
              (afun', h1) <- traverseAfun3 lvl afun
              (acc',  h2) <- traverseAcc lvl acc
              (s'   , h3) <- traverseSeq lvl s
              return (FoldSeqFlatten afun' acc' s', h1 `max` h2 `max` h3 + 1)
            Stuple t -> consumer $ do
              (t', h1) <- traverseTup lvl t
              return (Stuple t', h1 + 1)
--}


-- Type used to track how often each shared subterm has occurred so far during a bottom-up sweep,
-- as well as the relation between subterms. It comprises a list of terms and a graph giving
-- their relation.
--
--   Invariants of the list:
--   - If one shared term 's' is itself a subterm of another shared term 't', then 's' must occur
--     *after* 't' in the list.
--   - No shared term occurs twice.
--   - A term may have a final occurrence count of only 1 iff it is either a free variable ('Atag'
--     or 'Tag') or an array computation lifted out of an expression.
--   - All 'Exp' node counts precede all 'SmartAcc' node counts as we don't share 'Exp' nodes across
--     'SmartAcc' nodes. Similarly, all 'Seq' nodes precede 'SmartAcc' nodes and 'Exp' nodes precede
--     'Seq' nodes.
--
-- We determine the subterm property by using the tree height in 'StableNameHeight'.  Trees get
-- smaller towards the end of a 'NodeCounts' list.  The height of free variables ('Atag' or 'Tag')
-- is 0, whereas other leaves have height 1.  This guarantees that all free variables are at the end
-- of the 'NodeCounts' list.
--
-- The graph is represented as a map where a stable name 'a' is mapped to a set of stable names 'b'
-- such that if there exists an edge from 'a' to 'c' then 'c' is contained within 'b'.
--
--  Properties of the graph:
--  - There exists an edge from 'a' to 'b' if the term named by 'a' is a subterm of the term named
--    by 'b'.
--
-- To ensure the list invariant and the graph properties are preserved over merging node counts from
-- sibling subterms, the function '(+++)' must be used.
--
type NodeCounts
  = ( [NodeCount]                                   -- counted nodes
    , Map.HashMap NodeName (Set.HashSet NodeName)   -- subterm -> enclosing superterms
    )

-- | A stable name whose type parameter has been hidden, so that names of
-- array and scalar terms can be stored together as keys of the subterm graph
-- in 'NodeCounts'.
data NodeName where
  NodeName :: forall a. StableName a -> NodeName

-- Two names are equal when their wrapped stable names are equal according to
-- 'eqStableName', which compares stable names of possibly different types.
instance Eq NodeName where
  NodeName x == NodeName y = x `eqStableName` y

-- Hash a name by the 'Int' hash of its stable name.
--
-- The salt is mixed in using the 'Int' instance of 'hashWithSalt' rather than
-- being simply added, as the "Data.Hashable" contract recommends. Equal
-- stable names have equal 'hashStableName' values, so this agrees with the
-- 'Eq' instance.
instance Hashable NodeName where
  hashWithSalt salt (NodeName sn) = hashWithSalt salt (hashStableName sn)

-- Render a name as the 'Int' hash of its stable name.
instance Show NodeName where
  show (NodeName name) = show $ hashStableName name

-- | An occurrence count attached to a shared array term ('AccNodeCount') or
-- a shared scalar term ('ExpNodeCount').
data NodeCount
  = AccNodeCount StableSharingAcc Int
  | ExpNodeCount StableSharingExp Int
  -- SeqNodeCount StableSharingSeq Int
  deriving Show

-- Empty node counts: no counted nodes and an empty subterm graph.
--
noNodeCounts :: NodeCounts
noNodeCounts = (,) [] Map.empty

-- Insert an 'Acc' node into the node counts, assuming that it is a superterm of all the existing
-- nodes.
--
-- TODO: Perform cycle detection here.
--
insertAccNode :: StableSharingAcc -> NodeCounts -> NodeCounts
insertAccNode ssa@(StableSharingAcc (StableNameHeight sn _) _)
  = insertSupertermNode (NodeName sn) (AccNodeCount ssa 1)

-- Insert an 'Exp' node into the node counts, assuming that it is a superterm of all the existing
-- nodes.
--
-- TODO: Perform cycle detection here.
--
insertExpNode :: StableSharingExp -> NodeCounts -> NodeCounts
insertExpNode sse@(StableSharingExp (StableNameHeight sn _) _)
  = insertSupertermNode (NodeName sn) (ExpNodeCount sse 1)

-- Shared implementation of 'insertAccNode' and 'insertExpNode'.
--
-- The new node 'count' (named 'k') is combined with the existing nodes using '(+++)'. The graph
-- gains an entry for 'k' with no superterms, plus an edge from every existing node to 'k'.
--
insertSupertermNode :: NodeName -> NodeCount -> NodeCounts -> NodeCounts
insertSupertermNode k count (subterms, g)
  = ([count], g') +++ (subterms, g)
  where
    g' = Map.fromList $ (k, Set.empty) : [ (nodeName h, Set.singleton k) | h <- subterms ]

{--
-- Insert an Seq node into the node counts, assuming that it is a superterm of the all the existing
-- nodes.
--
-- TODO: Perform cycle detection here.
--
insertSeqNode :: StableSharingSeq -> NodeCounts -> NodeCounts
insertSeqNode ssa@(StableSharingSeq (StableNameHeight sn _) _) (subterms,g)
  = ([SeqNodeCount ssa 1], g') +++ (subterms,g)
  where
    k  = NodeName sn
    hs = map nodeName subterms
    g' = Map.fromList $ (k, Set.empty) : [(h, Set.singleton k) | h <- hs]
--}

-- Remove nodes that aren't in the list from the graph.
--
-- Only names that appear in the node list are kept as keys, and each edge set is pruned to those
-- names. Membership is tested against a 'Set.HashSet' of the listed names, so pruning one edge set
-- costs time proportional to that set rather than to the whole node list. Every listed node is
-- expected to have an entry in the graph; 'Map.!' fails otherwise.
--
-- RCE: This is no longer necessary when NDP is supported.
--
cleanCounts :: NodeCounts -> NodeCounts
cleanCounts (ns, g) = (ns, Map.fromList [ (h, (g Map.! h) `Set.intersection` live) | h <- hs ])
  where
    hs   = map nodeName ns
    live = Set.fromList hs

-- | The type-erased stable name of the term a node count refers to.
nodeName :: NodeCount -> NodeName
nodeName = \case
  AccNodeCount (StableSharingAcc (StableNameHeight sn _) _) _ -> NodeName sn
  ExpNodeCount (StableSharingExp (StableNameHeight sn _) _) _ -> NodeName sn
-- nodeName (SeqNodeCount (StableSharingSeq (StableNameHeight sn _) _) _) = NodeName sn


-- Combine node counts that belong to the same node.
--
-- * We assume that the list invariant —subterms follow their parents— holds for both arguments and
--   guarantee that it still holds for the result.
--
-- * In the same manner, we assume that all 'Exp' node counts precede 'SmartAcc' node counts and
--   guarantee that this also holds for the result.
--
(+++) :: NodeCounts -> NodeCounts -> NodeCounts
([NodeCount]
ns1, HashMap NodeName (HashSet NodeName)
g1) +++ :: NodeCounts -> NodeCounts -> NodeCounts
+++ ([NodeCount]
ns2, HashMap NodeName (HashSet NodeName)
g2) = ([NodeCount] -> [NodeCount]
cleanup ([NodeCount] -> [NodeCount]) -> [NodeCount] -> [NodeCount]
forall a b. (a -> b) -> a -> b
$ [NodeCount] -> [NodeCount] -> [NodeCount]
merge [NodeCount]
ns1 [NodeCount]
ns2, (HashSet NodeName -> HashSet NodeName -> HashSet NodeName)
-> HashMap NodeName (HashSet NodeName)
-> HashMap NodeName (HashSet NodeName)
-> HashMap NodeName (HashSet NodeName)
forall k v.
(Eq k, Hashable k) =>
(v -> v -> v) -> HashMap k v -> HashMap k v -> HashMap k v
Map.unionWith HashSet NodeName -> HashSet NodeName -> HashSet NodeName
forall a. (Eq a, Hashable a) => HashSet a -> HashSet a -> HashSet a
Set.union HashMap NodeName (HashSet NodeName)
g1 HashMap NodeName (HashSet NodeName)
g2)
  where
    merge :: [NodeCount] -> [NodeCount] -> [NodeCount]
merge [] [NodeCount]
x = [NodeCount]
x
    merge [NodeCount]
x [] = [NodeCount]
x
    merge (x :: NodeCount
x@(AccNodeCount StableSharingAcc
sa1 Int
count1):[NodeCount]
xs) (y :: NodeCount
y@(AccNodeCount StableSharingAcc
sa2 Int
count2):[NodeCount]
ys)
     | StableSharingAcc
sa1 StableSharingAcc -> StableSharingAcc -> Bool
forall a. Eq a => a -> a -> Bool
== StableSharingAcc
sa2          = StableSharingAcc -> Int -> NodeCount
AccNodeCount (StableSharingAcc
sa1 StableSharingAcc -> StableSharingAcc -> StableSharingAcc
`pickNoneAvar` StableSharingAcc
sa2) (Int
count1 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
count2) NodeCount -> [NodeCount] -> [NodeCount]
forall a. a -> [a] -> [a]
: [NodeCount] -> [NodeCount] -> [NodeCount]
merge [NodeCount]
xs [NodeCount]
ys
     | StableSharingAcc
sa1 StableSharingAcc -> StableSharingAcc -> Bool
`higherSSA` StableSharingAcc
sa2 = NodeCount
x NodeCount -> [NodeCount] -> [NodeCount]
forall a. a -> [a] -> [a]
: [NodeCount] -> [NodeCount] -> [NodeCount]
merge [NodeCount]
xs (NodeCount
yNodeCount -> [NodeCount] -> [NodeCount]
forall a. a -> [a] -> [a]
:[NodeCount]
ys)
     | Bool
otherwise           = NodeCount
y NodeCount -> [NodeCount] -> [NodeCount]
forall a. a -> [a] -> [a]
: [NodeCount] -> [NodeCount] -> [NodeCount]
merge (NodeCount
xNodeCount -> [NodeCount] -> [NodeCount]
forall a. a -> [a] -> [a]
:[NodeCount]
xs) [NodeCount]
ys
    merge (x :: NodeCount
x@(ExpNodeCount StableSharingExp
se1 Int
count1):[NodeCount]
xs) (y :: NodeCount
y@(ExpNodeCount StableSharingExp
se2 Int
count2):[NodeCount]
ys)
     | StableSharingExp
se1 StableSharingExp -> StableSharingExp -> Bool
forall a. Eq a => a -> a -> Bool
== StableSharingExp
se2          = StableSharingExp -> Int -> NodeCount
ExpNodeCount (StableSharingExp
se1 StableSharingExp -> StableSharingExp -> StableSharingExp
`pickNoneVar` StableSharingExp
se2) (Int
count1 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
count2) NodeCount -> [NodeCount] -> [NodeCount]
forall a. a -> [a] -> [a]
: [NodeCount] -> [NodeCount] -> [NodeCount]
merge [NodeCount]
xs [NodeCount]
ys
     | StableSharingExp
se1 StableSharingExp -> StableSharingExp -> Bool
`higherSSE` StableSharingExp
se2 = NodeCount
x NodeCount -> [NodeCount] -> [NodeCount]
forall a. a -> [a] -> [a]
: [NodeCount] -> [NodeCount] -> [NodeCount]
merge [NodeCount]
xs (NodeCount
yNodeCount -> [NodeCount] -> [NodeCount]
forall a. a -> [a] -> [a]
:[NodeCount]
ys)
     | Bool
otherwise           = NodeCount
y NodeCount -> [NodeCount] -> [NodeCount]
forall a. a -> [a] -> [a]
: [NodeCount] -> [NodeCount] -> [NodeCount]
merge (NodeCount
xNodeCount -> [NodeCount] -> [NodeCount]
forall a. a -> [a] -> [a]
:[NodeCount]
xs) [NodeCount]
ys
    merge (x :: NodeCount
x@(AccNodeCount StableSharingAcc
_ Int
_):[NodeCount]
xs) (y :: NodeCount
y@(ExpNodeCount StableSharingExp
_ Int
_):[NodeCount]
ys) = NodeCount
y NodeCount -> [NodeCount] -> [NodeCount]
forall a. a -> [a] -> [a]
: [NodeCount] -> [NodeCount] -> [NodeCount]
merge (NodeCount
xNodeCount -> [NodeCount] -> [NodeCount]
forall a. a -> [a] -> [a]
:[NodeCount]
xs) [NodeCount]
ys
    merge (x :: NodeCount
x@(ExpNodeCount StableSharingExp
_ Int
_):[NodeCount]
xs) (y :: NodeCount
y@(AccNodeCount StableSharingAcc
_ Int
_):[NodeCount]
ys) = NodeCount
x NodeCount -> [NodeCount] -> [NodeCount]
forall a. a -> [a] -> [a]
: [NodeCount] -> [NodeCount] -> [NodeCount]
merge [NodeCount]
xs (NodeCount
yNodeCount -> [NodeCount] -> [NodeCount]
forall a. a -> [a] -> [a]
:[NodeCount]
ys)

    -- Choose between two array entries that compare equal. A left entry that is
    -- merely a variable occurrence ('AvarSharing') gives way to the right entry;
    -- any other left entry is kept as is.
    pickNoneAvar :: StableSharingAcc -> StableSharingAcc -> StableSharingAcc
    pickNoneAvar sa1 sa2
      | StableSharingAcc _ AvarSharing{} <- sa1 = sa2
      | otherwise                               = sa1

    -- Choose between two expression entries that compare equal (see the use in
    -- 'merge'). A left entry that is merely a variable occurrence ('VarSharing')
    -- gives way to the right entry; any other left entry is kept as is.
    pickNoneVar :: StableSharingExp -> StableSharingExp -> StableSharingExp
    pickNoneVar se1 se2
      | StableSharingExp _ VarSharing{} <- se1 = se2
      | otherwise                              = se1

    -- As the StableSharingAccs are not strictly ordered, this cleanup step is
    -- needed. Adjacent node counts of the same height are grouped with
    -- 'sameHeight' (which only ever groups AccNodeCounts with AccNodeCounts and
    -- ExpNodeCounts with ExpNodeCounts), and each group is rebuilt by merging
    -- its elements in one at a time, so that they are compared against each
    -- other. Without this step, duplicates may arise.
    --
    -- Note that while (+++) is morally symmetric, replacing `merge [node] acc'
    -- with `merge acc [node]' inside of `cleanup' won't check all required
    -- possibilities.
    --
    cleanup :: [NodeCount] -> [NodeCount]
    cleanup = concatMap remergeGroup . groupBy sameHeight
      where
        -- Insert the group's elements right-to-left; the singleton must stay
        -- the FIRST argument of 'merge' (see the note above).
        remergeGroup = foldr (\node acc -> merge [node] acc) []
    -- Two node counts have the same height when they are of the same kind and
    -- neither is strictly higher than the other. An AccNodeCount and an
    -- ExpNodeCount are never considered to be of the same height.
    sameHeight :: NodeCount -> NodeCount -> Bool
    sameHeight n1 n2 =
      case (n1, n2) of
        (AccNodeCount sa1 _, AccNodeCount sa2 _) -> not (sa1 `higherSSA` sa2 || sa2 `higherSSA` sa1)
        (ExpNodeCount se1 _, ExpNodeCount se2 _) -> not (se1 `higherSSE` se2 || se2 `higherSSE` se1)
        _                                        -> False


-- Build the initial environment for traversing an array expression.
--
-- The first argument lists the tag values, in environment order. The second argument holds the
-- 'StableSharingAcc's of all tags /actually used/ in the expression. (A tag is unused when its
-- bound variable has no usage occurrence.) Each tag is mapped to its unique 'Atag' node, or to a
-- placeholder entry when the tag is unused.
--
-- Bail out with 'internalError' if any tag occurs multiple times, as this indicates that the
-- sharing of an argument variable was not preserved and we cannot build an appropriate initial
-- environment (cf. the comments at 'determineScopesAcc'). Likewise bail out if the second
-- argument contains a node that is not a plain 'Atag'.
--
buildInitialEnvAcc
    :: HasCallStack
    => [Level]
    -> [StableSharingAcc]
    -> [StableSharingAcc]
buildInitialEnvAcc tags sas = [ entryFor tag | tag <- tags ]
  where
    -- The environment entry for one tag.
    entryFor tag =
      case filter (isAtagFor tag) sas of
        []   -> unused      -- tag is not used in the analysed expression
        [sa] -> sa          -- tag has a unique occurrence
        dups -> internalError ("Encountered duplicate 'ATag's\n  " ++ intercalate ", " (map describe dups))

    -- Every candidate must be a plain 'Atag'; anything else is an internal error.
    isAtagFor tag (StableSharingAcc _ (AccSharing _ (Atag _ tag'))) = tag == tag'
    isAtagFor _   other
      = internalError ("Encountered a node that is not a plain 'Atag'\n  " ++ describe other)

    -- NOTE(review): the 'SharingAcc' payload is 'undefined'; presumably it is never inspected for
    -- unused tags — confirm before relying on it.
    unused :: StableSharingAcc
    unused = StableSharingAcc noStableAccName (undefined :: SharingAcc acc exp ())

    -- Short rendering of a node, used only in the error messages above.
    describe :: StableSharingAcc -> String
    describe (StableSharingAcc _ sharing) =
      case sharing of
        AccSharing  sn acc -> show (hashStableNameHeight sn) ++ ": " ++ showPreAccOp acc
        AvarSharing sn _   -> "AvarSharing " ++ show (hashStableNameHeight sn)
        AletSharing sa _   -> "AletSharing " ++ show sa ++ "..."

-- Build the initial environment for traversing a scalar expression.
--
-- The first argument lists the tag values, in environment order. The second argument holds the
-- 'StableSharingExp's of all tags /actually used/ in the expression. (A tag is unused when its
-- bound variable has no usage occurrence.) Each tag is mapped to its unique 'Tag' node, or to a
-- placeholder entry when the tag is unused.
--
-- Bail out with 'internalError' if any tag occurs multiple times, as this indicates that the
-- sharing of an argument variable was not preserved and we cannot build an appropriate initial
-- environment (cf. the comments at 'determineScopesAcc'). Likewise bail out if the second
-- argument contains a node that is not a plain 'Tag'.
--
buildInitialEnvExp
    :: HasCallStack
    => [Level]
    -> [StableSharingExp]
    -> [StableSharingExp]
buildInitialEnvExp tags ses = [ entryFor tag | tag <- tags ]
  where
    -- The environment entry for one tag.
    entryFor tag =
      case filter (isTagFor tag) ses of
        []   -> unused      -- tag is not used in the analysed expression
        [se] -> se          -- tag has a unique occurrence
        dups -> internalError ("Encountered a duplicate 'Tag'\n  " ++ intercalate ", " (map describe dups))

    -- Every candidate must be a plain 'Tag'; anything else is an internal error.
    isTagFor tag (StableSharingExp _ (ExpSharing _ (Tag _ tag'))) = tag == tag'
    isTagFor _   other
      = internalError ("Encountered a node that is not a plain 'Tag'\n  " ++ describe other)

    -- NOTE(review): the 'SharingExp' payload is 'undefined'; presumably it is never inspected for
    -- unused tags — confirm before relying on it.
    unused :: StableSharingExp
    unused = StableSharingExp noStableExpName (undefined :: SharingExp acc exp ())

    -- Short rendering of a node, used only in the error messages above.
    describe :: StableSharingExp -> String
    describe (StableSharingExp _ sharing) =
      case sharing of
        ExpSharing sn e -> show (hashStableNameHeight sn) ++ ": " ++ showPreExpOp e
        VarSharing sn _ -> "VarSharing " ++ show (hashStableNameHeight sn)
        LetSharing se _ -> "LetSharing " ++ show se ++ "..."

-- Does this 'NodeCount' stand for a free variable? Free array variables are represented by
-- 'Atag' nodes and free scalar variables by 'Tag' nodes; every other node is not free.
--
isFreeVar :: NodeCount -> Bool
isFreeVar = \case
  AccNodeCount (StableSharingAcc _ (AccSharing _ Atag{})) _ -> True
  ExpNodeCount (StableSharingExp _ (ExpSharing _ Tag{}))  _ -> True
  _                                                         -> False


-- Determine scope of shared subterms
-- ==================================

-- Determine the scopes of all variables representing shared subterms (Phase Two) in a bottom-up
-- sweep. The 'Config' argument is passed on unchanged to 'determineScopesSharingAcc'.
-- NOTE(review): presumably the 'Config' carries the option that decides whether array
-- computations are floated out of expressions irrespective of whether they are shared — confirm.
--
-- In addition to the AST with sharing information, yield the 'StableSharingAcc's for all free
-- variables of 'rootAcc', which are represented by 'Atag' leaves in the tree. They are listed in
-- the order of the levels given in the second argument — i.e., in the same order that they need
-- to appear in an environment to use the tag for indexing into that environment.
--
-- Precondition: there are only 'AvarSharing' and 'AccSharing' nodes in the argument.
--
-- Fails with 'internalError' if the sweep leaves behind any node count that is not a free
-- variable (see 'isFreeVar'), i.e. an unbound shared subtree.
--
determineScopesAcc
    :: HasCallStack
    => Config
    -> [Level]
    -> OccMap SmartAcc
    -> UnscopedAcc a
    -> (ScopedAcc a, [StableSharingAcc])
determineScopesAcc config fvs accOccMap rootAcc
  | null unboundTrees = (sharingAcc, buildInitialEnvAcc fvs [ sa | AccNodeCount sa _ <- counts ])
  | otherwise         = internalError ("unbound shared subtrees: " ++ show unboundTrees)
  where
    (sharingAcc, (counts, _)) = determineScopesSharingAcc config accOccMap rootAcc

    -- Every left-over node count that is not a free variable is an unbound shared subtree.
    -- Both the check and the error report are derived from this one list, so they cannot
    -- disagree.
    unboundTrees = filter (not . isFreeVar) counts


determineScopesSharingAcc
    :: HasCallStack
    => Config
    -> OccMap SmartAcc
    -> UnscopedAcc a
    -> (ScopedAcc a, NodeCounts)
determineScopesSharingAcc :: Config
-> OccMap SmartAcc -> UnscopedAcc a -> (ScopedAcc a, NodeCounts)
determineScopesSharingAcc Config
config OccMap SmartAcc
accOccMap = UnscopedAcc a -> (ScopedAcc a, NodeCounts)
forall arrs.
HasCallStack =>
UnscopedAcc arrs -> (ScopedAcc arrs, NodeCounts)
scopesAcc
  where
    scopesAcc :: forall arrs. HasCallStack => UnscopedAcc arrs -> (ScopedAcc arrs, NodeCounts)
    scopesAcc :: UnscopedAcc arrs -> (ScopedAcc arrs, NodeCounts)
scopesAcc (UnscopedAcc [Int]
_ (AletSharing StableSharingAcc
_ UnscopedAcc arrs
_))
      = String -> (ScopedAcc arrs, NodeCounts)
forall a. HasCallStack => String -> a
internalError String
"unexpected 'AletSharing'"

    scopesAcc (UnscopedAcc [Int]
_ (AvarSharing StableAccName arrs
sn ArraysR arrs
tp))
      = ([StableSharingAcc]
-> SharingAcc ScopedAcc ScopedExp arrs -> ScopedAcc arrs
forall t.
[StableSharingAcc]
-> SharingAcc ScopedAcc ScopedExp t -> ScopedAcc t
ScopedAcc [] (StableAccName arrs
-> ArraysR arrs -> SharingAcc ScopedAcc ScopedExp arrs
forall arrs (acc :: * -> *) (exp :: * -> *).
StableAccName arrs -> ArraysR arrs -> SharingAcc acc exp arrs
AvarSharing StableAccName arrs
sn ArraysR arrs
tp), StableAccName arrs
-> SharingAcc ScopedAcc ScopedExp arrs -> StableSharingAcc
forall arrs.
StableAccName arrs
-> SharingAcc ScopedAcc ScopedExp arrs -> StableSharingAcc
StableSharingAcc StableAccName arrs
sn (StableAccName arrs
-> ArraysR arrs -> SharingAcc ScopedAcc ScopedExp arrs
forall arrs (acc :: * -> *) (exp :: * -> *).
StableAccName arrs -> ArraysR arrs -> SharingAcc acc exp arrs
AvarSharing StableAccName arrs
sn ArraysR arrs
tp) StableSharingAcc -> NodeCounts -> NodeCounts
`insertAccNode` NodeCounts
noNodeCounts)

    scopesAcc (UnscopedAcc [Int]
_ (AccSharing StableAccName arrs
sn PreSmartAcc UnscopedAcc RootExp arrs
pacc))
      = case PreSmartAcc UnscopedAcc RootExp arrs
pacc of
          Atag ArraysR arrs
tp Int
i               -> HasCallStack =>
PreSmartAcc ScopedAcc ScopedExp arrs
-> NodeCounts -> (ScopedAcc arrs, NodeCounts)
PreSmartAcc ScopedAcc ScopedExp arrs
-> NodeCounts -> (ScopedAcc arrs, NodeCounts)
reconstruct (ArraysR arrs -> Int -> PreSmartAcc ScopedAcc ScopedExp arrs
forall as (acc :: * -> *) (exp :: * -> *).
ArraysR as -> Int -> PreSmartAcc acc exp as
Atag ArraysR arrs
tp Int
i) NodeCounts
noNodeCounts
          Pipe ArraysR as
repr1 ArraysR bs
repr2 ArraysR arrs
repr3 SmartAcc as -> UnscopedAcc bs
afun1 SmartAcc bs -> UnscopedAcc arrs
afun2 UnscopedAcc as
acc
                                  -> let
                                       (SmartAcc as -> ScopedAcc bs
afun1', NodeCounts
accCount1) = (SmartAcc as -> UnscopedAcc bs)
-> (SmartAcc as -> ScopedAcc bs, NodeCounts)
forall a1 a2.
HasCallStack =>
(SmartAcc a1 -> UnscopedAcc a2)
-> (SmartAcc a1 -> ScopedAcc a2, NodeCounts)
scopesAfun1 SmartAcc as -> UnscopedAcc bs
afun1
                                       (SmartAcc bs -> ScopedAcc arrs
afun2', NodeCounts
accCount2) = (SmartAcc bs -> UnscopedAcc arrs)
-> (SmartAcc bs -> ScopedAcc arrs, NodeCounts)
forall a1 a2.
HasCallStack =>
(SmartAcc a1 -> UnscopedAcc a2)
-> (SmartAcc a1 -> ScopedAcc a2, NodeCounts)
scopesAfun1 SmartAcc bs -> UnscopedAcc arrs
afun2
                                       (ScopedAcc as
acc', NodeCounts
accCount3)   = UnscopedAcc as -> (ScopedAcc as, NodeCounts)
forall arrs.
HasCallStack =>
UnscopedAcc arrs -> (ScopedAcc arrs, NodeCounts)
scopesAcc UnscopedAcc as
acc
                                     in
                                     HasCallStack =>
PreSmartAcc ScopedAcc ScopedExp arrs
-> NodeCounts -> (ScopedAcc arrs, NodeCounts)
PreSmartAcc ScopedAcc ScopedExp arrs
-> NodeCounts -> (ScopedAcc arrs, NodeCounts)
reconstruct (ArraysR as
-> ArraysR bs
-> ArraysR arrs
-> (SmartAcc as -> ScopedAcc bs)
-> (SmartAcc bs -> ScopedAcc arrs)
-> ScopedAcc as
-> PreSmartAcc ScopedAcc ScopedExp arrs
forall as bs cs (acc :: * -> *) (exp :: * -> *).
ArraysR as
-> ArraysR bs
-> ArraysR cs
-> (SmartAcc as -> acc bs)
-> (SmartAcc bs -> acc cs)
-> acc as
-> PreSmartAcc acc exp cs
Pipe ArraysR as
repr1 ArraysR bs
repr2 ArraysR arrs
repr3 SmartAcc as -> ScopedAcc bs
afun1' SmartAcc bs -> ScopedAcc arrs
afun2' ScopedAcc as
acc')
                                                 (NodeCounts
accCount1 NodeCounts -> NodeCounts -> NodeCounts
+++ NodeCounts
accCount2 NodeCounts -> NodeCounts -> NodeCounts
+++ NodeCounts
accCount3)

          Aforeign ArraysR arrs
r asm (as -> arrs)
ff SmartAcc as -> SmartAcc arrs
afun UnscopedAcc as
acc  -> let
                                       (ScopedAcc as
acc', NodeCounts
accCount) = UnscopedAcc as -> (ScopedAcc as, NodeCounts)
forall arrs.
HasCallStack =>
UnscopedAcc arrs -> (ScopedAcc arrs, NodeCounts)
scopesAcc UnscopedAcc as
acc
                                     in
                                     HasCallStack =>
PreSmartAcc ScopedAcc ScopedExp arrs
-> NodeCounts -> (ScopedAcc arrs, NodeCounts)
PreSmartAcc ScopedAcc ScopedExp arrs
-> NodeCounts -> (ScopedAcc arrs, NodeCounts)
reconstruct (ArraysR arrs
-> asm (as -> arrs)
-> (SmartAcc as -> SmartAcc arrs)
-> ScopedAcc as
-> PreSmartAcc ScopedAcc ScopedExp arrs
forall (asm :: * -> *) bs as (acc :: * -> *) (exp :: * -> *).
Foreign asm =>
ArraysR bs
-> asm (as -> bs)
-> (SmartAcc as -> SmartAcc bs)
-> acc as
-> PreSmartAcc acc exp bs
Aforeign ArraysR arrs
r asm (as -> arrs)
ff SmartAcc as -> SmartAcc arrs
afun ScopedAcc as
acc') NodeCounts
accCount
          Acond RootExp PrimBool
e UnscopedAcc arrs
acc1 UnscopedAcc arrs
acc2       -> let
                                       (ScopedExp PrimBool
e'   , NodeCounts
accCount1) = RootExp PrimBool -> (ScopedExp PrimBool, NodeCounts)
forall t. HasCallStack => RootExp t -> (ScopedExp t, NodeCounts)
scopesExp RootExp PrimBool
e
                                       (ScopedAcc arrs
acc1', NodeCounts
accCount2) = UnscopedAcc arrs -> (ScopedAcc arrs, NodeCounts)
forall arrs.
HasCallStack =>
UnscopedAcc arrs -> (ScopedAcc arrs, NodeCounts)
scopesAcc UnscopedAcc arrs
acc1
                                       (ScopedAcc arrs
acc2', NodeCounts
accCount3) = UnscopedAcc arrs -> (ScopedAcc arrs, NodeCounts)
forall arrs.
HasCallStack =>
UnscopedAcc arrs -> (ScopedAcc arrs, NodeCounts)
scopesAcc UnscopedAcc arrs
acc2
                                     in
                                     HasCallStack =>
PreSmartAcc ScopedAcc ScopedExp arrs
-> NodeCounts -> (ScopedAcc arrs, NodeCounts)
PreSmartAcc ScopedAcc ScopedExp arrs
-> NodeCounts -> (ScopedAcc arrs, NodeCounts)
reconstruct (ScopedExp PrimBool
-> ScopedAcc arrs
-> ScopedAcc arrs
-> PreSmartAcc ScopedAcc ScopedExp arrs
forall (exp :: * -> *) (acc :: * -> *) as.
exp PrimBool -> acc as -> acc as -> PreSmartAcc acc exp as
Acond ScopedExp PrimBool
e' ScopedAcc arrs
acc1' ScopedAcc arrs
acc2')
                                                 (NodeCounts
accCount1 NodeCounts -> NodeCounts -> NodeCounts
+++ NodeCounts
accCount2 NodeCounts -> NodeCounts -> NodeCounts
+++ NodeCounts
accCount3)

          Awhile ArraysR arrs
repr SmartAcc arrs -> UnscopedAcc (Scalar PrimBool)
pred SmartAcc arrs -> UnscopedAcc arrs
iter UnscopedAcc arrs
init
                                  -> let
                                       (SmartAcc arrs -> ScopedAcc (Scalar PrimBool)
pred', NodeCounts
accCount1) = (SmartAcc arrs -> UnscopedAcc (Scalar PrimBool))
-> (SmartAcc arrs -> ScopedAcc (Scalar PrimBool), NodeCounts)
forall a1 a2.
HasCallStack =>
(SmartAcc a1 -> UnscopedAcc a2)
-> (SmartAcc a1 -> ScopedAcc a2, NodeCounts)
scopesAfun1 SmartAcc arrs -> UnscopedAcc (Scalar PrimBool)
pred
                                       (SmartAcc arrs -> ScopedAcc arrs
iter', NodeCounts
accCount2) = (SmartAcc arrs -> UnscopedAcc arrs)
-> (SmartAcc arrs -> ScopedAcc arrs, NodeCounts)
forall a1 a2.
HasCallStack =>
(SmartAcc a1 -> UnscopedAcc a2)
-> (SmartAcc a1 -> ScopedAcc a2, NodeCounts)
scopesAfun1 SmartAcc arrs -> UnscopedAcc arrs
iter
                                       (ScopedAcc arrs
init', NodeCounts
accCount3) = UnscopedAcc arrs -> (ScopedAcc arrs, NodeCounts)
forall arrs.
HasCallStack =>
UnscopedAcc arrs -> (ScopedAcc arrs, NodeCounts)
scopesAcc UnscopedAcc arrs
init
                                     in
                                     HasCallStack =>
PreSmartAcc ScopedAcc ScopedExp arrs
-> NodeCounts -> (ScopedAcc arrs, NodeCounts)
PreSmartAcc ScopedAcc ScopedExp arrs
-> NodeCounts -> (ScopedAcc arrs, NodeCounts)
reconstruct (ArraysR arrs
-> (SmartAcc arrs -> ScopedAcc (Scalar PrimBool))
-> (SmartAcc arrs -> ScopedAcc arrs)
-> ScopedAcc arrs
-> PreSmartAcc ScopedAcc ScopedExp arrs
forall arrs (acc :: * -> *) (exp :: * -> *).
ArraysR arrs
-> (SmartAcc arrs -> acc (Scalar PrimBool))
-> (SmartAcc arrs -> acc arrs)
-> acc arrs
-> PreSmartAcc acc exp arrs
Awhile ArraysR arrs
repr SmartAcc arrs -> ScopedAcc (Scalar PrimBool)
pred' SmartAcc arrs -> ScopedAcc arrs
iter' ScopedAcc arrs
init')
                                                 (NodeCounts
accCount1 NodeCounts -> NodeCounts -> NodeCounts
+++ NodeCounts
accCount2 NodeCounts -> NodeCounts -> NodeCounts
+++ NodeCounts
accCount3)

          PreSmartAcc UnscopedAcc RootExp arrs
Anil                    -> HasCallStack =>
PreSmartAcc ScopedAcc ScopedExp arrs
-> NodeCounts -> (ScopedAcc arrs, NodeCounts)
PreSmartAcc ScopedAcc ScopedExp arrs
-> NodeCounts -> (ScopedAcc arrs, NodeCounts)
reconstruct PreSmartAcc ScopedAcc ScopedExp arrs
forall (acc :: * -> *) (exp :: * -> *). PreSmartAcc acc exp ()
Anil NodeCounts
noNodeCounts
          Apair UnscopedAcc arrs1
a1 UnscopedAcc arrs2
a2             -> let
                                       (ScopedAcc arrs1
a1', NodeCounts
accCount1) = UnscopedAcc arrs1 -> (ScopedAcc arrs1, NodeCounts)
forall arrs.
HasCallStack =>
UnscopedAcc arrs -> (ScopedAcc arrs, NodeCounts)
scopesAcc UnscopedAcc arrs1
a1
                                       (ScopedAcc arrs2
a2', NodeCounts
accCount2) = UnscopedAcc arrs2 -> (ScopedAcc arrs2, NodeCounts)
forall arrs.
HasCallStack =>
UnscopedAcc arrs -> (ScopedAcc arrs, NodeCounts)
scopesAcc UnscopedAcc arrs2
a2
                                     in
                                       HasCallStack =>
PreSmartAcc ScopedAcc ScopedExp arrs
-> NodeCounts -> (ScopedAcc arrs, NodeCounts)
PreSmartAcc ScopedAcc ScopedExp arrs
-> NodeCounts -> (ScopedAcc arrs, NodeCounts)
reconstruct (ScopedAcc arrs1
-> ScopedAcc arrs2
-> PreSmartAcc ScopedAcc ScopedExp (arrs1, arrs2)
forall (acc :: * -> *) arrs1 arrs2 (exp :: * -> *).
acc arrs1 -> acc arrs2 -> PreSmartAcc acc exp (arrs1, arrs2)
Apair ScopedAcc arrs1
a1' ScopedAcc arrs2
a2') (NodeCounts
accCount1 NodeCounts -> NodeCounts -> NodeCounts
+++ NodeCounts
accCount2)
          Aprj PairIdx (arrs1, arrs2) arrs
ix UnscopedAcc (arrs1, arrs2)
a               -> (ScopedAcc (arrs1, arrs2) -> PreSmartAcc ScopedAcc ScopedExp arrs)
-> UnscopedAcc (arrs1, arrs2) -> (ScopedAcc arrs, NodeCounts)
forall arrs'.
HasCallStack =>
(ScopedAcc arrs' -> PreSmartAcc ScopedAcc ScopedExp arrs)
-> UnscopedAcc arrs' -> (ScopedAcc arrs, NodeCounts)
travA (PairIdx (arrs1, arrs2) arrs
-> ScopedAcc (arrs1, arrs2) -> PreSmartAcc ScopedAcc ScopedExp arrs
forall arrs1 arrs2 arrs (acc :: * -> *) (exp :: * -> *).
PairIdx (arrs1, arrs2) arrs
-> acc (arrs1, arrs2) -> PreSmartAcc acc exp arrs
Aprj PairIdx (arrs1, arrs2) arrs
ix) UnscopedAcc (arrs1, arrs2)
a

          Use ArrayR (Array sh e)
repr Array sh e
arr            -> HasCallStack =>
PreSmartAcc ScopedAcc ScopedExp arrs
-> NodeCounts -> (ScopedAcc arrs, NodeCounts)
PreSmartAcc ScopedAcc ScopedExp arrs
-> NodeCounts -> (ScopedAcc arrs, NodeCounts)
reconstruct (ArrayR (Array sh e)
-> Array sh e -> PreSmartAcc ScopedAcc ScopedExp (Array sh e)
forall sh e (acc :: * -> *) (exp :: * -> *).
ArrayR (Array sh e)
-> Array sh e -> PreSmartAcc acc exp (Array sh e)
Use ArrayR (Array sh e)
repr Array sh e
arr) NodeCounts
noNodeCounts
          Unit TypeR e
tp RootExp e
e               -> let
                                       (ScopedExp e
e', NodeCounts
accCount) = RootExp e -> (ScopedExp e, NodeCounts)
forall t. HasCallStack => RootExp t -> (ScopedExp t, NodeCounts)
scopesExp RootExp e
e
                                     in
                                     HasCallStack =>
PreSmartAcc ScopedAcc ScopedExp arrs
-> NodeCounts -> (ScopedAcc arrs, NodeCounts)
PreSmartAcc ScopedAcc ScopedExp arrs
-> NodeCounts -> (ScopedAcc arrs, NodeCounts)
reconstruct (TypeR e
-> ScopedExp e -> PreSmartAcc ScopedAcc ScopedExp (Scalar e)
forall e (exp :: * -> *) (acc :: * -> *).
TypeR e -> exp e -> PreSmartAcc acc exp (Scalar e)
Unit TypeR e
tp ScopedExp e
e') NodeCounts
accCount
          Generate ArrayR (Array sh e)
repr RootExp sh
sh SmartExp sh -> RootExp e
f      -> let
                                       (ScopedExp sh
sh', NodeCounts
accCount1) = RootExp sh -> (ScopedExp sh, NodeCounts)
forall t. HasCallStack => RootExp t -> (ScopedExp t, NodeCounts)
scopesExp RootExp sh
sh
                                       (SmartExp sh -> ScopedExp e
f' , NodeCounts
accCount2) = (SmartExp sh -> RootExp e)
-> (SmartExp sh -> ScopedExp e, NodeCounts)
forall e1 e2.
HasCallStack =>
(SmartExp e1 -> RootExp e2)
-> (SmartExp e1 -> ScopedExp e2, NodeCounts)
scopesFun1 SmartExp sh -> RootExp e
f
                                     in
                                     HasCallStack =>
PreSmartAcc ScopedAcc ScopedExp arrs
-> NodeCounts -> (ScopedAcc arrs, NodeCounts)
PreSmartAcc ScopedAcc ScopedExp arrs
-> NodeCounts -> (ScopedAcc arrs, NodeCounts)
reconstruct (ArrayR (Array sh e)
-> ScopedExp sh
-> (SmartExp sh -> ScopedExp e)
-> PreSmartAcc ScopedAcc ScopedExp (Array sh e)
forall sh e (exp :: * -> *) (acc :: * -> *).
ArrayR (Array sh e)
-> exp sh
-> (SmartExp sh -> exp e)
-> PreSmartAcc acc exp (Array sh e)
Generate ArrayR (Array sh e)
repr ScopedExp sh
sh' SmartExp sh -> ScopedExp e
f') (NodeCounts
accCount1 NodeCounts -> NodeCounts -> NodeCounts
+++ NodeCounts
accCount2)
          Reshape ShapeR sh
shr RootExp sh
sh UnscopedAcc (Array sh' e)
acc      -> (ScopedExp sh
 -> ScopedAcc (Array sh' e) -> PreSmartAcc ScopedAcc ScopedExp arrs)
-> RootExp sh
-> UnscopedAcc (Array sh' e)
-> (ScopedAcc arrs, NodeCounts)
forall e arrs'.
HasCallStack =>
(ScopedExp e
 -> ScopedAcc arrs' -> PreSmartAcc ScopedAcc ScopedExp arrs)
-> RootExp e -> UnscopedAcc arrs' -> (ScopedAcc arrs, NodeCounts)
travEA (ShapeR sh
-> ScopedExp sh
-> ScopedAcc (Array sh' e)
-> PreSmartAcc ScopedAcc ScopedExp (Array sh e)
forall sh (exp :: * -> *) (acc :: * -> *) sh' e.
ShapeR sh
-> exp sh -> acc (Array sh' e) -> PreSmartAcc acc exp (Array sh e)
Reshape ShapeR sh
shr) RootExp sh
sh UnscopedAcc (Array sh' e)
acc
          Replicate SliceIndex slix sl co sh
si RootExp slix
n UnscopedAcc (Array sl e)
acc      -> (ScopedExp slix
 -> ScopedAcc (Array sl e) -> PreSmartAcc ScopedAcc ScopedExp arrs)
-> RootExp slix
-> UnscopedAcc (Array sl e)
-> (ScopedAcc arrs, NodeCounts)
forall e arrs'.
HasCallStack =>
(ScopedExp e
 -> ScopedAcc arrs' -> PreSmartAcc ScopedAcc ScopedExp arrs)
-> RootExp e -> UnscopedAcc arrs' -> (ScopedAcc arrs, NodeCounts)
travEA (SliceIndex slix sl co sh
-> ScopedExp slix
-> ScopedAcc (Array sl e)
-> PreSmartAcc ScopedAcc ScopedExp (Array sh e)
forall slix sl co sh (exp :: * -> *) (acc :: * -> *) e.
SliceIndex slix sl co sh
-> exp slix -> acc (Array sl e) -> PreSmartAcc acc exp (Array sh e)
Replicate SliceIndex slix sl co sh
si) RootExp slix
n UnscopedAcc (Array sl e)
acc
          Slice SliceIndex slix sl co sh
si UnscopedAcc (Array sh e)
acc RootExp slix
i          -> (ScopedExp slix
 -> ScopedAcc (Array sh e) -> PreSmartAcc ScopedAcc ScopedExp arrs)
-> RootExp slix
-> UnscopedAcc (Array sh e)
-> (ScopedAcc arrs, NodeCounts)
forall e arrs'.
HasCallStack =>
(ScopedExp e
 -> ScopedAcc arrs' -> PreSmartAcc ScopedAcc ScopedExp arrs)
-> RootExp e -> UnscopedAcc arrs' -> (ScopedAcc arrs, NodeCounts)
travEA ((ScopedAcc (Array sh e)
 -> ScopedExp slix -> PreSmartAcc ScopedAcc ScopedExp (Array sl e))
-> ScopedExp slix
-> ScopedAcc (Array sh e)
-> PreSmartAcc ScopedAcc ScopedExp (Array sl e)
forall a b c. (a -> b -> c) -> b -> a -> c
flip ((ScopedAcc (Array sh e)
  -> ScopedExp slix -> PreSmartAcc ScopedAcc ScopedExp (Array sl e))
 -> ScopedExp slix
 -> ScopedAcc (Array sh e)
 -> PreSmartAcc ScopedAcc ScopedExp (Array sl e))
-> (ScopedAcc (Array sh e)
    -> ScopedExp slix -> PreSmartAcc ScopedAcc ScopedExp (Array sl e))
-> ScopedExp slix
-> ScopedAcc (Array sh e)
-> PreSmartAcc ScopedAcc ScopedExp (Array sl e)
forall a b. (a -> b) -> a -> b
$ SliceIndex slix sl co sh
-> ScopedAcc (Array sh e)
-> ScopedExp slix
-> PreSmartAcc ScopedAcc ScopedExp (Array sl e)
forall slix sl co sh (acc :: * -> *) e' (exp :: * -> *).
SliceIndex slix sl co sh
-> acc (Array sh e')
-> exp slix
-> PreSmartAcc acc exp (Array sl e')
Slice SliceIndex slix sl co sh
si) RootExp slix
i UnscopedAcc (Array sh e)
acc
          Map TypeR e
t1 TypeR e'
t2 SmartExp e -> RootExp e'
f UnscopedAcc (Array sh e)
acc         -> let
                                       (SmartExp e -> ScopedExp e'
f'  , NodeCounts
accCount1) = (SmartExp e -> RootExp e')
-> (SmartExp e -> ScopedExp e', NodeCounts)
forall e1 e2.
HasCallStack =>
(SmartExp e1 -> RootExp e2)
-> (SmartExp e1 -> ScopedExp e2, NodeCounts)
scopesFun1 SmartExp e -> RootExp e'
f
                                       (ScopedAcc (Array sh e)
acc', NodeCounts
accCount2) = UnscopedAcc (Array sh e) -> (ScopedAcc (Array sh e), NodeCounts)
forall arrs.
HasCallStack =>
UnscopedAcc arrs -> (ScopedAcc arrs, NodeCounts)
scopesAcc  UnscopedAcc (Array sh e)
acc
                                     in
                                     HasCallStack =>
PreSmartAcc ScopedAcc ScopedExp arrs
-> NodeCounts -> (ScopedAcc arrs, NodeCounts)
PreSmartAcc ScopedAcc ScopedExp arrs
-> NodeCounts -> (ScopedAcc arrs, NodeCounts)
reconstruct (TypeR e
-> TypeR e'
-> (SmartExp e -> ScopedExp e')
-> ScopedAcc (Array sh e)
-> PreSmartAcc ScopedAcc ScopedExp (Array sh e')
forall e e' (exp :: * -> *) (acc :: * -> *) sh.
TypeR e
-> TypeR e'
-> (SmartExp e -> exp e')
-> acc (Array sh e)
-> PreSmartAcc acc exp (Array sh e')
Map TypeR e
t1 TypeR e'
t2 SmartExp e -> ScopedExp e'
f' ScopedAcc (Array sh e)
acc') (NodeCounts
accCount1 NodeCounts -> NodeCounts -> NodeCounts
+++ NodeCounts
accCount2)
          ZipWith TypeR e1
t1 TypeR e2
t2 TypeR e3
t3 SmartExp e1 -> SmartExp e2 -> RootExp e3
f UnscopedAcc (Array sh e1)
acc1 UnscopedAcc (Array sh e2)
acc2
                                  -> ((SmartExp e1 -> SmartExp e2 -> ScopedExp e3)
 -> ScopedAcc (Array sh e1)
 -> ScopedAcc (Array sh e2)
 -> PreSmartAcc ScopedAcc ScopedExp arrs)
-> (SmartExp e1 -> SmartExp e2 -> RootExp e3)
-> UnscopedAcc (Array sh e1)
-> UnscopedAcc (Array sh e2)
-> (ScopedAcc arrs, NodeCounts)
forall a b c arrs1 arrs2.
HasCallStack =>
((SmartExp a -> SmartExp b -> ScopedExp c)
 -> ScopedAcc arrs1
 -> ScopedAcc arrs2
 -> PreSmartAcc ScopedAcc ScopedExp arrs)
-> (SmartExp a -> SmartExp b -> RootExp c)
-> UnscopedAcc arrs1
-> UnscopedAcc arrs2
-> (ScopedAcc arrs, NodeCounts)
travF2A2 (TypeR e1
-> TypeR e2
-> TypeR e3
-> (SmartExp e1 -> SmartExp e2 -> ScopedExp e3)
-> ScopedAcc (Array sh e1)
-> ScopedAcc (Array sh e2)
-> PreSmartAcc ScopedAcc ScopedExp (Array sh e3)
forall e1 e2 e3 (exp :: * -> *) (acc :: * -> *) sh.
TypeR e1
-> TypeR e2
-> TypeR e3
-> (SmartExp e1 -> SmartExp e2 -> exp e3)
-> acc (Array sh e1)
-> acc (Array sh e2)
-> PreSmartAcc acc exp (Array sh e3)
ZipWith TypeR e1
t1 TypeR e2
t2 TypeR e3
t3) SmartExp e1 -> SmartExp e2 -> RootExp e3
f UnscopedAcc (Array sh e1)
acc1 UnscopedAcc (Array sh e2)
acc2
          Fold TypeR e
tp SmartExp e -> SmartExp e -> RootExp e
f Maybe (RootExp e)
z UnscopedAcc (Array (sh, Int) e)
acc         -> ((SmartExp e -> SmartExp e -> ScopedExp e)
 -> Maybe (ScopedExp e)
 -> ScopedAcc (Array (sh, Int) e)
 -> PreSmartAcc ScopedAcc ScopedExp arrs)
-> (SmartExp e -> SmartExp e -> RootExp e)
-> Maybe (RootExp e)
-> UnscopedAcc (Array (sh, Int) e)
-> (ScopedAcc arrs, NodeCounts)
forall a b c e arrs'.
HasCallStack =>
((SmartExp a -> SmartExp b -> ScopedExp c)
 -> Maybe (ScopedExp e)
 -> ScopedAcc arrs'
 -> PreSmartAcc ScopedAcc ScopedExp arrs)
-> (SmartExp a -> SmartExp b -> RootExp c)
-> Maybe (RootExp e)
-> UnscopedAcc arrs'
-> (ScopedAcc arrs, NodeCounts)
travF2MEA (TypeR e
-> (SmartExp e -> SmartExp e -> ScopedExp e)
-> Maybe (ScopedExp e)
-> ScopedAcc (Array (sh, Int) e)
-> PreSmartAcc ScopedAcc ScopedExp (Array sh e)
forall e (exp :: * -> *) (acc :: * -> *) i.
TypeR e
-> (SmartExp e -> SmartExp e -> exp e)
-> Maybe (exp e)
-> acc (Array (i, Int) e)
-> PreSmartAcc acc exp (Array i e)
Fold TypeR e
tp) SmartExp e -> SmartExp e -> RootExp e
f Maybe (RootExp e)
z UnscopedAcc (Array (sh, Int) e)
acc
          FoldSeg IntegralType i
i TypeR e
tp SmartExp e -> SmartExp e -> RootExp e
f Maybe (RootExp e)
z UnscopedAcc (Array (sh, Int) e)
acc1 UnscopedAcc (Segments i)
acc2 -> let
                                       (SmartExp e -> SmartExp e -> ScopedExp e
f'   , NodeCounts
accCount1)  = (SmartExp e -> SmartExp e -> RootExp e)
-> (SmartExp e -> SmartExp e -> ScopedExp e, NodeCounts)
forall e1 e2 e3.
HasCallStack =>
(SmartExp e1 -> SmartExp e2 -> RootExp e3)
-> (SmartExp e1 -> SmartExp e2 -> ScopedExp e3, NodeCounts)
scopesFun2 SmartExp e -> SmartExp e -> RootExp e
f
                                       (Maybe (ScopedExp e)
z'   , NodeCounts
accCount2)  = Maybe (RootExp e) -> (Maybe (ScopedExp e), NodeCounts)
forall e.
HasCallStack =>
Maybe (RootExp e) -> (Maybe (ScopedExp e), NodeCounts)
travME Maybe (RootExp e)
z
                                       (ScopedAcc (Array (sh, Int) e)
acc1', NodeCounts
accCount3)  = UnscopedAcc (Array (sh, Int) e)
-> (ScopedAcc (Array (sh, Int) e), NodeCounts)
forall arrs.
HasCallStack =>
UnscopedAcc arrs -> (ScopedAcc arrs, NodeCounts)
scopesAcc  UnscopedAcc (Array (sh, Int) e)
acc1
                                       (ScopedAcc (Segments i)
acc2', NodeCounts
accCount4)  = UnscopedAcc (Segments i) -> (ScopedAcc (Segments i), NodeCounts)
forall arrs.
HasCallStack =>
UnscopedAcc arrs -> (ScopedAcc arrs, NodeCounts)
scopesAcc  UnscopedAcc (Segments i)
acc2
                                     in
                                     HasCallStack =>
PreSmartAcc ScopedAcc ScopedExp arrs
-> NodeCounts -> (ScopedAcc arrs, NodeCounts)
PreSmartAcc ScopedAcc ScopedExp arrs
-> NodeCounts -> (ScopedAcc arrs, NodeCounts)
reconstruct (IntegralType i
-> TypeR e
-> (SmartExp e -> SmartExp e -> ScopedExp e)
-> Maybe (ScopedExp e)
-> ScopedAcc (Array (sh, Int) e)
-> ScopedAcc (Segments i)
-> PreSmartAcc ScopedAcc ScopedExp (Array (sh, Int) e)
forall i e (exp :: * -> *) (acc :: * -> *) sh.
IntegralType i
-> TypeR e
-> (SmartExp e -> SmartExp e -> exp e)
-> Maybe (exp e)
-> acc (Array (sh, Int) e)
-> acc (Segments i)
-> PreSmartAcc acc exp (Array (sh, Int) e)
FoldSeg IntegralType i
i TypeR e
tp SmartExp e -> SmartExp e -> ScopedExp e
f' Maybe (ScopedExp e)
z' ScopedAcc (Array (sh, Int) e)
acc1' ScopedAcc (Segments i)
acc2')
                                       (NodeCounts
accCount1 NodeCounts -> NodeCounts -> NodeCounts
+++ NodeCounts
accCount2 NodeCounts -> NodeCounts -> NodeCounts
+++ NodeCounts
accCount3 NodeCounts -> NodeCounts -> NodeCounts
+++ NodeCounts
accCount4)
          Scan Direction
d TypeR e
tp SmartExp e -> SmartExp e -> RootExp e
f Maybe (RootExp e)
z UnscopedAcc (Array (sh, Int) e)
acc       -> ((SmartExp e -> SmartExp e -> ScopedExp e)
 -> Maybe (ScopedExp e)
 -> ScopedAcc (Array (sh, Int) e)
 -> PreSmartAcc ScopedAcc ScopedExp arrs)
-> (SmartExp e -> SmartExp e -> RootExp e)
-> Maybe (RootExp e)
-> UnscopedAcc (Array (sh, Int) e)
-> (ScopedAcc arrs, NodeCounts)
forall a b c e arrs'.
HasCallStack =>
((SmartExp a -> SmartExp b -> ScopedExp c)
 -> Maybe (ScopedExp e)
 -> ScopedAcc arrs'
 -> PreSmartAcc ScopedAcc ScopedExp arrs)
-> (SmartExp a -> SmartExp b -> RootExp c)
-> Maybe (RootExp e)
-> UnscopedAcc arrs'
-> (ScopedAcc arrs, NodeCounts)
travF2MEA (Direction
-> TypeR e
-> (SmartExp e -> SmartExp e -> ScopedExp e)
-> Maybe (ScopedExp e)
-> ScopedAcc (Array (sh, Int) e)
-> PreSmartAcc ScopedAcc ScopedExp (Array (sh, Int) e)
forall e (exp :: * -> *) (acc :: * -> *) e.
Direction
-> TypeR e
-> (SmartExp e -> SmartExp e -> exp e)
-> Maybe (exp e)
-> acc (Array (e, Int) e)
-> PreSmartAcc acc exp (Array (e, Int) e)
Scan Direction
d TypeR e
tp) SmartExp e -> SmartExp e -> RootExp e
f Maybe (RootExp e)
z UnscopedAcc (Array (sh, Int) e)
acc
          Scan' Direction
d TypeR e
tp SmartExp e -> SmartExp e -> RootExp e
f RootExp e
z UnscopedAcc (Array (sh, Int) e)
acc      -> ((SmartExp e -> SmartExp e -> ScopedExp e)
 -> ScopedExp e
 -> ScopedAcc (Array (sh, Int) e)
 -> PreSmartAcc ScopedAcc ScopedExp arrs)
-> (SmartExp e -> SmartExp e -> RootExp e)
-> RootExp e
-> UnscopedAcc (Array (sh, Int) e)
-> (ScopedAcc arrs, NodeCounts)
forall a b c e arrs'.
HasCallStack =>
((SmartExp a -> SmartExp b -> ScopedExp c)
 -> ScopedExp e
 -> ScopedAcc arrs'
 -> PreSmartAcc ScopedAcc ScopedExp arrs)
-> (SmartExp a -> SmartExp b -> RootExp c)
-> RootExp e
-> UnscopedAcc arrs'
-> (ScopedAcc arrs, NodeCounts)
travF2EA (Direction
-> TypeR e
-> (SmartExp e -> SmartExp e -> ScopedExp e)
-> ScopedExp e
-> ScopedAcc (Array (sh, Int) e)
-> PreSmartAcc ScopedAcc ScopedExp (Array (sh, Int) e, Array sh e)
forall e (exp :: * -> *) (acc :: * -> *) sh.
Direction
-> TypeR e
-> (SmartExp e -> SmartExp e -> exp e)
-> exp e
-> acc (Array (sh, Int) e)
-> PreSmartAcc acc exp (Array (sh, Int) e, Array sh e)
Scan' Direction
d TypeR e
tp) SmartExp e -> SmartExp e -> RootExp e
f RootExp e
z UnscopedAcc (Array (sh, Int) e)
acc
          Permute ArrayR (Array sh e)
repr SmartExp e -> SmartExp e -> RootExp e
fc UnscopedAcc (Array sh' e)
acc1 SmartExp sh -> RootExp (PrimMaybe sh')
fp UnscopedAcc (Array sh e)
acc2
                                  -> let
                                       (SmartExp e -> SmartExp e -> ScopedExp e
fc'  , NodeCounts
accCount1) = (SmartExp e -> SmartExp e -> RootExp e)
-> (SmartExp e -> SmartExp e -> ScopedExp e, NodeCounts)
forall e1 e2 e3.
HasCallStack =>
(SmartExp e1 -> SmartExp e2 -> RootExp e3)
-> (SmartExp e1 -> SmartExp e2 -> ScopedExp e3, NodeCounts)
scopesFun2 SmartExp e -> SmartExp e -> RootExp e
fc
                                       (ScopedAcc (Array sh' e)
acc1', NodeCounts
accCount2) = UnscopedAcc (Array sh' e) -> (ScopedAcc (Array sh' e), NodeCounts)
forall arrs.
HasCallStack =>
UnscopedAcc arrs -> (ScopedAcc arrs, NodeCounts)
scopesAcc  UnscopedAcc (Array sh' e)
acc1
                                       (SmartExp sh -> ScopedExp (PrimMaybe sh')
fp'  , NodeCounts
accCount3) = (SmartExp sh -> RootExp (PrimMaybe sh'))
-> (SmartExp sh -> ScopedExp (PrimMaybe sh'), NodeCounts)
forall e1 e2.
HasCallStack =>
(SmartExp e1 -> RootExp e2)
-> (SmartExp e1 -> ScopedExp e2, NodeCounts)
scopesFun1 SmartExp sh -> RootExp (PrimMaybe sh')
fp
                                       (ScopedAcc (Array sh e)
acc2', NodeCounts
accCount4) = UnscopedAcc (Array sh e) -> (ScopedAcc (Array sh e), NodeCounts)
forall arrs.
HasCallStack =>
UnscopedAcc arrs -> (ScopedAcc arrs, NodeCounts)
scopesAcc  UnscopedAcc (Array sh e)
acc2
                                     in
                                     HasCallStack =>
PreSmartAcc ScopedAcc ScopedExp arrs
-> NodeCounts -> (ScopedAcc arrs, NodeCounts)
PreSmartAcc ScopedAcc ScopedExp arrs
-> NodeCounts -> (ScopedAcc arrs, NodeCounts)
reconstruct (ArrayR (Array sh e)
-> (SmartExp e -> SmartExp e -> ScopedExp e)
-> ScopedAcc (Array sh' e)
-> (SmartExp sh -> ScopedExp (PrimMaybe sh'))
-> ScopedAcc (Array sh e)
-> PreSmartAcc ScopedAcc ScopedExp (Array sh' e)
forall sh e (exp :: * -> *) (acc :: * -> *) sh'.
ArrayR (Array sh e)
-> (SmartExp e -> SmartExp e -> exp e)
-> acc (Array sh' e)
-> (SmartExp sh -> exp (PrimMaybe sh'))
-> acc (Array sh e)
-> PreSmartAcc acc exp (Array sh' e)
Permute ArrayR (Array sh e)
repr SmartExp e -> SmartExp e -> ScopedExp e
fc' ScopedAcc (Array sh' e)
acc1' SmartExp sh -> ScopedExp (PrimMaybe sh')
fp' ScopedAcc (Array sh e)
acc2')
                                       (NodeCounts
accCount1 NodeCounts -> NodeCounts -> NodeCounts
+++ NodeCounts
accCount2 NodeCounts -> NodeCounts -> NodeCounts
+++ NodeCounts
accCount3 NodeCounts -> NodeCounts -> NodeCounts
+++ NodeCounts
accCount4)
          Backpermute ShapeR sh'
shr RootExp sh'
sh SmartExp sh' -> RootExp sh
fp UnscopedAcc (Array sh e)
acc
                                  -> let
                                       (ScopedExp sh'
sh' , NodeCounts
accCount1) = RootExp sh' -> (ScopedExp sh', NodeCounts)
forall t. HasCallStack => RootExp t -> (ScopedExp t, NodeCounts)
scopesExp  RootExp sh'
sh
                                       (SmartExp sh' -> ScopedExp sh
fp' , NodeCounts
accCount2) = (SmartExp sh' -> RootExp sh)
-> (SmartExp sh' -> ScopedExp sh, NodeCounts)
forall e1 e2.
HasCallStack =>
(SmartExp e1 -> RootExp e2)
-> (SmartExp e1 -> ScopedExp e2, NodeCounts)
scopesFun1 SmartExp sh' -> RootExp sh
fp
                                       (ScopedAcc (Array sh e)
acc', NodeCounts
accCount3) = UnscopedAcc (Array sh e) -> (ScopedAcc (Array sh e), NodeCounts)
forall arrs.
HasCallStack =>
UnscopedAcc arrs -> (ScopedAcc arrs, NodeCounts)
scopesAcc  UnscopedAcc (Array sh e)
acc
                                     in
                                     HasCallStack =>
PreSmartAcc ScopedAcc ScopedExp arrs
-> NodeCounts -> (ScopedAcc arrs, NodeCounts)
PreSmartAcc ScopedAcc ScopedExp arrs
-> NodeCounts -> (ScopedAcc arrs, NodeCounts)
reconstruct (ShapeR sh'
-> ScopedExp sh'
-> (SmartExp sh' -> ScopedExp sh)
-> ScopedAcc (Array sh e)
-> PreSmartAcc ScopedAcc ScopedExp (Array sh' e)
forall sh' (exp :: * -> *) sh (acc :: * -> *) e.
ShapeR sh'
-> exp sh'
-> (SmartExp sh' -> exp sh)
-> acc (Array sh e)
-> PreSmartAcc acc exp (Array sh' e)
Backpermute ShapeR sh'
shr ScopedExp sh'
sh' SmartExp sh' -> ScopedExp sh
fp' ScopedAcc (Array sh e)
acc')
                                       (NodeCounts
accCount1 NodeCounts -> NodeCounts -> NodeCounts
+++ NodeCounts
accCount2 NodeCounts -> NodeCounts -> NodeCounts
+++ NodeCounts
accCount3)
          Stencil StencilR sh a stencil
sr TypeR b
tp SmartExp stencil -> RootExp b
st PreBoundary UnscopedAcc RootExp (Array sh a)
bnd UnscopedAcc (Array sh a)
acc      -> let
                                       (SmartExp stencil -> ScopedExp b
st' , NodeCounts
accCount1) = UnscopedAcc (Array sh a)
-> (SmartExp stencil -> RootExp b)
-> (SmartExp stencil -> ScopedExp b, NodeCounts)
forall sh e1 e2 stencil.
HasCallStack =>
UnscopedAcc (Array sh e1)
-> (stencil -> RootExp e2) -> (stencil -> ScopedExp e2, NodeCounts)
scopesStencil1 UnscopedAcc (Array sh a)
acc SmartExp stencil -> RootExp b
st
                                       (PreBoundary ScopedAcc ScopedExp (Array sh a)
bnd', NodeCounts
accCount2) = PreBoundary UnscopedAcc RootExp (Array sh a)
-> (PreBoundary ScopedAcc ScopedExp (Array sh a), NodeCounts)
forall t.
HasCallStack =>
PreBoundary UnscopedAcc RootExp t
-> (PreBoundary ScopedAcc ScopedExp t, NodeCounts)
scopesBoundary PreBoundary UnscopedAcc RootExp (Array sh a)
bnd
                                       (ScopedAcc (Array sh a)
acc', NodeCounts
accCount3) = UnscopedAcc (Array sh a) -> (ScopedAcc (Array sh a), NodeCounts)
forall arrs.
HasCallStack =>
UnscopedAcc arrs -> (ScopedAcc arrs, NodeCounts)
scopesAcc UnscopedAcc (Array sh a)
acc
                                     in
                                     HasCallStack =>
PreSmartAcc ScopedAcc ScopedExp arrs
-> NodeCounts -> (ScopedAcc arrs, NodeCounts)
PreSmartAcc ScopedAcc ScopedExp arrs
-> NodeCounts -> (ScopedAcc arrs, NodeCounts)
reconstruct (StencilR sh a stencil
-> TypeR b
-> (SmartExp stencil -> ScopedExp b)
-> PreBoundary ScopedAcc ScopedExp (Array sh a)
-> ScopedAcc (Array sh a)
-> PreSmartAcc ScopedAcc ScopedExp (Array sh b)
forall sh a stencil sh (exp :: * -> *) (acc :: * -> *).
StencilR sh a stencil
-> TypeR sh
-> (SmartExp stencil -> exp sh)
-> PreBoundary acc exp (Array sh a)
-> acc (Array sh a)
-> PreSmartAcc acc exp (Array sh sh)
Stencil StencilR sh a stencil
sr TypeR b
tp SmartExp stencil -> ScopedExp b
st' PreBoundary ScopedAcc ScopedExp (Array sh a)
bnd' ScopedAcc (Array sh a)
acc') (NodeCounts
accCount1 NodeCounts -> NodeCounts -> NodeCounts
+++ NodeCounts
accCount2 NodeCounts -> NodeCounts -> NodeCounts
+++ NodeCounts
accCount3)
          Stencil2 StencilR sh a stencil1
s1 StencilR sh b stencil2
s2 TypeR c
tp SmartExp stencil1 -> SmartExp stencil2 -> RootExp c
st PreBoundary UnscopedAcc RootExp (Array sh a)
bnd1 UnscopedAcc (Array sh a)
acc1 PreBoundary UnscopedAcc RootExp (Array sh b)
bnd2 UnscopedAcc (Array sh b)
acc2
                                  -> let
                                       (SmartExp stencil1 -> SmartExp stencil2 -> ScopedExp c
st'  , NodeCounts
accCount1) = UnscopedAcc (Array sh a)
-> UnscopedAcc (Array sh b)
-> (SmartExp stencil1 -> SmartExp stencil2 -> RootExp c)
-> (SmartExp stencil1 -> SmartExp stencil2 -> ScopedExp c,
    NodeCounts)
forall sh e1 e2 e3 stencil1 stencil2.
HasCallStack =>
UnscopedAcc (Array sh e1)
-> UnscopedAcc (Array sh e2)
-> (stencil1 -> stencil2 -> RootExp e3)
-> (stencil1 -> stencil2 -> ScopedExp e3, NodeCounts)
scopesStencil2 UnscopedAcc (Array sh a)
acc1 UnscopedAcc (Array sh b)
acc2 SmartExp stencil1 -> SmartExp stencil2 -> RootExp c
st
                                       (PreBoundary ScopedAcc ScopedExp (Array sh a)
bnd1', NodeCounts
accCount2) = PreBoundary UnscopedAcc RootExp (Array sh a)
-> (PreBoundary ScopedAcc ScopedExp (Array sh a), NodeCounts)
forall t.
HasCallStack =>
PreBoundary UnscopedAcc RootExp t
-> (PreBoundary ScopedAcc ScopedExp t, NodeCounts)
scopesBoundary PreBoundary UnscopedAcc RootExp (Array sh a)
bnd1
                                       (ScopedAcc (Array sh a)
acc1', NodeCounts
accCount3) = UnscopedAcc (Array sh a) -> (ScopedAcc (Array sh a), NodeCounts)
forall arrs.
HasCallStack =>
UnscopedAcc arrs -> (ScopedAcc arrs, NodeCounts)
scopesAcc UnscopedAcc (Array sh a)
acc1
                                       (PreBoundary ScopedAcc ScopedExp (Array sh b)
bnd2', NodeCounts
accCount4) = PreBoundary UnscopedAcc RootExp (Array sh b)
-> (PreBoundary ScopedAcc ScopedExp (Array sh b), NodeCounts)
forall t.
HasCallStack =>
PreBoundary UnscopedAcc RootExp t
-> (PreBoundary ScopedAcc ScopedExp t, NodeCounts)
scopesBoundary PreBoundary UnscopedAcc RootExp (Array sh b)
bnd2
                                       (ScopedAcc (Array sh b)
acc2', NodeCounts
accCount5) = UnscopedAcc (Array sh b) -> (ScopedAcc (Array sh b), NodeCounts)
forall arrs.
HasCallStack =>
UnscopedAcc arrs -> (ScopedAcc arrs, NodeCounts)
scopesAcc UnscopedAcc (Array sh b)
acc2
                                     in
                                     HasCallStack =>
PreSmartAcc ScopedAcc ScopedExp arrs
-> NodeCounts -> (ScopedAcc arrs, NodeCounts)
PreSmartAcc ScopedAcc ScopedExp arrs
-> NodeCounts -> (ScopedAcc arrs, NodeCounts)
reconstruct (StencilR sh a stencil1
-> StencilR sh b stencil2
-> TypeR c
-> (SmartExp stencil1 -> SmartExp stencil2 -> ScopedExp c)
-> PreBoundary ScopedAcc ScopedExp (Array sh a)
-> ScopedAcc (Array sh a)
-> PreBoundary ScopedAcc ScopedExp (Array sh b)
-> ScopedAcc (Array sh b)
-> PreSmartAcc ScopedAcc ScopedExp (Array sh c)
forall sh a stencil1 b stencil2 c (exp :: * -> *) (acc :: * -> *).
StencilR sh a stencil1
-> StencilR sh b stencil2
-> TypeR c
-> (SmartExp stencil1 -> SmartExp stencil2 -> exp c)
-> PreBoundary acc exp (Array sh a)
-> acc (Array sh a)
-> PreBoundary acc exp (Array sh b)
-> acc (Array sh b)
-> PreSmartAcc acc exp (Array sh c)
Stencil2 StencilR sh a stencil1
s1 StencilR sh b stencil2
s2 TypeR c
tp SmartExp stencil1 -> SmartExp stencil2 -> ScopedExp c
st' PreBoundary ScopedAcc ScopedExp (Array sh a)
bnd1' ScopedAcc (Array sh a)
acc1' PreBoundary ScopedAcc ScopedExp (Array sh b)
bnd2' ScopedAcc (Array sh b)
acc2')
                                       (NodeCounts
accCount1 NodeCounts -> NodeCounts -> NodeCounts
+++ NodeCounts
accCount2 NodeCounts -> NodeCounts -> NodeCounts
+++ NodeCounts
accCount3 NodeCounts -> NodeCounts -> NodeCounts
+++ NodeCounts
accCount4 NodeCounts -> NodeCounts -> NodeCounts
+++ NodeCounts
accCount5)
          -- Collect seq             -> let
          --                              (seq', accCount1) = scopesSeq seq
          --                            in
          --                            reconstruct (Collect seq') accCount1

      where
        -- Scope an embedded scalar expression together with one array argument,
        -- then rebuild the current node with the constructor @mk@. The node
        -- counts gathered from the expression and from the array are merged
        -- (expression first) and handed on to 'reconstruct'.
        travEA :: HasCallStack
               => (ScopedExp e -> ScopedAcc arrs' -> PreSmartAcc ScopedAcc ScopedExp arrs)
               -> RootExp e
               -> UnscopedAcc arrs'
               -> (ScopedAcc arrs, NodeCounts)
        travEA mk rootE unscopedA =
          let (scopedE, expCounts) = scopesExp rootE
              (scopedA, accCounts) = scopesAcc unscopedA
          in  reconstruct (mk scopedE scopedA) (expCounts +++ accCounts)

        -- Scope a binary scalar function, a scalar expression and one array
        -- argument, then rebuild the current node with the constructor @mk@.
        -- Node counts are combined in argument order: function, expression,
        -- array.
        travF2EA
            :: HasCallStack
            => ((SmartExp a -> SmartExp b -> ScopedExp c) -> ScopedExp e -> ScopedAcc arrs' -> PreSmartAcc ScopedAcc ScopedExp arrs)
            -> (SmartExp a -> SmartExp b -> RootExp c)
            -> RootExp e
            -> UnscopedAcc arrs'
            -> (ScopedAcc arrs, NodeCounts)
        travF2EA mk fun rootE unscopedA =
          let (scopedFun, funCounts) = scopesFun2 fun
              (scopedE  , expCounts) = scopesExp  rootE
              (scopedA  , accCounts) = scopesAcc  unscopedA
          in  reconstruct (mk scopedFun scopedE scopedA)
                          (funCounts +++ expCounts +++ accCounts)

        -- Like 'travF2EA', but the scalar expression is optional (for example,
        -- the seed of a fold or scan). A missing expression contributes
        -- 'noNodeCounts' via 'travME'. Node counts are combined in argument
        -- order: function, optional expression, array.
        travF2MEA
            :: HasCallStack
            => ((SmartExp a -> SmartExp b -> ScopedExp c) -> Maybe (ScopedExp e) -> ScopedAcc arrs' -> PreSmartAcc ScopedAcc ScopedExp arrs)
            -> (SmartExp a -> SmartExp b -> RootExp c)
            -> Maybe (RootExp e)
            -> UnscopedAcc arrs'
            -> (ScopedAcc arrs, NodeCounts)
        travF2MEA mk fun mrootE unscopedA =
          let (scopedFun, funCounts) = scopesFun2 fun
              (scopedME , expCounts) = travME     mrootE
              (scopedA  , accCounts) = scopesAcc  unscopedA
          in  reconstruct (mk scopedFun scopedME scopedA)
                          (funCounts +++ expCounts +++ accCounts)

        -- Scope an optional scalar expression. An absent expression yields no
        -- node counts; a present one is scoped with 'scopesExp' and its counts
        -- are passed through unchanged.
        travME :: HasCallStack => Maybe (RootExp e) -> (Maybe (ScopedExp e), NodeCounts)
        travME mrootE =
          case mrootE of
            Nothing    -> (Nothing, noNodeCounts)
            Just rootE -> let (scopedE, counts) = scopesExp rootE
                          in  (Just scopedE, counts)

        -- Scope a binary scalar function and two array arguments, then rebuild
        -- the current node with the constructor @mk@. Node counts are combined
        -- in argument order: function, first array, second array.
        travF2A2
            :: HasCallStack
            => ((SmartExp a -> SmartExp b -> ScopedExp c) -> ScopedAcc arrs1 -> ScopedAcc arrs2 -> PreSmartAcc ScopedAcc ScopedExp arrs)
            -> (SmartExp a -> SmartExp b -> RootExp c)
            -> UnscopedAcc arrs1
            -> UnscopedAcc arrs2
            -> (ScopedAcc arrs, NodeCounts)
        travF2A2 mk fun unscopedA1 unscopedA2 =
          let (scopedFun, funCounts ) = scopesFun2 fun
              (scopedA1 , acc1Counts) = scopesAcc  unscopedA1
              (scopedA2 , acc2Counts) = scopesAcc  unscopedA2
          in  reconstruct (mk scopedFun scopedA1 scopedA2)
                          (funCounts +++ acc1Counts +++ acc2Counts)

        -- Scope a single array argument and rebuild the current node with the
        -- constructor @mk@, forwarding that argument's node counts unchanged.
        travA :: HasCallStack
              => (ScopedAcc arrs' -> PreSmartAcc ScopedAcc ScopedExp arrs)
              -> UnscopedAcc arrs'
              -> (ScopedAcc arrs, NodeCounts)
        travA mk unscopedA =
          let (scopedA, accCounts) = scopesAcc unscopedA
          in  reconstruct (mk scopedA) accCounts

          -- Occurrence count of the currently processed node, looked up in
          -- 'accOccMap' by the stable name stored in 'sn' (its height is
          -- ignored).
        accOccCount :: Int
        accOccCount = lookupWithASTName accOccMap (StableASTName nodeName)
          where
            StableNameHeight nodeName _ = sn

        -- Reconstruct the current tree node.
        --
        -- * If the current node is being shared ('accOccCount > 1'), replace it with an 'AvarSharing'
        --   node and float the shared subtree out wrapped in a 'NodeCounts' value.
        -- * If the current node is not shared, reconstruct it in place.
        -- * Special case for free variables ('Atag'): Replace the tree by a sharing variable and
        --   float the 'Atag' out in a 'NodeCounts' value.  This is independent of the number of
        --   occurrences.
        --
        -- In all cases, any completed 'NodeCounts' are injected as bindings using an 'AletSharing'
        -- node.
        --
        reconstruct
            :: HasCallStack
            => PreSmartAcc ScopedAcc ScopedExp arrs
            -> NodeCounts
            -> (ScopedAcc arrs, NodeCounts)
        reconstruct newAcc@(Atag tp _) _subCount
          -- A free variable is always replaced by a sharing variable, independent of how often
          -- it occurs; its node is floated outwards in a fresh 'NodeCounts'.
          = let freeCount = StableSharingAcc sn (AccSharing sn newAcc) `insertAccNode` noNodeCounts
            in
            tracePure "FREE" (show freeCount) (ScopedAcc [] (AvarSharing sn tp), freeCount)
        reconstruct newAcc subCount
          -- A shared subtree (only when the 'acc_sharing' flag is set in the config) is replaced
          -- by a sharing variable; the subtree itself, wrapped in the bindings completed here,
          -- is recorded in the node counts passed outwards.
          | accOccCount > 1 && acc_sharing `member` options config
          = let sharedCount = StableSharingAcc sn withLets `insertAccNode` pending
            in
            tracePure ("SHARED" ++ letsNote) (show sharedCount)
              (ScopedAcc [] (AvarSharing sn (Smart.arraysR newAcc)), sharedCount)
          -- Neither shared nor a free variable: the node stays in place, wrapped in the
          -- bindings completed here.
          | otherwise
          = tracePure ("Normal" ++ letsNote) (show pending) (ScopedAcc [] withLets, pending)
          where
            -- Split off the bindings that are complete at this node; the rest remain pending
            -- and are passed outwards.
            (pending, bindHere) = filterCompleted subCount

            -- Nest the node under one 'AletSharing' per completed binding; the last element
            -- of 'bindHere' ends up outermost.
            withLets = foldl (\inner binding -> AletSharing binding (ScopedAcc [] inner))
                             (AccSharing sn newAcc)
                             bindHere

            -- trace support
            letsNote
              | null bindHere = ""
              | otherwise     = "(" ++ show (length bindHere) ++ " lets)"

        -- Extract *leading* nodes that have a complete node count (i.e., their node count is equal
        -- to the number of occurrences of that node in the overall expression).
        --
        -- Nodes with a completed node count should be let bound at the currently processed node.
        --
        -- NB: Only extract leading nodes (i.e., the longest run at the *front* of the list that is
        --     complete).  Otherwise, we would let-bind subterms before their parents, which leads
        --     to scope errors.
        --
        filterCompleted :: NodeCounts -> (NodeCounts, [StableSharingAcc])
        filterCompleted (ns, graph) =
          ((map snd keep, graph), [ sa | (_, AccNodeCount sa _) <- emit ])
          where
            -- One flag per entry of 'ns', in the same order. The list is defined in terms of
            -- itself: a node's flag consults the flags of its super-terms (looked up by
            -- position in 'ns'), so it must remain a lazy list.
            names        = map nodeName ns
            canBind      = map isBindable ns
            (emit, keep) = partition fst (zip canBind ns)

            -- A node is complete once its count 'n' equals its total number of occurrences in
            -- the whole program (as recorded in 'accOccMap'); free variables are never complete.
            isCompleted :: NodeCount -> Bool
            isCompleted nc@(AccNodeCount sa n) = not (isFreeVar nc) && lookupWithSharingAcc accOccMap sa == n
            isCompleted _                      = False

            -- An array node may be bound here if it is complete and each of its super-terms
            -- (according to 'graph') that still occurs in 'ns' may be bound here as well.
            -- Scalar expression nodes are never bound here.
            isBindable :: NodeCount -> Bool
            isBindable nc@(AccNodeCount _ _) =
              let superTerms = Set.toList (graph Map.! nodeName nc)
                  superIdxs  = mapMaybe (`elemIndex` names) superTerms
              in  isCompleted nc && all (canBind !!) superIdxs
            isBindable (ExpNodeCount _ _) = False
            -- isBindable (SeqNodeCount _ _) = False

    -- scopesSeq :: forall arrs. RootSeq arrs -> (ScopedSeq arrs, NodeCounts)
    -- scopesSeq = determineScopesSeq config accOccMap

    -- Scalar expressions embedded in array computations are root (closed) expressions; their
    -- scopes are determined separately by 'determineScopesExp', sharing this configuration
    -- and the array occurrence map.
    scopesExp
        :: HasCallStack
        => RootExp t
        -> (ScopedExp t, NodeCounts)
    scopesExp rootExp = determineScopesExp config accOccMap rootExp

    -- The lambda-bound variable is irrelevant at this point, so the body is instantiated with
    -- 'undefined'; for details, see Note [Traversing functions and side effects]
    --
    -- Scopes the body of a unary array function. Node counts whose 'Atag' index is one of the
    -- body's free variables ('fvs') are bound at the lambda, seeding the environment of the
    -- resulting 'ScopedAcc' via 'buildInitialEnvAcc'; all other node counts are passed outwards.
    scopesAfun1
        :: HasCallStack
        => (SmartAcc a1 -> UnscopedAcc a2)
        -> (SmartAcc a1 -> ScopedAcc a2, NodeCounts)
    scopesAfun1 f =
      let body@(UnscopedAcc fvs _)                   = f undefined
          (ScopedAcc [] scopedBody, (counts, graph)) = scopesAcc body
          (boundCounts, passedOut)                   = partition isBoundHere counts
          lambdaEnv                                  = buildInitialEnvAcc fvs [ sa | AccNodeCount sa _ <- boundCounts ]

          isBoundHere :: NodeCount -> Bool
          isBoundHere (AccNodeCount (StableSharingAcc _ (AccSharing _ (Atag _ i))) _) = i `elem` fvs
          isBoundHere _                                                               = False
      in
      (\_ -> ScopedAcc lambdaEnv scopedBody, (passedOut, graph))

    -- The lambda-bound variable is irrelevant at this point, so the body is instantiated with
    -- 'undefined'; for details, see Note [Traversing functions and side effects]
    --
    -- Scopes the body of a unary scalar function with 'scopesExp'. The resulting function
    -- ignores its argument; the body's node counts are returned unchanged.
    scopesFun1
        :: HasCallStack
        => (SmartExp e1 -> RootExp e2)
        -> (SmartExp e1 -> ScopedExp e2, NodeCounts)
    scopesFun1 f =
      let (scopedBody, bodyCounts) = scopesExp (f undefined)
      in  (\_ -> scopedBody, bodyCounts)

    -- The lambda-bound variables are irrelevant at this point, so the body is instantiated with
    -- 'undefined' for both; for details, see Note [Traversing functions and side effects]
    --
    -- Scopes the body of a binary scalar function with 'scopesExp'. The resulting function
    -- ignores both arguments; the body's node counts are returned unchanged.
    scopesFun2
        :: HasCallStack
        => (SmartExp e1 -> SmartExp e2 -> RootExp e3)
        -> (SmartExp e1 -> SmartExp e2 -> ScopedExp e3, NodeCounts)
    scopesFun2 f =
      let (scopedBody, bodyCounts) = scopesExp (f undefined undefined)
      in  (\_ _ -> scopedBody, bodyCounts)

    -- The lambda-bound stencil is irrelevant at this point, so the body is instantiated with
    -- 'undefined'; for details, see Note [Traversing functions and side effects]
    --
    -- Scopes the body of a single-array stencil function with 'scopesExp'. The array argument
    -- is a dummy and is not inspected; the body's node counts are returned unchanged.
    scopesStencil1
        :: forall sh e1 e2 stencil. HasCallStack
        => UnscopedAcc (Array sh e1){-dummy-}
        -> (stencil -> RootExp e2)
        -> (stencil -> ScopedExp e2, NodeCounts)
    scopesStencil1 _ stencilFun =
      let (scopedBody, bodyCounts) = scopesExp (stencilFun undefined)
      in  (\_ -> scopedBody, bodyCounts)

    -- The lambda-bound stencils are irrelevant at this point, so the body is instantiated with
    -- 'undefined' for both; for details, see Note [Traversing functions and side effects]
    --
    -- Scopes the body of a two-array stencil function with 'scopesExp'. Both array arguments
    -- are dummies and are not inspected; the body's node counts are returned unchanged.
    scopesStencil2
        :: forall sh e1 e2 e3 stencil1 stencil2. HasCallStack
        => UnscopedAcc (Array sh e1){-dummy-}
        -> UnscopedAcc (Array sh e2){-dummy-}
        -> (stencil1 -> stencil2 -> RootExp e3)
        -> (stencil1 -> stencil2 -> ScopedExp e3, NodeCounts)
    scopesStencil2 _ _ stencilFun =
      let (scopedBody, bodyCounts) = scopesExp (stencilFun undefined undefined)
      in  (\_ _ -> scopedBody, bodyCounts)

    -- Scopes a stencil boundary specification. Only a 'Function' boundary contains a scalar
    -- function (scoped with 'scopesFun1'); every other boundary contributes no node counts.
    scopesBoundary
        :: HasCallStack
        => PreBoundary UnscopedAcc RootExp t
        -> (PreBoundary ScopedAcc ScopedExp t, NodeCounts)
    scopesBoundary = \case
      Clamp      -> (Clamp,      noNodeCounts)
      Mirror     -> (Mirror,     noNodeCounts)
      Wrap       -> (Wrap,       noNodeCounts)
      Constant v -> (Constant v, noNodeCounts)
      Function f ->
        let (scopedFun, funCounts) = scopesFun1 f
        in  (Function scopedFun, funCounts)


-- | Determine the scopes of a root (closed) scalar expression.
--
-- The body is processed by 'determineScopesSharingExp'. The resulting node counts are then
-- split in two:
--
--   * the scalar-expression node counts become the initial environment of the returned
--     'ScopedExp' (via 'buildInitialEnvExp');
--   * the array node counts are returned after passing through 'cleanCounts'.
--
determineScopesExp
    :: HasCallStack
    => Config
    -> OccMap SmartAcc
    -> RootExp t
    -> (ScopedExp t, NodeCounts)          -- Root (closed) expression plus Acc node counts
determineScopesExp config accOccMap (RootExp expOccMap rootBody@(UnscopedExp fvs _))
  = (ScopedExp rootEnv scopedBody, cleanCounts (accCounts, graph))
  where
    (ScopedExp [] scopedBody, (nodeCounts, graph)) = determineScopesSharingExp config accOccMap expOccMap rootBody
    (expCounts, accCounts)                         = partition isExpNodeCount nodeCounts
    rootEnv                                        = buildInitialEnvExp fvs [ se | ExpNodeCount se _ <- expCounts ]

    isExpNodeCount :: NodeCount -> Bool
    isExpNodeCount ExpNodeCount{} = True
    isExpNodeCount _              = False


determineScopesSharingExp
    :: HasCallStack
    => Config
    -> OccMap SmartAcc
    -> OccMap SmartExp
    -> UnscopedExp t
    -> (ScopedExp t, NodeCounts)
determineScopesSharingExp :: Config
-> OccMap SmartAcc
-> OccMap SmartExp
-> UnscopedExp t
-> (ScopedExp t, NodeCounts)
determineScopesSharingExp Config
config OccMap SmartAcc
accOccMap OccMap SmartExp
expOccMap = UnscopedExp t -> (ScopedExp t, NodeCounts)
forall t.
HasCallStack =>
UnscopedExp t -> (ScopedExp t, NodeCounts)
scopesExp
  where
    -- Array terms embedded in a scalar expression are scoped by 'determineScopesSharingAcc',
    -- using the same configuration and array occurrence map.
    scopesAcc
        :: HasCallStack
        => UnscopedAcc a
        -> (ScopedAcc a, NodeCounts)
    scopesAcc acc = determineScopesSharingAcc config accOccMap acc

    -- Scopes the body of a unary scalar function nested inside a scalar expression; the
    -- lambda-bound variable is irrelevant here, so the body is instantiated with 'undefined'.
    -- Node counts whose 'Tag' index is one of the body's free variables ('fvs') are bound at
    -- the lambda, seeding the environment of the resulting 'ScopedExp' via
    -- 'buildInitialEnvExp'; all other node counts are passed outwards.
    scopesFun1
        :: HasCallStack
        => (SmartExp a -> UnscopedExp b)
        -> (SmartExp a -> ScopedExp b, NodeCounts)
    scopesFun1 f =
      let body@(UnscopedExp fvs _)                   = f undefined
          (ScopedExp [] scopedBody, (counts, graph)) = scopesExp body
          (boundCounts, passedOut)                   = partition isBoundHere counts
          lambdaEnv                                  = buildInitialEnvExp fvs [ se | ExpNodeCount se _ <- boundCounts ]

          isBoundHere :: NodeCount -> Bool
          isBoundHere (ExpNodeCount (StableSharingExp _ (ExpSharing _ (Tag _ i))) _) = i `elem` fvs
          isBoundHere _                                                              = False
      in
      tracePure ("LAMBDA " ++ show lambdaEnv) (show counts)
        (\_ -> ScopedExp lambdaEnv scopedBody, (passedOut, graph))

    scopesExp
        :: forall t. HasCallStack
        => UnscopedExp t
        -> (ScopedExp t, NodeCounts)
    scopesExp :: UnscopedExp t -> (ScopedExp t, NodeCounts)
scopesExp (UnscopedExp [Int]
_ (LetSharing StableSharingExp
_ UnscopedExp t
_))
      = String -> (ScopedExp t, NodeCounts)
forall a. HasCallStack => String -> a
internalError String
"unexpected 'LetSharing'"

    scopesExp (UnscopedExp [Int]
_ (VarSharing StableExpName t
sn TypeR t
tp))
      = ([StableSharingExp]
-> SharingExp ScopedAcc ScopedExp t -> ScopedExp t
forall t.
[StableSharingExp]
-> SharingExp ScopedAcc ScopedExp t -> ScopedExp t
ScopedExp [] (StableExpName t -> TypeR t -> SharingExp ScopedAcc ScopedExp t
forall t (acc :: * -> *) (exp :: * -> *).
StableExpName t -> TypeR t -> SharingExp acc exp t
VarSharing StableExpName t
sn TypeR t
tp), StableExpName t
-> SharingExp ScopedAcc ScopedExp t -> StableSharingExp
forall t.
StableExpName t
-> SharingExp ScopedAcc ScopedExp t -> StableSharingExp
StableSharingExp StableExpName t
sn (StableExpName t -> TypeR t -> SharingExp ScopedAcc ScopedExp t
forall t (acc :: * -> *) (exp :: * -> *).
StableExpName t -> TypeR t -> SharingExp acc exp t
VarSharing StableExpName t
sn TypeR t
tp) StableSharingExp -> NodeCounts -> NodeCounts
`insertExpNode` NodeCounts
noNodeCounts)

    scopesExp (UnscopedExp [Int]
_ (ExpSharing StableExpName t
sn PreSmartExp UnscopedAcc UnscopedExp t
pexp))
      = case PreSmartExp UnscopedAcc UnscopedExp t
pexp of
          Tag TypeR t
tp Int
i              -> HasCallStack =>
PreSmartExp ScopedAcc ScopedExp t
-> NodeCounts -> (ScopedExp t, NodeCounts)
PreSmartExp ScopedAcc ScopedExp t
-> NodeCounts -> (ScopedExp t, NodeCounts)
reconstruct (TypeR t -> Int -> PreSmartExp ScopedAcc ScopedExp t
forall t (acc :: * -> *) (exp :: * -> *).
TypeR t -> Int -> PreSmartExp acc exp t
Tag TypeR t
tp Int
i) NodeCounts
noNodeCounts
          Const ScalarType t
tp t
c            -> HasCallStack =>
PreSmartExp ScopedAcc ScopedExp t
-> NodeCounts -> (ScopedExp t, NodeCounts)
PreSmartExp ScopedAcc ScopedExp t
-> NodeCounts -> (ScopedExp t, NodeCounts)
reconstruct (ScalarType t -> t -> PreSmartExp ScopedAcc ScopedExp t
forall t (acc :: * -> *) (exp :: * -> *).
ScalarType t -> t -> PreSmartExp acc exp t
Const ScalarType t
tp t
c) NodeCounts
noNodeCounts
          Undef ScalarType t
tp              -> HasCallStack =>
PreSmartExp ScopedAcc ScopedExp t
-> NodeCounts -> (ScopedExp t, NodeCounts)
PreSmartExp ScopedAcc ScopedExp t
-> NodeCounts -> (ScopedExp t, NodeCounts)
reconstruct (ScalarType t -> PreSmartExp ScopedAcc ScopedExp t
forall t (acc :: * -> *) (exp :: * -> *).
ScalarType t -> PreSmartExp acc exp t
Undef ScalarType t
tp) NodeCounts
noNodeCounts
          Pair UnscopedExp t1
e1 UnscopedExp t2
e2            -> (ScopedExp t1 -> ScopedExp t2 -> PreSmartExp ScopedAcc ScopedExp t)
-> UnscopedExp t1 -> UnscopedExp t2 -> (ScopedExp t, NodeCounts)
forall a b.
HasCallStack =>
(ScopedExp a -> ScopedExp b -> PreSmartExp ScopedAcc ScopedExp t)
-> UnscopedExp a -> UnscopedExp b -> (ScopedExp t, NodeCounts)
travE2 ScopedExp t1 -> ScopedExp t2 -> PreSmartExp ScopedAcc ScopedExp t
forall (exp :: * -> *) t1 t2 (acc :: * -> *).
exp t1 -> exp t2 -> PreSmartExp acc exp (t1, t2)
Pair UnscopedExp t1
e1 UnscopedExp t2
e2
          PreSmartExp UnscopedAcc UnscopedExp t
Nil                   -> HasCallStack =>
PreSmartExp ScopedAcc ScopedExp t
-> NodeCounts -> (ScopedExp t, NodeCounts)
PreSmartExp ScopedAcc ScopedExp t
-> NodeCounts -> (ScopedExp t, NodeCounts)
reconstruct PreSmartExp ScopedAcc ScopedExp t
forall (acc :: * -> *) (exp :: * -> *). PreSmartExp acc exp ()
Nil NodeCounts
noNodeCounts
          Prj PairIdx (t1, t2) t
i UnscopedExp (t1, t2)
e               -> (ScopedExp (t1, t2) -> PreSmartExp ScopedAcc ScopedExp t)
-> UnscopedExp (t1, t2) -> (ScopedExp t, NodeCounts)
forall a.
HasCallStack =>
(ScopedExp a -> PreSmartExp ScopedAcc ScopedExp t)
-> UnscopedExp a -> (ScopedExp t, NodeCounts)
travE1 (PairIdx (t1, t2) t
-> ScopedExp (t1, t2) -> PreSmartExp ScopedAcc ScopedExp t
forall t1 t2 t (exp :: * -> *) (acc :: * -> *).
PairIdx (t1, t2) t -> exp (t1, t2) -> PreSmartExp acc exp t
Prj PairIdx (t1, t2) t
i) UnscopedExp (t1, t2)
e
          VecPack   VecR n s tup
vec UnscopedExp tup
e       -> (ScopedExp tup -> PreSmartExp ScopedAcc ScopedExp t)
-> UnscopedExp tup -> (ScopedExp t, NodeCounts)
forall a.
HasCallStack =>
(ScopedExp a -> PreSmartExp ScopedAcc ScopedExp t)
-> UnscopedExp a -> (ScopedExp t, NodeCounts)
travE1 (VecR n s tup
-> ScopedExp tup -> PreSmartExp ScopedAcc ScopedExp (Vec n s)
forall (n :: Nat) s tup (exp :: * -> *) (acc :: * -> *).
KnownNat n =>
VecR n s tup -> exp tup -> PreSmartExp acc exp (Vec n s)
VecPack   VecR n s tup
vec) UnscopedExp tup
e
          VecUnpack VecR n s t
vec UnscopedExp (Vec n s)
e       -> (ScopedExp (Vec n s) -> PreSmartExp ScopedAcc ScopedExp t)
-> UnscopedExp (Vec n s) -> (ScopedExp t, NodeCounts)
forall a.
HasCallStack =>
(ScopedExp a -> PreSmartExp ScopedAcc ScopedExp t)
-> UnscopedExp a -> (ScopedExp t, NodeCounts)
travE1 (VecR n s t
-> ScopedExp (Vec n s) -> PreSmartExp ScopedAcc ScopedExp t
forall (n :: Nat) s tup (exp :: * -> *) (acc :: * -> *).
KnownNat n =>
VecR n s tup -> exp (Vec n s) -> PreSmartExp acc exp tup
VecUnpack VecR n s t
vec) UnscopedExp (Vec n s)
e
          ToIndex ShapeR sh
shr UnscopedExp sh
sh UnscopedExp sh
ix     -> (ScopedExp sh -> ScopedExp sh -> PreSmartExp ScopedAcc ScopedExp t)
-> UnscopedExp sh -> UnscopedExp sh -> (ScopedExp t, NodeCounts)
forall a b.
HasCallStack =>
(ScopedExp a -> ScopedExp b -> PreSmartExp ScopedAcc ScopedExp t)
-> UnscopedExp a -> UnscopedExp b -> (ScopedExp t, NodeCounts)
travE2 (ShapeR sh
-> ScopedExp sh
-> ScopedExp sh
-> PreSmartExp ScopedAcc ScopedExp Int
forall sh (exp :: * -> *) (acc :: * -> *).
ShapeR sh -> exp sh -> exp sh -> PreSmartExp acc exp Int
ToIndex ShapeR sh
shr) UnscopedExp sh
sh UnscopedExp sh
ix
          FromIndex ShapeR t
shr UnscopedExp t
sh UnscopedExp Int
e    -> (ScopedExp t -> ScopedExp Int -> PreSmartExp ScopedAcc ScopedExp t)
-> UnscopedExp t -> UnscopedExp Int -> (ScopedExp t, NodeCounts)
forall a b.
HasCallStack =>
(ScopedExp a -> ScopedExp b -> PreSmartExp ScopedAcc ScopedExp t)
-> UnscopedExp a -> UnscopedExp b -> (ScopedExp t, NodeCounts)
travE2 (ShapeR t
-> ScopedExp t
-> ScopedExp Int
-> PreSmartExp ScopedAcc ScopedExp t
forall sh (exp :: * -> *) (acc :: * -> *).
ShapeR sh -> exp sh -> exp Int -> PreSmartExp acc exp sh
FromIndex ShapeR t
shr) UnscopedExp t
sh UnscopedExp Int
e
          Match TagR t
t UnscopedExp t
e             -> (ScopedExp t -> PreSmartExp ScopedAcc ScopedExp t)
-> UnscopedExp t -> (ScopedExp t, NodeCounts)
forall a.
HasCallStack =>
(ScopedExp a -> PreSmartExp ScopedAcc ScopedExp t)
-> UnscopedExp a -> (ScopedExp t, NodeCounts)
travE1 (TagR t -> ScopedExp t -> PreSmartExp ScopedAcc ScopedExp t
forall t (exp :: * -> *) (acc :: * -> *).
TagR t -> exp t -> PreSmartExp acc exp t
Match TagR t
t) UnscopedExp t
e
          Case UnscopedExp a
e [(TagR a, UnscopedExp t)]
rhs            -> let (ScopedExp a
e',   NodeCounts
accCount1) = UnscopedExp a -> (ScopedExp a, NodeCounts)
forall t.
HasCallStack =>
UnscopedExp t -> (ScopedExp t, NodeCounts)
scopesExp UnscopedExp a
e
                                       ([(TagR a, ScopedExp t)]
rhs', [NodeCounts]
accCount2) = [((TagR a, ScopedExp t), NodeCounts)]
-> ([(TagR a, ScopedExp t)], [NodeCounts])
forall a b. [(a, b)] -> ([a], [b])
unzip [ ((TagR a
t,ScopedExp t
c'), NodeCounts
counts)| (TagR a
t,UnscopedExp t
c) <- [(TagR a, UnscopedExp t)]
rhs, let (ScopedExp t
c', NodeCounts
counts) = UnscopedExp t -> (ScopedExp t, NodeCounts)
forall t.
HasCallStack =>
UnscopedExp t -> (ScopedExp t, NodeCounts)
scopesExp UnscopedExp t
c ]
                                    in HasCallStack =>
PreSmartExp ScopedAcc ScopedExp t
-> NodeCounts -> (ScopedExp t, NodeCounts)
PreSmartExp ScopedAcc ScopedExp t
-> NodeCounts -> (ScopedExp t, NodeCounts)
reconstruct (ScopedExp a
-> [(TagR a, ScopedExp t)] -> PreSmartExp ScopedAcc ScopedExp t
forall (exp :: * -> *) a b (acc :: * -> *).
exp a -> [(TagR a, exp b)] -> PreSmartExp acc exp b
Case ScopedExp a
e' [(TagR a, ScopedExp t)]
rhs') ((NodeCounts -> NodeCounts -> NodeCounts)
-> NodeCounts -> [NodeCounts] -> NodeCounts
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr NodeCounts -> NodeCounts -> NodeCounts
(+++) NodeCounts
accCount1 [NodeCounts]
accCount2)
          Cond UnscopedExp PrimBool
e1 UnscopedExp t
e2 UnscopedExp t
e3         -> (ScopedExp PrimBool
 -> ScopedExp t -> ScopedExp t -> PreSmartExp ScopedAcc ScopedExp t)
-> UnscopedExp PrimBool
-> UnscopedExp t
-> UnscopedExp t
-> (ScopedExp t, NodeCounts)
forall a b c.
HasCallStack =>
(ScopedExp a
 -> ScopedExp b -> ScopedExp c -> PreSmartExp ScopedAcc ScopedExp t)
-> UnscopedExp a
-> UnscopedExp b
-> UnscopedExp c
-> (ScopedExp t, NodeCounts)
travE3 ScopedExp PrimBool
-> ScopedExp t -> ScopedExp t -> PreSmartExp ScopedAcc ScopedExp t
forall (exp :: * -> *) t (acc :: * -> *).
exp PrimBool -> exp t -> exp t -> PreSmartExp acc exp t
Cond UnscopedExp PrimBool
e1 UnscopedExp t
e2 UnscopedExp t
e3
          While TypeR t
tp SmartExp t -> UnscopedExp PrimBool
p SmartExp t -> UnscopedExp t
it UnscopedExp t
i       -> let (SmartExp t -> ScopedExp PrimBool
p' , NodeCounts
accCount1) = (SmartExp t -> UnscopedExp PrimBool)
-> (SmartExp t -> ScopedExp PrimBool, NodeCounts)
forall a b.
HasCallStack =>
(SmartExp a -> UnscopedExp b)
-> (SmartExp a -> ScopedExp b, NodeCounts)
scopesFun1 SmartExp t -> UnscopedExp PrimBool
p
                                       (SmartExp t -> ScopedExp t
it', NodeCounts
accCount2) = (SmartExp t -> UnscopedExp t)
-> (SmartExp t -> ScopedExp t, NodeCounts)
forall a b.
HasCallStack =>
(SmartExp a -> UnscopedExp b)
-> (SmartExp a -> ScopedExp b, NodeCounts)
scopesFun1 SmartExp t -> UnscopedExp t
it
                                       (ScopedExp t
i' , NodeCounts
accCount3) = UnscopedExp t -> (ScopedExp t, NodeCounts)
forall t.
HasCallStack =>
UnscopedExp t -> (ScopedExp t, NodeCounts)
scopesExp UnscopedExp t
i
                                    in HasCallStack =>
PreSmartExp ScopedAcc ScopedExp t
-> NodeCounts -> (ScopedExp t, NodeCounts)
PreSmartExp ScopedAcc ScopedExp t
-> NodeCounts -> (ScopedExp t, NodeCounts)
reconstruct (TypeR t
-> (SmartExp t -> ScopedExp PrimBool)
-> (SmartExp t -> ScopedExp t)
-> ScopedExp t
-> PreSmartExp ScopedAcc ScopedExp t
forall t (exp :: * -> *) (acc :: * -> *).
TypeR t
-> (SmartExp t -> exp PrimBool)
-> (SmartExp t -> exp t)
-> exp t
-> PreSmartExp acc exp t
While TypeR t
tp SmartExp t -> ScopedExp PrimBool
p' SmartExp t -> ScopedExp t
it' ScopedExp t
i') (NodeCounts
accCount1 NodeCounts -> NodeCounts -> NodeCounts
+++ NodeCounts
accCount2 NodeCounts -> NodeCounts -> NodeCounts
+++ NodeCounts
accCount3)
          PrimConst PrimConst t
c           -> HasCallStack =>
PreSmartExp ScopedAcc ScopedExp t
-> NodeCounts -> (ScopedExp t, NodeCounts)
PreSmartExp ScopedAcc ScopedExp t
-> NodeCounts -> (ScopedExp t, NodeCounts)
reconstruct (PrimConst t -> PreSmartExp ScopedAcc ScopedExp t
forall t (acc :: * -> *) (exp :: * -> *).
PrimConst t -> PreSmartExp acc exp t
PrimConst PrimConst t
c) NodeCounts
noNodeCounts
          PrimApp PrimFun (a -> t)
p UnscopedExp a
e           -> (ScopedExp a -> PreSmartExp ScopedAcc ScopedExp t)
-> UnscopedExp a -> (ScopedExp t, NodeCounts)
forall a.
HasCallStack =>
(ScopedExp a -> PreSmartExp ScopedAcc ScopedExp t)
-> UnscopedExp a -> (ScopedExp t, NodeCounts)
travE1 (PrimFun (a -> t)
-> ScopedExp a -> PreSmartExp ScopedAcc ScopedExp t
forall a r (exp :: * -> *) (acc :: * -> *).
PrimFun (a -> r) -> exp a -> PreSmartExp acc exp r
PrimApp PrimFun (a -> t)
p) UnscopedExp a
e
          Index TypeR t
tp UnscopedAcc (Array sh t)
a UnscopedExp sh
e          -> (ScopedAcc (Array sh t)
 -> ScopedExp sh -> PreSmartExp ScopedAcc ScopedExp t)
-> UnscopedAcc (Array sh t)
-> UnscopedExp sh
-> (ScopedExp t, NodeCounts)
forall a b.
HasCallStack =>
(ScopedAcc a -> ScopedExp b -> PreSmartExp ScopedAcc ScopedExp t)
-> UnscopedAcc a -> UnscopedExp b -> (ScopedExp t, NodeCounts)
travAE (TypeR t
-> ScopedAcc (Array sh t)
-> ScopedExp sh
-> PreSmartExp ScopedAcc ScopedExp t
forall t (acc :: * -> *) sh (exp :: * -> *).
TypeR t -> acc (Array sh t) -> exp sh -> PreSmartExp acc exp t
Index TypeR t
tp) UnscopedAcc (Array sh t)
a UnscopedExp sh
e
          LinearIndex TypeR t
tp UnscopedAcc (Array sh t)
a UnscopedExp Int
e    -> (ScopedAcc (Array sh t)
 -> ScopedExp Int -> PreSmartExp ScopedAcc ScopedExp t)
-> UnscopedAcc (Array sh t)
-> UnscopedExp Int
-> (ScopedExp t, NodeCounts)
forall a b.
HasCallStack =>
(ScopedAcc a -> ScopedExp b -> PreSmartExp ScopedAcc ScopedExp t)
-> UnscopedAcc a -> UnscopedExp b -> (ScopedExp t, NodeCounts)
travAE (TypeR t
-> ScopedAcc (Array sh t)
-> ScopedExp Int
-> PreSmartExp ScopedAcc ScopedExp t
forall t (acc :: * -> *) sh (exp :: * -> *).
TypeR t -> acc (Array sh t) -> exp Int -> PreSmartExp acc exp t
LinearIndex TypeR t
tp) UnscopedAcc (Array sh t)
a UnscopedExp Int
e
          Shape ShapeR t
shr UnscopedAcc (Array t e)
a           -> (ScopedAcc (Array t e) -> PreSmartExp ScopedAcc ScopedExp t)
-> UnscopedAcc (Array t e) -> (ScopedExp t, NodeCounts)
forall a.
HasCallStack =>
(ScopedAcc a -> PreSmartExp ScopedAcc ScopedExp t)
-> UnscopedAcc a -> (ScopedExp t, NodeCounts)
travA (ShapeR t
-> ScopedAcc (Array t e) -> PreSmartExp ScopedAcc ScopedExp t
forall sh (acc :: * -> *) e (exp :: * -> *).
ShapeR sh -> acc (Array sh e) -> PreSmartExp acc exp sh
Shape ShapeR t
shr) UnscopedAcc (Array t e)
a
          ShapeSize ShapeR sh
shr UnscopedExp sh
e       -> (ScopedExp sh -> PreSmartExp ScopedAcc ScopedExp t)
-> UnscopedExp sh -> (ScopedExp t, NodeCounts)
forall a.
HasCallStack =>
(ScopedExp a -> PreSmartExp ScopedAcc ScopedExp t)
-> UnscopedExp a -> (ScopedExp t, NodeCounts)
travE1 (ShapeR sh -> ScopedExp sh -> PreSmartExp ScopedAcc ScopedExp Int
forall sh (exp :: * -> *) (acc :: * -> *).
ShapeR sh -> exp sh -> PreSmartExp acc exp Int
ShapeSize ShapeR sh
shr) UnscopedExp sh
e
          Foreign TypeR t
tp asm (x -> t)
ff SmartExp x -> SmartExp t
f UnscopedExp x
e     -> (ScopedExp x -> PreSmartExp ScopedAcc ScopedExp t)
-> UnscopedExp x -> (ScopedExp t, NodeCounts)
forall a.
HasCallStack =>
(ScopedExp a -> PreSmartExp ScopedAcc ScopedExp t)
-> UnscopedExp a -> (ScopedExp t, NodeCounts)
travE1 (TypeR t
-> asm (x -> t)
-> (SmartExp x -> SmartExp t)
-> ScopedExp x
-> PreSmartExp ScopedAcc ScopedExp t
forall (asm :: * -> *) y x (exp :: * -> *) (acc :: * -> *).
Foreign asm =>
TypeR y
-> asm (x -> y)
-> (SmartExp x -> SmartExp y)
-> exp x
-> PreSmartExp acc exp y
Foreign TypeR t
tp asm (x -> t)
ff SmartExp x -> SmartExp t
f) UnscopedExp x
e
          Coerce ScalarType a
t1 ScalarType t
t2 UnscopedExp a
e        -> (ScopedExp a -> PreSmartExp ScopedAcc ScopedExp t)
-> UnscopedExp a -> (ScopedExp t, NodeCounts)
forall a.
HasCallStack =>
(ScopedExp a -> PreSmartExp ScopedAcc ScopedExp t)
-> UnscopedExp a -> (ScopedExp t, NodeCounts)
travE1 (ScalarType a
-> ScalarType t -> ScopedExp a -> PreSmartExp ScopedAcc ScopedExp t
forall a b (exp :: * -> *) (acc :: * -> *).
BitSizeEq a b =>
ScalarType a -> ScalarType b -> exp a -> PreSmartExp acc exp b
Coerce ScalarType a
t1 ScalarType t
t2) UnscopedExp a
e
      where
        travE1 :: HasCallStack
               => (ScopedExp a -> PreSmartExp ScopedAcc ScopedExp t)
               -> UnscopedExp a
               -> (ScopedExp t, NodeCounts)
        travE1 :: (ScopedExp a -> PreSmartExp ScopedAcc ScopedExp t)
-> UnscopedExp a -> (ScopedExp t, NodeCounts)
travE1 ScopedExp a -> PreSmartExp ScopedAcc ScopedExp t
c UnscopedExp a
e = HasCallStack =>
PreSmartExp ScopedAcc ScopedExp t
-> NodeCounts -> (ScopedExp t, NodeCounts)
PreSmartExp ScopedAcc ScopedExp t
-> NodeCounts -> (ScopedExp t, NodeCounts)
reconstruct (ScopedExp a -> PreSmartExp ScopedAcc ScopedExp t
c ScopedExp a
e') NodeCounts
accCount
          where
            (ScopedExp a
e', NodeCounts
accCount) = UnscopedExp a -> (ScopedExp a, NodeCounts)
forall t.
HasCallStack =>
UnscopedExp t -> (ScopedExp t, NodeCounts)
scopesExp UnscopedExp a
e

        travE2 :: HasCallStack
               => (ScopedExp a -> ScopedExp b -> PreSmartExp ScopedAcc ScopedExp t)
               -> UnscopedExp a
               -> UnscopedExp b
               -> (ScopedExp t, NodeCounts)
        travE2 :: (ScopedExp a -> ScopedExp b -> PreSmartExp ScopedAcc ScopedExp t)
-> UnscopedExp a -> UnscopedExp b -> (ScopedExp t, NodeCounts)
travE2 ScopedExp a -> ScopedExp b -> PreSmartExp ScopedAcc ScopedExp t
c UnscopedExp a
e1 UnscopedExp b
e2 = HasCallStack =>
PreSmartExp ScopedAcc ScopedExp t
-> NodeCounts -> (ScopedExp t, NodeCounts)
PreSmartExp ScopedAcc ScopedExp t
-> NodeCounts -> (ScopedExp t, NodeCounts)
reconstruct (ScopedExp a -> ScopedExp b -> PreSmartExp ScopedAcc ScopedExp t
c ScopedExp a
e1' ScopedExp b
e2') (NodeCounts
accCount1 NodeCounts -> NodeCounts -> NodeCounts
+++ NodeCounts
accCount2)
          where
            (ScopedExp a
e1', NodeCounts
accCount1) = UnscopedExp a -> (ScopedExp a, NodeCounts)
forall t.
HasCallStack =>
UnscopedExp t -> (ScopedExp t, NodeCounts)
scopesExp UnscopedExp a
e1
            (ScopedExp b
e2', NodeCounts
accCount2) = UnscopedExp b -> (ScopedExp b, NodeCounts)
forall t.
HasCallStack =>
UnscopedExp t -> (ScopedExp t, NodeCounts)
scopesExp UnscopedExp b
e2

        travE3 :: HasCallStack
               => (ScopedExp a -> ScopedExp b -> ScopedExp c -> PreSmartExp ScopedAcc ScopedExp t)
               -> UnscopedExp a
               -> UnscopedExp b
               -> UnscopedExp c
               -> (ScopedExp t, NodeCounts)
        travE3 :: (ScopedExp a
 -> ScopedExp b -> ScopedExp c -> PreSmartExp ScopedAcc ScopedExp t)
-> UnscopedExp a
-> UnscopedExp b
-> UnscopedExp c
-> (ScopedExp t, NodeCounts)
travE3 ScopedExp a
-> ScopedExp b -> ScopedExp c -> PreSmartExp ScopedAcc ScopedExp t
c UnscopedExp a
e1 UnscopedExp b
e2 UnscopedExp c
e3 = HasCallStack =>
PreSmartExp ScopedAcc ScopedExp t
-> NodeCounts -> (ScopedExp t, NodeCounts)
PreSmartExp ScopedAcc ScopedExp t
-> NodeCounts -> (ScopedExp t, NodeCounts)
reconstruct (ScopedExp a
-> ScopedExp b -> ScopedExp c -> PreSmartExp ScopedAcc ScopedExp t
c ScopedExp a
e1' ScopedExp b
e2' ScopedExp c
e3') (NodeCounts
accCount1 NodeCounts -> NodeCounts -> NodeCounts
+++ NodeCounts
accCount2 NodeCounts -> NodeCounts -> NodeCounts
+++ NodeCounts
accCount3)
          where
            (ScopedExp a
e1', NodeCounts
accCount1) = UnscopedExp a -> (ScopedExp a, NodeCounts)
forall t.
HasCallStack =>
UnscopedExp t -> (ScopedExp t, NodeCounts)
scopesExp UnscopedExp a
e1
            (ScopedExp b
e2', NodeCounts
accCount2) = UnscopedExp b -> (ScopedExp b, NodeCounts)
forall t.
HasCallStack =>
UnscopedExp t -> (ScopedExp t, NodeCounts)
scopesExp UnscopedExp b
e2
            (ScopedExp c
e3', NodeCounts
accCount3) = UnscopedExp c -> (ScopedExp c, NodeCounts)
forall t.
HasCallStack =>
UnscopedExp t -> (ScopedExp t, NodeCounts)
scopesExp UnscopedExp c
e3

        travA :: HasCallStack
              => (ScopedAcc a -> PreSmartExp ScopedAcc ScopedExp t) -> UnscopedAcc a
              -> (ScopedExp t, NodeCounts)
        travA :: (ScopedAcc a -> PreSmartExp ScopedAcc ScopedExp t)
-> UnscopedAcc a -> (ScopedExp t, NodeCounts)
travA ScopedAcc a -> PreSmartExp ScopedAcc ScopedExp t
c UnscopedAcc a
acc = (ScopedAcc a -> PreSmartExp ScopedAcc ScopedExp t)
-> ScopedAcc a -> NodeCounts -> (ScopedExp t, NodeCounts)
forall a.
HasCallStack =>
(ScopedAcc a -> PreSmartExp ScopedAcc ScopedExp t)
-> ScopedAcc a -> NodeCounts -> (ScopedExp t, NodeCounts)
floatOutAcc ScopedAcc a -> PreSmartExp ScopedAcc ScopedExp t
c ScopedAcc a
acc' NodeCounts
accCount
          where
            (ScopedAcc a
acc', NodeCounts
accCount)  = UnscopedAcc a -> (ScopedAcc a, NodeCounts)
forall arrs.
HasCallStack =>
UnscopedAcc arrs -> (ScopedAcc arrs, NodeCounts)
scopesAcc UnscopedAcc a
acc

        travAE :: HasCallStack
               => (ScopedAcc a -> ScopedExp b -> PreSmartExp ScopedAcc ScopedExp t)
               -> UnscopedAcc a
               -> UnscopedExp b
               -> (ScopedExp t, NodeCounts)
        travAE :: (ScopedAcc a -> ScopedExp b -> PreSmartExp ScopedAcc ScopedExp t)
-> UnscopedAcc a -> UnscopedExp b -> (ScopedExp t, NodeCounts)
travAE ScopedAcc a -> ScopedExp b -> PreSmartExp ScopedAcc ScopedExp t
c UnscopedAcc a
acc UnscopedExp b
e = (ScopedAcc a -> PreSmartExp ScopedAcc ScopedExp t)
-> ScopedAcc a -> NodeCounts -> (ScopedExp t, NodeCounts)
forall a.
HasCallStack =>
(ScopedAcc a -> PreSmartExp ScopedAcc ScopedExp t)
-> ScopedAcc a -> NodeCounts -> (ScopedExp t, NodeCounts)
floatOutAcc (ScopedAcc a -> ScopedExp b -> PreSmartExp ScopedAcc ScopedExp t
`c` ScopedExp b
e') ScopedAcc a
acc' (NodeCounts
accCountA NodeCounts -> NodeCounts -> NodeCounts
+++ NodeCounts
accCountE)
          where
            (ScopedAcc a
acc', NodeCounts
accCountA) = UnscopedAcc a -> (ScopedAcc a, NodeCounts)
forall arrs.
HasCallStack =>
UnscopedAcc arrs -> (ScopedAcc arrs, NodeCounts)
scopesAcc UnscopedAcc a
acc
            (ScopedExp b
e'  , NodeCounts
accCountE) = UnscopedExp b -> (ScopedExp b, NodeCounts)
forall t.
HasCallStack =>
UnscopedExp t -> (ScopedExp t, NodeCounts)
scopesExp UnscopedExp b
e

        floatOutAcc
            :: HasCallStack
            => (ScopedAcc a -> PreSmartExp ScopedAcc ScopedExp t)
            -> ScopedAcc a
            -> NodeCounts
            -> (ScopedExp t, NodeCounts)
        floatOutAcc :: (ScopedAcc a -> PreSmartExp ScopedAcc ScopedExp t)
-> ScopedAcc a -> NodeCounts -> (ScopedExp t, NodeCounts)
floatOutAcc ScopedAcc a -> PreSmartExp ScopedAcc ScopedExp t
c acc :: ScopedAcc a
acc@(ScopedAcc [StableSharingAcc]
_ (AvarSharing StableAccName a
_ ArraysR a
_)) NodeCounts
accCount        -- nothing to float out
          = HasCallStack =>
PreSmartExp ScopedAcc ScopedExp t
-> NodeCounts -> (ScopedExp t, NodeCounts)
PreSmartExp ScopedAcc ScopedExp t
-> NodeCounts -> (ScopedExp t, NodeCounts)
reconstruct (ScopedAcc a -> PreSmartExp ScopedAcc ScopedExp t
c ScopedAcc a
acc) NodeCounts
accCount
        floatOutAcc ScopedAcc a -> PreSmartExp ScopedAcc ScopedExp t
c ScopedAcc a
acc NodeCounts
accCount
          = HasCallStack =>
PreSmartExp ScopedAcc ScopedExp t
-> NodeCounts -> (ScopedExp t, NodeCounts)
PreSmartExp ScopedAcc ScopedExp t
-> NodeCounts -> (ScopedExp t, NodeCounts)
reconstruct (ScopedAcc a -> PreSmartExp ScopedAcc ScopedExp t
c ScopedAcc a
var) ((StableSharingAcc
stableAcc StableSharingAcc -> NodeCounts -> NodeCounts
`insertAccNode` NodeCounts
noNodeCounts) NodeCounts -> NodeCounts -> NodeCounts
+++ NodeCounts
accCount)
          where
             (ScopedAcc a
var, StableSharingAcc
stableAcc) = ScopedAcc a
-> (ScopedAcc a -> SharingAcc ScopedAcc ScopedExp a)
-> (ScopedAcc a, StableSharingAcc)
forall a.
HasCallStack =>
ScopedAcc a
-> (ScopedAcc a -> SharingAcc ScopedAcc ScopedExp a)
-> (ScopedAcc a, StableSharingAcc)
abstract ScopedAcc a
acc (\(ScopedAcc [StableSharingAcc]
_ SharingAcc ScopedAcc ScopedExp a
s) -> SharingAcc ScopedAcc ScopedExp a
s)

        abstract
            :: HasCallStack
            => ScopedAcc a
            -> (ScopedAcc a -> SharingAcc ScopedAcc ScopedExp a)
            -> (ScopedAcc a, StableSharingAcc)
        abstract :: ScopedAcc a
-> (ScopedAcc a -> SharingAcc ScopedAcc ScopedExp a)
-> (ScopedAcc a, StableSharingAcc)
abstract (ScopedAcc [StableSharingAcc]
_   (AvarSharing StableAccName a
_ ArraysR a
_))     ScopedAcc a -> SharingAcc ScopedAcc ScopedExp a
_    = String -> (ScopedAcc a, StableSharingAcc)
forall a. HasCallStack => String -> a
internalError String
"AvarSharing"
        abstract (ScopedAcc [StableSharingAcc]
ssa (AletSharing StableSharingAcc
sa ScopedAcc a
acc))  ScopedAcc a -> SharingAcc ScopedAcc ScopedExp a
lets = ScopedAcc a
-> (ScopedAcc a -> SharingAcc ScopedAcc ScopedExp a)
-> (ScopedAcc a, StableSharingAcc)
forall a.
HasCallStack =>
ScopedAcc a
-> (ScopedAcc a -> SharingAcc ScopedAcc ScopedExp a)
-> (ScopedAcc a, StableSharingAcc)
abstract ScopedAcc a
acc (ScopedAcc a -> SharingAcc ScopedAcc ScopedExp a
lets (ScopedAcc a -> SharingAcc ScopedAcc ScopedExp a)
-> (ScopedAcc a -> ScopedAcc a)
-> ScopedAcc a
-> SharingAcc ScopedAcc ScopedExp a
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [StableSharingAcc]
-> SharingAcc ScopedAcc ScopedExp a -> ScopedAcc a
forall t.
[StableSharingAcc]
-> SharingAcc ScopedAcc ScopedExp t -> ScopedAcc t
ScopedAcc [StableSharingAcc]
ssa (SharingAcc ScopedAcc ScopedExp a -> ScopedAcc a)
-> (ScopedAcc a -> SharingAcc ScopedAcc ScopedExp a)
-> ScopedAcc a
-> ScopedAcc a
forall b c a. (b -> c) -> (a -> b) -> a -> c
. StableSharingAcc -> ScopedAcc a -> SharingAcc ScopedAcc ScopedExp a
forall (acc :: * -> *) arrs (exp :: * -> *).
StableSharingAcc -> acc arrs -> SharingAcc acc exp arrs
AletSharing StableSharingAcc
sa)
        abstract acc :: ScopedAcc a
acc@(ScopedAcc [StableSharingAcc]
ssa (AccSharing StableAccName a
sn PreSmartAcc ScopedAcc ScopedExp a
a)) ScopedAcc a -> SharingAcc ScopedAcc ScopedExp a
lets = ([StableSharingAcc]
-> SharingAcc ScopedAcc ScopedExp a -> ScopedAcc a
forall t.
[StableSharingAcc]
-> SharingAcc ScopedAcc ScopedExp t -> ScopedAcc t
ScopedAcc [StableSharingAcc]
ssa (StableAccName a -> ArraysR a -> SharingAcc ScopedAcc ScopedExp a
forall arrs (acc :: * -> *) (exp :: * -> *).
StableAccName arrs -> ArraysR arrs -> SharingAcc acc exp arrs
AvarSharing StableAccName a
sn (ArraysR a -> SharingAcc ScopedAcc ScopedExp a)
-> ArraysR a -> SharingAcc ScopedAcc ScopedExp a
forall a b. (a -> b) -> a -> b
$ PreSmartAcc ScopedAcc ScopedExp a -> ArraysR a
forall (f :: * -> *) a. HasArraysR f => f a -> ArraysR a
Smart.arraysR PreSmartAcc ScopedAcc ScopedExp a
a), StableAccName a
-> SharingAcc ScopedAcc ScopedExp a -> StableSharingAcc
forall arrs.
StableAccName arrs
-> SharingAcc ScopedAcc ScopedExp arrs -> StableSharingAcc
StableSharingAcc StableAccName a
sn (ScopedAcc a -> SharingAcc ScopedAcc ScopedExp a
lets ScopedAcc a
acc))

        -- Occurrence count of the currently processed node
        expOccCount :: Int
expOccCount = let StableNameHeight StableName (SmartExp t)
sn' Int
_ = StableExpName t
sn
                       in OccMap SmartExp -> StableASTName SmartExp -> Int
forall (c :: * -> *). OccMap c -> StableASTName c -> Int
lookupWithASTName OccMap SmartExp
expOccMap (StableName (SmartExp t) -> StableASTName SmartExp
forall (c :: * -> *) t. StableName (c t) -> StableASTName c
StableASTName StableName (SmartExp t)
sn')

        -- Reconstruct the current tree node.
        --
        -- * If the current node is being shared ('expOccCount > 1') and the 'exp_sharing' option
        --   is enabled, replace it by a 'VarSharing' node and float the shared subtree out wrapped
        --   in a 'NodeCounts' value.
        -- * If the current node is not shared, reconstruct it in place.
        -- * Special case for free variables ('Tag'): Replace the tree by a sharing variable and
        --   float the 'Tag' out in a 'NodeCounts' value.  This is independent of the number of
        --   occurrences.
        --
        -- In the shared and unshared cases, any completed 'NodeCounts' are injected as bindings
        -- using 'LetSharing' nodes.
        --
        reconstruct
            :: HasCallStack
            => PreSmartExp ScopedAcc ScopedExp t
            -> NodeCounts
            -> (ScopedExp t, NodeCounts)
        reconstruct :: PreSmartExp ScopedAcc ScopedExp t
-> NodeCounts -> (ScopedExp t, NodeCounts)
reconstruct newExp :: PreSmartExp ScopedAcc ScopedExp t
newExp@(Tag TypeR t
tp Int
_) NodeCounts
_subCount
              -- free variable => replace by a sharing variable regardless of the number of
              -- occurrences
          = let thisCount :: NodeCounts
thisCount = StableExpName t
-> SharingExp ScopedAcc ScopedExp t -> StableSharingExp
forall t.
StableExpName t
-> SharingExp ScopedAcc ScopedExp t -> StableSharingExp
StableSharingExp StableExpName t
sn (StableExpName t
-> PreSmartExp ScopedAcc ScopedExp t
-> SharingExp ScopedAcc ScopedExp t
forall t (acc :: * -> *) (exp :: * -> *).
StableExpName t -> PreSmartExp acc exp t -> SharingExp acc exp t
ExpSharing StableExpName t
sn PreSmartExp ScopedAcc ScopedExp t
newExp) StableSharingExp -> NodeCounts -> NodeCounts
`insertExpNode` NodeCounts
noNodeCounts
            in
            String
-> String -> (ScopedExp t, NodeCounts) -> (ScopedExp t, NodeCounts)
forall a. String -> String -> a -> a
tracePure String
"FREE" (NodeCounts -> String
forall a. Show a => a -> String
show NodeCounts
thisCount)
            ([StableSharingExp]
-> SharingExp ScopedAcc ScopedExp t -> ScopedExp t
forall t.
[StableSharingExp]
-> SharingExp ScopedAcc ScopedExp t -> ScopedExp t
ScopedExp [] (StableExpName t -> TypeR t -> SharingExp ScopedAcc ScopedExp t
forall t (acc :: * -> *) (exp :: * -> *).
StableExpName t -> TypeR t -> SharingExp acc exp t
VarSharing StableExpName t
sn TypeR t
tp), NodeCounts
thisCount)
        reconstruct PreSmartExp ScopedAcc ScopedExp t
newExp NodeCounts
subCount
              -- shared subtree => replace by a sharing variable (if the 'exp_sharing' flag is enabled)
          | Int
expOccCount Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
1 Bool -> Bool -> Bool
&& Flag
exp_sharing Flag -> BitSet Word32 Flag -> Bool
forall a c. (Enum a, Bits c) => a -> BitSet c a -> Bool
`member` Config -> BitSet Word32 Flag
options Config
config
          = let allCount :: NodeCounts
allCount = StableExpName t
-> SharingExp ScopedAcc ScopedExp t -> StableSharingExp
forall t.
StableExpName t
-> SharingExp ScopedAcc ScopedExp t -> StableSharingExp
StableSharingExp StableExpName t
sn SharingExp ScopedAcc ScopedExp t
sharingExp StableSharingExp -> NodeCounts -> NodeCounts
`insertExpNode` NodeCounts
newCount
            in
            String
-> String -> (ScopedExp t, NodeCounts) -> (ScopedExp t, NodeCounts)
forall a. String -> String -> a -> a
tracePure (String
"SHARED" String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
completed) (NodeCounts -> String
forall a. Show a => a -> String
show NodeCounts
allCount)
            ([StableSharingExp]
-> SharingExp ScopedAcc ScopedExp t -> ScopedExp t
forall t.
[StableSharingExp]
-> SharingExp ScopedAcc ScopedExp t -> ScopedExp t
ScopedExp [] (StableExpName t -> TypeR t -> SharingExp ScopedAcc ScopedExp t
forall t (acc :: * -> *) (exp :: * -> *).
StableExpName t -> TypeR t -> SharingExp acc exp t
VarSharing StableExpName t
sn (TypeR t -> SharingExp ScopedAcc ScopedExp t)
-> TypeR t -> SharingExp ScopedAcc ScopedExp t
forall a b. (a -> b) -> a -> b
$ PreSmartExp ScopedAcc ScopedExp t -> TypeR t
forall (f :: * -> *) t.
(HasTypeR f, HasCallStack) =>
f t -> TypeR t
typeR PreSmartExp ScopedAcc ScopedExp t
newExp), NodeCounts
allCount)
              -- neither shared nor free variable => leave it as it is
          | Bool
otherwise
          = String
-> String -> (ScopedExp t, NodeCounts) -> (ScopedExp t, NodeCounts)
forall a. String -> String -> a -> a
tracePure (String
"Normal" String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
completed) (NodeCounts -> String
forall a. Show a => a -> String
show NodeCounts
newCount)
            ([StableSharingExp]
-> SharingExp ScopedAcc ScopedExp t -> ScopedExp t
forall t.
[StableSharingExp]
-> SharingExp ScopedAcc ScopedExp t -> ScopedExp t
ScopedExp [] SharingExp ScopedAcc ScopedExp t
sharingExp, NodeCounts
newCount)
          where
              -- Determine the bindings that need to be attached to the current node...
            (NodeCounts
newCount, [StableSharingExp]
bindHere) = HasCallStack => NodeCounts -> (NodeCounts, [StableSharingExp])
NodeCounts -> (NodeCounts, [StableSharingExp])
filterCompleted NodeCounts
subCount

              -- ...and wrap them in 'LetSharing' constructors
            lets :: SharingExp ScopedAcc ScopedExp t
-> SharingExp ScopedAcc ScopedExp t
lets       = ((SharingExp ScopedAcc ScopedExp t
  -> SharingExp ScopedAcc ScopedExp t)
 -> (SharingExp ScopedAcc ScopedExp t
     -> SharingExp ScopedAcc ScopedExp t)
 -> SharingExp ScopedAcc ScopedExp t
 -> SharingExp ScopedAcc ScopedExp t)
-> (SharingExp ScopedAcc ScopedExp t
    -> SharingExp ScopedAcc ScopedExp t)
-> [SharingExp ScopedAcc ScopedExp t
    -> SharingExp ScopedAcc ScopedExp t]
-> SharingExp ScopedAcc ScopedExp t
-> SharingExp ScopedAcc ScopedExp t
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (((SharingExp ScopedAcc ScopedExp t
  -> SharingExp ScopedAcc ScopedExp t)
 -> (SharingExp ScopedAcc ScopedExp t
     -> SharingExp ScopedAcc ScopedExp t)
 -> SharingExp ScopedAcc ScopedExp t
 -> SharingExp ScopedAcc ScopedExp t)
-> (SharingExp ScopedAcc ScopedExp t
    -> SharingExp ScopedAcc ScopedExp t)
-> (SharingExp ScopedAcc ScopedExp t
    -> SharingExp ScopedAcc ScopedExp t)
-> SharingExp ScopedAcc ScopedExp t
-> SharingExp ScopedAcc ScopedExp t
forall a b c. (a -> b -> c) -> b -> a -> c
flip (SharingExp ScopedAcc ScopedExp t
 -> SharingExp ScopedAcc ScopedExp t)
-> (SharingExp ScopedAcc ScopedExp t
    -> SharingExp ScopedAcc ScopedExp t)
-> SharingExp ScopedAcc ScopedExp t
-> SharingExp ScopedAcc ScopedExp t
forall b c a. (b -> c) -> (a -> b) -> a -> c
(.)) SharingExp ScopedAcc ScopedExp t
-> SharingExp ScopedAcc ScopedExp t
forall a. a -> a
id ([SharingExp ScopedAcc ScopedExp t
  -> SharingExp ScopedAcc ScopedExp t]
 -> SharingExp ScopedAcc ScopedExp t
 -> SharingExp ScopedAcc ScopedExp t)
-> ([StableSharingExp]
    -> [SharingExp ScopedAcc ScopedExp t
        -> SharingExp ScopedAcc ScopedExp t])
-> [StableSharingExp]
-> SharingExp ScopedAcc ScopedExp t
-> SharingExp ScopedAcc ScopedExp t
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (StableSharingExp
 -> SharingExp ScopedAcc ScopedExp t
 -> SharingExp ScopedAcc ScopedExp t)
-> [StableSharingExp]
-> [SharingExp ScopedAcc ScopedExp t
    -> SharingExp ScopedAcc ScopedExp t]
forall a b. (a -> b) -> [a] -> [b]
map (\StableSharingExp
x SharingExp ScopedAcc ScopedExp t
y -> StableSharingExp -> ScopedExp t -> SharingExp ScopedAcc ScopedExp t
forall (exp :: * -> *) t (acc :: * -> *).
StableSharingExp -> exp t -> SharingExp acc exp t
LetSharing StableSharingExp
x ([StableSharingExp]
-> SharingExp ScopedAcc ScopedExp t -> ScopedExp t
forall t.
[StableSharingExp]
-> SharingExp ScopedAcc ScopedExp t -> ScopedExp t
ScopedExp [] SharingExp ScopedAcc ScopedExp t
y)) ([StableSharingExp]
 -> SharingExp ScopedAcc ScopedExp t
 -> SharingExp ScopedAcc ScopedExp t)
-> [StableSharingExp]
-> SharingExp ScopedAcc ScopedExp t
-> SharingExp ScopedAcc ScopedExp t
forall a b. (a -> b) -> a -> b
$ [StableSharingExp]
bindHere
            sharingExp :: SharingExp ScopedAcc ScopedExp t
sharingExp = SharingExp ScopedAcc ScopedExp t
-> SharingExp ScopedAcc ScopedExp t
lets (SharingExp ScopedAcc ScopedExp t
 -> SharingExp ScopedAcc ScopedExp t)
-> SharingExp ScopedAcc ScopedExp t
-> SharingExp ScopedAcc ScopedExp t
forall a b. (a -> b) -> a -> b
$ StableExpName t
-> PreSmartExp ScopedAcc ScopedExp t
-> SharingExp ScopedAcc ScopedExp t
forall t (acc :: * -> *) (exp :: * -> *).
StableExpName t -> PreSmartExp acc exp t -> SharingExp acc exp t
ExpSharing StableExpName t
sn PreSmartExp ScopedAcc ScopedExp t
newExp

              -- trace support
            completed :: String
completed | [StableSharingExp] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [StableSharingExp]
bindHere = String
""
                      | Bool
otherwise     = String
"(" String -> ShowS
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show ([StableSharingExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [StableSharingExp]
bindHere) String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
" lets)"

        -- Extract the nodes that have a complete node count (i.e., their node count is equal
        -- to the number of occurrences of that node in the overall expression).
        --
        -- Nodes with a completed node count should be let bound at the currently processed node.
        --
        -- NB: A completed node is only extracted if every one of its super-terms (according to
        --     the dependency graph) that is still present in the list is extracted as well.
        --     Otherwise, we would let-bind subterms before their parents, which leads to
        --     scope errors.
        --
        filterCompleted :: HasCallStack => NodeCounts -> (NodeCounts, [StableSharingExp])
        filterCompleted :: NodeCounts -> (NodeCounts, [StableSharingExp])
filterCompleted ([NodeCount]
ns,HashMap NodeName (HashSet NodeName)
graph)
          = let bindable :: [Bool]
bindable       = (NodeCount -> Bool) -> [NodeCount] -> [Bool]
forall a b. (a -> b) -> [a] -> [b]
map ([Bool] -> [NodeName] -> NodeCount -> Bool
isBindable [Bool]
bindable ((NodeCount -> NodeName) -> [NodeCount] -> [NodeName]
forall a b. (a -> b) -> [a] -> [b]
map NodeCount -> NodeName
nodeName [NodeCount]
ns)) [NodeCount]
ns
                ([(Bool, NodeCount)]
bind, [(Bool, NodeCount)]
unbind) = ((Bool, NodeCount) -> Bool)
-> [(Bool, NodeCount)]
-> ([(Bool, NodeCount)], [(Bool, NodeCount)])
forall a. (a -> Bool) -> [a] -> ([a], [a])
partition (Bool, NodeCount) -> Bool
forall a b. (a, b) -> a
fst ([(Bool, NodeCount)] -> ([(Bool, NodeCount)], [(Bool, NodeCount)]))
-> [(Bool, NodeCount)]
-> ([(Bool, NodeCount)], [(Bool, NodeCount)])
forall a b. (a -> b) -> a -> b
$ [Bool] -> [NodeCount] -> [(Bool, NodeCount)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Bool]
bindable [NodeCount]
ns
            in ((((Bool, NodeCount) -> NodeCount)
-> [(Bool, NodeCount)] -> [NodeCount]
forall a b. (a -> b) -> [a] -> [b]
map (Bool, NodeCount) -> NodeCount
forall a b. (a, b) -> b
snd [(Bool, NodeCount)]
unbind, HashMap NodeName (HashSet NodeName)
graph), [StableSharingExp
se | ExpNodeCount StableSharingExp
se Int
_ <- ((Bool, NodeCount) -> NodeCount)
-> [(Bool, NodeCount)] -> [NodeCount]
forall a b. (a -> b) -> [a] -> [b]
map (Bool, NodeCount) -> NodeCount
forall a b. (a, b) -> b
snd [(Bool, NodeCount)]
bind])
          where
            -- a node is not yet complete while the node count 'n' is below the overall number
            -- of occurrences for that node in the whole program, with the exception that free
            -- variables are never complete
            isCompleted :: NodeCount -> Bool
isCompleted nc :: NodeCount
nc@(ExpNodeCount StableSharingExp
sa Int
n) | Bool -> Bool
not (Bool -> Bool) -> (NodeCount -> Bool) -> NodeCount -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. NodeCount -> Bool
isFreeVar (NodeCount -> Bool) -> NodeCount -> Bool
forall a b. (a -> b) -> a -> b
$ NodeCount
nc = OccMap SmartExp -> StableSharingExp -> Int
lookupWithSharingExp OccMap SmartExp
expOccMap StableSharingExp
sa Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
n
            isCompleted NodeCount
_                                             = Bool
False

            isBindable :: [Bool] -> [NodeName] -> NodeCount -> Bool
            isBindable :: [Bool] -> [NodeName] -> NodeCount -> Bool
isBindable [Bool]
bindable [NodeName]
nodes nc :: NodeCount
nc@(ExpNodeCount StableSharingExp
_ Int
_) =
              let superTerms :: [NodeName]
superTerms = HashSet NodeName -> [NodeName]
forall a. HashSet a -> [a]
Set.toList (HashSet NodeName -> [NodeName]) -> HashSet NodeName -> [NodeName]
forall a b. (a -> b) -> a -> b
$ HashMap NodeName (HashSet NodeName)
graph HashMap NodeName (HashSet NodeName) -> NodeName -> HashSet NodeName
forall k v.
(Eq k, Hashable k, HasCallStack) =>
HashMap k v -> k -> v
Map.! NodeCount -> NodeName
nodeName NodeCount
nc
                  unbound :: [Int]
unbound    = (NodeName -> Maybe Int) -> [NodeName] -> [Int]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe (NodeName -> [NodeName] -> Maybe Int
forall a. Eq a => a -> [a] -> Maybe Int
`elemIndex` [NodeName]
nodes) [NodeName]
superTerms
              in    NodeCount -> Bool
isCompleted NodeCount
nc
                 Bool -> Bool -> Bool
&& (Int -> Bool) -> [Int] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all ([Bool]
bindable [Bool] -> Int -> Bool
forall a. [a] -> Int -> a
!!) [Int]
unbound
            isBindable [Bool]
_ [NodeName]
_ (AccNodeCount StableSharingAcc
_ Int
_) = Bool
False
            -- isBindable _ _ (SeqNodeCount _ _) = False

{--
determineScopesSeq
    :: Config
    -> OccMap Acc
    -> RootSeq t
    -> (ScopedSeq t, NodeCounts)          -- Root (closed) expression plus Acc node counts
determineScopesSeq config accOccMap (RootSeq seqOccMap seq)
  = let
        (ScopedSeq seqWithScopes, (nodeCounts,graph)) = determineScopesSharingSeq config accOccMap seqOccMap seq
        binds      = [s | SeqNodeCount s _ <- nodeCounts]
        lets       = foldl (flip (.)) id . map (\x y -> SletSharing x (ScopedSeq y)) $ binds
        sharingSeq = lets seqWithScopes
        newCounts  = filter (not . isSeqCount) nodeCounts
        isSeqCount SeqNodeCount{} = True
        isSeqCount _              = False
    in
    (ScopedSeq sharingSeq, cleanCounts (newCounts,graph))

determineScopesSharingSeq
  :: Config
  -> OccMap Acc
  -> OccMap Seq
  -> UnscopedSeq t
  -> (ScopedSeq t, NodeCounts)
determineScopesSharingSeq config accOccMap _seqOccMap = scopesSeq
  where
    scopesAcc :: UnscopedAcc a -> (ScopedAcc a, NodeCounts)
    scopesAcc = determineScopesSharingAcc config accOccMap

    scopesExp :: RootExp t -> (ScopedExp t, NodeCounts)
    scopesExp = determineScopesExp config accOccMap

    scopesFun2 :: (Elt e1, Elt e2)
               => (Exp e1 -> Exp e2 -> RootExp e3)
               -> (Exp e1 -> Exp e2 -> ScopedExp e3, NodeCounts)
    scopesFun2 f = (\_ _ -> body, counts)
      where
        (body, counts) = scopesExp (f undefined undefined)

    -- The lambda bound variable is at this point already irrelevant; for details, see
    -- Note [Traversing functions and side effects]
    --
    scopesAfun1 :: Arrays a1 => (Acc a1 -> UnscopedAcc a2) -> (Acc a1 -> ScopedAcc a2, NodeCounts)
    scopesAfun1 f = (const (ScopedAcc ssa body'), (counts',graph))
      where
        body@(UnscopedAcc fvs _) = f undefined
        ((ScopedAcc [] body'), (counts,graph)) = scopesAcc body
        ssa     = buildInitialEnvAcc fvs [sa | AccNodeCount sa _ <- freeCounts]
        (freeCounts, counts') = partition isBoundHere counts

        isBoundHere (AccNodeCount (StableSharingAcc _ (AccSharing _ (Atag i))) _) = i `elem` fvs
        isBoundHere _                                                             = False

    scopesAfun2 :: (Arrays a1, Arrays a2) => (Acc a1 -> Acc a2 -> UnscopedAcc a3) -> (Acc a1 -> Acc a2 -> ScopedAcc a3, NodeCounts)
    scopesAfun2 f = (\ _ _ -> (ScopedAcc ssa body'), (counts',graph))
      where
        body@(UnscopedAcc fvs _) = f undefined undefined
        ((ScopedAcc [] body'), (counts,graph)) = scopesAcc body
        ssa     = buildInitialEnvAcc fvs [sa | AccNodeCount sa _ <- freeCounts]
        (freeCounts, counts') = partition isBoundHere counts

        isBoundHere (AccNodeCount (StableSharingAcc _ (AccSharing _ (Atag i))) _) = i `elem` fvs
        isBoundHere _                                                             = False

    scopesAfun3 :: (Arrays a1, Arrays a2, Arrays a3) => (Acc a1 -> Acc a2 -> Acc a3 -> UnscopedAcc a4) -> (Acc a1 -> Acc a2 -> Acc a3 -> ScopedAcc a4, NodeCounts)
    scopesAfun3 f = (\ _ _ _ -> (ScopedAcc ssa body'), (counts',graph))
      where
        body@(UnscopedAcc fvs _) = f undefined undefined undefined
        ((ScopedAcc [] body'), (counts,graph)) = scopesAcc body
        ssa     = buildInitialEnvAcc fvs [sa | AccNodeCount sa _ <- freeCounts]
        (freeCounts, counts') = partition isBoundHere counts

        isBoundHere (AccNodeCount (StableSharingAcc _ (AccSharing _ (Atag i))) _) = i `elem` fvs
        isBoundHere _                                                             = False

    scopesTup :: Atuple UnscopedSeq tup -> (Atuple ScopedSeq tup, NodeCounts)
    scopesTup NilAtup          = (NilAtup, noNodeCounts)
    scopesTup (SnocAtup tup s) = let
                                   (tup', accCountT) = scopesTup tup
                                   (s'  , accCountS) = scopesSeq s
                                 in
                                 (SnocAtup tup' s', accCountT +++ accCountS)

    scopesSeq :: forall t. UnscopedSeq t -> (ScopedSeq t, NodeCounts)
    scopesSeq (UnscopedSeq (SletSharing _ _))
      = $internalError "determineScopesSharingSeq: scopesSeq" "unexpected 'LetSharing'"
    scopesSeq (UnscopedSeq (SvarSharing sn))
      = (ScopedSeq (SvarSharing sn), StableSharingSeq sn (SvarSharing sn) `insertSeqNode` noNodeCounts)

    scopesSeq (UnscopedSeq (SeqSharing sn s)) =
      case s of
        StreamIn arrs -> producer (StreamIn arrs) noNodeCounts
        ToSeq sl acc   -> let
                            (acc', accCount1) = scopesAcc acc
                          in producer (ToSeq sl acc') accCount1
        MapSeq     afun s'  -> let
                                 (afun', accCount1) = scopesAfun1 afun
                                 (s''  , accCount2) = scopesSeq s'
                               in producer (MapSeq afun' s'') (accCount1 +++ accCount2)
        ZipWithSeq afun s1 s2 -> let
                                   (afun', accCount1) = scopesAfun2 afun
                                   (s1'  , accCount2) = scopesSeq s1
                                   (s2'  , accCount3) = scopesSeq s2
                                 in producer (ZipWithSeq afun' s1' s2') (accCount1 +++ accCount2 +++ accCount3)
        ScanSeq fun e s' -> let
                              (fun', accCount1) = scopesFun2 fun
                              (e'  , accCount2) = scopesExp e
                              (s'' , accCount3) = scopesSeq s'
                            in producer (ScanSeq fun' e' s'') (accCount1 +++ accCount2 +++ accCount3)
        FoldSeq fun e s' -> let
                              (fun', accCount1) = scopesFun2 fun
                              (e'  , accCount2) = scopesExp e
                              (s'' , accCount3) = scopesSeq s'
                            in consumer (FoldSeq fun' e' s'') (accCount1 +++ accCount2 +++ accCount3)
        FoldSeqFlatten afun acc s' ->
                               let
                                 (afun', accCount1) = scopesAfun3 afun
                                 (acc' , accCount2) = scopesAcc acc
                                 (s''  , accCount3) = scopesSeq s'
                               in consumer (FoldSeqFlatten afun' acc' s'') (accCount1 +++ accCount2 +++ accCount3)
        Stuple tup          -> let
                                 (tup', accCount1) = scopesTup tup
                               in consumer (Stuple tup') accCount1
      where
        -- All producers must be replaced by sharing variables
        --
        producer :: (t ~ [a], Arrays a)
                 => PreSeq ScopedAcc ScopedSeq ScopedExp t
                 -> NodeCounts
                 -> (ScopedSeq t, NodeCounts)
        producer newSeq subCount
          = let allCount = StableSharingSeq sn (SeqSharing sn newSeq) `insertSeqNode` subCount
            in
            tracePure "Producer" (show allCount)
            (ScopedSeq (SvarSharing sn), allCount)

        -- Consumers cannot be shared.
        --
        consumer :: PreSeq ScopedAcc ScopedSeq ScopedExp t
                 -> NodeCounts
                 -> (ScopedSeq t, NodeCounts)
        consumer newSeq subCount
          = tracePure "Consumer" (show subCount)
            (ScopedSeq (SeqSharing sn newSeq), subCount)
--}

-- | Recover sharing information and annotate the HOAS AST with variable and let binding
-- annotations.  The first argument is the conversion 'Config'; its options determine, among
-- other things, whether array computations are floated out of expressions irrespective of
-- whether they are shared or not.
--
-- Also returns the 'StableSharingAcc's of all 'Atag' leaves in environment order — they represent
-- the free variables of the AST.
--
-- NB: Strictly speaking, this function is not deterministic, as it uses stable pointers to
--     determine the sharing of subterms.  The stable pointer API does not guarantee its
--     completeness; i.e., it may miss some equalities, which implies that we may fail to discover
--     some sharing.  However, sharing does not affect the denotational meaning of an array
--     computation; hence, we do not compromise denotational correctness.
--
--     There is one caveat: We currently rely on the 'Atag' and 'Tag' leaves representing free
--     variables to be shared if any of them is used more than once.  If one is duplicated, the
--     environment for de Bruijn conversion will have a duplicate entry, and hence, be of the wrong
--     size, which is fatal. (The 'buildInitialEnv*' functions will already bail out.)
--
-- NOINLINE is required because the body uses 'unsafePerformIO'. Inlining could duplicate or
-- float the stable-pointer traversal.
{-# NOINLINE recoverSharingAcc #-}
recoverSharingAcc
    :: HasCallStack
    => Config
    -> Level            -- The level of currently bound array variables
    -> [Level]          -- The tags of newly introduced free array variables
    -> SmartAcc a
    -> (ScopedAcc a, [StableSharingAcc])
recoverSharingAcc config alvl avars acc
  = let -- Phase 1: build the occurrence map of all array subterms (needs IO for stable
        -- pointers), together with the annotated, still-unscoped AST.
        (acc', occMap)
          = unsafePerformIO             -- to enable stable pointers; this is safe as explained above
          $ makeOccMapAcc config alvl acc
    in
    -- Phase 2: use the occurrence map to decide where shared subterms are let-bound.
    determineScopesAcc config avars occMap acc'


-- | Recover sharing information for a scalar expression and annotate it with variable and
-- let binding annotations.  This is the scalar counterpart of 'recoverSharingAcc'.
--
-- Arguments:
--
--   * the conversion 'Config';
--   * the level of the currently bound scalar variables;
--   * the tags of newly introduced free scalar variables;
--   * the HOAS expression to convert.
--
-- Returns the scoped expression, with its root list of 'StableSharingExp's emptied, together
-- with the 'StableSharingExp's that 'determineScopesExp' attached at the root.  By analogy
-- with 'recoverSharingAcc' these are presumably the free scalar variables. NOTE(review):
-- confirm.
--
-- NOINLINE is required because the body uses 'unsafePerformIO'.
{-# NOINLINE recoverSharingExp #-}
recoverSharingExp
    :: HasCallStack
    => Config
    -> Level            -- The level of currently bound scalar variables
    -> [Level]          -- The tags of newly introduced free scalar variables
    -> SmartExp e
    -> (ScopedExp e, [StableSharingExp])
recoverSharingExp config lvl fvar exp
  = let
        -- Build the occurrence map of embedded array computations in IO (stable pointers),
        -- then freeze it so that the scoping phase below is pure.
        (rootExp, accOccMap) = unsafePerformIO $ do
          accOccMapHash   <- newASTHashTable
          (exp', _)       <- makeOccMapRootExp config accOccMapHash lvl fvar exp
          frozenAccOccMap <- freezeOccMap accOccMapHash
          return (exp', frozenAccOccMap)

        -- Only the scoped expression is needed here; the node counts are discarded.
        (ScopedExp sse sharingExp, _) =
          determineScopesExp config accOccMap rootExp
    in
    (ScopedExp [] sharingExp, sse)


{--
{-# NOINLINE recoverSharingSeq #-}
recoverSharingSeq
    :: Config
    -> Seq e
    -> (ScopedSeq e, [StableSharingSeq])
recoverSharingSeq config seq
  = let
        (rootSeq, accOccMap) = unsafePerformIO $ do
          accOccMap       <- newASTHashTable
          (seq', _)       <- makeOccMapRootSeq config accOccMap 0 seq
          frozenAccOccMap <- freezeOccMap accOccMap

          return (seq', frozenAccOccMap)

        (ScopedSeq sharingSeq, (ns, _)) =
          determineScopesSeq config accOccMap rootSeq
    in
    (ScopedSeq sharingSeq, [a | SeqNodeCount a _ <- ns])
--}


-- Debugging
-- ---------

-- | Emit the debug message @header ++ ": " ++ msg@ via 'Debug.traceIO', gated on the
-- 'Debug.dump_sharing' flag (presumably printed only when that flag is enabled).
traceLine :: String -> String -> IO ()
traceLine header msg
  = Debug.traceIO Debug.dump_sharing
  $ header ++ ": " ++ msg

-- | Like 'traceLine', but places @msg@ on a new line after @header@, indented by six
-- spaces.  This is suitable for multi-line payloads.  Output goes through 'Debug.traceIO'
-- on the 'Debug.dump_sharing' flag.
traceChunk :: String -> String -> IO ()
traceChunk header msg
  = Debug.traceIO Debug.dump_sharing
  $ header ++ "\n      " ++ msg

-- | Pure variant of 'traceLine'.  It emits @header ++ ": " ++ msg@ via 'Debug.trace' on the
-- 'Debug.dump_sharing' flag and then returns its third argument unchanged.
tracePure :: String -> String -> a -> a
tracePure header msg
  = Debug.trace Debug.dump_sharing
  $ header ++ ": " ++ msg