{- |
Generate and apply index maps.
This unifies the @replicate@ and @slice@ functions of the @accelerate@ package.
However, the structure of slicing and replicating cannot depend on parameters.
If you need that, you must use 'ShapeDep.backpermute' and friends.
-}
{-
Some notes on the design choice:

Instead of the shallow embedding implemented by the 'T' type,
we could maintain a symbolic representation of the Slice and Replicate pattern,
like the accelerate package does.
We actually used that representation in former versions.
It has however some drawbacks:

* We need additional type functions that map from the pattern
  to the source and the target shape, and we need a proof
  that the images of these type functions are actually shapes.
  This already worked, but was rather cumbersome.

* We need a way to store and pass this pattern through the Parameter handler.
  This yields new problems:
  We need a wrapper type for wrapping Index, Shape, Slice, Replicate, Fold patterns.
  Then the question is whether we use one Wrap type with a phantom parameter
  or whether we define a Wrap type for every pattern type.
  That is, the options are to write either

  > Wrap Shape (Z:.Int:.Int)

  or

  > Shape (Z:.Int:.Int)

  The first one seems to save us many duplicate instances of
  Storable, MultiValue etc.
  and it easily allows us to reuse the (:.) for all kinds of patterns.
  However, we need a way to restrict the element type of the (:.)-list elements.
  We can define that using ConstraintKinds with a constraint variable,
  but e.g. we are not able to add a Storable superclass constraint
  to the instance Storable (Wrap constr).
  Thus, we are left with the second option
  and have to define a lot of similar Storable, MultiValue instances.
-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Knead.Symbolic.Slice (
   T,
   Cubic,
   apply,
   passAny,
   pass,
   pick,
   pickFst,
   pickSnd,
   extrude,
   extrudeFst,
   extrudeSnd,
   transpose,
   (Core.$:.),

   id,
   first,
   second,
   compose,
   ) where

import qualified Data.Array.Knead.Symbolic.ShapeDependent as ShapeDep
import qualified Data.Array.Knead.Symbolic.Private as Core

import qualified Data.Array.Knead.Shape.Cubic.Int as Index
import qualified Data.Array.Knead.Shape.Cubic as Cubic
import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Shape.Cubic ((#:.), (:.)((:.)), )
import Data.Array.Knead.Expression (Exp, )

import qualified LLVM.Extra.Multi.Value as MultiValue
import LLVM.Extra.Multi.Value (atom, )

import qualified Type.Data.Num.Unary as Unary

import qualified Prelude as P
import Prelude hiding (id, zipWith, zipWith3, zip, zip3, replicate, )



{-
This data type is almost identical to Core.Array.
The only difference is
that the shape @sh1@ in T can depend on another shape @sh0@.
-}
{- |
A slice or replicate pattern from source shape @sh0@ to target shape @sh1@.

The first field computes the target shape from the source shape.
The second field maps an index of the target shape
back to an index of the source shape.
The index types are existentially quantified,
but the equality constraints pin them to
@Shape.Index sh0@ and @Shape.Index sh1@.
-}
data T sh0 sh1 where
   Cons ::
      forall sh0 sh1 ix0 ix1.
      (Shape.Index sh0 ~ ix0, Shape.Index sh1 ~ ix1) =>
      (Exp sh0 -> Exp sh1) ->
      (Exp ix1 -> Exp ix0) ->
      T sh0 sh1

{- |
Apply a slice/replicate pattern to an array.

The pattern's shape function computes the shape of the result array
from the shape of the input array.
Its index function maps every result index
to the input index whose element is taken.
This is essentially a 'ShapeDep.backpermute'.
-}
apply ::
   (Core.C array, Shape.C sh0, Shape.C sh1, MultiValue.C a) =>
   T sh0 sh1 ->
   array sh0 a ->
   array sh1 a
apply (Cons fsh fix) =
   ShapeDep.backpermute fsh fix


{- |
Remove the first component of a pair shape
by fixing its index to @i@.

The shape map keeps only the second component.
The index map pairs the fixed index @i@ with the remaining index.
-}
pickFst :: Exp (Shape.Index n) -> T (n,sh) sh
pickFst i = Cons Expr.snd (Expr.zip i)

{- |
Remove the second component of a pair shape
by fixing its index to @i@.

The shape map keeps only the first component.
The index map pairs the remaining index with the fixed index @i@.
-}
pickSnd :: Exp (Shape.Index n) -> T (sh,n) sh
pickSnd i = Cons Expr.fst (flip Expr.zip i)

{- |
Add a new first component @n@ to a shape.

The index map discards the index of the new component,
so every position along it refers to the same source index.

Extrusion has the potential to do duplicate work.
Only use it to add dimensions of size 1, e.g. numeric 1 or unit @()@
or to duplicate slices of physical arrays.
-}
extrudeFst :: Exp n -> T sh (n,sh)
extrudeFst n = Cons (Expr.zip n) Expr.snd

{- |
Add a new second component @n@ to a shape.

The index map discards the index of the new component,
so every position along it refers to the same source index.

Extrusion has the potential to do duplicate work.
Only use it to add dimensions of size 1, e.g. numeric 1 or unit @()@
or to duplicate slices of physical arrays.
-}
extrudeSnd :: Exp n -> T sh (sh,n)
extrudeSnd n = Cons (flip Expr.zip n) Expr.fst

{- |
Swap the two components of a pair shape.
The index map swaps the components of the index pair in the same way.
-}
transpose :: T (sh0,sh1) (sh1,sh0)
transpose = Cons Expr.swap Expr.swap


-- Arrow combinators

{- |
The identity pattern.
It leaves both the shape and the index unchanged.
-}
id :: T sh sh
id = Cons P.id P.id

{- |
Apply a pattern to the first component of a pair shape
and pass the second component through unchanged.
Shape and index maps are both applied to the first component only.
-}
first :: T sh0 sh1 -> T (sh0,sh) (sh1,sh)
first (Cons fsh fix) = Cons (Expr.mapFst fsh) (Expr.mapFst fix)

{- |
Apply a pattern to the second component of a pair shape
and pass the first component through unchanged.
Shape and index maps are both applied to the second component only.
-}
second :: T sh0 sh1 -> T (sh,sh0) (sh,sh1)
second (Cons fsh fix) = Cons (Expr.mapSnd fsh) (Expr.mapSnd fix)

infixr 1 `compose`

{- |
Sequential composition: apply the first pattern, then the second.

The shape maps are composed in forward order (@sh0 -> sh1 -> sh2@).
The index maps are composed in reverse order,
because they map target indices back to source indices.
-}
compose :: T sh0 sh1 -> T sh1 sh2 -> T sh0 sh2
compose (Cons fshA fixA) (Cons fshB fixB) =
   Cons (fshB . fshA) (fixA . fixB)


{- |
A pattern between cubic shapes:
the source shape is @Cubic.Shape rank0@
and the target shape is @Cubic.Shape rank1@.
-}
type Cubic rank0 rank1 = T (Cubic.Shape rank0) (Cubic.Shape rank1)

{- |
Like @Any@ in @accelerate@.

Passes all remaining dimensions through unchanged.
Both the shape map and the index map are the identity.
-}
passAny :: Cubic rank rank
passAny = Cons P.id P.id

{- |
Like @All@ in @accelerate@.

Extends a pattern by one dimension that is passed through unchanged.
Shapes and indices are split into the part handled by the inner pattern
and the last component (as matched by ':.').
The inner shape or index function is applied to the former,
and the latter is kept as is.
-}
pass ::
   (Unary.Natural rank0, Unary.Natural rank1) =>
   Cubic rank0 rank1 ->
   Cubic (Unary.Succ rank0) (Unary.Succ rank1)
pass (Cons fsh fix) =
   Cons
      (Expr.modify (atom:.atom) $ \(sh:.s) -> fsh sh :. s)
      (Expr.modify (atom:.atom) $ \(ix:.i) -> fix ix :. i)

{- |
Like @Int@ in @accelerate/slice@.

Removes one dimension of the source by fixing its index to @i@.
The shape map strips that dimension with 'Cubic.tail'
before applying the inner shape function.
The index map applies the inner index function
and re-attaches the fixed index @i@ with '#:.'.
-}
pick ::
   (Unary.Natural rank0, Unary.Natural rank1) =>
   Exp Index.Int ->
   Cubic rank0 rank1 ->
   Cubic (Unary.Succ rank0) rank1
pick i (Cons fsh fix) =
   Cons
      (fsh . Cubic.tail)
      (\ix -> fix ix #:. i)

{- |
Like @Int@ in @accelerate/replicate@.

Adds one dimension of extent @n@ to the target.
The shape map applies the inner shape function
and appends @n@ with '#:.'.
The index map drops the added component with 'Cubic.tail'
before applying the inner index function.
Like all extrusions, this may duplicate work.
-}
extrude ::
   (Unary.Natural rank0, Unary.Natural rank1) =>
   Exp Index.Int ->
   Cubic rank0 rank1 ->
   Cubic rank0 (Unary.Succ rank1)
extrude n (Cons fsh fix) =
   Cons
      (\sh -> fsh sh #:. n)
      (fix . Cubic.tail)


instance Core.Process (T sh0 sh1) where