module Data.Array.Knead.Simple.Fold (
T,
Linear,
apply,
passAny,
pass,
fold,
(Core.$:.),
) where
import qualified Data.Array.Knead.Simple.Private as Core
import Data.Array.Knead.Simple.Private (Array(Array), Code, Val, )
import qualified Data.Array.Knead.Index.Linear as Linear
import qualified Data.Array.Knead.Index.Linear.Int as IndexInt
import qualified Data.Array.Knead.Index.Nested.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp, unExp, )
import Data.Array.Knead.Index.Linear ((#:.), (:.)((:.)), )
import qualified LLVM.Extra.Multi.Value as MultiValue
import LLVM.Extra.Multi.Value (atom, )
import Prelude hiding (zipWith, zipWith3, zip, zip3, replicate, )
-- | A reduction that turns an array of shape @sh0@ into an array of shape
-- @sh1@ when run with 'apply'.
--
-- The index types @i0@ and @i1@ are existentially bound and tied to
-- @Shape.Index sh0@ and @Shape.Index sh1@ by equality constraints.
data T sh0 sh1 a =
   forall i0 i1.
   (Shape.Index sh0 ~ i0, Shape.Index sh1 ~ i1) =>
   Cons
      -- Computes the target shape from the source shape.
      (Exp sh0 -> Exp sh1)
      -- Given the evaluated source shape and an element generator indexed by
      -- source indices, yields an element generator indexed by target indices.
      (forall r. Val sh0 -> (Val i0 -> Code r a) -> Val i1 -> Code r a)
-- | Run a reduction on an array.
--
-- The result shape is the reduction's shape function applied to the source
-- shape. For each target index, the source shape is evaluated inside the
-- element code. That value is handed, together with the source element
-- generator and the target index, to the reduction.
apply ::
   (Core.C array, Shape.C sh0, Shape.C sh1, MultiValue.C a) =>
   T sh0 sh1 a ->
   array sh0 a ->
   array sh1 a
apply (Cons mapShape reduce) =
   Core.lift1
      (\(Array srcShape srcCode) ->
         Array (mapShape srcShape)
            (\ix -> unExp srcShape >>= \sh -> reduce sh srcCode ix))
-- | 'T' specialised to 'Linear.Shape' on both the source and the target side.
-- The element type is deliberately left out so the synonym can be
-- partially applied.
type Linear sh0 sh1 = T (Linear.Shape sh0) (Linear.Shape sh1)
-- | The identity reduction. The shape is left unchanged, and the source
-- element generator is returned as the target element generator.
passAny :: Linear sh sh a
passAny = Cons id (\_srcShape elementCode -> elementCode)
-- | Extend a reduction on the leading dimensions with a trailing dimension
-- @i@ that is carried through unchanged.
--
-- * Shape: the trailing extent is copied from the source shape to the target
--   shape, and the inner shape function is applied to the leading part.
-- * Indices: 'Linear.switchR' splits a target index into a leading index and
--   a trailing component. The leading index goes to the inner reduction.
--   The trailing component is re-attached with '#:.' to every source index
--   the inner reduction asks for.
pass ::
   Linear sh0 sh1 a ->
   Linear (sh0:.i) (sh1:.i) a
pass (Cons mapLead reduceLead) =
   Cons
      (Expr.modify (Linear.shape (atom:.atom))
         (\(lead:.trailing) -> mapLead lead :. trailing))
      (\srcShape srcCode ->
         Linear.switchR
            (\dstLead k ->
               reduceLead (Linear.tail srcShape)
                  (\srcLead -> srcCode (srcLead #:. k))
                  dstLead))
-- | Build an element generator for a leading index by folding over the
-- trailing 'IndexInt.Int' dimension.
--
-- Arguments:
--
-- * the combining function;
-- * the extent of the trailing dimension;
-- * an element generator for full indices (leading index plus trailing
--   component).
--
-- The fold is delegated to 'Core.fold1Code'. Each trailing component it
-- visits is appended to the leading index before the element generator is
-- called.
fold1CodeLinear ::
   (MultiValue.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Exp IndexInt.Int ->
   (Val (Linear.Index (sh :. IndexInt.Int)) -> Code r a) ->
   (Val (Linear.Index sh) -> Code r a)
fold1CodeLinear f len elementCode =
   Core.fold1Code f (IndexInt.decons len) $ \leadIx k ->
      elementCode (leadIx #:. IndexInt.cons k)
-- | Extend a reduction on the leading dimensions so that it first folds away
-- a trailing 'IndexInt.Int' dimension with the given combining function.
--
-- * Shape: the trailing dimension is dropped ('Linear.tail') before the inner
--   shape function is applied.
-- * Elements: each element the inner reduction sees is the
--   'fold1CodeLinear' of the source elements along the trailing dimension.
--   The extent of that dimension is read from the source shape
--   ('Linear.head').
--
-- NOTE(review): the fold goes through 'Core.fold1Code'. Judging by the name,
-- it has no neutral element and so may require a non-empty trailing
-- dimension — confirm.
fold ::
   (MultiValue.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Linear sh0 sh1 a ->
   Linear (sh0:.IndexInt.Int) sh1 a
fold f (Cons mapLead reduceLead) =
   Cons
      (mapLead . Linear.tail)
      (\srcShape srcCode ->
         reduceLead (Linear.tail srcShape)
            (fold1CodeLinear f (Expr.lift0 (Linear.head srcShape)) srcCode))
instance Core.Process (T sh0 sh1 a) where