{-# LANGUAGE TypeOperators #-}
module Numeric.BLAS.Slice where

import qualified Data.Array.Comfort.Shape as Shape
import qualified Data.Array.Comfort.Boxed as BoxedArray
import Data.Array.Comfort.Shape ((::+)((::+)))

import qualified Data.Traversable as Trav
import qualified Data.List as List
import Data.Map (Map)


{- $setup
>>> import qualified Numeric.BLAS.Slice as Slice
>>> import Test.Slice (shapeInt)
>>>
>>> import qualified Data.Array.Comfort.Boxed as Array
>>> import qualified Data.Array.Comfort.Shape as Shape
>>> import qualified Data.Map as Map
>>> import Data.Array.Comfort.Shape ((::+)((::+)))
>>> import Data.Array.Comfort.Boxed ((!))
>>>
>>> import Control.Applicative (liftA3, pure)
>>>
>>> import qualified Test.QuickCheck as QC
>>>
>>> genSlice :: sh -> QC.Gen (Slice.T sh)
>>> genSlice sh =
>>>    liftA3 Slice.Cons (QC.choose (0,100)) (QC.choose (1,100)) (pure sh)
>>>
>>> genSlice2 :: shA -> shB -> QC.Gen (Slice.T shA, Slice.T shB)
>>> genSlice2 shA shB = do
>>>    s <- QC.choose (0,100)
>>>    k <- QC.choose (1,100)
>>>    return (Slice.Cons s k shA, Slice.Cons s k shB)
-}


-- | A strided slice over some linear storage.
--
-- Element number @i@ (in the linear order of 'shape') is addressed as
-- @start + skip * i@; all the selectors below are built on this rule.
data T sh =
   Cons {
      start :: Int,
      skip :: Int,
      shape :: sh
   }
   deriving (Eq, Show)

-- | Slice covering an entire shape: it starts at offset 0 and uses unit stride.
fromShape :: (Shape.C sh) => sh -> T sh
fromShape sh = Cons {start = 0, skip = 1, shape = sh}


-- | Select row @ix0@ of a row-major matrix slice.
--
-- Rows are @Shape.size sh1@ slice elements apart. The selected row therefore
-- begins @offset ix0 * size sh1@ elements (times the stride) after the
-- matrix start, and it keeps the original stride.
row ::
   (Shape.Indexed sh0, Shape.C sh1) => Shape.Index sh0 -> T (sh0,sh1) -> T sh1
row ix0 (Cons s k (sh0,sh1)) =
   Cons {start = s + rowOffset * k, skip = k, shape = sh1}
   where rowOffset = Shape.offset sh0 ix0 * Shape.size sh1

-- | Select column @ix1@ of a row-major matrix slice.
--
-- The column begins at the column's offset within the first row.
-- Consecutive column elements are one full row (@Shape.size sh1@ elements)
-- apart, so the stride is multiplied by the row length.
column ::
   (Shape.C sh0, Shape.Indexed sh1) => Shape.Index sh1 -> T (sh0,sh1) -> T sh0
column ix1 (Cons s k (sh0,sh1)) =
   Cons (s + k * columnOffset) (k * rowLength) sh0
   where
      (rowLength, offsetOf) = Shape.sizeOffset sh1
      columnOffset = offsetOf ix1

{- |
Split a matrix slice into one slice per row, collected in an array that is
indexed by the row shape.
The @i@-th row starts @i * Shape.size sh1 * k@ elements after the matrix
start and keeps the original stride @k@. Hence entry @ix@ equals
@row ix slice@, as checked by the property below.

prop> QC.forAll (QC.choose (1,100)) $ \numRows -> QC.forAll (QC.choose (0,100)) $ \numColumns -> QC.forAll (genSlice (shapeInt numRows, shapeInt numColumns)) $ \slice -> QC.forAll (QC.elements $ Shape.indices $ shapeInt numRows) $ \ix -> Slice.row ix slice == Slice.rowArray slice ! ix
-}
rowArray ::
   (Shape.Indexed sh0, Shape.C sh1) =>
   T (sh0,sh1) -> BoxedArray.Array sh0 (T sh1)
rowArray (Cons s k (sh0,sh1)) =
   BoxedArray.fromList sh0
      [Cons (s + i * rowStep) k sh1 | i <- [0 .. Shape.size sh0 - 1]]
   where rowStep = Shape.size sh1 * k

{- |
Split a matrix slice into one slice per column, collected in an array that
is indexed by the column shape.
The @j@-th column starts @j * k@ elements after the matrix start. Its
elements are one full row apart, giving a stride of @Shape.size sh1 * k@.
Hence entry @ix@ equals @column ix slice@, as checked by the property below.

prop> QC.forAll (QC.choose (0,100)) $ \numRows -> QC.forAll (QC.choose (1,100)) $ \numColumns -> QC.forAll (genSlice (shapeInt numRows, shapeInt numColumns)) $ \slice -> QC.forAll (QC.elements $ Shape.indices $ shapeInt numColumns) $ \ix -> Slice.column ix slice == Slice.columnArray slice ! ix
-}
columnArray ::
   (Shape.C sh0, Shape.Indexed sh1) =>
   T (sh0,sh1) -> BoxedArray.Array sh1 (T sh0)
columnArray (Cons s k (sh0,sh1)) =
   BoxedArray.fromList sh1 $
      fmap (\j -> Cons (s + j * k) rowStep sh0) [0 .. Shape.size sh1 - 1]
   where rowStep = Shape.size sh1 * k


-- | Restrict a matrix slice whose rows are split as @sh0::+sh1@ to the upper
-- @sh0@ rows.
--
-- The upper rows come first in row-major order, so the start and stride
-- stay the same; only the row shape shrinks.
topSubmatrix ::
   (Shape.C sh, Shape.C sh0, Shape.C sh1) =>
   T (sh0::+sh1, sh) -> T (sh0,sh)
topSubmatrix slice =
   case shape slice of
      (sh0 ::+ _sh1, sh) -> slice {shape = (sh0, sh)}

{- |
Restrict a matrix slice whose rows are split as @sh0::+sh1@ to the lower
@sh1@ rows.

The matrix is laid out row by row, as in 'row'. Skipping the upper block
therefore means skipping @Shape.size sh0@ complete rows of @Shape.size sh@
elements each, i.e. @Shape.size (sh0,sh)@ slice elements. It is not just
@Shape.size sh0@ elements. The stride is unchanged.

prop> QC.forAll (QC.choose (0,10)) $ \numTop -> QC.forAll (QC.choose (1,10)) $ \numBottom -> QC.forAll (QC.choose (0,10)) $ \numColumns -> QC.forAll (genSlice (shapeInt numTop ::+ shapeInt numBottom, shapeInt numColumns)) $ \slice -> QC.forAll (QC.elements $ Shape.indices $ shapeInt numBottom) $ \ix -> Slice.row ix (Slice.bottomSubmatrix slice) == Slice.row (Right ix) slice
-}
bottomSubmatrix ::
   (Shape.C sh, Shape.C sh0, Shape.C sh1) =>
   T (sh0::+sh1, sh) -> T (sh1,sh)
bottomSubmatrix (Cons s k (sh0::+sh1, sh)) =
   -- size of the pair = number of elements in the upper block (rows * columns)
   Cons (s + k * Shape.size (sh0,sh)) k (sh1,sh)


-- | View a square-matrix slice as a slice over the pair of its (equal) row
-- and column shapes. The start and stride are unchanged.
cartesianFromSquare :: T (Shape.Square sh) -> T (sh,sh)
cartesianFromSquare slice =
   case shape slice of
      Shape.Square sh -> slice {shape = (sh, sh)}

-- | Select row @ix0@ of a square-matrix slice (see 'row').
squareRow ::
   (Shape.Indexed sh) => Shape.Index sh -> T (Shape.Square sh) -> T sh
squareRow ix0 square = row ix0 (cartesianFromSquare square)

-- | Select column @ix1@ of a square-matrix slice (see 'column').
squareColumn ::
   (Shape.Indexed sh) => Shape.Index sh -> T (Shape.Square sh) -> T sh
squareColumn ix1 square = column ix1 (cartesianFromSquare square)



-- | Fix the first index of a three-dimensional slice.
--
-- The result is the contiguous @(sh1,sh2)@ plane at @ix0@. Planes are
-- @Shape.size (sh1,sh2)@ elements apart, and the stride is kept.
plane12 ::
   (Shape.Indexed sh0, Shape.C sh1, Shape.C sh2) =>
   Shape.Index sh0 -> T (sh0,sh1,sh2) -> T (sh1,sh2)
plane12 ix0 (Cons s k (sh0,sh1,sh2)) =
   let planeShape = (sh1,sh2)
       planeStart = s + k * Shape.offset sh0 ix0 * Shape.size planeShape
   in  Cons planeStart k planeShape

-- | Fix the last index of a three-dimensional slice.
--
-- The result is the @(sh0,sh1)@ plane at @ix2@. It starts at that index's
-- offset within the innermost dimension, and its elements are
-- @Shape.size sh2@ elements apart, so the stride grows by that factor.
plane01 ::
   (Shape.C sh0, Shape.C sh1, Shape.Indexed sh2) =>
   Shape.Index sh2 -> T (sh0,sh1,sh2) -> T (sh0,sh1)
plane01 ix2 (Cons s k (sh0,sh1,sh2)) =
   Cons {start = s + k * offset2 ix2, skip = k * size2, shape = (sh0,sh1)}
   where (size2, offset2) = Shape.sizeOffset sh2

-- | Fix the first two indices of a three-dimensional slice. The result is
-- a line along the last dimension.
column2of3 ::
   (Shape.Indexed sh0, Shape.Indexed sh1, Shape.C sh2) =>
   Shape.Index sh0 -> Shape.Index sh1 -> T (sh0,sh1,sh2) -> T sh2
column2of3 ix0 ix1 slice = row ix1 (plane12 ix0 slice)

-- | Fix the first and last indices of a three-dimensional slice. The result
-- is a line along the middle dimension.
column1of3 ::
   (Shape.Indexed sh0, Shape.C sh1, Shape.Indexed sh2) =>
   Shape.Index sh0 -> Shape.Index sh2 -> T (sh0,sh1,sh2) -> T sh1
column1of3 ix0 ix2 slice = column ix2 (plane12 ix0 slice)

-- | Fix the last two indices of a three-dimensional slice. The result is
-- a line along the first dimension.
column0of3 ::
   (Shape.C sh0, Shape.Indexed sh1, Shape.Indexed sh2) =>
   Shape.Index sh1 -> Shape.Index sh2 -> T (sh0,sh1,sh2) -> T sh0
column0of3 ix1 ix2 slice = column ix1 (plane01 ix2 slice)


-- | Restrict a slice over an appended shape @sh0::+sh1@ to its first part.
-- That part comes first, so the start and stride are unchanged.
left :: (Shape.C sh0, Shape.C sh1) => T (sh0::+sh1) -> T sh0
left slice =
   case shape slice of
      sh0 ::+ _sh1 -> slice {shape = sh0}

-- | Restrict a slice over an appended shape @sh0::+sh1@ to its second part.
-- That part starts after the @Shape.size sh0@ elements of the first part.
right :: (Shape.C sh0, Shape.C sh1) => T (sh0::+sh1) -> T sh1
right (Cons s k (sh0::+sh1)) =
   Cons {start = s + skipped, skip = k, shape = sh1}
   where skipped = k * Shape.size sh0


{- |
Split a slice over a 'Map' of shapes into one slice per key.
The parts are laid out one after another in the traversal order of the
'Map', which is ascending key order. Each part occupies @Shape.size sh@
consecutive slice elements and keeps the stride @k@.

prop> QC.forAll (fmap shapeInt $ QC.choose (0,100)) $ \shapeA -> QC.forAll (fmap shapeInt $ QC.choose (0,100)) $ \shapeB -> QC.forAll (fmap shapeInt $ QC.choose (0,100)) $ \shapeC -> QC.forAll (genSlice2 (Map.fromList $ ('a', shapeA) : ('b', shapeB) : ('c', shapeC) : []) (shapeA ::+ shapeB ::+ shapeC)) $ \(sliceMap, sliceParted) -> Slice.map sliceMap Map.! 'b' == Slice.left (Slice.right sliceParted)

prop> QC.forAll (QC.choose (0,100)) $ \numRows -> QC.forAll (QC.choose (0,100)) $ \numColumns -> let rowShape = shapeInt numRows; columnShape = shapeInt numColumns; mapShape = Map.fromList $ map (\k -> (k, columnShape)) (Shape.indices rowShape) in QC.forAll (genSlice2 mapShape (rowShape, columnShape)) $ \(sliceMap, sliceMatrix) -> Map.toAscList (Slice.map sliceMap) == Array.toAssociations (Slice.rowArray sliceMatrix)
-}
map :: (Shape.C sh) => T (Map k sh) -> Map k (T sh)
map (Cons s k m) = snd (Trav.mapAccumL place s m)
   where
      -- emit the part at the running offset, then advance past its elements
      place offset sh = (offset + k * Shape.size sh, Cons offset k sh)