module Data.Array.Knead.Simple.ShapeDependent where
import qualified Data.Array.Knead.Simple.Private as Core
import Data.Array.Knead.Simple.Private (Array(Array), )
import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp, )
import qualified Control.Monad.HT as Monad
import Control.Monad ((<=<), )
-- | Turn the shape of an array into the content of a new array.
--
-- The result has the scalar shape type @z@ (constrained by 'Shape.Scalar').
-- Its content is the argument's shape expression, wrapped by
-- 'Core.fromScalar'. The elements of the argument are not read.
shape :: (Core.C array, Shape.C sh, Shape.Scalar z) => array sh a -> array z sh
shape = Core.lift1 (\arr -> Core.fromScalar (Core.shape arr))
-- | Rearrange the elements of an array through an index mapping.
--
-- * @createShape@ computes the result shape from the source shape.
-- * @projectIndex@ maps a result index to a source index. The mapping runs
--   from destination to source.
--
-- The element at a result index is produced by projecting that index with
-- @projectIndex@ and then running the source array's element code at the
-- projected position. No bounds check is performed here.
-- NOTE(review): presumably @projectIndex@ must stay inside the source
-- shape; confirm against callers.
backpermute ::
   (Core.C array,
    Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1) =>
   (Exp sh0 -> Exp sh1) ->
   (Exp ix1 -> Exp ix0) ->
   array sh0 a ->
   array sh1 a
backpermute createShape projectIndex = Core.lift1 reorder
   where
      -- Keep the source element code and precompose it with the index
      -- projection (Kleisli composition).
      reorder (Array srcShape readSrc) =
         Array
            (createShape srcShape)
            (readSrc <=< Expr.unliftM1 projectIndex)
-- | Like 'backpermute', but the result shape may also depend on the shape
-- of a second array.
--
-- * @newShape@ computes the result shape from both input shapes.
-- * @projectIndex@ maps a result index to an index of the first array.
--
-- Elements are taken only from the first array. The second array
-- contributes its shape alone, and its elements are never read.
backpermuteExtra ::
   (Core.C array,
    Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Shape.C sh, Shape.Index sh ~ ix) =>
   (Exp sh0 -> Exp sh1 -> Exp sh) ->
   (Exp ix -> Exp ix0) ->
   array sh0 a -> array sh1 b -> array sh a
backpermuteExtra newShape projectIndex =
   Core.lift2 $ \(Array shapeA readA) (Array shapeB _) ->
      Array
         (newShape shapeA shapeB)
         -- project the result index, then read the first array there
         (readA <=< Expr.unliftM1 projectIndex)
-- | Combine two arrays element-wise, each through its own index mapping.
--
-- * @combineShape@ computes the result shape from the two input shapes.
-- * @projectIndex0@ / @projectIndex1@ map a result index to an index of
--   the first / second array, respectively.
-- * @f@ merges the two fetched elements into the result element.
backpermute2 ::
   (Core.C array,
    Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Shape.C sh, Shape.Index sh ~ ix) =>
   (Exp sh0 -> Exp sh1 -> Exp sh) ->
   (Exp ix -> Exp ix0) ->
   (Exp ix -> Exp ix1) ->
   (Exp a -> Exp b -> Exp c) ->
   array sh0 a -> array sh1 b -> array sh c
backpermute2 combineShape projectIndex0 projectIndex1 f =
   Core.lift2 $ \(Array shape0 read0) (Array shape1 read1) ->
      Array (combineShape shape0 shape1) (\ix ->
         -- The fetch actions are bound inside the index lambda so that they
         -- share the element code's monad type.
         let fetchA = Expr.unliftM1 projectIndex0 ix >>= read0
             fetchB = Expr.unliftM1 projectIndex1 ix >>= read1
         in  Monad.liftJoin2 (Expr.unliftM2 f) fetchA fetchB)
-- | Build an array whose shape is derived from another array's shape.
--
-- @fsh@ transforms the input's shape into the result shape. The result is
-- then built with 'Core.fill' from that shape and the expression @a@.
-- Only the input's shape is consulted; its elements are never read.
fill ::
   (Core.C array) =>
   (Exp sh0 -> Exp sh1) -> Exp b ->
   array sh0 a -> array sh1 b
fill fsh a = Core.lift1 constant
   where
      constant arr = Core.fill (fsh (Core.shape arr)) a