{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# OPTIONS_GHC -fno-warn-name-shadowing #-}
{-# OPTIONS_HADDOCK hide #-}
module Data.Array.Accelerate.Trafo.Sharing (
convertAcc, convertAfun, Afunction, AfunctionR,
convertExp, convertFun, Function, FunctionR,
) where
import Control.Applicative hiding ( Const )
import Control.Monad.Fix
import Data.List
import Data.Maybe
import Data.Hashable
import Data.Typeable
import System.Mem.StableName
import System.IO.Unsafe ( unsafePerformIO )
import qualified Data.HashTable.IO as Hash
import qualified Data.IntMap as IntMap
import qualified Data.HashMap.Strict as Map
import qualified Data.HashSet as Set
import Prelude
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Smart
import Data.Array.Accelerate.Array.Sugar as Sugar
import Data.Array.Accelerate.AST hiding ( PreOpenAcc(..), OpenAcc(..), Acc
, PreOpenExp(..), OpenExp, PreExp, Exp
, PreBoundary(..), Boundary, Stencil(..)
, showPreAccOp, showPreExpOp )
import qualified Data.Array.Accelerate.AST as AST
import qualified Data.Array.Accelerate.Debug as Debug
-- | Options controlling sharing recovery and conversion to the de Bruijn AST.
data Config = Config
  {
    recoverAccSharing :: Bool   -- ^ recover sharing of array computations ('Acc')
  , recoverExpSharing :: Bool   -- ^ recover sharing of scalar expressions ('Exp')
  , recoverSeqSharing :: Bool   -- ^ recover sharing of sequence computations
  , floatOutAcc       :: Bool   -- ^ float array computations out of expressions;
                                --   'convertAcc' / 'convertAfun' only set this
                                --   when 'recoverAccSharing' is also set
  }
-- | A layout maps the positions of the environment @env'@ to typed de Bruijn
-- indices into the environment @env@.
--
-- Each 'PushLayout' adds one entry. The entry's index carries a 'Typeable'
-- dictionary so that 'prjIdx' can check the index's type at runtime.
data Layout env env' where
  EmptyLayout :: Layout env ()
  PushLayout  :: Typeable t
              => Layout env env' -> Idx env t -> Layout env (env', t)
-- | Fetch the typed index stored at position @n@ of a layout.
--
-- Position 0 is the most recently pushed entry. The stored index's type is
-- compared against the requested type @t@ with 'gcast'. A type mismatch, or a
-- position past the end of the layout, aborts through 'possiblyNestedErr'.
-- @ctxt@ describes the caller and is included in that error message.
prjIdx :: forall t env env'. Typeable t => String -> Int -> Layout env env' -> Idx env t
prjIdx ctxt 0 (PushLayout _ (ix :: Idx env0 t0))
  -- the pattern signature binds t0 so the error can report the actual type
  = flip fromMaybe (gcast ix)
  $ possiblyNestedErr ctxt $
      "Couldn't match expected type `" ++ show (typeOf (undefined::t)) ++
      "' with actual type `" ++ show (typeOf (undefined::t0)) ++ "'" ++
      "\n Type mismatch"
prjIdx ctxt n (PushLayout l _) = prjIdx ctxt (n - 1) l
prjIdx ctxt _ EmptyLayout      = possiblyNestedErr ctxt "Environment doesn't contain index"
-- | Abort with the diagnostic used by 'prjIdx'.
--
-- The first argument is the context of the failed lookup and the second is
-- the reason it failed. The message suggests nested data parallelism as a
-- likely cause.
possiblyNestedErr :: String -> String -> a
possiblyNestedErr ctxt failreason = error message
  where
    message = concat
      [ "Fatal error in Sharing.prjIdx:"
      , "\n ", failreason, " at ", ctxt
      , "\n Possible reason: nested data parallelism — array computation that depends on a"
      , "\n scalar variable of type 'Exp a'"
      ]
-- | Weaken every index in a layout by one ('SuccIdx').
--
-- This makes room for a freshly bound variable at index zero of the outer
-- environment. The number of entries is unchanged.
incLayout :: Layout env env' -> Layout (env, t) env'
incLayout lyt =
  case lyt of
    EmptyLayout         -> EmptyLayout
    PushLayout rest idx -> PushLayout (incLayout rest) (SuccIdx idx)
-- | Number of entries pushed onto a layout.
sizeLayout :: Layout env env' -> Int
sizeLayout lyt =
  case lyt of
    EmptyLayout       -> 0
    PushLayout rest _ -> sizeLayout rest + 1
-- | Convert a closed array computation to the de Bruijn AST, recovering
-- sharing according to the given flags.
--
-- Float-out of array computations is only enabled when 'Acc' sharing
-- recovery is enabled as well.
convertAcc
    :: Arrays arrs
    => Bool           -- ^ recover sharing of array computations
    -> Bool           -- ^ recover sharing of scalar expressions
    -> Bool           -- ^ recover sharing of sequence computations
    -> Bool           -- ^ float array computations out of expressions
    -> Acc arrs
    -> AST.Acc arrs
convertAcc shareAcc shareExp shareSeq floatAcc acc
  = convertOpenAcc cfg 0 [] EmptyLayout acc
  where
    cfg = Config
      { recoverAccSharing = shareAcc
      , recoverExpSharing = shareExp
      , recoverSeqSharing = shareSeq
      , floatOutAcc       = shareAcc && floatAcc
      }
-- | Convert a closed array function to the de Bruijn AST.
--
-- The flags have the same meaning as for 'convertAcc'.
convertAfun :: Afunction f => Bool -> Bool -> Bool -> Bool -> f -> AST.Afun (AfunctionR f)
convertAfun shareAcc shareExp shareSeq floatAcc = aconvert cfg EmptyLayout
  where
    cfg = Config
      { recoverAccSharing = shareAcc
      , recoverExpSharing = shareExp
      , recoverSeqSharing = shareSeq
      , floatOutAcc       = shareAcc && floatAcc
      }
-- | Surface-language array functions (curried over 'Acc' arguments, returning
-- an 'Acc') that can be converted to the de Bruijn AST.
class Afunction f where
  -- | The corresponding function type over plain array types.
  type AfunctionR f
  -- | Convert, given the layout of the array variables already in scope.
  aconvert :: Config -> Layout aenv aenv -> f -> AST.OpenAfun aenv (AfunctionR f)
-- | One array argument: bind a new variable and convert the remaining function.
instance (Arrays a, Afunction r) => Afunction (Acc a -> r) where
  type AfunctionR (Acc a -> r) = a -> AfunctionR r
  aconvert config alyt f = Alam (aconvert config innerLayout (f argument))
    where
      -- the argument is a placeholder tagged with the current binding depth
      argument    = Acc (Atag (sizeLayout alyt))
      innerLayout = PushLayout (incLayout alyt) ZeroIdx
-- | Function body: convert it with every bound variable's level (innermost
-- first) passed as a free variable.
instance Arrays b => Afunction (Acc b) where
  type AfunctionR (Acc b) = b
  aconvert config alyt body = Abody (convertOpenAcc config depth freeVars alyt body)
    where
      depth    = sizeLayout alyt
      freeVars = [depth - 1, depth - 2 .. 0]
-- | Recover sharing in an array term and convert it under the given layout.
--
-- The level and free-variable list are passed to 'recoverSharingAcc'. The
-- initial environment it returns is used for the conversion.
convertOpenAcc
    :: Arrays arrs
    => Config
    -> Level
    -> [Level]
    -> Layout aenv aenv
    -> Acc arrs
    -> AST.OpenAcc aenv arrs
convertOpenAcc config lvl fvs alyt acc
  = convertSharingAcc config alyt startEnv scoped
  where
    (scoped, startEnv) = recoverSharingAcc config lvl fvs acc
-- | Convert a sharing-annotated array term ('ScopedAcc') to the de Bruijn AST.
--
-- * @alyt@ maps the array variables in scope to their indices.
-- * @aenv@ lists the shared terms bound in the enclosing context. Each term's
--   own @lams@ are prepended to it before variables are looked up.
--
-- 'AvarSharing' becomes an 'AST.Avar', 'AletSharing' becomes an 'AST.Alet',
-- and 'AccSharing' is translated constructor by constructor.
convertSharingAcc
    :: forall aenv arrs. Arrays arrs
    => Config
    -> Layout aenv aenv
    -> [StableSharingAcc]
    -> ScopedAcc arrs
    -> AST.OpenAcc aenv arrs
-- A reference to a shared term: its position in the environment is its index.
convertSharingAcc _ alyt aenv (ScopedAcc lams (AvarSharing sa))
  | Just i <- findIndex (matchStableAcc sa) aenv'
  = AST.OpenAcc $ AST.Avar (prjIdx (ctxt ++ "; i = " ++ show i) i alyt)
  -- with no shared terms in scope at all, the failure is reported as a cycle
  | null aenv'
  = error $ "Cyclic definition of a value of type 'Acc' (sa = " ++
            show (hashStableNameHeight sa) ++ ")"
  | otherwise
  = $internalError "convertSharingAcc" err
  where
    aenv' = lams ++ aenv
    ctxt  = "shared 'Acc' tree with stable name " ++ show (hashStableNameHeight sa)
    err   = "inconsistent valuation @ " ++ ctxt ++ ";\n aenv = " ++ show aenv'
-- A let of a shared term. The bound term is converted in the current layout.
-- The body sees the layout extended by one variable, with 'sa' pushed onto
-- the environment so that references to it resolve to index 0.
convertSharingAcc config alyt aenv (ScopedAcc lams (AletSharing sa@(StableSharingAcc _ boundAcc) bodyAcc))
  = AST.OpenAcc
  $ let alyt' = incLayout alyt `PushLayout` ZeroIdx
        aenv' = lams ++ aenv
    in
    AST.Alet (convertSharingAcc config alyt aenv' (ScopedAcc [] boundAcc))
             (convertSharingAcc config alyt' (sa:aenv') bodyAcc)
-- An ordinary node: convert its children with the helpers below.
convertSharingAcc config alyt aenv (ScopedAcc lams (AccSharing _ preAcc))
  = AST.OpenAcc
  $ let aenv' = lams ++ aenv

        cvtA :: Arrays a => ScopedAcc a -> AST.OpenAcc aenv a
        cvtA = convertSharingAcc config alyt aenv'

        -- scalar expressions start with an empty scalar layout
        cvtE :: Elt t => ScopedExp t -> AST.Exp aenv t
        cvtE = convertSharingExp config EmptyLayout alyt [] aenv'

        cvtF1 :: (Elt a, Elt b) => (Exp a -> ScopedExp b) -> AST.Fun aenv (a -> b)
        cvtF1 = convertSharingFun1 config alyt aenv'

        cvtF2 :: (Elt a, Elt b, Elt c) => (Exp a -> Exp b -> ScopedExp c) -> AST.Fun aenv (a -> b -> c)
        cvtF2 = convertSharingFun2 config alyt aenv'

        cvtAfun1 :: (Arrays a, Arrays b) => (Acc a -> ScopedAcc b) -> AST.OpenAfun aenv (a -> b)
        cvtAfun1 = convertSharingAfun1 config alyt aenv'
    in
    case preAcc of

      Atag i
        -> AST.Avar (prjIdx ("de Bruijn conversion tag " ++ show i) i alyt)

      -- Let-bind (afun1 applied to acc), then apply afun2 to that variable.
      -- A placeholder entry keeps the environment aligned with the extended
      -- layout.
      Pipe afun1 afun2 acc
        -> let noStableSharing = StableSharingAcc noStableAccName (undefined :: SharingAcc acc exp ())
               alyt'           = incLayout alyt `PushLayout` ZeroIdx
               boundAcc        = cvtAfun1 afun1 `AST.Apply` cvtA acc
               bodyAcc         = convertSharingAfun1 config alyt' (noStableSharing : aenv') afun2
                                   `AST.Apply`
                                 AST.OpenAcc (AST.Avar AST.ZeroIdx)
           in
           AST.Alet (AST.OpenAcc boundAcc) (AST.OpenAcc bodyAcc)

      -- the fallback implementation is converted as a separate closed
      -- function, using the same configuration flags
      Aforeign ff afun acc
        -> let a = recoverAccSharing config
               e = recoverExpSharing config
               s = recoverSeqSharing config
               f = floatOutAcc config
           in
           AST.Aforeign ff (convertAfun a e s f afun) (cvtA acc)

      Acond b acc1 acc2           -> AST.Acond (cvtE b) (cvtA acc1) (cvtA acc2)
      Awhile pred iter init       -> AST.Awhile (cvtAfun1 pred) (cvtAfun1 iter) (cvtA init)
      Atuple arrs                 -> AST.Atuple (convertSharingAtuple config alyt aenv' arrs)
      Aprj ix a                   -> AST.Aprj ix (cvtA a)
      Use array                   -> AST.Use (fromArr array)
      Unit e                      -> AST.Unit (cvtE e)
      Generate sh f               -> AST.Generate (cvtE sh) (cvtF1 f)
      Reshape e acc               -> AST.Reshape (cvtE e) (cvtA acc)
      Replicate ix acc            -> mkReplicate (cvtE ix) (cvtA acc)
      Slice acc ix                -> mkIndex (cvtA acc) (cvtE ix)
      Map f acc                   -> AST.Map (cvtF1 f) (cvtA acc)
      ZipWith f acc1 acc2         -> AST.ZipWith (cvtF2 f) (cvtA acc1) (cvtA acc2)
      Fold f e acc                -> AST.Fold (cvtF2 f) (cvtE e) (cvtA acc)
      Fold1 f acc                 -> AST.Fold1 (cvtF2 f) (cvtA acc)
      FoldSeg f e acc1 acc2       -> AST.FoldSeg (cvtF2 f) (cvtE e) (cvtA acc1) (cvtA acc2)
      Fold1Seg f acc1 acc2        -> AST.Fold1Seg (cvtF2 f) (cvtA acc1) (cvtA acc2)
      Scanl f e acc               -> AST.Scanl (cvtF2 f) (cvtE e) (cvtA acc)
      Scanl' f e acc              -> AST.Scanl' (cvtF2 f) (cvtE e) (cvtA acc)
      Scanl1 f acc                -> AST.Scanl1 (cvtF2 f) (cvtA acc)
      Scanr f e acc               -> AST.Scanr (cvtF2 f) (cvtE e) (cvtA acc)
      Scanr' f e acc              -> AST.Scanr' (cvtF2 f) (cvtE e) (cvtA acc)
      Scanr1 f acc                -> AST.Scanr1 (cvtF2 f) (cvtA acc)
      Permute f dftAcc perm acc   -> AST.Permute (cvtF2 f) (cvtA dftAcc) (cvtF1 perm) (cvtA acc)
      Backpermute newDim perm acc -> AST.Backpermute (cvtE newDim) (cvtF1 perm) (cvtA acc)
      Stencil stencil boundary acc
        -> AST.Stencil (convertSharingStencilFun1 config acc alyt aenv' stencil)
                       (convertSharingBoundary config alyt aenv' boundary)
                       (cvtA acc)
      Stencil2 stencil bndy1 acc1 bndy2 acc2
        -> AST.Stencil2 (convertSharingStencilFun2 config acc1 acc2 alyt aenv' stencil)
                        (convertSharingBoundary config alyt aenv' bndy1)
                        (cvtA acc1)
                        (convertSharingBoundary config alyt aenv' bndy2)
                        (cvtA acc2)
-- | Convert a unary array function whose body has already been scoped.
--
-- The function is applied to an unused placeholder. Its body is converted in
-- the layout extended by one array variable, bound by the resulting 'Alam'.
convertSharingAfun1
    :: forall aenv a b. (Arrays a, Arrays b)
    => Config
    -> Layout aenv aenv
    -> [StableSharingAcc]
    -> (Acc a -> ScopedAcc b)
    -> OpenAfun aenv (a -> b)
convertSharingAfun1 config alyt aenv f
  = let extendedLayout = incLayout alyt `PushLayout` ZeroIdx
        scopedBody     = f undefined
    in  Alam . Abody $ convertSharingAcc config extendedLayout aenv scopedBody
-- | Convert every component of an array tuple with 'convertSharingAcc', using
-- the same layout and environment throughout.
convertSharingAtuple
    :: forall aenv a.
       Config
    -> Layout aenv aenv
    -> [StableSharingAcc]
    -> Atuple ScopedAcc a
    -> Atuple (AST.OpenAcc aenv) a
convertSharingAtuple config alyt aenv = go
  where
    go :: Atuple ScopedAcc t -> Atuple (AST.OpenAcc aenv) t
    go tup =
      case tup of
        NilAtup         -> NilAtup
        SnocAtup rest x -> SnocAtup (go rest) (convertSharingAcc config alyt aenv x)
-- | Convert a stencil boundary specification.
--
-- Constant values are converted with 'fromElt'. Function boundaries are
-- converted with 'convertSharingFun1'.
convertSharingBoundary
    :: forall aenv t.
       Config
    -> Layout aenv aenv
    -> [StableSharingAcc]
    -> PreBoundary ScopedAcc ScopedExp t
    -> AST.PreBoundary AST.OpenAcc aenv t
convertSharingBoundary config alyt aenv = go
  where
    go :: PreBoundary ScopedAcc ScopedExp t -> AST.Boundary aenv t
    go Clamp        = AST.Clamp
    go Mirror       = AST.Mirror
    go Wrap         = AST.Wrap
    go (Constant v) = AST.Constant (fromElt v)
    go (Function f) = AST.Function (convertSharingFun1 config alyt aenv f)
-- | Build an 'AST.Slice' node. The slice index descriptor is derived from the
-- slice specifier's type alone.
mkIndex :: forall slix e aenv. (Slice slix, Elt e)
        => AST.OpenAcc aenv (Array (FullShape slix) e)
        -> AST.Exp     aenv slix
        -> AST.PreOpenAcc AST.OpenAcc aenv (Array (SliceShape slix) e)
mkIndex = AST.Slice (sliceIndex (undefined :: slix))
-- | Build an 'AST.Replicate' node. The slice index descriptor is derived from
-- the slice specifier's type alone.
mkReplicate :: forall slix e aenv. (Slice slix, Elt e)
            => AST.Exp     aenv slix
            -> AST.OpenAcc aenv (Array (SliceShape slix) e)
            -> AST.PreOpenAcc AST.OpenAcc aenv (Array (FullShape slix) e)
mkReplicate = AST.Replicate (sliceIndex (undefined :: slix))
-- | Convert a closed scalar function to the de Bruijn AST.
--
-- Only 'Exp' sharing recovery is configurable here. Every other option is off.
convertFun :: Function f => Bool -> f -> AST.Fun () (FunctionR f)
convertFun shareExp = convert cfg EmptyLayout
  where
    cfg = Config
      { recoverAccSharing = False
      , recoverExpSharing = shareExp
      , recoverSeqSharing = False
      , floatOutAcc       = False
      }
-- | Surface-language scalar functions (curried over 'Exp' arguments,
-- returning an 'Exp') that can be converted to the de Bruijn AST.
class Function f where
  -- | The corresponding function type over plain element types.
  type FunctionR f
  -- | Convert, given the layout of the scalar variables already in scope.
  convert :: Config -> Layout env env -> f -> AST.OpenFun env () (FunctionR f)
-- | One scalar argument: bind a new variable and convert the remaining function.
instance (Elt a, Function r) => Function (Exp a -> r) where
  type FunctionR (Exp a -> r) = a -> FunctionR r
  convert config lyt f = Lam (convert config innerLayout (f argument))
    where
      -- the argument is a placeholder tagged with the current binding depth
      argument    = Exp (Tag (sizeLayout lyt))
      innerLayout = PushLayout (incLayout lyt) ZeroIdx
-- | Function body: convert it with every bound variable's level (innermost
-- first) passed as a free variable.
instance Elt b => Function (Exp b) where
  type FunctionR (Exp b) = b
  convert config lyt body = Body (convertOpenExp config depth freeVars lyt body)
    where
      depth    = sizeLayout lyt
      freeVars = [depth - 1, depth - 2 .. 0]
-- | Convert a closed scalar expression to the de Bruijn AST.
--
-- Only 'Exp' sharing recovery is configurable here. Every other option is off.
convertExp
    :: Elt e
    => Bool
    -> Exp e
    -> AST.Exp () e
convertExp shareExp exp = convertOpenExp cfg 0 [] EmptyLayout exp
  where
    cfg = Config
      { recoverAccSharing = False
      , recoverExpSharing = shareExp
      , recoverSeqSharing = False
      , floatOutAcc       = False
      }
-- | Recover sharing in a scalar expression and convert it under the given
-- scalar layout.
--
-- The array layout is empty, and no shared array terms are in scope.
convertOpenExp
    :: Elt e
    => Config
    -> Level
    -> [Level]
    -> Layout env env
    -> Exp e
    -> AST.OpenExp env () e
convertOpenExp config lvl fvar lyt exp
  = convertSharingExp config lyt EmptyLayout startEnv [] scoped
  where
    (scoped, startEnv) = recoverSharingExp config lvl fvar exp
-- | Convert a sharing-annotated scalar expression ('ScopedExp') to the de
-- Bruijn AST.
--
-- * @lyt@ / @alyt@ are the scalar and array layouts in scope.
-- * @env@ / @aenv@ list the shared scalar and array terms bound in the
--   enclosing context.
--
-- The root expression's own @lams@ are prepended to @env@. Note that @cvt@
-- ignores the @lams@ of nested sub-expressions.
convertSharingExp
    :: forall t env aenv. Elt t
    => Config
    -> Layout env env
    -> Layout aenv aenv
    -> [StableSharingExp]
    -> [StableSharingAcc]
    -> ScopedExp t
    -> AST.OpenExp env aenv t
convertSharingExp config lyt alyt env aenv exp@(ScopedExp lams _) = cvt exp
  where
    env' = lams ++ env

    cvt :: Elt t' => ScopedExp t' -> AST.OpenExp env aenv t'
    -- a reference to a shared term: its position in env' is its index
    cvt (ScopedExp _ (VarSharing se))
      | Just i <- findIndex (matchStableExp se) env'
      = AST.Var (prjIdx (ctxt ++ "; i = " ++ show i) i lyt)
      | null env'
      = error $ "Cyclic definition of a value of type 'Exp' (sa = " ++ show (hashStableNameHeight se) ++ ")"
      | otherwise
      = $internalError "convertSharingExp" err
      where
        ctxt = "shared 'Exp' tree with stable name " ++ show (hashStableNameHeight se)
        err  = "inconsistent valuation @ " ++ ctxt ++ ";\n env' = " ++ show env'
    -- a let of a shared term; the body sees one more variable, with se at index 0
    cvt (ScopedExp _ (LetSharing se@(StableSharingExp _ boundExp) bodyExp))
      = let lyt' = incLayout lyt `PushLayout` ZeroIdx
        in
        AST.Let (cvt (ScopedExp [] boundExp)) (convertSharingExp config lyt' alyt (se:env') aenv bodyExp)
    -- an ordinary node: translate constructor by constructor
    cvt (ScopedExp _ (ExpSharing _ pexp))
      = case pexp of
          Tag i           -> AST.Var (prjIdx ("de Bruijn conversion tag " ++ show i) i lyt)
          Const v         -> AST.Const (fromElt v)
          Tuple tup       -> AST.Tuple (cvtT tup)
          Prj idx e       -> AST.Prj idx (cvt e)
          IndexNil        -> AST.IndexNil
          IndexCons ix i  -> AST.IndexCons (cvt ix) (cvt i)
          IndexHead i     -> AST.IndexHead (cvt i)
          IndexTail ix    -> AST.IndexTail (cvt ix)
          IndexAny        -> AST.IndexAny
          ToIndex sh ix   -> AST.ToIndex (cvt sh) (cvt ix)
          FromIndex sh e  -> AST.FromIndex (cvt sh) (cvt e)
          Cond e1 e2 e3   -> AST.Cond (cvt e1) (cvt e2) (cvt e3)
          While p it i    -> AST.While (cvtFun1 p) (cvtFun1 it) (cvt i)
          PrimConst c     -> AST.PrimConst c
          PrimApp f e     -> cvtPrimFun f (cvt e)
          Index a e       -> AST.Index (cvtA a) (cvt e)
          LinearIndex a i -> AST.LinearIndex (cvtA a) (cvt i)
          Shape a         -> AST.Shape (cvtA a)
          ShapeSize e     -> AST.ShapeSize (cvt e)
          Intersect sh1 sh2 -> AST.Intersect (cvt sh1) (cvt sh2)
          Union sh1 sh2   -> AST.Union (cvt sh1) (cvt sh2)
          -- the fallback is converted as a separate closed scalar function
          Foreign ff f e  -> AST.Foreign ff (convertFun (recoverExpSharing config) f) (cvt e)

    cvtA :: Arrays a => ScopedAcc a -> AST.OpenAcc aenv a
    cvtA = convertSharingAcc config alyt aenv

    cvtT :: Tuple ScopedExp tup -> Tuple (AST.OpenExp env aenv) tup
    cvtT = convertSharingTuple config lyt alyt env' aenv

    -- a scalar function nested in this expression, e.g. the loop functions of
    -- 'While': its body is converted with one extra scalar variable
    cvtFun1 :: (Elt a, Elt b) => (Exp a -> ScopedExp b) -> AST.OpenFun env aenv (a -> b)
    cvtFun1 f = Lam (Body (convertSharingExp config lyt' alyt env' aenv body))
      where
        lyt' = incLayout lyt `PushLayout` ZeroIdx
        body = f undefined

    -- Apply a primitive function. If the converted argument starts with
    -- 'AST.Let' bindings, the application is pushed underneath them, so the
    -- lets stay outermost.
    cvtPrimFun :: (Elt a, Elt r)
               => AST.PrimFun (a -> r) -> AST.OpenExp env' aenv' a -> AST.OpenExp env' aenv' r
    cvtPrimFun f e = case e of
      AST.Let bnd body -> AST.Let bnd (cvtPrimFun f body)
      x                -> AST.PrimApp f x
-- | Convert every component of a scalar tuple with 'convertSharingExp', using
-- the same layouts and environments throughout.
convertSharingTuple
    :: Config
    -> Layout env env
    -> Layout aenv aenv
    -> [StableSharingExp]
    -> [StableSharingAcc]
    -> Tuple ScopedExp t
    -> Tuple (AST.OpenExp env aenv) t
convertSharingTuple _ _ _ _ _ NilTup = NilTup
convertSharingTuple config lyt alyt env aenv (SnocTup rest x)
  = SnocTup (convertSharingTuple config lyt alyt env aenv rest)
            (convertSharingExp config lyt alyt env aenv x)
-- | Convert a unary scalar function that is embedded in an array computation.
--
-- The body is converted in a scalar layout holding only the argument, with
-- no shared scalar terms in scope.
convertSharingFun1
    :: forall a b aenv. (Elt a, Elt b)
    => Config
    -> Layout aenv aenv
    -> [StableSharingAcc]
    -> (Exp a -> ScopedExp b)
    -> AST.Fun aenv (a -> b)
convertSharingFun1 config alyt aenv f
  = Lam . Body $ convertSharingExp config argLayout alyt [] aenv (f (Exp undefined))
  where
    argLayout = EmptyLayout `PushLayout` (ZeroIdx :: Idx ((), a) a)
-- | Convert a binary scalar function that is embedded in an array computation.
--
-- The body is converted in a scalar layout holding only the two arguments.
-- The first argument is at index 1 and the second at index 0.
convertSharingFun2
    :: forall a b c aenv. (Elt a, Elt b, Elt c)
    => Config
    -> Layout aenv aenv
    -> [StableSharingAcc]
    -> (Exp a -> Exp b -> ScopedExp c)
    -> AST.Fun aenv (a -> b -> c)
convertSharingFun2 config alyt aenv f
  = Lam . Lam . Body $ convertSharingExp config argLayout alyt [] aenv (f (Exp undefined) (Exp undefined))
  where
    argLayout = (EmptyLayout `PushLayout` (SuccIdx ZeroIdx :: Idx (((), a), b) a))
                             `PushLayout` (ZeroIdx         :: Idx (((), a), b) b)
-- | Convert a stencil function over one array.
--
-- The function receives a single scalar variable holding the stencil
-- representation. 'stencilPrj' unpacks that variable into the user's
-- stencil pattern before the body is converted. The array argument is only
-- used to fix its type.
convertSharingStencilFun1
    :: forall sh a stencil b aenv. (Elt a, Stencil sh a stencil, Elt b)
    => Config
    -> ScopedAcc (Array sh a)
    -> Layout aenv aenv
    -> [StableSharingAcc]
    -> (stencil -> ScopedExp b)
    -> AST.Fun aenv (StencilRepr sh stencil -> b)
convertSharingStencilFun1 config _ alyt aenv stencilFun = Lam (Body openStencilFun)
  where
    stencil = Exp undefined :: Exp (StencilRepr sh stencil)
    lyt     = EmptyLayout
              `PushLayout`
              (ZeroIdx :: Idx ((), StencilRepr sh stencil)
                              (StencilRepr sh stencil))

    body           = stencilFun (stencilPrj (undefined::sh) (undefined::a) stencil)
    openStencilFun = convertSharingExp config lyt alyt [] aenv body
-- | Convert a stencil function over two arrays.
--
-- The first stencil representation is bound at scalar index 1 and the second
-- at index 0. Each is unpacked with 'stencilPrj' before the body is converted.
-- The array arguments are only used to fix their types.
convertSharingStencilFun2
    :: forall sh a b stencil1 stencil2 c aenv.
       (Elt a, Stencil sh a stencil1,
        Elt b, Stencil sh b stencil2,
        Elt c)
    => Config
    -> ScopedAcc (Array sh a)
    -> ScopedAcc (Array sh b)
    -> Layout aenv aenv
    -> [StableSharingAcc]
    -> (stencil1 -> stencil2 -> ScopedExp c)
    -> AST.Fun aenv (StencilRepr sh stencil1 -> StencilRepr sh stencil2 -> c)
convertSharingStencilFun2 config _ _ alyt aenv stencilFun = Lam (Lam (Body openStencilFun))
  where
    stencil1 = Exp undefined :: Exp (StencilRepr sh stencil1)
    stencil2 = Exp undefined :: Exp (StencilRepr sh stencil2)
    lyt      = EmptyLayout
               `PushLayout`
               (SuccIdx ZeroIdx :: Idx (((), StencilRepr sh stencil1),
                                             StencilRepr sh stencil2)
                                       (StencilRepr sh stencil1))
               `PushLayout`
               (ZeroIdx :: Idx (((), StencilRepr sh stencil1),
                                     StencilRepr sh stencil2)
                               (StencilRepr sh stencil2))

    body           = stencilFun (stencilPrj (undefined::sh) (undefined::a) stencil1)
                                (stencilPrj (undefined::sh) (undefined::b) stencil2)
    openStencilFun = convertSharingExp config lyt alyt [] aenv body
-- | A stable name for an AST node of type @c t@.
--
-- The node's type index @t@ is hidden, so names of nodes with different types
-- can be keys of one hash table.
data StableASTName c where
  StableASTName :: (Typeable c, Typeable t) => StableName (c t) -> StableASTName c

-- shown as the hash of the underlying stable name
instance Show (StableASTName c) where
  show (StableASTName sn) = show $ hashStableName sn

-- Equal only when the hidden types agree ('gcast' succeeds) and the stable
-- names themselves are equal.
instance Eq (StableASTName c) where
  StableASTName sn1 == StableASTName sn2
    | Just sn1' <- gcast sn1 = sn1' == sn2
    | otherwise              = False

instance Hashable (StableASTName c) where
  hashWithSalt s (StableASTName sn) = hashWithSalt s sn
-- | Make a stable name for an AST node. The node is forced to weak head normal
-- form first, so the name refers to the evaluated node.
makeStableAST :: c t -> IO (StableName (c t))
makeStableAST node = makeStableName $! node
-- | A stable name paired with the height of the AST node it names.
data StableNameHeight t = StableNameHeight (StableName t) Int

-- Equality looks only at the stable name; the height is ignored.
instance Eq (StableNameHeight t) where
  StableNameHeight a _ == StableNameHeight b _ = a == b

-- | Is the first node strictly taller than the second?
higherSNH :: StableNameHeight t1 -> StableNameHeight t2 -> Bool
higherSNH (StableNameHeight _ lhs) (StableNameHeight _ rhs) = lhs > rhs

-- | Hash of the underlying stable name (the height does not contribute).
hashStableNameHeight :: StableNameHeight t -> Int
hashStableNameHeight (StableNameHeight name _) = hashStableName name
-- | Mutable hash table used while traversing the AST.
type HashTable key val = Hash.BasicHashTable key val

-- | Hash table keyed by stable AST names.
type ASTHashTable c v  = HashTable (StableASTName c) v

-- | Occurrence table. Each value is (occurrence count, height recorded at
-- the node's first occurrence); see 'enterOcc'.
type OccMapHash c      = ASTHashTable c (Int, Int)

-- | Create an empty 'ASTHashTable'.
newASTHashTable :: IO (ASTHashTable c v)
newASTHashTable = Hash.new
-- | Record one occurrence of an AST node.
--
-- * On the first occurrence, store count 1 with the given height and return
--   'Nothing'.
-- * On a repeated occurrence, increment the count, keep the height stored
--   the first time, and return that height.
--
-- The height is stored without being forced. 'makeOccMapSharingAcc' passes a
-- height that is only defined through 'mfix', once the traversal finishes.
enterOcc :: OccMapHash c -> StableASTName c -> Int -> IO (Maybe Int)
enterOcc occMap sa height
  = Hash.lookup occMap sa >>= maybe firstSighting repeatSighting
  where
    firstSighting = do
      Hash.insert occMap sa (1, height)
      return Nothing
    repeatSighting (count, storedHeight) = do
      Hash.insert occMap sa (count + 1, storedHeight)
      return (Just storedHeight)
-- | Immutable occurrence map: buckets of (stable name, occurrence count)
-- pairs, keyed by the raw hash of the stable name.
type OccMap c = IntMap.IntMap [(StableASTName c, Int)]

-- | Freeze a mutable occurrence table into an 'OccMap'.
--
-- Each table entry goes into the bucket for its stable name's hash. The
-- entry's height is dropped and only the count is kept.
--
-- Buckets are built with 'IntMap.fromListWith', so all entries with the same
-- hash end up in one bucket whatever order 'Hash.toList' produces them in.
-- The previous 'groupBy' version only merged /adjacent/ entries, and
-- 'IntMap.fromList' keeps just the last value for a repeated key. Colliding
-- entries that were not adjacent could therefore be lost, and a lost node
-- would default to a count of 1 in 'lookupWithASTName'.
freezeOccMap :: OccMapHash c -> IO (OccMap c)
freezeOccMap oc
  = do
      ocl <- Hash.toList oc
      traceChunk "OccMap" (show ocl)

      return . IntMap.fromListWith (++)
             . map (\(k, (cnt, _)) -> (key k, [(k, cnt)]))
             $ ocl
  where
    key :: StableASTName c' -> Int
    key (StableASTName sn) = hashStableName sn
-- | Occurrence count of a node in a frozen occurrence map.
--
-- Nodes that are not in the map count as occurring once.
lookupWithASTName :: OccMap c -> StableASTName c -> Int
lookupWithASTName oc sa@(StableASTName sn)
  = case IntMap.lookup (hashStableName sn) oc of
      Just bucket | Just count <- Prelude.lookup sa bucket -> count
      _                                                    -> 1

-- | Occurrence count of the array node named by a 'StableSharingAcc'.
lookupWithSharingAcc :: OccMap Acc -> StableSharingAcc -> Int
lookupWithSharingAcc occ (StableSharingAcc (StableNameHeight name _) _)
  = occ `lookupWithASTName` StableASTName name

-- | Occurrence count of the scalar node named by a 'StableSharingExp'.
lookupWithSharingExp :: OccMap Exp -> StableSharingExp -> Int
lookupWithSharingExp occ (StableSharingExp (StableNameHeight name _) _)
  = occ `lookupWithASTName` StableASTName name
-- | Stable name (with height) of an 'Acc' node.
type StableAccName arrs = StableNameHeight (Acc arrs)

-- | An array term annotated with sharing information:
--
-- * 'AvarSharing' — a reference to a shared term, by stable name;
-- * 'AletSharing' — a shared term bound around a body;
-- * 'AccSharing'  — an ordinary node, tagged with its own stable name.
data SharingAcc acc exp arrs where
  AvarSharing :: Arrays arrs
              => StableAccName arrs -> SharingAcc acc exp arrs
  AletSharing :: StableSharingAcc -> acc arrs -> SharingAcc acc exp arrs
  AccSharing  :: Arrays arrs
              => StableAccName arrs -> PreAcc acc exp arrs -> SharingAcc acc exp arrs

-- | Array term produced by the occurrence-counting pass. The list holds
-- levels of lambda-bound variables ('makeOccMapAfun1' sets it to @[lvl]@;
-- the traversal itself uses @[]@).
data UnscopedAcc t = UnscopedAcc [Int] (SharingAcc UnscopedAcc RootExp t)

-- | Array term after scoping. 'convertSharingAcc' prepends the list of
-- shared terms to its environment before resolving variables.
data ScopedAcc t = ScopedAcc [StableSharingAcc] (SharingAcc ScopedAcc ScopedExp t)
-- | A shared array term together with its stable name. The array type is
-- hidden, so terms of different types can share one environment list.
data StableSharingAcc where
  StableSharingAcc :: Arrays arrs
                   => StableAccName arrs
                   -> SharingAcc ScopedAcc ScopedExp arrs
                   -> StableSharingAcc
-- shown as the hash of its stable name
instance Show StableSharingAcc where
  show (StableSharingAcc name _) = show (hashStableNameHeight name)

-- Equal only when the hidden array types agree and the stable names are equal.
instance Eq StableSharingAcc where
  StableSharingAcc lhs _ == StableSharingAcc rhs _ = maybe False (== rhs) (gcast lhs)

-- | Is the first shared term strictly taller than the second?
higherSSA :: StableSharingAcc -> StableSharingAcc -> Bool
higherSSA (StableSharingAcc lhs _) (StableSharingAcc rhs _) = higherSNH lhs rhs

-- | Does the given stable name identify this shared term (same type and same
-- name)?
matchStableAcc :: Typeable arrs => StableAccName arrs -> StableSharingAcc -> Bool
matchStableAcc wanted (StableSharingAcc candidate _) = maybe False (== candidate) (gcast wanted)
-- | Placeholder stable name (height 0) for an environment entry that has no
-- real AST node, e.g. the variable bound when 'Pipe' is converted.
--
-- NOTE(review): this is a top-level 'unsafePerformIO' with no NOINLINE
-- pragma. It only names 'undefined' (makeStableName is documented not to
-- evaluate its argument), so duplication by inlining looks harmless — confirm.
noStableAccName :: StableAccName arrs
noStableAccName = unsafePerformIO $ StableNameHeight <$> makeStableName undefined <*> pure 0
-- | Stable name (with height) of an 'Exp' node.
type StableExpName t = StableNameHeight (Exp t)

-- | A scalar term annotated with sharing information:
--
-- * 'VarSharing' — a reference to a shared term, by stable name;
-- * 'LetSharing' — a shared term bound around a body;
-- * 'ExpSharing' — an ordinary node, tagged with its own stable name.
data SharingExp (acc :: * -> *) exp t where
  VarSharing :: Elt t
             => StableExpName t -> SharingExp acc exp t
  LetSharing :: StableSharingExp -> exp t -> SharingExp acc exp t
  ExpSharing :: Elt t
             => StableExpName t -> PreExp acc exp t -> SharingExp acc exp t

-- | Scalar term produced by the occurrence-counting pass; the list holds
-- levels of lambda-bound variables ('makeOccMapRootExp' sets it to @fvs@).
data UnscopedExp t = UnscopedExp [Int] (SharingExp UnscopedAcc UnscopedExp t)

-- | Scalar term after scoping. 'convertSharingExp' prepends the root
-- expression's list to its environment.
data ScopedExp t = ScopedExp [StableSharingExp] (SharingExp ScopedAcc ScopedExp t)

-- | A root scalar expression, paired with its own occurrence map of 'Exp'
-- nodes (built by 'makeOccMapRootExp').
data RootExp t = RootExp (OccMap Exp) (UnscopedExp t)

-- | A shared scalar term together with its stable name; the element type is
-- hidden.
data StableSharingExp where
  StableSharingExp :: Elt t => StableExpName t -> SharingExp ScopedAcc ScopedExp t -> StableSharingExp
-- shown as the hash of its stable name
instance Show StableSharingExp where
  show (StableSharingExp name _) = show (hashStableNameHeight name)

-- Equal only when the hidden element types agree and the stable names are equal.
instance Eq StableSharingExp where
  StableSharingExp lhs _ == StableSharingExp rhs _ = maybe False (== rhs) (gcast lhs)

-- | Is the first shared term strictly taller than the second?
higherSSE :: StableSharingExp -> StableSharingExp -> Bool
higherSSE (StableSharingExp lhs _) (StableSharingExp rhs _) = higherSNH lhs rhs

-- | Does the given stable name identify this shared term (same type and same
-- name)?
matchStableExp :: Typeable t => StableExpName t -> StableSharingExp -> Bool
matchStableExp wanted (StableSharingExp candidate _) = maybe False (== candidate) (gcast wanted)
-- | Placeholder stable name (height 0) for a scalar term with no real AST node.
--
-- NOTE(review): a top-level 'unsafePerformIO' with no NOINLINE pragma; it only
-- names 'undefined', so duplication by inlining looks harmless — confirm.
noStableExpName :: StableExpName t
noStableExpName = unsafePerformIO $ StableNameHeight <$> makeStableName undefined <*> pure 0
-- | Count occurrences of every 'Acc' node in an array term.
--
-- Returns the annotated term and the frozen occurrence map. Calls are
-- bracketed by trace messages.
makeOccMapAcc
    :: Typeable arrs
    => Config
    -> Level
    -> Acc arrs
    -> IO (UnscopedAcc arrs, OccMap Acc)
makeOccMapAcc config lvl acc = do
  traceLine "makeOccMapAcc" "Enter"
  table         <- newASTHashTable
  (unscoped, _) <- makeOccMapSharingAcc config table lvl acc
  frozen        <- freezeOccMap table
  traceLine "makeOccMapAcc" "Exit"
  pure (unscoped, frozen)
-- | Traverse an array term, counting node occurrences in @accOccMap@ and
-- annotating each node with its stable name and height.
--
-- A node's height is one more than the maximum height of its children
-- ('Atag' has height 0; 'Use' has height 1). Scalar sub-terms are traversed
-- as separate roots with their own expression occurrence maps.
makeOccMapSharingAcc
    :: Typeable arrs
    => Config
    -> OccMapHash Acc
    -> Level
    -> Acc arrs
    -> IO (UnscopedAcc arrs, Int)
makeOccMapSharingAcc config accOccMap = traverseAcc
  where
    traverseFun1 :: (Elt a, Typeable b) => Level -> (Exp a -> Exp b) -> IO (Exp a -> RootExp b, Int)
    traverseFun1 = makeOccMapFun1 config accOccMap

    traverseFun2 :: (Elt a, Elt b, Typeable c)
                 => Level
                 -> (Exp a -> Exp b -> Exp c)
                 -> IO (Exp a -> Exp b -> RootExp c, Int)
    traverseFun2 = makeOccMapFun2 config accOccMap

    traverseAfun1 :: (Arrays a, Typeable b) => Level -> (Acc a -> Acc b) -> IO (Acc a -> UnscopedAcc b, Int)
    traverseAfun1 = makeOccMapAfun1 config accOccMap

    traverseExp :: Typeable e => Level -> Exp e -> IO (RootExp e, Int)
    traverseExp = makeOccMapExp config accOccMap

    -- only function boundaries have sub-terms; the others have height 0
    traverseBoundary
        :: Level
        -> PreBoundary Acc Exp t
        -> IO (PreBoundary UnscopedAcc RootExp t, Int)
    traverseBoundary lvl bndy =
      case bndy of
        Clamp      -> return (Clamp, 0)
        Mirror     -> return (Mirror, 0)
        Wrap       -> return (Wrap, 0)
        Constant v -> return (Constant v, 0)
        Function f -> do
          (f', h) <- traverseFun1 lvl f
          return (Function f', h)

    traverseAcc :: forall arrs. Typeable arrs => Level -> Acc arrs -> IO (UnscopedAcc arrs, Int)
    traverseAcc lvl acc@(Acc pacc)
      -- 'mfix' ties the knot: the node's own height is only known after its
      -- children are traversed, but it is stored in the occurrence table on
      -- entry. The lazy pattern keeps 'height' unevaluated until then.
      = mfix $ \ ~(_, height) -> do
          sn <- makeStableAST acc
          heightIfRepeatedOccurrence <- enterOcc accOccMap (StableASTName sn) height

          traceLine (showPreAccOp pacc) $ do
            let hash = show (hashStableName sn)
            case heightIfRepeatedOccurrence of
              Just height -> "REPEATED occurrence (sn = " ++ hash ++ "; height = " ++ show height ++ ")"
              Nothing     -> "first occurrence (sn = " ++ hash ++ ")"

          -- A repeated occurrence, when 'Acc' sharing recovery is enabled,
          -- becomes a variable carrying the height from its first occurrence,
          -- and its children are not traversed again. Otherwise the node is
          -- rebuilt from its traversed children.
          let reconstruct :: Arrays arrs
                          => IO (PreAcc UnscopedAcc RootExp arrs, Int)
                          -> IO (UnscopedAcc arrs, Int)
              reconstruct newAcc
                = case heightIfRepeatedOccurrence of
                    Just height | recoverAccSharing config
                      -> return (UnscopedAcc [] (AvarSharing (StableNameHeight sn height)), height)
                    _ -> do (acc, height) <- newAcc
                            return (UnscopedAcc [] (AccSharing (StableNameHeight sn height) acc), height)

          case pacc of
            Atag i -> reconstruct $ return (Atag i, 0)
            Pipe afun1 afun2 acc -> reconstruct $ do
              (afun1', h1) <- traverseAfun1 lvl afun1
              (afun2', h2) <- traverseAfun1 lvl afun2
              (acc', h3) <- traverseAcc lvl acc
              return (Pipe afun1' afun2' acc'
                     , h1 `max` h2 `max` h3 + 1)
            Aforeign ff afun acc -> reconstruct $ travA (Aforeign ff afun) acc
            Acond e acc1 acc2 -> reconstruct $ do
              (e'   , h1) <- traverseExp lvl e
              (acc1', h2) <- traverseAcc lvl acc1
              (acc2', h3) <- traverseAcc lvl acc2
              return (Acond e' acc1' acc2', h1 `max` h2 `max` h3 + 1)
            Awhile pred iter init -> reconstruct $ do
              (pred', h1) <- traverseAfun1 lvl pred
              (iter', h2) <- traverseAfun1 lvl iter
              (init', h3) <- traverseAcc lvl init
              return (Awhile pred' iter' init'
                     , h1 `max` h2 `max` h3 + 1)
            Atuple tup -> reconstruct $ do
              (tup', h) <- travAtup tup
              return (Atuple tup', h)
            Aprj ix a -> reconstruct $ travA (Aprj ix) a
            Use arr -> reconstruct $ return (Use arr, 1)
            Unit e -> reconstruct $ do
              (e', h) <- traverseExp lvl e
              return (Unit e', h + 1)
            Generate e f -> reconstruct $ do
              (e', h1) <- traverseExp lvl e
              (f', h2) <- traverseFun1 lvl f
              return (Generate e' f', h1 `max` h2 + 1)
            Reshape e acc -> reconstruct $ travEA Reshape e acc
            Replicate e acc -> reconstruct $ travEA Replicate e acc
            Slice acc e -> reconstruct $ travEA (flip Slice) e acc
            Map f acc -> reconstruct $ do
              (f'  , h1) <- traverseFun1 lvl f
              (acc', h2) <- traverseAcc lvl acc
              return (Map f' acc', h1 `max` h2 + 1)
            ZipWith f acc1 acc2 -> reconstruct $ travF2A2 ZipWith f acc1 acc2
            Fold f e acc -> reconstruct $ travF2EA Fold f e acc
            Fold1 f acc -> reconstruct $ travF2A Fold1 f acc
            FoldSeg f e acc1 acc2 -> reconstruct $ do
              (f'   , h1) <- traverseFun2 lvl f
              (e'   , h2) <- traverseExp lvl e
              (acc1', h3) <- traverseAcc lvl acc1
              (acc2', h4) <- traverseAcc lvl acc2
              return (FoldSeg f' e' acc1' acc2',
                      h1 `max` h2 `max` h3 `max` h4 + 1)
            Fold1Seg f acc1 acc2 -> reconstruct $ travF2A2 Fold1Seg f acc1 acc2
            Scanl f e acc -> reconstruct $ travF2EA Scanl f e acc
            Scanl' f e acc -> reconstruct $ travF2EA Scanl' f e acc
            Scanl1 f acc -> reconstruct $ travF2A Scanl1 f acc
            Scanr f e acc -> reconstruct $ travF2EA Scanr f e acc
            Scanr' f e acc -> reconstruct $ travF2EA Scanr' f e acc
            Scanr1 f acc -> reconstruct $ travF2A Scanr1 f acc
            Permute c acc1 p acc2 -> reconstruct $ do
              (c'   , h1) <- traverseFun2 lvl c
              (p'   , h2) <- traverseFun1 lvl p
              (acc1', h3) <- traverseAcc lvl acc1
              (acc2', h4) <- traverseAcc lvl acc2
              return (Permute c' acc1' p' acc2',
                      h1 `max` h2 `max` h3 `max` h4 + 1)
            Backpermute e p acc -> reconstruct $ do
              (e'  , h1) <- traverseExp lvl e
              (p'  , h2) <- traverseFun1 lvl p
              (acc', h3) <- traverseAcc lvl acc
              return (Backpermute e' p' acc', h1 `max` h2 `max` h3 + 1)
            Stencil s bnd acc -> reconstruct $ do
              (s'  , h1) <- makeOccMapStencil1 config accOccMap acc lvl s
              (bnd', h2) <- traverseBoundary lvl bnd
              (acc', h3) <- traverseAcc lvl acc
              return (Stencil s' bnd' acc', h1 `max` h2 `max` h3 + 1)
            Stencil2 s bnd1 acc1
                     bnd2 acc2 -> reconstruct $ do
              (s'   , h1) <- makeOccMapStencil2 config accOccMap acc1 acc2 lvl s
              (bnd1', h2) <- traverseBoundary lvl bnd1
              (acc1', h3) <- traverseAcc lvl acc1
              (bnd2', h4) <- traverseBoundary lvl bnd2
              (acc2', h5) <- traverseAcc lvl acc2
              return (Stencil2 s' bnd1' acc1' bnd2' acc2',
                      h1 `max` h2 `max` h3 `max` h4 `max` h5 + 1)
      where
        -- helpers for common child shapes; each returns 1 + max child height
        travA :: Arrays arrs'
              => (UnscopedAcc arrs' -> PreAcc UnscopedAcc RootExp arrs)
              -> Acc arrs' -> IO (PreAcc UnscopedAcc RootExp arrs, Int)
        travA c acc
          = do
              (acc', h) <- traverseAcc lvl acc
              return (c acc', h + 1)

        travEA :: (Typeable b, Arrays arrs')
               => (RootExp b -> UnscopedAcc arrs' -> PreAcc UnscopedAcc RootExp arrs)
               -> Exp b -> Acc arrs' -> IO (PreAcc UnscopedAcc RootExp arrs, Int)
        travEA c exp acc
          = do
              (exp', h1) <- traverseExp lvl exp
              (acc', h2) <- traverseAcc lvl acc
              return (c exp' acc', h1 `max` h2 + 1)

        travF2A :: (Elt b, Elt c, Typeable d, Arrays arrs')
                => ((Exp b -> Exp c -> RootExp d) -> UnscopedAcc arrs'
                    -> PreAcc UnscopedAcc RootExp arrs)
                -> (Exp b -> Exp c -> Exp d) -> Acc arrs'
                -> IO (PreAcc UnscopedAcc RootExp arrs, Int)
        travF2A c fun acc
          = do
              (fun', h1) <- traverseFun2 lvl fun
              (acc', h2) <- traverseAcc lvl acc
              return (c fun' acc', h1 `max` h2 + 1)

        travF2EA :: (Elt b, Elt c, Typeable d, Typeable e, Arrays arrs')
                 => ((Exp b -> Exp c -> RootExp d) -> RootExp e -> UnscopedAcc arrs' -> PreAcc UnscopedAcc RootExp arrs)
                 -> (Exp b -> Exp c -> Exp d) -> Exp e -> Acc arrs'
                 -> IO (PreAcc UnscopedAcc RootExp arrs, Int)
        travF2EA c fun exp acc
          = do
              (fun', h1) <- traverseFun2 lvl fun
              (exp', h2) <- traverseExp lvl exp
              (acc', h3) <- traverseAcc lvl acc
              return (c fun' exp' acc', h1 `max` h2 `max` h3 + 1)

        travF2A2 :: (Elt b, Elt c, Typeable d, Arrays arrs1, Arrays arrs2)
                 => ((Exp b -> Exp c -> RootExp d) -> UnscopedAcc arrs1 -> UnscopedAcc arrs2 -> PreAcc UnscopedAcc RootExp arrs)
                 -> (Exp b -> Exp c -> Exp d) -> Acc arrs1 -> Acc arrs2
                 -> IO (PreAcc UnscopedAcc RootExp arrs, Int)
        travF2A2 c fun acc1 acc2
          = do
              (fun' , h1) <- traverseFun2 lvl fun
              (acc1', h2) <- traverseAcc lvl acc1
              (acc2', h3) <- traverseAcc lvl acc2
              return (c fun' acc1' acc2', h1 `max` h2 `max` h3 + 1)

        -- the empty tuple has height 1; each component adds one level
        travAtup :: Atuple Acc a
                 -> IO (Atuple UnscopedAcc a, Int)
        travAtup NilAtup          = return (NilAtup, 1)
        travAtup (SnocAtup tup a) = do
          (tup', h1) <- travAtup tup
          (a',   h2) <- traverseAcc lvl a
          return (SnocAtup tup' a', h1 `max` h2 + 1)
-- | Traverse the body of a unary array function.
--
-- The function is applied to a placeholder tagged with the current level, and
-- its body is traversed one level deeper. The returned function ignores its
-- argument and yields the body, recording the bound level @[lvl]@.
makeOccMapAfun1 :: (Arrays a, Typeable b)
                => Config
                -> OccMapHash Acc
                -> Level
                -> (Acc a -> Acc b)
                -> IO (Acc a -> UnscopedAcc b, Int)
makeOccMapAfun1 config accOccMap lvl f = do
  let placeholder = Acc (Atag lvl)
  (UnscopedAcc [] body, height) <- makeOccMapSharingAcc config accOccMap (lvl + 1) (f placeholder)
  pure (\_ -> UnscopedAcc [lvl] body, height)
-- | Traverse a scalar expression embedded in an array term, as a root with no
-- lambda-bound variables.
makeOccMapExp
    :: Typeable e
    => Config
    -> OccMapHash Acc
    -> Level
    -> Exp e
    -> IO (RootExp e, Int)
makeOccMapExp config accOccMap lvl exp
  = makeOccMapRootExp config accOccMap lvl [] exp
-- | Traverse the body of a unary scalar function.
--
-- The argument is a placeholder tagged with @lvl@. The body is a root
-- expression one level deeper, with @lvl@ as its only bound variable.
makeOccMapFun1
    :: (Elt a, Typeable b)
    => Config
    -> OccMapHash Acc
    -> Level
    -> (Exp a -> Exp b)
    -> IO (Exp a -> RootExp b, Int)
makeOccMapFun1 config accOccMap lvl f
  = (\(body, height) -> (\_ -> body, height))
      <$> makeOccMapRootExp config accOccMap (lvl + 1) [lvl] (f (Exp (Tag lvl)))
-- | Traverse the body of a binary scalar function.
--
-- The first argument is tagged @lvl+1@ and the second @lvl@. The body is a
-- root expression two levels deeper, with both levels bound.
makeOccMapFun2
    :: (Elt a, Elt b, Typeable c)
    => Config
    -> OccMapHash Acc
    -> Level
    -> (Exp a -> Exp b -> Exp c)
    -> IO (Exp a -> Exp b -> RootExp c, Int)
makeOccMapFun2 config accOccMap lvl f = do
  let firstArg  = Exp (Tag (lvl + 1))
      secondArg = Exp (Tag lvl)
  (body, height) <- makeOccMapRootExp config accOccMap (lvl + 2) [lvl, lvl + 1] (f firstArg secondArg)
  pure (\_ _ -> body, height)
-- | Traverse the body of a stencil function over one array.
--
-- A placeholder tagged @lvl@ is unpacked with 'stencilPrj' and passed to the
-- stencil. The body is traversed as a root one level deeper. The array
-- argument is only used to fix its type.
makeOccMapStencil1
    :: forall sh a b stencil. (Stencil sh a stencil, Typeable b)
    => Config
    -> OccMapHash Acc
    -> Acc (Array sh a)
    -> Level
    -> (stencil -> Exp b)
    -> IO (stencil -> RootExp b, Int)
makeOccMapStencil1 config accOccMap _ lvl stencil = do
  (body, height) <- makeOccMapRootExp config accOccMap (lvl + 1) [lvl]
                      (stencil (stencilPrj (undefined::sh) (undefined::a) (Exp (Tag lvl))))
  pure (\_ -> body, height)
-- | Traverse the body of a stencil function over two arrays.
--
-- The first stencil is unpacked from a placeholder tagged @lvl+1@ and the
-- second from one tagged @lvl@. The body is traversed as a root two levels
-- deeper. The array arguments are only used to fix their types.
makeOccMapStencil2
    :: forall sh a b c stencil1 stencil2. (Stencil sh a stencil1, Stencil sh b stencil2, Typeable c)
    => Config
    -> OccMapHash Acc
    -> Acc (Array sh a)
    -> Acc (Array sh b)
    -> Level
    -> (stencil1 -> stencil2 -> Exp c)
    -> IO (stencil1 -> stencil2 -> RootExp c, Int)
makeOccMapStencil2 config accOccMap _ _ lvl stencil = do
  (body, height) <- makeOccMapRootExp config accOccMap (lvl + 2) [lvl, lvl + 1]
                      (stencil (stencilPrj (undefined::sh) (undefined::a) (Exp (Tag (lvl + 1))))
                               (stencilPrj (undefined::sh) (undefined::b) (Exp (Tag lvl))))
  pure (\_ _ -> body, height)
-- | Traverse a root scalar expression with a fresh expression occurrence table.
--
-- Returns the expression, paired with its frozen occurrence map and the given
-- bound variable levels (@fvs@), together with its height. Calls are
-- bracketed by trace messages.
makeOccMapRootExp
    :: Typeable e
    => Config
    -> OccMapHash Acc
    -> Level
    -> [Int]
    -> Exp e
    -> IO (RootExp e, Int)
makeOccMapRootExp config accOccMap lvl fvs exp = do
  traceLine "makeOccMapRootExp" "Enter"
  table                         <- newASTHashTable
  (UnscopedExp [] root, height) <- makeOccMapSharingExp config accOccMap table lvl exp
  frozen                        <- freezeOccMap table
  traceLine "makeOccMapRootExp" "Exit"
  pure (RootExp frozen (UnscopedExp fvs root), height)
-- | Traverse a scalar expression, entering each node into the expression
-- occurrence map @expOccMap@ and annotating it with a 'StableNameHeight'.
-- Array terms embedded in the expression ('Index', 'LinearIndex', 'Shape')
-- are traversed with 'makeOccMapSharingAcc' against @accOccMap@.
--
-- Returns the annotated expression (always with an empty free-variable
-- list) and its height: 'Tag' has height 0, the other leaves height 1, and
-- most other nodes 1 more than the maximum height of their children.
makeOccMapSharingExp
:: Typeable e
=> Config
-> OccMapHash Acc
-> OccMapHash Exp
-> Level
-> Exp e
-> IO (UnscopedExp e, Int)
makeOccMapSharingExp config accOccMap expOccMap = travE
where
travE :: forall a. Typeable a => Level -> Exp a -> IO (UnscopedExp a, Int)
travE lvl exp@(Exp pexp)
-- 'enterOcc' needs this node's height before it has been computed; 'mfix'
-- together with the lazy pattern ties that knot.
= mfix $ \ ~(_, height) -> do
sn <- makeStableAST exp
-- 'Just' a recorded height for a repeated occurrence and 'Nothing' for the
-- first one (as the trace message below distinguishes).
heightIfRepeatedOccurrence <- enterOcc expOccMap (StableASTName sn) height
-- The 'do' below holds only a 'let' and a final expression, so it is just
-- the 'String' message passed to 'traceLine'.
traceLine (showPreExpOp pexp) $ do
let hash = show (hashStableName sn)
case heightIfRepeatedOccurrence of
Just height -> "REPEATED occurrence (sn = " ++ hash ++ "; height = " ++ show height ++ ")"
Nothing -> "first occurrence (sn = " ++ hash ++ ")"
-- On a repeated occurrence with expression sharing recovery enabled, emit
-- a 'VarSharing' reference and do not traverse the subterms again;
-- otherwise run @newExp@ to traverse them and wrap the result in
-- 'ExpSharing'.
let reconstruct :: Elt a
=> IO (PreExp UnscopedAcc UnscopedExp a, Int)
-> IO (UnscopedExp a, Int)
reconstruct newExp
= case heightIfRepeatedOccurrence of
Just height | recoverExpSharing config
-> return (UnscopedExp [] (VarSharing (StableNameHeight sn height)), height)
_ -> do (exp, height) <- newExp
return (UnscopedExp [] (ExpSharing (StableNameHeight sn height) exp), height)
case pexp of
-- Placeholder tags are leaves of height 0; the other leaves have height 1.
Tag i -> reconstruct $ return (Tag i, 0)
Const c -> reconstruct $ return (Const c, 1)
Tuple tup -> reconstruct $ do
(tup', h) <- travTup tup
return (Tuple tup', h)
Prj i e -> reconstruct $ travE1 (Prj i) e
IndexNil -> reconstruct $ return (IndexNil, 1)
IndexCons ix i -> reconstruct $ travE2 IndexCons ix i
IndexHead i -> reconstruct $ travE1 IndexHead i
IndexTail ix -> reconstruct $ travE1 IndexTail ix
IndexAny -> reconstruct $ return (IndexAny, 1)
ToIndex sh ix -> reconstruct $ travE2 ToIndex sh ix
FromIndex sh e -> reconstruct $ travE2 FromIndex sh e
Cond e1 e2 e3 -> reconstruct $ travE3 Cond e1 e2 e3
While p iter init -> reconstruct $ do
(p' , h1) <- traverseFun1 lvl p
(iter', h2) <- traverseFun1 lvl iter
(init', h3) <- travE lvl init
return (While p' iter' init', h1 `max` h2 `max` h3 + 1)
PrimConst c -> reconstruct $ return (PrimConst c, 1)
PrimApp p e -> reconstruct $ travE1 (PrimApp p) e
Index a e -> reconstruct $ travAE Index a e
LinearIndex a i -> reconstruct $ travAE LinearIndex a i
Shape a -> reconstruct $ travA Shape a
ShapeSize e -> reconstruct $ travE1 ShapeSize e
Intersect sh1 sh2 -> reconstruct $ travE2 Intersect sh1 sh2
Union sh1 sh2 -> reconstruct $ travE2 Union sh1 sh2
-- Only the argument of a foreign call is traversed; @ff@ and @f@ are kept
-- unchanged.
Foreign ff f e -> reconstruct $ do
(e', h) <- travE lvl e
return (Foreign ff f e', h+1)
where
-- Embedded array terms are handled by the array occurrence-map pass.
traverseAcc :: Typeable arrs => Level -> Acc arrs -> IO (UnscopedAcc arrs, Int)
traverseAcc = makeOccMapSharingAcc config accOccMap
-- Traverse a unary scalar function (the predicate and step of 'While') by
-- applying it to the placeholder @Tag lvl@ and traversing the body at
-- @lvl+1@. The body records @lvl@ as its only free variable, and the
-- returned function ignores its argument.
traverseFun1 :: (Elt a, Typeable b)
=> Level
-> (Exp a -> Exp b)
-> IO (Exp a -> UnscopedExp b, Int)
traverseFun1 lvl f
= do
let x = Exp (Tag lvl)
(UnscopedExp [] body, height) <- travE (lvl+1) (f x)
return (const (UnscopedExp [lvl] body), height + 1)
-- The helpers below traverse the children at the current level @lvl@ and
-- give the node height 1 + the maximum height of its children.
travE1 :: Typeable b => (UnscopedExp b -> PreExp UnscopedAcc UnscopedExp a) -> Exp b
-> IO (PreExp UnscopedAcc UnscopedExp a, Int)
travE1 c e
= do
(e', h) <- travE lvl e
return (c e', h + 1)
travE2 :: (Typeable b, Typeable c)
=> (UnscopedExp b -> UnscopedExp c -> PreExp UnscopedAcc UnscopedExp a)
-> Exp b -> Exp c
-> IO (PreExp UnscopedAcc UnscopedExp a, Int)
travE2 c e1 e2
= do
(e1', h1) <- travE lvl e1
(e2', h2) <- travE lvl e2
return (c e1' e2', h1 `max` h2 + 1)
travE3 :: (Typeable b, Typeable c, Typeable d)
=> (UnscopedExp b -> UnscopedExp c -> UnscopedExp d -> PreExp UnscopedAcc UnscopedExp a)
-> Exp b -> Exp c -> Exp d
-> IO (PreExp UnscopedAcc UnscopedExp a, Int)
travE3 c e1 e2 e3
= do
(e1', h1) <- travE lvl e1
(e2', h2) <- travE lvl e2
(e3', h3) <- travE lvl e3
return (c e1' e2' e3', h1 `max` h2 `max` h3 + 1)
travA :: Typeable b => (UnscopedAcc b -> PreExp UnscopedAcc UnscopedExp a) -> Acc b
-> IO (PreExp UnscopedAcc UnscopedExp a, Int)
travA c acc
= do
(acc', h) <- traverseAcc lvl acc
return (c acc', h + 1)
travAE :: (Typeable b, Typeable c)
=> (UnscopedAcc b -> UnscopedExp c -> PreExp UnscopedAcc UnscopedExp a)
-> Acc b -> Exp c
-> IO (PreExp UnscopedAcc UnscopedExp a, Int)
travAE c acc e
= do
(acc', h1) <- traverseAcc lvl acc
(e' , h2) <- travE lvl e
return (c acc' e', h1 `max` h2 + 1)
-- 'NilTup' has height 1 and each 'SnocTup' adds 1; the enclosing 'Tuple'
-- node adds nothing further.
travTup :: Tuple Exp tup -> IO (Tuple UnscopedExp tup, Int)
travTup NilTup = return (NilTup, 1)
travTup (SnocTup tup e) = do
(tup', h1) <- travTup tup
(e' , h2) <- travE lvl e
return (SnocTup tup' e', h1 `max` h2 + 1)
-- | Node occurrence counts gathered bottom-up, paired with a graph that maps
-- each node name to the names of the nodes inserted above it while it was
-- still in the list (see 'insertAccNode' / 'insertExpNode'). The list is
-- merged by '+++', which orders it using 'higherSSA' / 'higherSSE'.
type NodeCounts = ([NodeCount], Map.HashMap NodeName (Set.HashSet NodeName))
-- | A 'StableName' of any 'Typeable' type, hidden behind an existential so
-- that names of differently-typed nodes can be keys of the same map.
data NodeName where
NodeName :: Typeable a => StableName a -> NodeName
-- | Two names are equal only if their underlying types coincide (checked
-- with 'gcast') and the stable names themselves are equal.
instance Eq NodeName where
  NodeName a == NodeName b = maybe False (a ==) (gcast b)
-- | Hash a node name by adding the salt to the stable name's hash.
instance Hashable NodeName where
  hashWithSalt salt (NodeName name) =
    let h = hashStableName name
    in  salt + h
-- | Display a node name as the hash of its stable name.
instance Show NodeName where
  show (NodeName name) = show $ hashStableName name
-- | The number of occurrences counted so far for a shared node: either an
-- array node ('AccNodeCount') or a scalar expression node ('ExpNodeCount').
data NodeCount = AccNodeCount StableSharingAcc Int
| ExpNodeCount StableSharingExp Int
deriving Show
-- | No counted nodes and an empty graph.
noNodeCounts :: NodeCounts
noNodeCounts = (mempty, Map.empty)
-- | Add one occurrence of an array node above the given counts.
--
-- Every node currently in the counts gets an edge to the new node in the
-- graph, and the new node itself is entered with no enclosing nodes.
insertAccNode :: StableSharingAcc -> NodeCounts -> NodeCounts
insertAccNode ssa@(StableSharingAcc (StableNameHeight sn _) _) counts@(subterms, _)
  = ([AccNodeCount ssa 1], edges) +++ counts
  where
    self  = NodeName sn
    edges = Map.fromList
          $ (self, Set.empty)
          : zip (map nodeName subterms) (repeat (Set.singleton self))
-- | Add one occurrence of a scalar expression node above the given counts.
--
-- Every node currently in the counts gets an edge to the new node in the
-- graph, and the new node itself is entered with no enclosing nodes.
insertExpNode :: StableSharingExp -> NodeCounts -> NodeCounts
insertExpNode sse@(StableSharingExp (StableNameHeight sn _) _) counts@(subterms, _)
  = ([ExpNodeCount sse 1], edges) +++ counts
  where
    self  = NodeName sn
    edges = Map.fromList
          $ (self, Set.empty)
          : zip (map nodeName subterms) (repeat (Set.singleton self))
-- | Restrict the graph to the nodes that are still in the count list.
--
-- Only the listed nodes keep an entry, and each entry keeps only the
-- enclosing nodes that are themselves listed. Membership is tested against a
-- 'Set.HashSet' of the listed names. The previous version scanned the name
-- list once per edge, which made the function quadratic. Every listed node
-- is assumed to have an entry in the graph (as 'insertAccNode' /
-- 'insertExpNode' guarantee); 'Map.!' fails otherwise.
cleanCounts :: NodeCounts -> NodeCounts
cleanCounts (ns, g) =
  (ns, Map.fromList [ (h, Set.filter (`Set.member` live) (g Map.! h)) | h <- hs ])
  where
    hs   = map nodeName ns
    live = Set.fromList hs
-- | The stable name of the node that a count refers to.
nodeName :: NodeCount -> NodeName
nodeName nc = case nc of
  AccNodeCount (StableSharingAcc (StableNameHeight sn _) _) _ -> NodeName sn
  ExpNodeCount (StableSharingExp (StableNameHeight sn _) _) _ -> NodeName sn
-- | Merge two sets of node counts.
--
-- Each entry of the right-hand list is inserted into the left-hand list.
-- An entry for the same node as an existing one adds the two counts,
-- keeping the sharing term that is not a variable where possible. Among
-- nodes of the same kind, entries are placed using 'higherSSA' /
-- 'higherSSE'; expression entries go in front of array entries. The graphs
-- are merged by taking the union of each node's edge set.
(+++) :: NodeCounts -> NodeCounts -> NodeCounts
(nodesL, graphL) +++ (nodesR, graphR) =
  ( foldr merge nodesL nodesR
  , Map.unionWith Set.union graphL graphR )
  where
    merge new []                  = [new]
    merge new rest@(old : others) =
      case (new, old) of
        (AccNodeCount a n, AccNodeCount b m)
          | a == b          -> AccNodeCount (preferNonAvar a b) (n + m) : others
          | a `higherSSA` b -> new : rest
          | otherwise       -> old : merge new others
        (ExpNodeCount a n, ExpNodeCount b m)
          | a == b          -> ExpNodeCount (preferNonVar a b) (n + m) : others
          | a `higherSSE` b -> new : rest
          | otherwise       -> old : merge new others
        (AccNodeCount _ _, ExpNodeCount _ _) -> old : merge new others
        (ExpNodeCount _ _, AccNodeCount _ _) -> new : merge old others

    -- Prefer the array sharing term that is not an 'AvarSharing' variable.
    preferNonAvar a b = case a of
      StableSharingAcc _ (AvarSharing _) -> b
      _                                  -> a

    -- Prefer the expression sharing term that is not a 'VarSharing' variable.
    preferNonVar a b = case a of
      StableSharingExp _ (VarSharing _) -> b
      _                                 -> a
-- | Build the initial array environment for a list of free-variable tags.
--
-- For each level in @tags@ (in order), pick the unique entry of @sas@ that is
-- an 'Atag' node with that tag. A tag with no matching entry is filled with
-- a placeholder whose sharing term is 'undefined'. A duplicate match, or an
-- entry of @sas@ that is not a plain 'Atag' node, raises an internal error.
buildInitialEnvAcc :: [Level] -> [StableSharingAcc] -> [StableSharingAcc]
buildInitialEnvAcc tags sas = [ entryFor tag | tag <- tags ]
  where
    entryFor tag =
      case filter (isTag tag) sas of
        []   -> placeholder
        [sa] -> sa
        dups -> $internalError "buildInitialEnvAcc"
              $ "Encountered duplicate 'ATag's\n " ++ intercalate ", " (map describe dups)

    isTag tag (StableSharingAcc _ (AccSharing _ (Atag tag'))) = tag == tag'
    isTag _   other
      = $internalError "buildInitialEnvAcc"
      $ "Encountered a node that is not a plain 'Atag'\n " ++ describe other

    placeholder :: StableSharingAcc
    placeholder = StableSharingAcc noStableAccName (undefined :: SharingAcc acc exp ())

    describe (StableSharingAcc _ (AccSharing sn acc)) = show (hashStableNameHeight sn) ++ ": " ++ showPreAccOp acc
    describe (StableSharingAcc _ (AvarSharing sn))    = "AvarSharing " ++ show (hashStableNameHeight sn)
    describe (StableSharingAcc _ (AletSharing sa _ )) = "AletSharing " ++ show sa ++ "..."
-- | Build the initial scalar environment for a list of free-variable tags.
--
-- For each level in @tags@ (in order), pick the unique entry of @ses@ that is
-- a 'Tag' node with that tag. A tag with no matching entry is filled with a
-- placeholder whose sharing term is 'undefined'. A duplicate match, or an
-- entry of @ses@ that is not a plain 'Tag' node, raises an internal error.
buildInitialEnvExp :: [Level] -> [StableSharingExp] -> [StableSharingExp]
buildInitialEnvExp tags ses = [ entryFor tag | tag <- tags ]
  where
    entryFor tag =
      case filter (isTag tag) ses of
        []   -> placeholder
        [se] -> se
        dups -> $internalError "buildInitialEnvExp"
                  ("Encountered a duplicate 'Tag'\n " ++ intercalate ", " (map describe dups))

    isTag tag (StableSharingExp _ (ExpSharing _ (Tag tag'))) = tag == tag'
    isTag _   other
      = $internalError "buildInitialEnvExp"
          ("Encountered a node that is not a plain 'Tag'\n " ++ describe other)

    placeholder :: StableSharingExp
    placeholder = StableSharingExp noStableExpName (undefined :: SharingExp acc exp ())

    describe (StableSharingExp _ (ExpSharing sn exp)) = show (hashStableNameHeight sn) ++ ": " ++ showPreExpOp exp
    describe (StableSharingExp _ (VarSharing sn))     = "VarSharing " ++ show (hashStableNameHeight sn)
    describe (StableSharingExp _ (LetSharing se _ ))  = "LetSharing " ++ show se ++ "..."
-- | A count refers to a free variable when its node is a placeholder tag
-- ('Atag' for arrays, 'Tag' for scalar expressions).
isFreeVar :: NodeCount -> Bool
isFreeVar nc = case nc of
  AccNodeCount (StableSharingAcc _ (AccSharing _ (Atag _))) _ -> True
  ExpNodeCount (StableSharingExp _ (ExpSharing _ (Tag _))) _  -> True
  _                                                           -> False
-- | Determine the scopes of all shared array subterms of a root array term.
--
-- After scoping, the only nodes left unbound at the root must be free
-- variables. Their 'Atag' entries are turned into the initial environment
-- for the free-variable levels @fvs@. Any other leftover node means a shared
-- subtree was never bound, which is reported as an internal error listing
-- the offending node counts.
determineScopesAcc
  :: Typeable a
  => Config
  -> [Level]
  -> OccMap Acc
  -> UnscopedAcc a
  -> (ScopedAcc a, [StableSharingAcc])
determineScopesAcc config fvs accOccMap rootAcc
  | null unboundTrees
  = (sharingAcc, buildInitialEnvAcc fvs [sa | AccNodeCount sa _ <- counts])
  | otherwise
  -- separate the message from the listing so the error is readable
  = $internalError "determineScopesAcc" ("unbound shared subtrees: " ++ show unboundTrees)
  where
    (sharingAcc, (counts, _)) = determineScopesSharingAcc config accOccMap rootAcc
    unboundTrees              = filter (not . isFreeVar) counts
-- | Determine where to bind the shared array subterms of an annotated array
-- term, inserting 'AletSharing' nodes at the chosen places.
--
-- Returns the scoped term together with the 'NodeCounts' that remain
-- pending: shared subterms not bound inside the term (to be bound further
-- out) and placeholder 'Atag' free variables. A pending subterm is bound at
-- a node once the number of its occurrences counted below that node equals
-- the count recorded in @accOccMap@, and every pending node enclosing it is
-- bindable there too (see @filterCompleted@). Scalar subterms are scoped
-- with 'determineScopesExp', which contributes only array node counts.
determineScopesSharingAcc
:: Config
-> OccMap Acc
-> UnscopedAcc a
-> (ScopedAcc a, NodeCounts)
determineScopesSharingAcc config accOccMap = scopesAcc
where
scopesAcc :: forall arrs. UnscopedAcc arrs -> (ScopedAcc arrs, NodeCounts)
-- The unscoped input is not expected to contain let bindings yet.
scopesAcc (UnscopedAcc _ (AletSharing _ _))
= $internalError "determineScopesSharingAcc: scopesAcc" "unexpected 'AletSharing'"
-- A variable reference contributes one occurrence of the node it names.
scopesAcc (UnscopedAcc _ (AvarSharing sn))
= (ScopedAcc [] (AvarSharing sn), StableSharingAcc sn (AvarSharing sn) `insertAccNode` noNodeCounts)
scopesAcc (UnscopedAcc _ (AccSharing sn pacc))
= case pacc of
Atag i -> reconstruct (Atag i) noNodeCounts
Pipe afun1 afun2 acc -> let
(afun1', accCount1) = scopesAfun1 afun1
(afun2', accCount2) = scopesAfun1 afun2
(acc', accCount3) = scopesAcc acc
in
reconstruct (Pipe afun1' afun2' acc')
(accCount1 +++ accCount2 +++ accCount3)
-- The foreign function @afun@ is passed through; only @acc@ is scoped.
Aforeign ff afun acc -> let
(acc', accCount) = scopesAcc acc
in
reconstruct (Aforeign ff afun acc') accCount
Acond e acc1 acc2 -> let
(e' , accCount1) = scopesExp e
(acc1', accCount2) = scopesAcc acc1
(acc2', accCount3) = scopesAcc acc2
in
reconstruct (Acond e' acc1' acc2')
(accCount1 +++ accCount2 +++ accCount3)
Awhile pred iter init -> let
(pred', accCount1) = scopesAfun1 pred
(iter', accCount2) = scopesAfun1 iter
(init', accCount3) = scopesAcc init
in
reconstruct (Awhile pred' iter' init')
(accCount1 +++ accCount2 +++ accCount3)
Atuple tup -> let (tup', accCount) = travAtup tup
in reconstruct (Atuple tup') accCount
Aprj ix a -> travA (Aprj ix) a
Use arr -> reconstruct (Use arr) noNodeCounts
Unit e -> let
(e', accCount) = scopesExp e
in
reconstruct (Unit e') accCount
Generate sh f -> let
(sh', accCount1) = scopesExp sh
(f' , accCount2) = scopesFun1 f
in
reconstruct (Generate sh' f') (accCount1 +++ accCount2)
Reshape sh acc -> travEA Reshape sh acc
Replicate n acc -> travEA Replicate n acc
-- 'Slice' takes the array first; 'flip' adapts it to travEA's
-- (expression, array) argument order.
Slice acc i -> travEA (flip Slice) i acc
Map f acc -> let
(f' , accCount1) = scopesFun1 f
(acc', accCount2) = scopesAcc acc
in
reconstruct (Map f' acc') (accCount1 +++ accCount2)
ZipWith f acc1 acc2 -> travF2A2 ZipWith f acc1 acc2
Fold f z acc -> travF2EA Fold f z acc
Fold1 f acc -> travF2A Fold1 f acc
FoldSeg f z acc1 acc2 -> let
(f' , accCount1) = scopesFun2 f
(z' , accCount2) = scopesExp z
(acc1', accCount3) = scopesAcc acc1
(acc2', accCount4) = scopesAcc acc2
in
reconstruct (FoldSeg f' z' acc1' acc2')
(accCount1 +++ accCount2 +++ accCount3 +++ accCount4)
Fold1Seg f acc1 acc2 -> travF2A2 Fold1Seg f acc1 acc2
Scanl f z acc -> travF2EA Scanl f z acc
Scanl' f z acc -> travF2EA Scanl' f z acc
Scanl1 f acc -> travF2A Scanl1 f acc
Scanr f z acc -> travF2EA Scanr f z acc
Scanr' f z acc -> travF2EA Scanr' f z acc
Scanr1 f acc -> travF2A Scanr1 f acc
Permute fc acc1 fp acc2 -> let
(fc' , accCount1) = scopesFun2 fc
(acc1', accCount2) = scopesAcc acc1
(fp' , accCount3) = scopesFun1 fp
(acc2', accCount4) = scopesAcc acc2
in
reconstruct (Permute fc' acc1' fp' acc2')
(accCount1 +++ accCount2 +++ accCount3 +++ accCount4)
Backpermute sh fp acc -> let
(sh' , accCount1) = scopesExp sh
(fp' , accCount2) = scopesFun1 fp
(acc', accCount3) = scopesAcc acc
in
reconstruct (Backpermute sh' fp' acc')
(accCount1 +++ accCount2 +++ accCount3)
-- The source array is passed to scopesStencil1/2 but ignored there.
Stencil st bnd acc -> let
(st' , accCount1) = scopesStencil1 acc st
(bnd', accCount2) = scopesBoundary bnd
(acc', accCount3) = scopesAcc acc
in
reconstruct (Stencil st' bnd' acc') (accCount1 +++ accCount2 +++ accCount3)
Stencil2 st bnd1 acc1 bnd2 acc2
-> let
(st' , accCount1) = scopesStencil2 acc1 acc2 st
(bnd1', accCount2) = scopesBoundary bnd1
(acc1', accCount3) = scopesAcc acc1
(bnd2', accCount4) = scopesBoundary bnd2
(acc2', accCount5) = scopesAcc acc2
in
reconstruct (Stencil2 st' bnd1' acc1' bnd2' acc2')
(accCount1 +++ accCount2 +++ accCount3 +++ accCount4 +++ accCount5)
where
-- Helpers for common argument shapes: scope each child, combine the
-- children's counts with '+++', and rebuild the node with 'reconstruct'.
travEA :: (ScopedExp e -> ScopedAcc arrs' -> PreAcc ScopedAcc ScopedExp arrs)
-> RootExp e
-> UnscopedAcc arrs'
-> (ScopedAcc arrs, NodeCounts)
travEA c e acc = reconstruct (c e' acc') (accCount1 +++ accCount2)
where
(e' , accCount1) = scopesExp e
(acc', accCount2) = scopesAcc acc
travF2A :: (Elt a, Elt b)
=> ((Exp a -> Exp b -> ScopedExp c) -> ScopedAcc arrs'
-> PreAcc ScopedAcc ScopedExp arrs)
-> (Exp a -> Exp b -> RootExp c)
-> UnscopedAcc arrs'
-> (ScopedAcc arrs, NodeCounts)
travF2A c f acc = reconstruct (c f' acc') (accCount1 +++ accCount2)
where
(f' , accCount1) = scopesFun2 f
(acc', accCount2) = scopesAcc acc
travF2EA :: (Elt a, Elt b)
=> ((Exp a -> Exp b -> ScopedExp c) -> ScopedExp e
-> ScopedAcc arrs' -> PreAcc ScopedAcc ScopedExp arrs)
-> (Exp a -> Exp b -> RootExp c)
-> RootExp e
-> UnscopedAcc arrs'
-> (ScopedAcc arrs, NodeCounts)
travF2EA c f e acc = reconstruct (c f' e' acc') (accCount1 +++ accCount2 +++ accCount3)
where
(f' , accCount1) = scopesFun2 f
(e' , accCount2) = scopesExp e
(acc', accCount3) = scopesAcc acc
travF2A2 :: (Elt a, Elt b)
=> ((Exp a -> Exp b -> ScopedExp c) -> ScopedAcc arrs1
-> ScopedAcc arrs2 -> PreAcc ScopedAcc ScopedExp arrs)
-> (Exp a -> Exp b -> RootExp c)
-> UnscopedAcc arrs1
-> UnscopedAcc arrs2
-> (ScopedAcc arrs, NodeCounts)
travF2A2 c f acc1 acc2 = reconstruct (c f' acc1' acc2')
(accCount1 +++ accCount2 +++ accCount3)
where
(f' , accCount1) = scopesFun2 f
(acc1', accCount2) = scopesAcc acc1
(acc2', accCount3) = scopesAcc acc2
travAtup :: Atuple UnscopedAcc a
-> (Atuple ScopedAcc a, NodeCounts)
travAtup NilAtup = (NilAtup, noNodeCounts)
travAtup (SnocAtup tup a) = let (tup', accCountT) = travAtup tup
(a', accCountA) = scopesAcc a
in
(SnocAtup tup' a', accCountT +++ accCountA)
travA :: (ScopedAcc arrs' -> PreAcc ScopedAcc ScopedExp arrs)
-> UnscopedAcc arrs'
-> (ScopedAcc arrs, NodeCounts)
travA c acc = reconstruct (c acc') accCount
where
(acc', accCount) = scopesAcc acc
-- Number of occurrences of the current node recorded in the occurrence map.
accOccCount = let StableNameHeight sn' _ = sn
in
lookupWithASTName accOccMap (StableASTName sn')
reconstruct :: PreAcc ScopedAcc ScopedExp arrs
-> NodeCounts
-> (ScopedAcc arrs, NodeCounts)
-- A placeholder tag is a free variable: it becomes an 'AvarSharing'
-- reference and its own node count is passed upwards (@_subCount@ is
-- ignored).
reconstruct newAcc@(Atag _) _subCount
= let thisCount = StableSharingAcc sn (AccSharing sn newAcc) `insertAccNode` noNodeCounts
in
tracePure "FREE" (show thisCount)
(ScopedAcc [] (AvarSharing sn), thisCount)
-- A node that occurs more than once (with array sharing recovery enabled)
-- is replaced by a variable here. Its full term, including any lets
-- placed at it, is passed upwards in the counts to be bound further out.
-- Otherwise the node stays in place, wrapped in the lets placed at it.
reconstruct newAcc subCount
| accOccCount > 1 && recoverAccSharing config
= let allCount = (StableSharingAcc sn sharingAcc `insertAccNode` newCount)
in
tracePure ("SHARED" ++ completed) (show allCount)
(ScopedAcc [] (AvarSharing sn), allCount)
| otherwise
= tracePure ("Normal" ++ completed) (show newCount)
(ScopedAcc [] sharingAcc, newCount)
where
(newCount, bindHere) = filterCompleted subCount
lets = foldl (flip (.)) id . map (\x y -> AletSharing x (ScopedAcc [] y)) $ bindHere
sharingAcc = lets $ AccSharing sn newAcc
completed | null bindHere = ""
| otherwise = "(" ++ show (length bindHere) ++ " lets)"
-- Split the pending counts into the array nodes to let-bind at this node
-- and the rest. @bindable@ is defined lazily in terms of itself: a node is
-- bindable when all of its occurrences have been counted and every pending
-- node that encloses it (per @graph@) is bindable as well. Expression node
-- counts are never bound here.
filterCompleted :: NodeCounts -> (NodeCounts, [StableSharingAcc])
filterCompleted (ns, graph)
= let bindable = map (isBindable bindable (map nodeName ns)) ns
(bind, rest) = partition fst $ zip bindable ns
in ((map snd rest, graph), [sa | AccNodeCount sa _ <- map snd bind])
where
-- All occurrences of a (non-free) array node have been counted.
isCompleted nc@(AccNodeCount sa n) | not . isFreeVar $ nc = lookupWithSharingAcc accOccMap sa == n
isCompleted _ = False
isBindable :: [Bool] -> [NodeName] -> NodeCount -> Bool
isBindable bindable nodes nc@(AccNodeCount _ _) =
let superTerms = Set.toList $ graph Map.! nodeName nc
unbound = mapMaybe (`elemIndex` nodes) superTerms
in isCompleted nc
&& all (bindable !!) unbound
isBindable _ _ (ExpNodeCount _ _) = False
scopesExp :: RootExp t -> (ScopedExp t, NodeCounts)
scopesExp = determineScopesExp config accOccMap
-- Scope the body of a unary array function. The function is applied to
-- 'undefined', which presumes it ignores its argument (TODO confirm for
-- every producer of these functions). Counts for free 'Atag' variables
-- whose tag is among the body's @fvs@ are removed from the counts and
-- become the function's initial environment.
scopesAfun1 :: Arrays a1 => (Acc a1 -> UnscopedAcc a2) -> (Acc a1 -> ScopedAcc a2, NodeCounts)
scopesAfun1 f = (const (ScopedAcc ssa body'), (counts',graph))
where
body@(UnscopedAcc fvs _) = f undefined
((ScopedAcc [] body'), (counts,graph)) = scopesAcc body
ssa = buildInitialEnvAcc fvs [sa | AccNodeCount sa _ <- freeCounts]
(freeCounts, counts') = partition isBoundHere counts
isBoundHere (AccNodeCount (StableSharingAcc _ (AccSharing _ (Atag i))) _) = i `elem` fvs
isBoundHere _ = False
-- The scalar and stencil functions below are likewise applied to
-- 'undefined', relying on them ignoring their arguments.
scopesFun1 :: Elt e1 => (Exp e1 -> RootExp e2) -> (Exp e1 -> ScopedExp e2, NodeCounts)
scopesFun1 f = (const body, counts)
where
(body, counts) = scopesExp (f undefined)
scopesFun2 :: (Elt e1, Elt e2)
=> (Exp e1 -> Exp e2 -> RootExp e3)
-> (Exp e1 -> Exp e2 -> ScopedExp e3, NodeCounts)
scopesFun2 f = (\_ _ -> body, counts)
where
(body, counts) = scopesExp (f undefined undefined)
scopesStencil1 :: forall sh e1 e2 stencil. Stencil sh e1 stencil
=> UnscopedAcc (Array sh e1)
-> (stencil -> RootExp e2)
-> (stencil -> ScopedExp e2, NodeCounts)
scopesStencil1 _ stencilFun = (const body, counts)
where
(body, counts) = scopesExp (stencilFun undefined)
scopesStencil2 :: forall sh e1 e2 e3 stencil1 stencil2.
(Stencil sh e1 stencil1, Stencil sh e2 stencil2)
=> UnscopedAcc (Array sh e1)
-> UnscopedAcc (Array sh e2)
-> (stencil1 -> stencil2 -> RootExp e3)
-> (stencil1 -> stencil2 -> ScopedExp e3, NodeCounts)
scopesStencil2 _ _ stencilFun = (\_ _ -> body, counts)
where
(body, counts) = scopesExp (stencilFun undefined undefined)
-- Only a 'Function' boundary contains a scalar function to scope.
scopesBoundary :: PreBoundary UnscopedAcc RootExp t
-> (PreBoundary ScopedAcc ScopedExp t, NodeCounts)
scopesBoundary bndy =
case bndy of
Clamp -> (Clamp, noNodeCounts)
Mirror -> (Mirror, noNodeCounts)
Wrap -> (Wrap, noNodeCounts)
Constant v -> (Constant v, noNodeCounts)
Function f -> let (body, counts) = scopesFun1 f
in (Function body, counts)
-- | Determine the scopes of the shared subterms of a root scalar expression.
--
-- Scalar node counts left over at the root are turned into the expression's
-- initial environment for its free variables @fvs@. The remaining array node
-- counts are returned, with the graph restricted to them by 'cleanCounts',
-- so that the enclosing array term can bind them.
determineScopesExp
  :: Config
  -> OccMap Acc
  -> RootExp t
  -> (ScopedExp t, NodeCounts)
determineScopesExp config accOccMap (RootExp expOccMap exp@(UnscopedExp fvs _))
  = (ScopedExp env expWithScopes, cleanCounts (accCounts, graph))
  where
    (ScopedExp [] expWithScopes, (nodeCounts, graph))
      = determineScopesSharingExp config accOccMap expOccMap exp
    (expCounts, accCounts) = partition isExpNode nodeCounts
    env = buildInitialEnvExp fvs [se | ExpNodeCount se _ <- expCounts]

    isExpNode ExpNodeCount{} = True
    isExpNode _              = False
-- | Determine where to bind the shared scalar subterms of an annotated
-- expression, inserting 'LetSharing' nodes at the chosen places.
--
-- Follows the same scheme as 'determineScopesSharingAcc' but for 'Exp'
-- nodes, checking completeness against @expOccMap@. Embedded array terms are
-- scoped with 'determineScopesSharingAcc'. When 'floatOutAcc' is set, an
-- embedded array term that is not already a variable is replaced by an
-- 'AvarSharing' variable, and its term is passed upwards in the counts.
-- Array node counts are never let-bound inside an expression.
determineScopesSharingExp
:: Config
-> OccMap Acc
-> OccMap Exp
-> UnscopedExp t
-> (ScopedExp t, NodeCounts)
determineScopesSharingExp config accOccMap expOccMap = scopesExp
where
scopesAcc :: UnscopedAcc a -> (ScopedAcc a, NodeCounts)
scopesAcc = determineScopesSharingAcc config accOccMap
-- Scope the body of a unary scalar function. The function is applied to
-- 'undefined', which presumes it ignores its argument. Counts for free
-- 'Tag' variables whose tag is among the body's @fvs@ are removed from
-- the counts and become the function's initial environment.
scopesFun1 :: (Exp a -> UnscopedExp b) -> (Exp a -> ScopedExp b, NodeCounts)
scopesFun1 f = tracePure ("LAMBDA " ++ (show ssa)) (show counts) (const (ScopedExp ssa body'), (counts',graph))
where
body@(UnscopedExp fvs _) = f undefined
((ScopedExp [] body'), (counts, graph)) = scopesExp body
ssa = buildInitialEnvExp fvs [se | ExpNodeCount se _ <- freeCounts]
(freeCounts, counts') = partition isBoundHere counts
isBoundHere (ExpNodeCount (StableSharingExp _ (ExpSharing _ (Tag i))) _) = i `elem` fvs
isBoundHere _ = False
scopesExp :: forall t. UnscopedExp t -> (ScopedExp t, NodeCounts)
-- The unscoped input is not expected to contain let bindings yet.
scopesExp (UnscopedExp _ (LetSharing _ _))
= $internalError "determineScopesSharingExp: scopesExp" "unexpected 'LetSharing'"
-- A variable reference contributes one occurrence of the node it names.
scopesExp (UnscopedExp _ (VarSharing sn))
= (ScopedExp [] (VarSharing sn), StableSharingExp sn (VarSharing sn) `insertExpNode` noNodeCounts)
scopesExp (UnscopedExp _ (ExpSharing sn pexp))
= case pexp of
Tag i -> reconstruct (Tag i) noNodeCounts
Const c -> reconstruct (Const c) noNodeCounts
Tuple tup -> let (tup', accCount) = travTup tup
in
reconstruct (Tuple tup') accCount
Prj i e -> travE1 (Prj i) e
IndexNil -> reconstruct IndexNil noNodeCounts
IndexCons ix i -> travE2 IndexCons ix i
IndexHead i -> travE1 IndexHead i
IndexTail ix -> travE1 IndexTail ix
IndexAny -> reconstruct IndexAny noNodeCounts
ToIndex sh ix -> travE2 ToIndex sh ix
FromIndex sh e -> travE2 FromIndex sh e
Cond e1 e2 e3 -> travE3 Cond e1 e2 e3
While p it i -> let
(p' , accCount1) = scopesFun1 p
(it', accCount2) = scopesFun1 it
(i' , accCount3) = scopesExp i
in reconstruct (While p' it' i') (accCount1 +++ accCount2 +++ accCount3)
PrimConst c -> reconstruct (PrimConst c) noNodeCounts
PrimApp p e -> travE1 (PrimApp p) e
Index a e -> travAE Index a e
LinearIndex a e -> travAE LinearIndex a e
Shape a -> travA Shape a
ShapeSize e -> travE1 ShapeSize e
Intersect sh1 sh2 -> travE2 Intersect sh1 sh2
Union sh1 sh2 -> travE2 Union sh1 sh2
Foreign ff f e -> travE1 (Foreign ff f) e
where
-- Helpers for common argument shapes: scope each child, combine the
-- children's counts with '+++', and rebuild the node with 'reconstruct'.
travTup :: Tuple UnscopedExp tup -> (Tuple ScopedExp tup, NodeCounts)
travTup NilTup = (NilTup, noNodeCounts)
travTup (SnocTup tup e) = let
(tup', accCountT) = travTup tup
(e' , accCountE) = scopesExp e
in
(SnocTup tup' e', accCountT +++ accCountE)
travE1 :: (ScopedExp a -> PreExp ScopedAcc ScopedExp t) -> UnscopedExp a
-> (ScopedExp t, NodeCounts)
travE1 c e = reconstruct (c e') accCount
where
(e', accCount) = scopesExp e
travE2 :: (ScopedExp a -> ScopedExp b -> PreExp ScopedAcc ScopedExp t)
-> UnscopedExp a
-> UnscopedExp b
-> (ScopedExp t, NodeCounts)
travE2 c e1 e2 = reconstruct (c e1' e2') (accCount1 +++ accCount2)
where
(e1', accCount1) = scopesExp e1
(e2', accCount2) = scopesExp e2
travE3 :: (ScopedExp a -> ScopedExp b -> ScopedExp c -> PreExp ScopedAcc ScopedExp t)
-> UnscopedExp a
-> UnscopedExp b
-> UnscopedExp c
-> (ScopedExp t, NodeCounts)
travE3 c e1 e2 e3 = reconstruct (c e1' e2' e3') (accCount1 +++ accCount2 +++ accCount3)
where
(e1', accCount1) = scopesExp e1
(e2', accCount2) = scopesExp e2
(e3', accCount3) = scopesExp e3
travA :: (ScopedAcc a -> PreExp ScopedAcc ScopedExp t) -> UnscopedAcc a
-> (ScopedExp t, NodeCounts)
travA c acc = maybeFloatOutAcc c acc' accCount
where
(acc', accCount) = scopesAcc acc
travAE :: (ScopedAcc a -> ScopedExp b -> PreExp ScopedAcc ScopedExp t)
-> UnscopedAcc a
-> UnscopedExp b
-> (ScopedExp t, NodeCounts)
travAE c acc e = maybeFloatOutAcc (`c` e') acc' (accCountA +++ accCountE)
where
(acc', accCountA) = scopesAcc acc
(e' , accCountE) = scopesExp e
-- Leave an embedded array that is already a variable in place. Otherwise,
-- if 'floatOutAcc' is set, replace it by a variable and add its term (with
-- its lets) to the counts so it is bound outside this expression.
maybeFloatOutAcc :: (ScopedAcc a -> PreExp ScopedAcc ScopedExp t)
-> ScopedAcc a
-> NodeCounts
-> (ScopedExp t, NodeCounts)
maybeFloatOutAcc c acc@(ScopedAcc _ (AvarSharing _)) accCount
= reconstruct (c acc) accCount
maybeFloatOutAcc c acc accCount
| floatOutAcc config = reconstruct (c var) ((stableAcc `insertAccNode` noNodeCounts) +++ accCount)
| otherwise = reconstruct (c acc) accCount
where
(var, stableAcc) = abstract acc (\(ScopedAcc _ s) -> s)
-- Peel off 'AletSharing' wrappers, accumulating them in @lets@, down
-- to the underlying 'AccSharing' node. Returns a variable for that node
-- and the node's term with the peeled lets re-applied.
abstract :: ScopedAcc a -> (ScopedAcc a -> SharingAcc ScopedAcc ScopedExp a)
-> (ScopedAcc a, StableSharingAcc)
abstract (ScopedAcc _ (AvarSharing _)) _ = $internalError "sharingAccToVar" "AvarSharing"
abstract (ScopedAcc ssa (AletSharing sa acc)) lets = abstract acc (lets . (\x -> ScopedAcc ssa (AletSharing sa x)))
abstract acc@(ScopedAcc ssa (AccSharing sn _)) lets = (ScopedAcc ssa (AvarSharing sn), StableSharingAcc sn (lets acc))
-- Number of occurrences of the current node recorded in the occurrence map.
expOccCount = let StableNameHeight sn' _ = sn
in
lookupWithASTName expOccMap (StableASTName sn')
reconstruct :: PreExp ScopedAcc ScopedExp t -> NodeCounts
-> (ScopedExp t, NodeCounts)
-- A placeholder tag is a free variable: it becomes a 'VarSharing'
-- reference and its own node count is passed upwards (@_subCount@ is
-- ignored).
reconstruct newExp@(Tag _) _subCount
= let thisCount = StableSharingExp sn (ExpSharing sn newExp) `insertExpNode` noNodeCounts
in
tracePure "FREE" (show thisCount)
(ScopedExp [] (VarSharing sn), thisCount)
-- A node that occurs more than once (with expression sharing recovery
-- enabled) is replaced by a variable here. Its full term, including any
-- lets placed at it, is passed upwards in the counts to be bound further
-- out. Otherwise the node stays in place, wrapped in the lets placed at it.
reconstruct newExp subCount
| expOccCount > 1 && recoverExpSharing config
= let allCount = StableSharingExp sn sharingExp `insertExpNode` newCount
in
tracePure ("SHARED" ++ completed) (show allCount)
(ScopedExp [] (VarSharing sn), allCount)
| otherwise
= tracePure ("Normal" ++ completed) (show newCount)
(ScopedExp [] sharingExp, newCount)
where
(newCount, bindHere) = filterCompleted subCount
lets = foldl (flip (.)) id . map (\x y -> LetSharing x (ScopedExp [] y)) $ bindHere
sharingExp = lets $ ExpSharing sn newExp
completed | null bindHere = ""
| otherwise = "(" ++ show (length bindHere) ++ " lets)"
-- Split the pending counts into the expression nodes to let-bind at this
-- node and the rest. @bindable@ is defined lazily in terms of itself: a
-- node is bindable when all of its occurrences have been counted and every
-- pending node that encloses it (per @graph@) is bindable as well. Array
-- node counts are never bound here.
filterCompleted :: NodeCounts -> (NodeCounts, [StableSharingExp])
filterCompleted (ns,graph)
= let bindable = map (isBindable bindable (map nodeName ns)) ns
(bind, unbind) = partition fst $ zip bindable ns
in ((map snd unbind, graph), [se | ExpNodeCount se _ <- map snd bind])
where
-- All occurrences of a (non-free) expression node have been counted.
isCompleted nc@(ExpNodeCount sa n) | not . isFreeVar $ nc = lookupWithSharingExp expOccMap sa == n
isCompleted _ = False
isBindable :: [Bool] -> [NodeName] -> NodeCount -> Bool
isBindable bindable nodes nc@(ExpNodeCount _ _) =
let superTerms = Set.toList $ graph Map.! nodeName nc
unbound = mapMaybe (`elemIndex` nodes) superTerms
in isCompleted nc
&& all (bindable !!) unbound
isBindable _ _ (AccNodeCount _ _) = False
{-# NOINLINE recoverSharingAcc #-}
-- | Recover sharing in an array computation.
--
-- The occurrence map is built by 'makeOccMapAcc', which runs in 'IO' and is
-- therefore wrapped in 'unsafePerformIO'. The NOINLINE pragma keeps GHC from
-- inlining, and so possibly duplicating, that call. 'determineScopesAcc'
-- then places the @let@ bindings and computes the initial environment for
-- the free-variable levels @avars@.
recoverSharingAcc
  :: Typeable a
  => Config
  -> Level
  -> [Level]
  -> Acc a
  -> (ScopedAcc a, [StableSharingAcc])
recoverSharingAcc config alvl avars acc =
  determineScopesAcc config avars occMap annotated
  where
    (annotated, occMap) = unsafePerformIO (makeOccMapAcc config alvl acc)
{-# NOINLINE recoverSharingExp #-}
-- | Recover sharing in a scalar expression.
--
-- A fresh array occurrence table is created and the expression's occurrence
-- map is built with 'makeOccMapRootExp'. This happens in 'IO' behind
-- 'unsafePerformIO'; the NOINLINE pragma keeps GHC from inlining, and so
-- possibly duplicating, that call. The frozen array table is then used by
-- 'determineScopesExp' to place the @let@ bindings. The result is the scoped
-- expression with an empty environment, paired with the initial environment
-- computed for the free variables @fvar@.
recoverSharingExp
  :: Typeable e
  => Config
  -> Level
  -> [Level]
  -> Exp e
  -> (ScopedExp e, [StableSharingExp])
recoverSharingExp config lvl fvar exp = (ScopedExp [] sharingExp, sse)
  where
    (ScopedExp sse sharingExp, _) = determineScopesExp config frozenAccOccMap rootExp
    (rootExp, frozenAccOccMap)    = unsafePerformIO $ do
      table     <- newASTHashTable
      (root, _) <- makeOccMapRootExp config table lvl fvar exp
      frozen    <- freezeOccMap table
      return (root, frozen)
-- | Emit a one-line sharing trace message (@header: msg@) under the
-- 'Debug.dump_sharing' flag.
traceLine :: String -> String -> IO ()
traceLine header msg =
  Debug.traceIO Debug.dump_sharing (concat [header, ": ", msg])
-- | Emit a sharing trace message whose body starts on the line after the
-- header, under the 'Debug.dump_sharing' flag.
traceChunk :: String -> String -> IO ()
traceChunk header msg =
  Debug.traceIO Debug.dump_sharing (concat [header, "\n ", msg])
-- | Pure counterpart of 'traceLine': emit @header: msg@ under the
-- 'Debug.dump_sharing' flag and return the final argument unchanged.
tracePure :: String -> String -> a -> a
tracePure header msg =
  Debug.trace Debug.dump_sharing (concat [header, ": ", msg])