{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# OPTIONS_HADDOCK hide #-}
module Data.Array.Accelerate.Trafo.Delayed
where
import Data.Array.Accelerate.AST
import Data.Array.Accelerate.Analysis.Hash
import Data.Array.Accelerate.Analysis.Match
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Trafo.Substitution
import Data.Array.Accelerate.Debug.Stats as Stats
import Control.DeepSeq
import Data.ByteString.Builder
import Data.ByteString.Builder.Extra
-- | Array functions whose bodies are built from 'DelayedOpenAcc' terms,
-- still parameterised over the array environment.
type DelayedOpenAfun = PreOpenAfun DelayedOpenAcc

-- | 'DelayedOpenAfun' with the array environment instantiated to @()@
-- (presumably: no free array variables).
type DelayedAfun = DelayedOpenAfun ()

-- | 'DelayedOpenAcc' with the array environment instantiated to @()@
-- (presumably: no free array variables).
type DelayedAcc = DelayedOpenAcc ()
-- | An array computation in which every node is either an ordinary array
-- operation ('Manifest') or an array described only by its extent and
-- indexing functions ('Delayed').
--
-- NOTE(review): the constructor name suggests 'Delayed' arrays are never
-- materialised and are instead consumed through their indexing functions;
-- that behaviour lives outside this type — confirm against the fusion pass.
data DelayedOpenAcc aenv a where
  -- | A regular 'PreOpenAcc' node whose sub-terms are themselves
  -- 'DelayedOpenAcc' terms.
  Manifest :: PreOpenAcc DelayedOpenAcc aenv a
           -> DelayedOpenAcc aenv a

  -- | An array of shape @sh@ and element type @e@, given by functions
  -- over the array environment @aenv@ rather than by an array operation.
  Delayed ::
    { reprD        :: ArrayR (Array sh e)   -- ^ Representation of the resulting array type.
    , extentD      :: Exp aenv sh           -- ^ Expression computing the array's shape.
    , indexD       :: Fun aenv (sh -> e)    -- ^ Element at a given shape-typed index.
    , linearIndexD :: Fun aenv (Int -> e)   -- ^ Element at a given 'Int' index (linearisation order defined elsewhere — TODO confirm).
    } -> DelayedOpenAcc aenv (Array sh e)
-- | The array type of a delayed term: a manifest node reports whatever its
-- wrapped operation reports; a delayed node carries a single array whose
-- representation is stored in its first field.
instance HasArraysR DelayedOpenAcc where
  arraysR = \case
    Manifest pacc    -> arraysR pacc
    Delayed aR _ _ _ -> TupRsingle aR
-- | Rebuild a delayed term under the array-variable substitution @v@.
--
-- Manifest nodes delegate to the 'PreOpenAcc' instance. For delayed nodes
-- the extent, index function and linear-index function are rebuilt in that
-- order (via the 'OpenAccExp' / 'OpenAccFun' wrappers) and the results are
-- reassembled; the array representation is copied unchanged.
instance Rebuildable DelayedOpenAcc where
  type AccClo DelayedOpenAcc = DelayedOpenAcc
  rebuildPartial v acc =
    case acc of
      Manifest pacc ->
        Manifest <$> rebuildPartial v pacc
      Delayed aR sh ix lx ->
        (\sh' ix' lx' -> Delayed aR (unOpenAccExp sh') (unOpenAccFun ix') (unOpenAccFun lx'))
          <$> rebuildPartial v (OpenAccExp sh)
          <*> rebuildPartial v (OpenAccFun ix)
          <*> rebuildPartial v (OpenAccFun lx)
-- | Weaken a delayed term along the environment embedding @k@ by rebuilding
-- it with 'rebuildWeakenVar'. The result is wrapped in 'Stats.substitution'
-- under the label "weaken" (presumably a debug-statistics counter).
instance Sink DelayedOpenAcc where
  weaken k acc = Stats.substitution "weaken" $ rebuildA (rebuildWeakenVar k) acc
-- | Deep evaluation of a delayed array function, using
-- 'rnfDelayedOpenAcc' for the terms it contains.
instance NFData (DelayedOpenAfun aenv t) where
  rnf afun = rnfPreOpenAfun rnfDelayedOpenAcc afun
-- | Deep evaluation of a delayed term; see 'rnfDelayedOpenAcc'.
instance NFData (DelayedOpenAcc aenv t) where
  rnf acc = rnfDelayedOpenAcc acc
-- | Encode a delayed term into a 'Builder' (used for hashing).
--
-- Each constructor first emits a compile-time hash of its name as a tag.
--
-- * 'Manifest': if @perfect options@ is set, the wrapped operation is
--   encoded structurally; otherwise only its arrays type is encoded.
-- * 'Delayed': the extent, the index function and the linear-index
--   function are encoded in that order. The 'ArrayR' field is not encoded.
encodeDelayedOpenAcc :: EncodeAcc DelayedOpenAcc
encodeDelayedOpenAcc options = \case
  Manifest pacc ->
    intHost $(hashQ ("Manifest" :: String)) <> manifestBody pacc
  Delayed _ sh ix lx ->
    intHost $(hashQ ("Delayed" :: String))
      <> encodeOpenExp sh
      <> encodeOpenFun ix
      <> encodeOpenFun lx
  where
    -- Full structural encoding only when a perfect hash is requested;
    -- otherwise the (cheaper) arrays type stands in for the node.
    manifestBody :: PreOpenAcc DelayedOpenAcc aenv' a' -> Builder
    manifestBody pacc
      | perfect options = encodePreOpenAcc options encodeDelayedOpenAcc pacc
      | otherwise       = encodeArraysType (arraysR pacc)
-- | Structural equality of two delayed terms, returning a type-equality
-- witness on success.
--
-- Manifest nodes are compared with 'matchPreOpenAcc'. Delayed nodes must
-- agree on extent, index function and linear-index function, checked in
-- that order and stopping at the first mismatch. The 'ArrayR' field is not
-- compared: matching the @sh -> e@ index functions already equates the
-- array types. A manifest node never matches a delayed one.
matchDelayedOpenAcc :: MatchAcc DelayedOpenAcc
matchDelayedOpenAcc (Manifest lhs) (Manifest rhs) =
  matchPreOpenAcc matchDelayedOpenAcc lhs rhs
matchDelayedOpenAcc (Delayed _ shL ixL lxL) (Delayed _ shR ixR lxR) = do
  Refl <- matchOpenExp shL shR
  Refl <- matchOpenFun ixL ixR
  Refl <- matchOpenFun lxL lxR
  Just Refl
matchDelayedOpenAcc _ _ =
  Nothing
-- | Reduce a delayed term to normal form.
--
-- Manifest nodes delegate to 'rnfPreOpenAcc'. Delayed nodes force their
-- representation, extent, index function and linear-index function in that
-- order. Folding 'seq' over the unit results is the same as chaining them
-- with 'seq' directly.
rnfDelayedOpenAcc :: NFDataAcc DelayedOpenAcc
rnfDelayedOpenAcc = \case
  Manifest pacc ->
    rnfPreOpenAcc rnfDelayedOpenAcc pacc
  Delayed aR sh ix lx ->
    foldr seq () [rnfArrayR aR, rnfOpenExp sh, rnfOpenFun ix, rnfOpenFun lx]
-- | Lift a delayed term into a typed Template Haskell quotation that
-- reconstructs it. Each field is lifted with its own lifting helper, and
-- the wrapped operation of a manifest node is lifted recursively.
liftDelayedOpenAcc :: LiftAcc DelayedOpenAcc
liftDelayedOpenAcc = \case
  Manifest pacc ->
    [|| Manifest $$(liftPreOpenAcc liftDelayedOpenAcc pacc) ||]
  Delayed aR sh ix lx ->
    [|| Delayed $$(liftArrayR aR) $$(liftOpenExp sh) $$(liftOpenFun ix) $$(liftOpenFun lx) ||]