{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Accelerate.Trafo.Environment
where
import Data.Array.Accelerate.AST
import Data.Array.Accelerate.AST.Environment
import Data.Array.Accelerate.AST.Idx
import Data.Array.Accelerate.AST.LeftHandSide
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Trafo.Substitution
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Debug.Stats as Stats
-- | An environment of scalar expressions, grown one entry at a time.
--
--  * @env@ is the scalar environment every stored expression is typed in.
--
--  * @env'@ gains one component per pushed expression, so that an
--    @Idx env' t@ selects one of the stored entries (see 'prjExp').
--
--  * @aenv@ is passed through unchanged to the stored expressions.
--
data Gamma env env' aenv where
  -- Nothing recorded. Note that this is polymorphic in @env'@.
  EmptyExp
    :: Gamma env env' aenv

  -- Record one more expression, of type @t@, on top of an existing
  -- environment; @env'@ is extended by @t@ accordingly.
  PushExp
    :: Gamma env env' aenv
    -> WeakOpenExp env aenv t
    -> Gamma env (env', t) aenv
-- | A scalar expression together with a weakening that has been applied to it.
--
-- 'Subst' stores, in order: a weakening from @env@ to @env'@, the expression
-- as typed in @env@, and the expression as typed in @env'@. The last field
-- carries no strictness annotation, so it is only evaluated on demand.
--
data WeakOpenExp env aenv t where
  Subst
    :: env :> env'
    -> OpenExp env  aenv t
    -> OpenExp env' aenv t      -- lazy field: no bang on purpose here
    -> WeakOpenExp env' aenv t
-- | Move a whole 'Gamma' under one additional scalar binding of type @s@.
--
-- Every entry keeps its original (un-weakened) expression. Its stored
-- weakening is extended by one step with @weakenSucc'@, and the weakened form
-- is rebuilt from the original expression using that extended weakening. The
-- previously stored weakened form is discarded without being inspected.
--
incExp
    :: Gamma env     env' aenv
    -> Gamma (env,s) env' aenv
incExp EmptyExp           = EmptyExp
incExp (PushExp rest ent) = PushExp (incExp rest) (shift ent)
  where
    shift :: forall env aenv s t. WeakOpenExp env aenv t -> WeakOpenExp (env,s) aenv t
    shift (Subst k (e :: OpenExp env_ aenv t) _) = Subst k' e (weakenE k' e)
      where
        -- The extended weakening, named once and used for both fields.
        k' :: env_ :> (env,s)
        k' = weakenSucc' k
-- | Look up the expression stored at the given index of a 'Gamma'.
--
-- The result is the third field of the matching 'Subst', i.e. the expression
-- as typed in @env@. Any combination of index and environment that does not
-- line up (for example an index into 'EmptyExp') is reported through
-- 'internalError'.
--
prjExp :: HasCallStack => Idx env' t -> Gamma env env' aenv -> OpenExp env aenv t
prjExp ix gamma =
  case (ix, gamma) of
    (ZeroIdx,   PushExp _    (Subst _ _ e)) -> e
    (SuccIdx n, PushExp rest _)             -> prjExp n rest
    _                                       -> internalError "inconsistent valuation"
-- | Add an expression to the top of a 'Gamma'.
--
-- The new entry is stored with the identity weakening, so its original and
-- weakened forms are the very same expression.
--
pushExp :: Gamma env env' aenv -> OpenExp env aenv t -> Gamma env (env',t) aenv
pushExp gamma expr = PushExp gamma (Subst weakenId expr expr)
-- | A heterogeneous snoc-list of bound terms, witnessing how an environment
-- @env@ is extended to @env'@.
--
--  * @s@ is the kind of leaf used by the left-hand sides (the functions in
--    this module instantiate it to 'ArrayR' or 'ScalarType').
--
--  * @f@ is the type of the bound terms, indexed by the environment in which
--    the term is typed and by the type of its result.
--
data Extend s f env env' where
  -- No additional bindings: the environment is unchanged.
  BaseEnv
    :: Extend s f env env

  -- One more binding. The term is typed in the environment /before/ the
  -- left-hand side brings its variables into scope.
  PushEnv
    :: Extend s f env env'
    -> LeftHandSide s t env' env''
    -> f env' t
    -> Extend s f env env''
-- | Extend an array environment witness with a single array-valued term.
--
-- The term is bound by a 'LeftHandSideSingle' whose representation is taken
-- from the term itself via 'arrayR'.
--
pushArrayEnv
    :: HasArraysR acc
    => Extend ArrayR acc aenv aenv'
    -> acc aenv' (Array sh e)
    -> Extend ArrayR acc aenv (aenv', Array sh e)
pushArrayEnv ext arr = PushEnv ext lhs arr
  where
    lhs = LeftHandSideSingle (arrayR arr)
-- | Concatenate two environment witnesses: the result contains every binding
-- of the first argument followed by every binding of the second.
--
append :: Extend s acc env env' -> Extend s acc env' env'' -> Extend s acc env env''
append outer inner =
  case inner of
    BaseEnv               -> outer
    PushEnv rest lhs term -> PushEnv (append outer rest) lhs term
-- | Wrap a term in an 'Alet' for every binding recorded in the 'Extend',
-- turning a term typed in the extended environment @aenv'@ into one typed in
-- the base environment @aenv@.
--
-- The binding pushed last ends up innermost. The first argument converts a
-- 'PreOpenAcc' node into the @acc@ representation; it is applied to the body
-- of each 'Alet' that is created.
--
bind :: (forall env t. PreOpenAcc acc env t -> acc env t)
     -> Extend ArrayR  acc aenv aenv'
     -> PreOpenAcc acc      aenv' a
     -> PreOpenAcc acc aenv       a
bind _      BaseEnv                 = id
bind inject (PushEnv rest lhs bnd)  = \body -> bind inject rest (Alet lhs bnd (inject body))
-- | Weaken a term from the base environment @env@ into the extended
-- environment @env'@ described by the 'Extend' witness.
--
sinkA :: Sink f => Extend s acc env env' -> f env t -> f env' t
sinkA ext = weaken k
  where
    -- Computed from the witness alone, before any term is supplied.
    k = sinkWeaken ext
-- | Like 'sinkA', but for a term that has one extra binding of type @t'@ on
-- top of @env@; that top binding is kept on top of @env'@ in the result.
--
sink1 :: Sink f => Extend s acc env env' -> f (env,t') t -> f (env',t') t
sink1 ext = weaken (sink k)
  where
    k = sinkWeaken ext
-- | Compute the weakening from @env@ to @env'@ witnessed by an 'Extend'.
--
-- Only the left-hand sides are inspected; the bound terms are never touched.
-- A 'LeftHandSideWildcard' contributes nothing, each 'LeftHandSideSingle'
-- contributes one @weakenSucc'@ step, and a 'LeftHandSidePair' contributes
-- the steps of its first component followed by those of its second.
--
-- The 'BaseEnv' case is reached exactly once per call and records a
-- @"sink"@ substitution with "Stats".
--
sinkWeaken :: Extend s acc env env' -> env :> env'
sinkWeaken BaseEnv              = Stats.substitution "sink" weakenId
sinkWeaken (PushEnv rest lhs _) = overLHS lhs (sinkWeaken rest)
  where
    -- Extend a weakening across the variables bound by a left-hand side.
    -- Recursing on the pattern directly means a pair can be split without
    -- fabricating intermediate 'PushEnv' nodes holding 'undefined' terms.
    overLHS :: LeftHandSide s' t env1 env2 -> env0 :> env1 -> env0 :> env2
    overLHS (LeftHandSideWildcard _) k = k
    overLHS (LeftHandSideSingle _)   k = weakenSucc' k
    overLHS (LeftHandSidePair l1 l2) k = overLHS l2 (overLHS l1 k)
-- | Wrapper around 'OpenExp' with the @env@ and @aenv@ type arguments
-- swapped, so that the scalar environment and the result type are the last
-- two parameters. This is the shape 'Extend' requires of its @f@ argument
-- (see 'bindExps').
newtype OpenExp' aenv env e = OpenExp' (OpenExp env aenv e)
-- | Wrap a scalar expression in a 'Let' for every binding recorded in the
-- 'Extend', turning an expression typed in the extended environment @env'@
-- into one typed in the base environment @env@.
--
-- The binding pushed last ends up innermost.
--
bindExps :: Extend ScalarType (OpenExp' aenv) env env'
         -> OpenExp env' aenv e
         -> OpenExp env  aenv e
bindExps ext =
  case ext of
    BaseEnv                         -> id
    PushEnv rest lhs (OpenExp' bnd) -> \body -> bindExps rest (Let lhs bnd body)