{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Accelerate.Trafo.Var
where
import Data.Array.Accelerate.AST
import Data.Array.Accelerate.AST.Environment
import Data.Array.Accelerate.AST.Idx
import Data.Array.Accelerate.AST.LeftHandSide
import Data.Array.Accelerate.AST.Var
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Type
-- | Result of 'declareVars'. It packages three things:
--
--   * a left-hand side that binds fresh variables of tuple type @t@, taking
--     the environment @env@ to an extended environment @env'@;
--   * a weakening from @env@ into @env'@;
--   * a function that, given any further weakening out of @env'@, produces
--     the freshly bound variables in the resulting environment.
data DeclareVars s t env where
  DeclareVars
    :: LeftHandSide s t env env'
    -- ^ binds the new variables
    -> (env :> env')
    -- ^ weakening past the new bindings
    -> (forall env''. env' :> env'' -> Vars s env'' t)
    -- ^ references to the new variables in any extension of @env'@
    -> DeclareVars s t env
-- | Build a left-hand side that binds one fresh variable for every leaf of the
-- given tuple representation. The result also carries the weakening into the
-- extended environment and a way to refer to the new variables from any
-- environment that extends it.
declareVars :: TupR s t -> DeclareVars s t env
declareVars = \case
  TupRunit ->
    -- Nothing is bound: the environment is unchanged.
    DeclareVars LeftHandSideUnit weakenId (const TupRunit)
  TupRsingle s ->
    -- Exactly one new binding; it is the innermost index (ZeroIdx).
    DeclareVars
      (LeftHandSideSingle s)
      (weakenSucc weakenId)
      (\k -> TupRsingle (Var s (k >:> ZeroIdx)))
  TupRpair r1 r2 ->
    -- The right component is declared in the environment produced by the
    -- left component, so the two weakenings compose right-after-left.
    case declareVars r1 of
      DeclareVars lhsL weakenL varsL ->
        case declareVars r2 of
          DeclareVars lhsR weakenR varsR ->
            DeclareVars
              (LeftHandSidePair lhsL lhsR)
              (weakenR .> weakenL)
              (\k -> TupRpair (varsL (k .> weakenR)) (varsR k))
-- | A function that embeds one 'PreOpenAcc' layer into the term type @acc@.
type InjectAcc acc = forall aenv a. PreOpenAcc acc aenv a -> acc aenv a

-- | A function that tries to view a term of type @acc@ as one 'PreOpenAcc'
-- layer, returning 'Nothing' when it cannot.
type ExtractAcc acc = forall aenv a. acc aenv a -> Maybe (PreOpenAcc acc aenv a)
-- | Turn a single array variable into a term (an 'Avar' node wrapped by
-- @inject@).
--
-- The match on the 'ArrayR' constructor is deliberate. It appears to be what
-- refines @a@ to an array type so that 'Avar' type-checks.
-- NOTE(review): confirm against the definitions of 'ArrayR' and 'Avar'.
avarIn
  :: InjectAcc acc
  -> ArrayVar aenv a
  -> acc aenv a
avarIn inject var = case var of
  Var ArrayR{} _ -> inject (Avar var)
-- | Turn a tuple of array variables into a term with the same tuple shape.
-- The unit becomes 'Anil', a pair becomes 'Apair', and each leaf becomes a
-- variable reference built by 'avarIn'. Every node is wrapped with @inject@.
avarsIn
  :: forall acc aenv arrs.
     InjectAcc acc
  -> ArrayVars aenv arrs
  -> acc aenv arrs
avarsIn inject = build
  where
    -- @acc@ and @aenv@ come from the outer signature (ScopedTypeVariables).
    -- The tuple type is left polymorphic because the recursive calls on the
    -- two halves of a pair are made at different types.
    build :: ArrayVars aenv r -> acc aenv r
    build = \case
      TupRunit          -> inject Anil
      TupRsingle var    -> avarIn inject var
      TupRpair left right -> inject (Apair (build left) (build right))
-- | Try to undo 'avarsIn'. It recognises a term built only from 'Anil',
-- 'Avar' and 'Apair' nodes and recovers the tuple of array variables it
-- denotes.
--
-- Returns 'Nothing' in two cases: any other constructor is encountered, or
-- @extract@ fails to expose one of a pair's children.
avarsOut
  :: ExtractAcc acc
  -> PreOpenAcc acc aenv a
  -> Maybe (ArrayVars aenv a)
avarsOut extract = \case
  Anil     -> Just TupRunit
  Avar var -> Just (TupRsingle var)
  Apair l r -> do
    -- Expose both children first, then convert each, left before right.
    leftTerm  <- extract l
    rightTerm <- extract r
    leftVars  <- avarsOut extract leftTerm
    rightVars <- avarsOut extract rightTerm
    pure (TupRpair leftVars rightVars)
  _ -> Nothing