module Data.Array.Knead.Parameterized.Private where
import qualified Data.Array.Knead.Simple.Symbolic as Core
import qualified Data.Array.Knead.Parameter as Param
import qualified Data.Array.Knead.Index.Nested.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp, )
import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue
import Foreign.Storable (Storable, )
import Control.Applicative (Applicative (pure, (<*>)), )
import Prelude hiding (id, map, zipWith, replicate, )
-- | An array whose definition depends on a run-time value of type @p@.
--
-- The storable parameter representation (@parameter@) and the auxiliary
-- state (@context@) are hidden existentially, so arrays that use different
-- parameter layouts still share the type @Array p sh a@.
data Array p sh a =
   forall parameter context.
   (Storable parameter, MultiValueMemory.C parameter) =>
   Array
      { core :: MultiValue.T parameter -> Core.Array sh a
        -- ^ symbolic array built from the (multi-value) parameter
      , createContext :: p -> IO (context, parameter)
        -- ^ derive the context and the storable parameter from @p@
      , deleteContext :: context -> IO ()
        -- ^ counterpart to 'createContext'; receives the context it produced
      }
-- | Symbolic array combinators lifted to parameterized arrays.
--
-- 'lift0' embeds a parameter-independent array using a unit parameter and
-- unit context.  'lift1' rewrites only the symbolic core.  'lift2' pairs the
-- parameters and contexts of both operands and splits the paired parameter
-- again before handing each half to its own operand.
instance Core.C (Array p) where
   lift0 arr = Array (const arr) (createPlain (const ())) deletePlain
   lift1 f (Array arr create delete) =
      Array (\param -> f (arr param)) create delete
   lift2 f (Array arrA createA deleteA) (Array arrB createB deleteB) =
      Array
         ((\(paramA, paramB) -> f (arrA paramA) (arrB paramB))
            . MultiValue.unzip)
         (combineCreate createA createB)
         (combineDelete deleteA deleteB)
-- | Pick the single element at an index that is supplied at run time.
--
-- The index is obtained from @pix@ and appended to the array's parameter
-- (via 'mapHullWithExp').  The result is an array of scalar shape @z@.
(!) ::
   (Shape.C sh, Shape.Index sh ~ ix,
    Storable ix, MultiValueMemory.C ix,
    Shape.Scalar z) =>
   Array p sh a -> Param.T p ix -> Array p z a
(!) arr pix =
   runHull (mapHullWithExp pick (expParam pix) (arrayHull arr))
   where
      -- read one element symbolically and wrap it as a scalar-shaped array
      pick ix carr = Core.fromScalar (carr Core.! ix)
-- | Parameterized array of shape @sh@ in which every element is @a@.
--
-- Both the shape and the element value are taken from the run-time
-- parameter and stored as one paired parameter.  The context is @()@.
-- Construction is delegated to 'Core.fill'.
fill ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a) =>
   Param.T p sh -> Param.T p a -> Array p sh a
fill sh a =
   Shape.paramWith sh $ \getSh valueSh ->
   Param.withMulti a $ \getA valueA ->
   Array
      ((\(vsh, va) -> Core.fill (valueSh vsh) (Expr.lift0 (valueA va)))
         . MultiValue.unzip)
      (createPlain (\p -> (getSh p, getA p)))
      deletePlain
-- | Parameterized version of 'Core.gather'.
--
-- The first array supplies indices of type @Shape.Index sh0@ into the
-- second array.  The result takes its shape @sh1@ from the index array.
gather ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, MultiValue.C a) =>
   Array p sh1 ix0 ->
   Array p sh0 a ->
   Array p sh1 a
gather indices source = Core.gather indices source
-- | Parameterized version of 'Core.id' for a run-time shape.
--
-- The element type is the shape's index type.  Presumably each element is
-- its own index; see 'Core.id'.  The shape is the only parameter, and the
-- context is @()@.
id ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh, Shape.Index sh ~ ix) =>
   Param.T p sh -> Array p sh ix
id sh =
   Shape.paramWith sh $ \getSh valueSh ->
      Array (\param -> Core.id (valueSh param)) (createPlain getSh) deletePlain
-- | Element-wise map whose function also receives a run-time value @c@.
--
-- The first argument of the mapped function is the expression for @c@.
map ::
   (Shape.C sh, MultiValueMemory.C c, Storable c) =>
   (Exp c -> Exp a -> Exp b) ->
   Param.T p c -> Array p sh a -> Array p sh b
map f c arr = lift Core.map f c arr
-- | Like 'map', but the function additionally receives each element's index.
--
-- Argument order of the mapped function: run-time value, index, element.
mapWithIndex ::
   (Shape.C sh, MultiValueMemory.C c, Storable c, Shape.Index sh ~ ix) =>
   (Exp c -> Exp ix -> Exp a -> Exp b) ->
   Param.T p c -> Array p sh a -> Array p sh b
mapWithIndex f c arr = lift Core.mapWithIndex f c arr
-- | Fold away the inner dimension @sh1@ of a @(sh0, sh1)@-shaped array,
-- leaving shape @sh0@.
--
-- The combining function receives the run-time value @c@ as its first
-- argument.  Delegates to 'Core.fold1'.
fold1 ::
   (Shape.C sh0, Shape.C sh1,
    MultiValueMemory.C c, Storable c, MultiValue.C a) =>
   (Exp c -> Exp a -> Exp a -> Exp a) ->
   Param.T p c -> Array p (sh0, sh1) a -> Array p sh0 a
fold1 f c arr = lift Core.fold1 f c arr
-- | Fold a whole array into an array of scalar shape @z@.
--
-- The combining function receives the run-time value @c@ as its first
-- argument.  Delegates to 'Core.fold1All'.
fold1All ::
   (Shape.C sh, Shape.Scalar z,
    MultiValueMemory.C c, Storable c, MultiValue.C a) =>
   (Exp c -> Exp a -> Exp a -> Exp a) ->
   Param.T p c -> Array p sh a -> Array p z a
fold1All f c arr = lift Core.fold1All f c arr
-- | Lift a symbolic array transformation that needs an extra run-time value
-- of type @c@.
--
-- @g@ is the symbolic transformation, and @f@ turns the expression for @c@
-- into @g@'s first argument.  The value of @c@ is appended to the array's
-- existing parameter (see 'mapHullWithExp').
lift ::
   (Shape.C sh0, Shape.C sh1,
    MultiValueMemory.C c, Storable c) =>
   (f -> Core.Array sh0 a -> Core.Array sh1 b) ->
   (Exp c -> f) ->
   Param.T p c -> Array p sh0 a -> Array p sh1 b
lift g f c arr =
   runHull (mapHullWithExp (g . f) (expParam c) (arrayHull arr))
-- | Generalisation of 'Array' whose payload may be any value of type @a@,
-- not only a 'Core.Array'.
--
-- As in 'Array', the parameter representation and the context type are
-- hidden existentially.
data Hull p a =
   forall parameter context.
   (Storable parameter, MultiValueMemory.C parameter) =>
   Hull
      { hullCore :: MultiValue.T parameter -> a
        -- ^ payload built from the (multi-value) parameter
      , hullCreateContext :: p -> IO (context, parameter)
        -- ^ derive the context and the storable parameter from @p@
      , hullDeleteContext :: context -> IO ()
        -- ^ counterpart to 'hullCreateContext'; receives its context
      }
-- | Transform the payload.  The parameter and context handling stay as they
-- are.
instance Functor (Hull p) where
   fmap f (Hull build create delete) =
      Hull (\param -> f (build param)) create delete
-- | Combine hulls by pairing their parameters and contexts.
--
-- 'pure' carries no parameter data: both its context and its parameter are
-- @()@.  It is built with the same 'createPlain' / 'deletePlain' helpers used
-- elsewhere in this module for parameter-free components.
--
-- In @hf '<*>' hx@ the creation actions run left to right ('combineCreate'),
-- and so do the deletion actions ('combineDelete').  The paired parameter is
-- split again before each half is passed to its own hull.
instance Applicative (Hull p) where
   pure a = Hull (const a) (createPlain (const ())) deletePlain
   Hull arrA createA deleteA <*> Hull arrB createB deleteB =
      Hull
         (\p -> case MultiValue.unzip p of (a, b) -> arrA a $ arrB b)
         (combineCreate createA createB)
         (combineDelete deleteA deleteB)
-- | Extend a hull with one more run-time value and combine it with the
-- payload.
--
-- The tunnel's value is appended as the second component of the hull's
-- parameter.  @f@ receives that value as an 'Exp' together with the original
-- payload.  The context and the deletion action are left unchanged.
mapHullWithExp ::
   (Exp sl -> a -> b) ->
   Param.Tunnel p sl -> Hull p a -> Hull p b
mapHullWithExp f tunnel (Hull arr create delete) =
   case tunnel of
      Param.Tunnel getSl valueSl ->
         Hull
            ((\(arrp, sl) -> f (Expr.lift0 (valueSl sl)) (arr arrp))
               . MultiValue.unzip)
            (\p ->
               create p >>= \(ctx, param) ->
               return (ctx, (param, getSl p)))
            delete
-- | Hull whose payload is the tunnel's run-time value as an 'Exp'.
--
-- The tunnel's value is the whole parameter, and the context is @()@.
-- Creation and deletion use the module's shared 'createPlain' and
-- 'deletePlain' helpers rather than hand-written equivalents.
expHull :: Param.Tunnel p sl -> Hull p (Exp sl)
expHull tunnel =
   case tunnel of
      Param.Tunnel getSl valueSl ->
         Hull (Expr.lift0 . valueSl) (createPlain getSl) deletePlain
-- | View a parameterized array as a hull whose payload is the symbolic
-- array.  Inverse of 'runHull'.
arrayHull :: Array p sh a -> Hull p (Core.Array sh a)
arrayHull arr =
   case arr of
      Array build create delete -> Hull build create delete
-- | Turn a hull that produces a symbolic array back into a parameterized
-- array.  Inverse of 'arrayHull'.
runHull :: Hull p (Core.Array sh a) -> Array p sh a
runHull hull =
   case hull of
      Hull build create delete -> Array build create delete
-- | Change the outer parameter type by pre-composing the context creator
-- with @f@.  The payload and the deletion action are untouched.
extendHull :: (q -> p) -> Hull p a -> Hull q a
extendHull f (Hull build create delete) =
   Hull build (\q -> create (f q)) delete
-- | Turn a parameter description into a 'Param.Tunnel' by passing
-- 'MultiValue.cons' to 'Param.tunnel'.  Presumably 'MultiValue.cons' lifts
-- the stored value to a multi-value; confirm against "LLVM.Extra.Multi.Value".
expParam ::
   (Storable a, MultiValueMemory.C a) => Param.T p a -> Param.Tunnel p a
expParam param = Param.tunnel MultiValue.cons param
-- | Context creator with no effects: the context is @()@, and the parameter
-- is computed purely from the outer value.
createPlain :: (Monad m) => (p -> pl) -> p -> m ((), pl)
createPlain getPlain = \p -> return ((), getPlain p)
-- | Deletion counterpart of 'createPlain'.  It does nothing with its @()@
-- context.
deletePlain :: (Monad m) => () -> m ()
deletePlain unit =
   case unit of
      () -> return ()
-- | Run two context creators on the same outer value and pair their results.
--
-- @createA@ runs before @createB@.  The contexts are paired as
-- @(ctxA, ctxB)@ and the parameters as @(paramA, paramB)@.
--
-- NOTE(review): in 'IO', if @createB@ throws, the context already produced
-- by @createA@ is not released here.
combineCreate ::
   Monad m =>
   (p -> m (ctxA, paramA)) -> (p -> m (ctxB, paramB)) ->
   p -> m ((ctxA, ctxB), (paramA, paramB))
combineCreate createA createB p =
   createA p >>= \(ctxA, paramA) ->
   createB p >>= \(ctxB, paramB) ->
   return ((ctxA, ctxB), (paramA, paramB))
-- | Delete a paired context by running @deleteA@ on the first component and
-- then @deleteB@ on the second.
--
-- NOTE(review): in 'IO', if @deleteA@ throws, @deleteB@ is not run.
combineDelete ::
   Monad m =>
   (ctxA -> m ()) -> (ctxB -> m ()) -> (ctxA, ctxB) -> m ()
combineDelete deleteA deleteB (ctxA, ctxB) =
   deleteA ctxA >> deleteB ctxB
-- | Change the outer parameter type of an array by pre-composing its context
-- creator with @f@.
--
-- This is 'extendHull' applied to the array viewed as a hull.
extendParameter ::
   (q -> p) -> Array p sh a -> Array q sh a
extendParameter f = runHull . extendHull f . arrayHull