module Data.Array.Knead.Parameterized.PhysicalHull (
render,
Scatter(..),
scatter,
ScatterMaybe(..),
scatterMaybe,
MapAccumL(..),
mapAccumL,
FoldOuterL(..),
foldOuterL,
) where
import qualified Data.Array.Knead.Parameterized.Private as Sym
import qualified Data.Array.Knead.Simple.PhysicalPrivate as Priv
import qualified Data.Array.Knead.Simple.Physical as Phys
import qualified Data.Array.Knead.Simple.Private as Core
import qualified Data.Array.Knead.Index.Nested.Shape as Shape
import qualified Data.Array.Knead.Code as Code
import Data.Array.Knead.Expression (Exp, unExp, )
import Data.Array.Knead.Code (compile, )
import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Core as LLVM
import Foreign.Marshal.Utils (with, )
import Foreign.Marshal.Alloc (alloca, )
import Foreign.Storable (Storable, peek, )
import Foreign.ForeignPtr (withForeignPtr, mallocForeignPtrArray, )
import Foreign.Ptr (FunPtr, Ptr, )
import Control.Exception (bracket, )
import Control.Monad.HT (void, )
import Control.Applicative (liftA2, )
import Data.Word (Word32, )
-- | Shape of a GHC @\"dynamic\"@ foreign import: converts a C function
-- pointer into a Haskell function of the same type that calls it.
type Importer f = FunPtr f -> f

-- | Invokes a JIT-compiled shape function through its 'FunPtr'.
-- In 'materialize' the generated callee loads the parameter from the first
-- pointer, stores the array shape through the second pointer and returns
-- 'Shape.sizeCode' of that shape, which the caller uses as the number of
-- elements to allocate.
foreign import ccall safe "dynamic" callShaper ::
   Importer (Ptr param -> Ptr shape -> IO Word32)

-- | Invokes a JIT-compiled fill function through its 'FunPtr'.
-- In 'materialize' the generated callee receives pointers to the parameter,
-- the shape and the output element buffer (type variable @am@), and runs
-- the caller-supplied fill code on them.
foreign import ccall safe "dynamic" callFill ::
   Importer (Ptr param -> Ptr shape -> Ptr am -> IO ())
-- | Common implementation of 'render', 'scatter', 'scatterMaybe',
-- 'mapAccumL' and 'foldOuterL'.
--
-- Generates two LLVM functions and passes them, together with @name@, to
-- 'compile':
--
-- * @\"shape\"@ loads the parameter, evaluates the shape expression given by
--   the @shape@ argument, stores the shape through its second pointer and
--   returns 'Shape.sizeCode' of it;
--
-- * @\"fill\"@ loads the parameter and runs the @fill@ code generator on the
--   shape pointer and the element buffer pointer.
--
-- Compilation is performed once, when the outer 'IO' action runs.  Every
-- call of the returned function then:
--
-- 1. creates the hull context with 'bracket', so it is deleted even on
--    exceptions;
-- 2. asks the shape function for the element count;
-- 3. allocates a 'ForeignPtr' buffer of that many elements and fills it;
-- 4. returns the buffer together with the shape read back from memory.
materialize ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a) =>
   String ->
   (core -> Exp sh) ->
   (core ->
    LLVM.Value (Ptr (MultiValueMemory.Struct sh)) ->
    LLVM.Value (Ptr (MultiValueMemory.Struct a)) ->
    LLVM.CodeGenFunction () ()) ->
   Sym.Hull p core -> IO (p -> IO (Phys.Array sh a))
materialize name shape fill (Sym.Hull core create delete) = do
   -- Code for the shape function: store the shape, return its size.
   -- Code for the fill function: write all elements into the buffer.
   let shaperCode =
          Code.createFunction callShaper "shape" $ \paramPtr resultPtr -> do
             param <- Memory.load paramPtr
             sh <- unExp (shape (core param))
             MultiValueMemory.store sh resultPtr
             LLVM.ret =<< Shape.sizeCode sh
       fillerCode =
          Code.createFunction callFill "fill" $ \paramPtr shapePtr bufferPtr -> do
             param <- Memory.load paramPtr
             fill (core param) shapePtr bufferPtr
             LLVM.ret ()
   (computeShape, fillBuffer) <- compile name (liftA2 (,) shaperCode fillerCode)
   return $ \p ->
      bracket (create p) (delete . fst) $ \(_ctx, param) ->
      with param $ \paramPtr ->
      alloca $ \shapeSlot -> do
         let paramArg = MultiValueMemory.castStructPtr paramPtr
             shapeArg = MultiValueMemory.castStructPtr shapeSlot
         -- The shape function writes the shape into shapeSlot and returns
         -- the element count as a Word32; widen it for the allocator.
         count <- computeShape paramArg shapeArg
         buffer <- mallocForeignPtrArray (fromIntegral count)
         withForeignPtr buffer $ \bufferPtr ->
            fillBuffer paramArg shapeArg (MultiValueMemory.castStructPtr bufferPtr)
         fmap (\sh -> Phys.Array sh buffer) (peek shapeSlot)
-- | Compile a parameterized symbolic array into a function that yields a
-- physical, 'ForeignPtr'-backed array for every parameter value.
--
-- The generated fill code loads the shape, visits every index in the
-- order chosen by 'Shape.loop', stores the element computed for that index
-- at the current output position and advances the output pointer by one
-- element.
render ::
   (Shape.C sh, Shape.Index sh ~ ix,
    Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a) =>
   Sym.Hull p (Core.Array sh a) -> IO (p -> IO (Phys.Array sh a))
render =
   materialize "render" Core.shape $
   \(Core.Array esh code) shapePtr bufferPtr ->
      let -- Write the element for ix at ptr, then step to the next slot.
          storeAndAdvance ix ptr =
             code ix >>= flip Memory.store ptr >> A.advanceArrayElementPtr ptr
      in  Shape.load esh shapePtr >>= \sh ->
          void (Shape.loop storeAndAdvance sh bufferPtr)
-- | Input description for 'scatter'.
--
-- NOTE(review): the exact accumulation semantics (argument order of
-- 'scatterAccum', treatment of repeated target indices) live in
-- 'Priv.scatter' and are not visible in this module.
data Scatter sh0 sh1 a = Scatter
   { scatterAccum :: Exp a -> Exp a -> Exp a
      -- ^ Combining function handed to 'Priv.scatter'; presumably merges an
      -- incoming value with the one already stored at the target index.
   , scatterInit :: Core.Array sh1 a
      -- ^ Initial array; its shape is the shape of the result of 'scatter'.
   , scatterMap :: Core.Array sh0 (Shape.Index sh1, a)
      -- ^ Array over the source shape @sh0@ whose elements pair an index
      -- of @sh1@ with a value.
   }
-- | Compile a parameterized 'Scatter' description into a function that
-- yields a physical array for every parameter value.
--
-- The result takes its shape from 'scatterInit'; the buffer is written by
-- 'Priv.scatter' from the combining function, the initial array and the
-- index/value array.
scatter ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
   Sym.Hull p (Scatter sh0 sh1 a) -> IO (p -> IO (Phys.Array sh1 a))
scatter =
   materialize "scatter" (\job -> Core.shape (scatterInit job)) $ \job ->
      case job of
         Scatter accum arrInit arrMap -> Priv.scatter accum arrInit arrMap
-- | Input description for 'scatterMaybe'; like 'Scatter', but each source
-- element may be 'Nothing'.
--
-- NOTE(review): how 'Nothing' entries are treated (presumably skipped) is
-- implemented in 'Priv.scatterMaybe' and not visible here.
data ScatterMaybe sh0 sh1 a = ScatterMaybe
   { scatterMaybeAccum :: Exp a -> Exp a -> Exp a
      -- ^ Combining function handed to 'Priv.scatterMaybe'.
   , scatterMaybeInit :: Core.Array sh1 a
      -- ^ Initial array; its shape is the shape of the result of
      -- 'scatterMaybe'.
   , scatterMaybeMap :: Core.Array sh0 (Maybe (Shape.Index sh1, a))
      -- ^ Array over the source shape @sh0@ with optional index/value pairs.
   }
-- | Compile a parameterized 'ScatterMaybe' description into a function that
-- yields a physical array for every parameter value.
--
-- The result takes its shape from 'scatterMaybeInit'; the buffer is written
-- by 'Priv.scatterMaybe'.
scatterMaybe ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
   Sym.Hull p (ScatterMaybe sh0 sh1 a) -> IO (p -> IO (Phys.Array sh1 a))
scatterMaybe =
   materialize "scatterMaybe" (\job -> Core.shape (scatterMaybeInit job)) $ \job ->
      case job of
         ScatterMaybe accum arrInit arrMap ->
            Priv.scatterMaybe accum arrInit arrMap
-- | Input description for 'mapAccumL'.
data MapAccumL sh n acc a b = MapAccumL
   { mapAccumLAccum :: Exp acc -> Exp a -> Exp (acc,b)
      -- ^ Step function from an accumulator and an input element to a pair
      -- of accumulator and output element; handed to 'Priv.mapAccumL'.
   , mapAccumLInit :: Core.Array sh acc
      -- ^ Starting accumulators, presumably one per index of the outer
      -- shape @sh@.
   , mapAccumLMap :: Core.Array (sh, n) a
      -- ^ Input elements; its shape @(sh, n)@ is also the shape of the
      -- result of 'mapAccumL'.
   }
-- | Compile a parameterized 'MapAccumL' description into a function that
-- yields a physical array of output elements for every parameter value.
--
-- The result takes its shape from 'mapAccumLMap'; the buffer is written by
-- 'Priv.mapAccumL'.
mapAccumL ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Shape.C n, Storable n, MultiValueMemory.C n,
    MultiValue.C acc,
    Storable a, MultiValueMemory.C a,
    Storable b, MultiValueMemory.C b) =>
   Sym.Hull p (MapAccumL sh n acc a b) -> IO (p -> IO (Phys.Array (sh,n) b))
mapAccumL =
   materialize "mapAccumL" (\job -> Core.shape (mapAccumLMap job)) $ \job ->
      case job of
         MapAccumL f arrInit arrData -> Priv.mapAccumL f arrInit arrData
-- | Input description for 'foldOuterL'.
data FoldOuterL n sh a b = FoldOuterL
   { foldOuterLAccum :: Exp a -> Exp b -> Exp a
      -- ^ Combining function handed to 'Priv.foldOuterL'.
   , foldOuterLInit :: Core.Array sh a
      -- ^ Initial values; its shape is the shape of the result of
      -- 'foldOuterL'.
   , foldOuterLMap :: Core.Array (n,sh) b
      -- ^ Input array with outer dimension @n@; presumably folded along
      -- that dimension.
   }
-- | Compile a parameterized 'FoldOuterL' description into a function that
-- yields a physical array for every parameter value.
--
-- The result takes its shape from 'foldOuterLInit'; the buffer is written
-- by 'Priv.foldOuterL'.
foldOuterL ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Shape.C n, Storable n, MultiValueMemory.C n,
    Storable a, MultiValueMemory.C a) =>
   Sym.Hull p (FoldOuterL n sh a b) -> IO (p -> IO (Phys.Array sh a))
foldOuterL =
   materialize "foldOuterL" (\job -> Core.shape (foldOuterLInit job)) $ \job ->
      case job of
         FoldOuterL f arrInit arrData -> Priv.foldOuterL f arrInit arrData