{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ForeignFunctionInterface #-}
module Data.Array.Knead.Symbolic.PhysicalParametric (
the,
theMarshal,
render,
MapFilter(..),
mapFilter,
FilterOuter(..),
filterOuter,
Scatter(..),
scatter,
ScatterMaybe(..),
scatterMaybe,
MapAccumLSimple(..),
mapAccumLSimple,
MapAccumLSequence(..),
mapAccumLSequence,
MapAccumL(..),
mapAccumL,
FoldOuterL(..),
foldOuterL,
AddDimension(..),
addDimension,
Parametric,
Rendered,
) where
import qualified Data.Array.Knead.Symbolic.PhysicalPrivate as Priv
import qualified Data.Array.Knead.Symbolic.Private as Core
import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Symbolic.PhysicalPrivate (MarshalPtr)
import Data.Array.Comfort.Storable.Unchecked (Array(Array))
import qualified LLVM.DSL.Execution as Code
import LLVM.DSL.Expression (Exp(Exp), unExp)
import qualified LLVM.Extra.Multi.Value.Storable as Storable
import qualified LLVM.Extra.Multi.Value.Marshal as Marshal
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Core as LLVM
import Foreign.Marshal.Array (allocaArray, )
import Foreign.Marshal.Alloc (alloca, )
import Foreign.Storable (Storable, peek, peekElemOff, )
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr, mallocForeignPtrArray, )
import Foreign.Ptr (FunPtr, Ptr, )
import Control.Exception (bracket, finally, )
import Control.Monad.HT (void, )
import Control.Applicative (liftA2, )
-- | Allocate a garbage-collected buffer with room for @n@ elements.
--
-- The element count is converted with 'fromIntegral', so it is assumed to
-- fit into an 'Int'.
mallocArray :: (Storable a) => Shape.Size -> IO (ForeignPtr a)
mallocArray n = mallocForeignPtrArray (fromIntegral n)
-- | Turns a raw function pointer into a callable Haskell function of the
-- same type; this is the shape of the @foreign import ccall "dynamic"@
-- declarations in this module.
type Importer f = FunPtr f -> f
-- | A computation description that depends on a runtime parameter of type
-- @p@, which is supplied as a symbolic expression.
type Parametric p a = Exp p -> a
-- | A compiled computation waiting for its parameter: it takes an action
-- that creates the parameter together with a finalizer, and yields the
-- result.
type Rendered p a = IO (p, IO ()) -> IO a
-- | Wrap an action on an unmarshalled parameter as a 'Rendered' consumer.
--
-- The resulting function runs its @create@ argument to obtain the parameter
-- together with a finalizer. It passes the parameter to @act@ and runs the
-- finalizer afterwards, including when @act@ throws.
--
-- 'bracket' is used instead of binding the result of @create@ and then
-- calling 'finally'. With the bind-then-'finally' form, an asynchronous
-- exception delivered right after @create@ returns, before the cleanup
-- handler is installed, would skip the finalizer. 'bracket' runs the
-- acquisition with asynchronous exceptions masked and closes that window.
withManagedParam :: Monad m => (p -> IO a) -> m (Rendered p a)
withManagedParam act =
  return $ \create ->
    bracket create (\(_param, final) -> final) (\(param, _final) -> act param)
-- | Converts the compiled @eval@ function of 'the' into a Haskell function
-- taking the marshalled parameter pointer and a result pointer.
foreign import ccall safe "dynamic" callThe ::
  Importer (LLVM.Ptr param -> Ptr a -> IO ())

-- | Compile a parametric array with a scalar shape into a 'Rendered'
-- computation returning the element at 'Shape.zeroIndex'.
--
-- The generated @eval@ function loads the parameter from @paramPtr@,
-- evaluates the element and writes it with 'Storable.store' into
-- Haskell-allocated memory, which is then read back with 'peek'.
the ::
  (Marshal.C p, Shape.Scalar z, Storable.C a) =>
  Parametric p (Core.Array z a) -> IO (Rendered p a)
the arr = do
  evalFunc <-
    Code.compile "the" $
      Code.createFunction callThe "eval" $ \paramPtr resultPtr ->
        case arr (Exp (Memory.load paramPtr)) of
          Core.Array z code -> do
            value <- code (Shape.zeroIndex z)
            Storable.store value resultPtr
  withManagedParam $ \param ->
    Marshal.with param $ \pptr ->
      alloca $ \aptr -> do
        evalFunc pptr aptr
        peek aptr
-- | Converts the compiled @eval@ function of 'theMarshal' into a Haskell
-- function taking the marshalled parameter and result pointers.
foreign import ccall safe "dynamic" callTheMarshal ::
  Importer (LLVM.Ptr param -> LLVM.Ptr a -> IO ())

-- | Variant of 'the' for element types with a 'Marshal.C' instance.
--
-- The element at 'Shape.zeroIndex' is written with 'Memory.store' into a
-- buffer obtained from 'Marshal.alloca' and read back with 'Marshal.peek'.
theMarshal ::
  (Marshal.C p, Shape.Scalar z, Marshal.C a) =>
  Parametric p (Core.Array z a) -> IO (Rendered p a)
theMarshal arr = do
  evalFunc <-
    Code.compile "the-marshal" $
      Code.createFunction callTheMarshal "eval" $ \paramPtr resultPtr ->
        case arr (Exp (Memory.load paramPtr)) of
          Core.Array z code -> do
            value <- code (Shape.zeroIndex z)
            Memory.store value resultPtr
  withManagedParam $ \param ->
    Marshal.with param $ \pptr ->
      Marshal.alloca $ \aptr -> do
        evalFunc pptr aptr
        Marshal.peek aptr
-- | Converts a compiled shape function into Haskell. It receives the
-- marshalled parameter and a shape output pointer, and returns the element
-- count.
foreign import ccall safe "dynamic" callShaper ::
  Importer (LLVM.Ptr param -> LLVM.Ptr shape -> IO Shape.Size)

-- | Converts a compiled fill function into Haskell. It receives the
-- marshalled parameter, the shape pointer and the element buffer.
foreign import ccall safe "dynamic" callFill ::
  Importer (LLVM.Ptr param -> LLVM.Ptr shape -> Ptr a -> IO ())

-- | Common driver for the single-array renderers.
--
-- Two functions are compiled into one module. The @shape@ function
-- evaluates the symbolic shape, stores it at the shape pointer and returns
-- the element count. The @fill@ function writes the elements into the
-- buffer; it also gets the shape pointer and may overwrite it ('mapFilter'
-- does so).
--
-- At run time the buffer is allocated with the count returned by @shape@.
-- The shape is read back only after @fill@ has run.
materialize ::
  (Shape.C sh, Marshal.C sh, Marshal.C p, Storable.C a) =>
  String ->
  (core -> Exp sh) ->
  (core ->
   LLVM.Value (MarshalPtr sh) -> LLVM.Value (Ptr a) ->
   LLVM.CodeGenFunction () ()) ->
  Parametric p core -> IO (Rendered p (Array sh a))
materialize name shape fill core = do
  (computeShape, fillArray) <-
    Code.compile name $
      (,)
        <$> Code.createFunction callShaper "shape"
              (\paramPtr resultPtr -> do
                 sh <- unExp $ shape $ core $ Exp (Memory.load paramPtr)
                 Memory.store sh resultPtr
                 Shape.size sh)
        <*> Code.createFunction callFill "fill"
              (\paramPtr shapePtr bufferPtr ->
                 fill (core (Exp (Memory.load paramPtr))) shapePtr bufferPtr)
  withManagedParam $ \param ->
    Marshal.alloca $ \shptr ->
      Marshal.with param $ \paramPtr -> do
        elemCount <- computeShape paramPtr shptr
        fptr <- mallocArray elemCount
        withForeignPtr fptr (fillArray paramPtr shptr)
        -- Read the shape only now: fill may have replaced it.
        sh <- Marshal.peek shptr
        return (Array sh fptr)
-- | Converts a compiled fill function that also writes an extra scalar
-- result. It receives the marshalled parameter, the extra-result pointer,
-- the shape pointer and the element buffer.
foreign import ccall safe "dynamic" callFillExpArray ::
  Importer (LLVM.Ptr param -> Ptr final -> LLVM.Ptr shape -> Ptr a -> IO ())

-- | Like 'materialize', but the @fill@ function additionally receives a
-- pointer to a Haskell-allocated slot of type @b@.
--
-- After @fill@ has run, that slot is read with 'peek' and returned together
-- with the array. The shape is likewise read back only after @fill@.
materializeExpArray ::
  (Shape.C sh, Marshal.C sh, Marshal.C p, Storable.C a, Storable.C b) =>
  String ->
  (core -> Exp sh) ->
  (core ->
   LLVM.Value (Ptr b) ->
   LLVM.Value (MarshalPtr sh) ->
   LLVM.Value (Ptr a) ->
   LLVM.CodeGenFunction () ()) ->
  Parametric p core -> IO (Rendered p (b, Array sh a))
materializeExpArray name shape fill core = do
  (computeShape, fillArray) <-
    Code.compile name $
      (,)
        <$> Code.createFunction callShaper "shape"
              (\paramPtr resultPtr -> do
                 sh <- unExp $ shape $ core $ Exp (Memory.load paramPtr)
                 Memory.store sh resultPtr
                 Shape.size sh)
        <*> Code.createFunction callFillExpArray "fill"
              (\paramPtr finalPtr shapePtr bufferPtr ->
                 fill (core (Exp (Memory.load paramPtr)))
                   finalPtr shapePtr bufferPtr)
  withManagedParam $ \param ->
    Marshal.alloca $ \shptr ->
      alloca $ \finalPtr ->
        Marshal.with param $ \paramPtr -> do
          elemCount <- computeShape paramPtr shptr
          fptr <- mallocArray elemCount
          withForeignPtr fptr (fillArray paramPtr finalPtr shptr)
          sh <- Marshal.peek shptr
          final <- peek finalPtr
          return (final, Array sh fptr)
-- | Converts a compiled two-array shape function. It receives the
-- marshalled parameter, two shape output pointers and a pointer to two
-- consecutive 'Shape.Size' slots.
foreign import ccall safe "dynamic" callShaper2 ::
  Importer
    (LLVM.Ptr param ->
     LLVM.Ptr shapeA -> LLVM.Ptr shapeB -> Ptr Shape.Size -> IO ())

-- | Converts a compiled two-array fill function. It receives the marshalled
-- parameter and a shape pointer and element buffer for each of the two
-- output arrays.
foreign import ccall safe "dynamic" callFill2 ::
  Importer
    (LLVM.Ptr param ->
     LLVM.Ptr shapeA -> Ptr a -> LLVM.Ptr shapeB -> Ptr b -> IO ())

-- | Two-output counterpart of 'materialize'.
--
-- The @shape@ function evaluates a pair of shapes and stores each at its own
-- pointer. It writes the two element counts into a two-slot 'Shape.Size'
-- array. Haskell then allocates one buffer per array with those counts and
-- runs @fill@. Both shapes are read back after @fill@ has finished.
materialize2 ::
  (Shape.C sha, Marshal.C sha,
   Shape.C shb, Marshal.C shb,
   Marshal.C p, Storable.C a, Storable.C b) =>
  String ->
  (core -> Exp (sha,shb)) ->
  (core ->
   (LLVM.Value (MarshalPtr sha), LLVM.Value (Ptr a)) ->
   (LLVM.Value (MarshalPtr shb), LLVM.Value (Ptr b)) ->
   LLVM.CodeGenFunction () ()) ->
  Parametric p core -> IO (Rendered p (Array sha a, Array shb b))
materialize2 name shape fill core = do
  (fsh, farr) <-
    Code.compile name $
      liftA2 (,)
        (Code.createFunction callShaper2 "shape" $
          \paramPtr shapeAPtr shapeBPtr sizesPtr -> do
            (sha,shb) <-
              fmap MultiValue.unzip $ unExp $
                shape $ core $ Exp (Memory.load paramPtr)
            Memory.store sha shapeAPtr
            Memory.store shb shapeBPtr
            -- Write A's element count into the first slot of sizesPtr and
            -- B's into the next one. The Haskell side reads them back with
            -- peekElemOff 0 and 1. The bitcast presumably just re-types the
            -- Foreign.Ptr argument for LLVM.store.
            sizeAPtr <- LLVM.bitcast sizesPtr
            flip LLVM.store sizeAPtr =<< Shape.size sha
            sizeBPtr <- A.advanceArrayElementPtr sizeAPtr
            flip LLVM.store sizeBPtr =<< Shape.size shb)
        (Code.createFunction callFill2 "fill" $
          \paramPtr shapeAPtr bufferAPtr shapeBPtr bufferBPtr ->
            fill
              (core $ Exp (Memory.load paramPtr))
              (shapeAPtr, bufferAPtr) (shapeBPtr, bufferBPtr))
  withManagedParam $ \param ->
    Marshal.alloca $ \shaPtr ->
    Marshal.alloca $ \shbPtr ->
    -- Two slots: element counts of array A and array B.
    allocaArray 2 $ \sizesPtr ->
    Marshal.with param $ \paramPtr -> do
      fsh paramPtr shaPtr shbPtr sizesPtr
      afptr <- mallocArray =<< peekElemOff sizesPtr 0
      bfptr <- mallocArray =<< peekElemOff sizesPtr 1
      withForeignPtr afptr $ \aptr ->
        withForeignPtr bfptr $ \bptr ->
          farr paramPtr shaPtr aptr shbPtr bptr
      -- Shapes are read after fill, mirroring materialize.
      sha <- Marshal.peek shaPtr
      shb <- Marshal.peek shbPtr
      return (Array sha afptr, Array shb bfptr)
-- | Compile a parametric symbolic array into a renderer that evaluates every
-- element and writes it into a freshly allocated buffer.
--
-- Elements are emitted by 'Shape.loop'. Each element is written with
-- 'Storable.storeNext', and the returned pointer is threaded into the next
-- iteration.
render ::
  (Shape.C sh, Shape.Index sh ~ ix, Marshal.C sh,
   Marshal.C p, Storable.C a) =>
  Parametric p (Core.Array sh a) -> IO (Rendered p (Array sh a))
render =
  materialize "render" Core.shape $ \(Core.Array esh code) shapePtr bufferPtr -> do
    sh <- Shape.load esh shapePtr
    let writeElem ix ptr = do
          x <- code ix
          Storable.storeNext x ptr
    void $ Shape.loop writeElem sh bufferPtr
-- | Arguments of 'scatter'.
data Scatter sh0 sh1 a =
  Scatter
    { scatterAccum :: Exp a -> Exp a -> Exp a
      -- ^ Combining function handed to the private scatter implementation
      -- (presumably merges a scattered value with the existing element).
    , scatterInit :: Core.Array sh1 a
      -- ^ Initial target array; its shape is the shape of the result.
    , scatterMap :: Core.Array sh0 (Shape.Index sh1, a)
      -- ^ Source array of (target index, value) pairs.
    }

-- | Compile a parametric 'Scatter' into a renderer whose result has the
-- shape of 'scatterInit'.
scatter ::
  (Shape.C sh0, Shape.Index sh0 ~ ix0,
   Shape.C sh1, Shape.Index sh1 ~ ix1, Marshal.C sh1,
   Marshal.C p, Storable.C a) =>
  Parametric p (Scatter sh0 sh1 a) -> IO (Rendered p (Array sh1 a))
scatter =
  materialize "scatter" (\s -> Core.shape (scatterInit s)) $
    \Scatter {scatterAccum = combine, scatterInit = target, scatterMap = source} ->
      Priv.scatter combine target source
-- | Arguments of 'scatterMaybe'.
data ScatterMaybe sh0 sh1 a =
  ScatterMaybe
    { scatterMaybeAccum :: Exp a -> Exp a -> Exp a
      -- ^ Combining function handed to the private scatter implementation.
    , scatterMaybeInit :: Core.Array sh1 a
      -- ^ Initial target array; its shape is the shape of the result.
    , scatterMaybeMap :: Core.Array sh0 (Maybe (Shape.Index sh1, a))
      -- ^ Source array of optional (target index, value) pairs
      -- (presumably 'Nothing' entries are skipped).
    }

-- | Compile a parametric 'ScatterMaybe' into a renderer whose result has the
-- shape of 'scatterMaybeInit'.
scatterMaybe ::
  (Shape.C sh0, Shape.Index sh0 ~ ix0,
   Shape.C sh1, Shape.Index sh1 ~ ix1, Marshal.C sh1,
   Marshal.C p, Storable.C a) =>
  Parametric p (ScatterMaybe sh0 sh1 a) -> IO (Rendered p (Array sh1 a))
scatterMaybe =
  materialize "scatterMaybe" (\s -> Core.shape (scatterMaybeInit s)) $
    \ScatterMaybe
        { scatterMaybeAccum = combine
        , scatterMaybeInit = target
        , scatterMaybeMap = source
        } ->
      Priv.scatterMaybe combine target source
-- | Arguments of 'mapAccumLSimple'.
data MapAccumLSimple sh n acc a b =
  MapAccumLSimple
    { mapAccumLSimpleAccum :: Exp acc -> Exp a -> Exp (acc,b)
      -- ^ Step function: takes the accumulator and an input element and
      -- yields the new accumulator and an output element.
    , mapAccumLSimpleInit :: Core.Array sh acc
      -- ^ Initial accumulators.
    , mapAccumLSimpleArray :: Core.Array (sh, n) a
      -- ^ Input array; its shape is the shape of the result.
    }

-- | Compile a parametric 'MapAccumLSimple' into a renderer whose result has
-- the shape of 'mapAccumLSimpleArray'.
mapAccumLSimple ::
  (Shape.C sh, Marshal.C sh,
   Shape.C n, Marshal.C n,
   MultiValue.C acc, Marshal.C p, Storable.C a, Storable.C b) =>
  Parametric p (MapAccumLSimple sh n acc a b) ->
  IO (Rendered p (Array (sh,n) b))
mapAccumLSimple =
  materialize "mapAccumLSimple"
    (\s -> Core.shape (mapAccumLSimpleArray s))
    (\params -> case params of
       MapAccumLSimple step initial input ->
         Priv.mapAccumLSimple step initial input)
-- | Arguments of 'mapAccumLSequence'.
data MapAccumLSequence n acc final a b =
  MapAccumLSequence
    { mapAccumLSequenceAccum :: Exp acc -> Exp a -> Exp (acc,b)
      -- ^ Step function: accumulator and input element to new accumulator
      -- and output element.
    , mapAccumLSequenceFinal :: Exp acc -> Exp final
      -- ^ Turns an accumulator into the extra scalar result.
    , mapAccumLSequenceInit :: Exp acc
      -- ^ Initial accumulator.
    , mapAccumLSequenceArray :: Core.Array n a
      -- ^ Input array; its shape is the shape of the output array.
    }

-- | Compile a parametric 'MapAccumLSequence' into a renderer producing the
-- extra scalar result together with the output array.
mapAccumLSequence ::
  (Shape.C n, Marshal.C n, MultiValue.C acc, Storable.C final,
   Marshal.C p, Storable.C a, Storable.C b) =>
  Parametric p (MapAccumLSequence n acc final a b) ->
  IO (Rendered p (final, Array n b))
mapAccumLSequence =
  materializeExpArray "mapAccumLSequence"
    (\s -> Core.shape (mapAccumLSequenceArray s))
    (\MapAccumLSequence
        { mapAccumLSequenceAccum = step
        , mapAccumLSequenceFinal = finish
        , mapAccumLSequenceInit = start
        , mapAccumLSequenceArray = input
        } -> Priv.mapAccumLSequence step finish start input)
-- | Arguments of 'mapAccumL'.
data MapAccumL sh n acc final a b =
  MapAccumL
    { mapAccumLAccum :: Exp acc -> Exp a -> Exp (acc,b)
      -- ^ Step function: accumulator and input element to new accumulator
      -- and output element.
    , mapAccumLFinal :: Exp acc -> Exp final
      -- ^ Turns an accumulator into an element of the first result array.
    , mapAccumLInit :: Core.Array sh acc
      -- ^ Initial accumulators; shape of the first result array.
    , mapAccumLArray :: Core.Array (sh, n) a
      -- ^ Input array; shape of the second result array.
    }

-- | Compile a parametric 'MapAccumL' into a renderer producing the array of
-- final values and the mapped array.
mapAccumL ::
  (Shape.C sh, Marshal.C sh,
   Shape.C n, Marshal.C n,
   MultiValue.C acc, Storable.C final,
   Marshal.C p, Storable.C a, Storable.C b) =>
  Parametric p (MapAccumL sh n acc final a b) ->
  IO (Rendered p (Array sh final, Array (sh,n) b))
mapAccumL =
  materialize2 "mapAccumL"
    (\params ->
       Expr.zip
         (Core.shape (mapAccumLInit params))
         (Core.shape (mapAccumLArray params)))
    (\MapAccumL
        { mapAccumLAccum = step
        , mapAccumLFinal = finish
        , mapAccumLInit = initials
        , mapAccumLArray = input
        } -> Priv.mapAccumL step finish initials input)
-- | Arguments of 'foldOuterL'.
data FoldOuterL n sh a b =
  FoldOuterL
    { foldOuterLAccum :: Exp a -> Exp b -> Exp a
      -- ^ Combining function: accumulator and input element to new
      -- accumulator.
    , foldOuterLInit :: Core.Array sh a
      -- ^ Initial accumulators; shape of the result.
    , foldOuterLArray :: Core.Array (n,sh) b
      -- ^ Input array whose outer dimension @n@ is folded over
      -- (presumably from left to right).
    }

-- | Compile a parametric 'FoldOuterL' into a renderer whose result has the
-- shape of 'foldOuterLInit'.
foldOuterL ::
  (Shape.C n, Marshal.C n,
   Shape.C sh, Marshal.C sh,
   Marshal.C p, Storable.C a) =>
  Parametric p (FoldOuterL n sh a b) -> IO (Rendered p (Array sh a))
foldOuterL =
  materialize "foldOuterL" (\s -> Core.shape (foldOuterLInit s)) $ \params ->
    case params of
      FoldOuterL combine start input -> Priv.foldOuterL combine start input
-- | Arguments of 'mapFilter'.
data MapFilter n a b =
  MapFilter
    { mapFilterMap :: Exp a -> Exp b
      -- ^ Transformation applied to the elements.
    , mapFilterPredicate :: Exp a -> Exp Bool
      -- ^ Selection predicate on input elements.
    , mapFilterArray :: Core.Array n a
      -- ^ Input sequence; its shape bounds the allocated output buffer.
    }

-- | Compile a parametric 'MapFilter' into a renderer.
--
-- The buffer is sized from the input shape. The shape returned by the
-- private filter implementation is then stored back as the result shape.
mapFilter ::
  (Shape.Sequence n, Marshal.C n, Marshal.C p, Storable.C b) =>
  Parametric p (MapFilter n a b) -> IO (Rendered p (Array n b))
mapFilter =
  materialize "mapFilter" (\s -> Core.shape (mapFilterArray s)) $
    \(MapFilter transform keep input) shapePtr bufferPtr -> do
      filteredShape <- Priv.mapFilter transform keep input shapePtr bufferPtr
      Memory.store filteredShape shapePtr
-- | Arguments of 'filterOuter'.
data FilterOuter n sh a =
  FilterOuter
    { filterOuterPredicate :: Core.Array n Bool
      -- ^ One flag per index of the outer dimension.
    , filterOuterArray :: Core.Array (n,sh) a
      -- ^ Input array; its shape bounds the allocated output buffer.
    }

-- | Compile a parametric 'FilterOuter' into a renderer.
--
-- The buffer is sized from the input shape. The shape returned by the
-- private filter implementation is then stored back as the result shape.
filterOuter ::
  (Shape.Sequence n, Marshal.C n,
   Shape.C sh, Marshal.C sh,
   Marshal.C p, Storable.C a) =>
  Parametric p (FilterOuter n sh a) -> IO (Rendered p (Array (n,sh) a))
filterOuter =
  materialize "filterOuter" (\s -> Core.shape (filterOuterArray s)) $
    \FilterOuter {filterOuterPredicate = keep, filterOuterArray = input}
     shapePtr bufferPtr -> do
      filteredShape <- Priv.filterOuter keep input shapePtr bufferPtr
      Memory.store filteredShape shapePtr
-- | Arguments of 'addDimension'.
data AddDimension sh n a b =
  AddDimension
    { addDimensionSize :: Exp n
      -- ^ Shape of the new inner dimension.
    , addDimensionSelect :: Exp (Shape.Index n) -> Exp a -> Exp b
      -- ^ Computes an output element from an index of the new dimension
      -- and the corresponding input element.
    , addDimensionArray :: Core.Array sh a
      -- ^ Input array; provides the outer part of the result shape.
    }

-- | Compile a parametric 'AddDimension' into a renderer whose result shape
-- pairs the input shape with 'addDimensionSize'.
addDimension ::
  (Shape.C sh, Marshal.C sh,
   Shape.C n, Marshal.C n,
   Marshal.C p, Storable.C b) =>
  Parametric p (AddDimension sh n a b) -> IO (Rendered p (Array (sh,n) b))
addDimension =
  materialize "addDimension"
    (\params ->
       Expr.zip (Core.shape $ addDimensionArray params) (addDimensionSize params))
    (\AddDimension
        { addDimensionSize = count
        , addDimensionSelect = pick
        , addDimensionArray = input
        } -> Priv.addDimension count pick input)