{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE GADTs #-}
module Data.Array.Accelerate.Trafo.LetSplit (
convertAfun,
convertAcc,
) where
import Prelude hiding ( exp )
import Data.Array.Accelerate.AST
import Data.Array.Accelerate.AST.LeftHandSide
import Data.Array.Accelerate.AST.Environment
import Data.Array.Accelerate.Trafo.Substitution
-- | Run the let-splitting pass over an array function.
--
-- Lambda binders are kept as they are; the traversal descends through them
-- until it reaches the function body, which is rewritten with 'convertAcc'.
convertAfun :: PreOpenAfun OpenAcc aenv f -> PreOpenAfun OpenAcc aenv f
convertAfun = \case
  Alam lhs rest -> Alam lhs (convertAfun rest)
  Abody body    -> Abody (convertAcc body)
-- | Run the let-splitting pass over an array term.
--
-- Unwraps the 'OpenAcc' constructor, rewrites the underlying term with
-- 'convertPreOpenAcc', and wraps the result again.
convertAcc :: OpenAcc aenv a -> OpenAcc aenv a
convertAcc acc =
  case acc of
    OpenAcc inner -> OpenAcc $ convertPreOpenAcc inner
-- | Rewrite a single array-level AST node.
--
-- Every array subterm ('OpenAcc') and array function ('PreOpenAfun') of the
-- node is traversed recursively with 'convertAcc' / 'convertAfun'. Scalar
-- expressions and scalar functions are passed through unchanged. The only
-- node that is actually restructured is 'Alet', which is handed to
-- 'convertLHS' after both the binding and the body have been converted.
convertPreOpenAcc :: PreOpenAcc OpenAcc aenv a -> PreOpenAcc OpenAcc aenv a
convertPreOpenAcc = \case
  Alet lhs bnd body               -> convertLHS lhs (convertAcc bnd) (convertAcc body)
  Avar var                        -> Avar var
  Apair a1 a2                     -> Apair (convertAcc a1) (convertAcc a2)
  Anil                            -> Anil
  Apply repr f a                  -> Apply repr (convertAfun f) (convertAcc a)
  Aforeign repr asm f a           -> Aforeign repr asm (convertAfun f) (convertAcc a)
  Acond e a1 a2                   -> Acond e (convertAcc a1) (convertAcc a2)
  Awhile c f a                    -> Awhile (convertAfun c) (convertAfun f) (convertAcc a)
  Use repr arr                    -> Use repr arr
  Unit tp e                       -> Unit tp e
  -- The array argument of a reshape must be traversed like every other
  -- array subterm; otherwise lets nested below it would never be split.
  Reshape shr e a                 -> Reshape shr e (convertAcc a)
  Generate repr e f               -> Generate repr e f
  Transform repr sh f g a         -> Transform repr sh f g (convertAcc a)
  Replicate slix sl a             -> Replicate slix sl (convertAcc a)
  Slice slix a sl                 -> Slice slix (convertAcc a) sl
  Map tp f a                      -> Map tp f (convertAcc a)
  ZipWith tp f a1 a2              -> ZipWith tp f (convertAcc a1) (convertAcc a2)
  Fold f e a                      -> Fold f e (convertAcc a)
  FoldSeg i f e a s               -> FoldSeg i f e (convertAcc a) (convertAcc s)
  Scan  d f e a                   -> Scan  d f e (convertAcc a)
  Scan' d f e a                   -> Scan' d f e (convertAcc a)
  Permute f a1 g a2               -> Permute f (convertAcc a1) g (convertAcc a2)
  Backpermute shr sh f a          -> Backpermute shr sh f (convertAcc a)
  Stencil s tp f b a              -> Stencil s tp f b (convertAcc a)
  Stencil2 s1 s2 tp f b1 a1 b2 a2 -> Stencil2 s1 s2 tp f b1 (convertAcc a1) b2 (convertAcc a2)
-- | Build the replacement for @Alet lhs bnd a@, splitting the binding where
-- possible.
--
-- * A wildcard left-hand side binds nothing, so the result is just the body
--   (the binding is dropped).
-- * A single-variable left-hand side is kept as an ordinary 'Alet'.
-- * A pair left-hand side whose binding is syntactically an 'Apair' is split
--   into two nested bindings, one per component, each processed recursively.
--   The second component is weakened with @weakenWithLHS l1@ because it ends
--   up underneath the variables introduced by the first left-hand side.
--   Any other binding of a pair pattern is left as a single 'Alet'.
convertLHS
    :: ALeftHandSide bnd aenv aenv'
    -> OpenAcc aenv bnd
    -> OpenAcc aenv' a
    -> PreOpenAcc OpenAcc aenv a
convertLHS lhs bnd@(OpenAcc pbnd) a@(OpenAcc pa) =
  case lhs of
    LeftHandSideWildcard{} -> pa
    LeftHandSideSingle{}   -> unsplit
    LeftHandSidePair l1 l2
      | Apair b1 b2 <- pbnd ->
          convertLHS l1 b1 $
            OpenAcc $
              convertLHS l2 (weaken (weakenWithLHS l1) b2) a
      | otherwise -> unsplit
  where
    -- Fallback: keep the let-binding exactly as given.
    unsplit = Alet lhs bnd a