{-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeOperators #-} -- | -- Module : Data.Array.Accelerate.Trafo.Var -- Copyright : [2012..2020] The Accelerate Team -- License : BSD3 -- -- Maintainer : Trevor L. McDonell -- Stability : experimental -- Portability : non-portable (GHC extensions) -- module Data.Array.Accelerate.Trafo.Var where import Data.Array.Accelerate.AST import Data.Array.Accelerate.AST.Environment import Data.Array.Accelerate.AST.Idx import Data.Array.Accelerate.AST.LeftHandSide import Data.Array.Accelerate.AST.Var import Data.Array.Accelerate.Representation.Array import Data.Array.Accelerate.Representation.Type data DeclareVars s t aenv where DeclareVars :: LeftHandSide s t env env' -> (env :> env') -> (forall env''. env' :> env'' -> Vars s env'' t) -> DeclareVars s t env declareVars :: TupR s t -> DeclareVars s t env declareVars TupRunit = DeclareVars LeftHandSideUnit weakenId $ const $ TupRunit declareVars (TupRsingle s) = DeclareVars (LeftHandSideSingle s) (weakenSucc weakenId) $ \k -> TupRsingle $ Var s $ k >:> ZeroIdx declareVars (TupRpair r1 r2) | DeclareVars lhs1 subst1 a1 <- declareVars r1 , DeclareVars lhs2 subst2 a2 <- declareVars r2 = DeclareVars (LeftHandSidePair lhs1 lhs2) (subst2 .> subst1) $ \k -> a1 (k .> subst2) `TupRpair` a2 k type InjectAcc acc = forall env t. PreOpenAcc acc env t -> acc env t type ExtractAcc acc = forall env t. acc env t -> Maybe (PreOpenAcc acc env t) avarIn :: InjectAcc acc -> ArrayVar aenv a -> acc aenv a avarIn inject v@(Var ArrayR{} _) = inject (Avar v) avarsIn :: forall acc aenv arrs. InjectAcc acc -> ArrayVars aenv arrs -> acc aenv arrs avarsIn inject = go where go :: ArrayVars aenv t -> acc aenv t go TupRunit = inject Anil go (TupRsingle v) = avarIn inject v go (TupRpair a b) = inject (go a `Apair` go b) avarsOut :: ExtractAcc acc -> PreOpenAcc acc aenv a -> Maybe (ArrayVars aenv a) avarsOut extract = \case Anil -> Just $ TupRunit Avar v -> Just $ TupRsingle v Apair al ar | Just pl <- extract al , Just pr <- extract ar , Just as <- avarsOut extract pl , Just bs <- avarsOut extract pr -> Just (TupRpair as bs) _ -> Nothing