-- {-# LANGUAGE BangPatterns #-} -- for Debug.Trace

-- |
-- Copyright   : (c) Johannes Kropp
-- License     : BSD 3-Clause
-- Maintainer  : Johannes Kropp <jodak932@gmail.com>

module Math.Nuha.Types where


import Data.Vector.Unboxed (Vector, Unbox, (!))
import qualified Data.Vector.Unboxed as V
-- import Foreign.Storable (Storable, sizeOf)
-- import qualified Debug.Trace as D
-- import Math.Nuha.Base (sizeOfElems)
import Math.Nuha.Internal


{- | A holor, i.e. a multidimensional array.

The literature sometimes calls such an object a tensor. That name is not accurate: a tensor carries additional properties that multidimensional arrays in general do not have. For an explanation of holors see <https://en.wikipedia.org/wiki/Parry_Moon#Holors>.

Most often a holor serves as a vector or a matrix. Row vectors have shape @[1,n]@, column vectors have shape @[m,1]@ and matrices have shape @[m,n]@. A holor with a single element has shape @[1,1]@. Note that the length of the shape is always at least two.

Entries are indexed starting from 0 in every dimension.
-}
data Holor a = Holor
    { -- | Shape of the holor. Its length is the dimension of the holor.
      hShape :: ![Int]
    , -- | Step size of each dimension, needed for indexing.
      hStrides :: ![Int]
    , -- | All entries of the holor in row-major order.
      hValues :: !(Vector a)
    }

-- | Errors that the non-trivial algorithms of this library can report.
--
-- The one-line descriptions below are read off the constructor names; the
-- functions that produce these errors define the exact conditions.
data Error
    = NoUpperTriError            -- ^ A matrix was expected to be upper triangular.
    | DimensionMismatchError     -- ^ Operand dimensions do not fit together.
    | NoMatrixError              -- ^ A matrix was expected.
    | NoSquareMatrixError        -- ^ A square matrix was expected.
    | TooFewRowsError            -- ^ The input has too few rows.
    | RankDeficiencyError        -- ^ The input is rank deficient.
    | UnderdeterminedSystemError -- ^ The system of equations is underdetermined.
    deriving (Show, Eq)

-- | Homogeneous pair (2-tuple).
type T2 a = (a, a)
-- | Homogeneous triple (3-tuple).
type T3 a = (a, a, a)
-- | Homogeneous quadruple (4-tuple).
type T4 a = (a, a, a, a)

-- | 2,2-tuple: a pair of pairs.
type T22 a = (T2 a, T2 a)
-- | 3,2-tuple: a triple of pairs.
type T32 a = (T2 a, T2 a, T2 a)
-- | 4,2-tuple: a quadruple of pairs.
type T42 a = (T2 a, T2 a, T2 a, T2 a)

-- | 2,3-tuple: a pair of triples.
type T23 a = (T3 a, T3 a)
-- | 3,3-tuple: a triple of triples.
type T33 a = (T3 a, T3 a, T3 a)
-- | 4,3-tuple: a quadruple of triples.
type T43 a = (T3 a, T3 a, T3 a, T3 a)

-- | 2,4-tuple: a pair of quadruples.
type T24 a = (T4 a, T4 a)
-- | 3,4-tuple: a triple of quadruples.
type T34 a = (T4 a, T4 a, T4 a)
-- | 4,4-tuple: a quadruple of quadruples.
type T44 a = (T4 a, T4 a, T4 a, T4 a)



-- TODO: improve formatting
-- | Plain-text rendering of a holor.
--
-- * Shape of length 1 (or the degenerate length 0): all values on one line.
-- * Shape @[m,n]@: @m@ lines with @n@ values each.
-- * Higher dimensions: one @[m,n]@ slice over the last two dimensions after
--   another, each preceded by a header line such as @[0,1,:,:] =@ naming the
--   fixed leading indices.
--
-- Every value is preceded by two spaces and every line ends with a newline.
instance (Show a, Unbox a) => Show (Holor a) where
    show (Holor shape strides values)
        | dim <= 1  = showRow 0 lenValues
        | dim == 2  = showMatrix 0 lenMatrix
        | otherwise = showMatrices 0 lenValues
      where
        dim = length shape
        lenValues = V.length values

        -- Row and column counts of the trailing two dimensions. The fallback
        -- alternative is never taken: this binding is only forced when
        -- dim >= 2, and then dropping (dim - 2) elements leaves exactly two.
        -- Matching explicitly avoids a partial pattern binding.
        (rows, cols) = case drop (dim - 2) shape of
            [m, n] -> (m, n)
            _      -> (0, 0)
        lenMatrix = rows * cols

        -- Values at flat indices [start, end) on a single line.
        showRow :: Int -> Int -> String
        showRow start end =
            concatMap (\k -> "  " ++ show (values ! k)) [start .. end - 1] ++ "\n"

        -- Rows of length cols beginning at flat index start, stopping before
        -- end. The cols > 0 guard prevents a non-advancing (endless) range.
        showMatrix :: Int -> Int -> String
        showMatrix start end
            | cols > 0  = concatMap (\r -> showRow r (r + cols)) [start, start + cols .. end - 1]
            | otherwise = ""

        -- Consecutive slices of lenMatrix values within [start, end), each
        -- with its header. The lenMatrix > 0 guard prevents an endless range.
        showMatrices :: Int -> Int -> String
        showMatrices start end
            | lenMatrix > 0 = concatMap showSlice [start, start + lenMatrix .. end - 1]
            | otherwise     = ""

        -- Header line plus matrix for the slice starting at flat index offset.
        showSlice :: Int -> String
        showSlice offset = header ++ ",:,:] =\n" ++ showMatrix offset (offset + lenMatrix)
          where
            -- Leading (non-matrix) indices of the slice, rendered like "[0,1":
            -- the closing bracket is dropped so the ",:,:]" suffix completes it.
            leading = take (dim - 2) (fromIndexToMultiIndex strides offset)
            shown = show leading
            header = take (length shown - 1) shown


-- | Structural equality: two holors are equal exactly when their shapes,
-- their strides and their values all coincide. The fields are checked in that
-- order, stopping at the first mismatch.
instance (Eq a, Unbox a) => Eq (Holor a) where
    {-# INLINE (==) #-}
    Holor shape1 strides1 values1 == Holor shape2 strides2 values2 =
        shape1 == shape2 && strides1 == strides2 && values1 == values2

    {-# INLINE (/=) #-}
    x /= y = not (x == y)

-- | Holors are ordered primarily by their values, compared lexicographically
-- via the 'Ord' instance of 'Vector'. Shape, and then strides, only break
-- ties between holors whose values are equal.
--
-- The tie-breakers keep the ordering consistent with the 'Eq' instance,
-- which also compares shape and strides: @compare h1 h2 == EQ@ holds exactly
-- when @h1 == h2@. Comparing values alone would report 'EQ' for equal-valued
-- holors of different shape while '==' returns 'False'. That violates the
-- 'Ord' laws and misbehaves in ordered containers such as @Data.Map@.
--
-- '<', '<=', '>', '>=', 'max' and 'min' use the class defaults derived from
-- 'compare'.
instance (Ord a, Unbox a) => Ord (Holor a) where
    {-# INLINE compare #-}
    compare h1 h2 =
        compare (hValues h1) (hValues h2)
            <> compare (hShape h1) (hShape h2)
            <> compare (hStrides h1) (hStrides h2)


{- TODO:
instance Storable a => Storable (Holor a) where
    sizeOfInt = 8
    sizeOfValues = sizeOfElems hlr

    {-# INLINE sizeOf #-}
    -- Size as size of holor values + size of dimension variable + size of shape information
    sizeOf hlr = sizeOfValues + sizeOfInt + (dim hlr) * sizeOfInt

    {-# INLINE alignment #-}
    alignment _ = alignment (undefined::a)

    {-# INLINE peek #-}
    peek ptr1 = do
        let ptr2 = castPtr ptr1 :: Ptr a
        valuesPtr <- peek ptr2
        dimPtr <- peekByteOff ptr2 (sizeOfValues)
        shapePtr <- peekByteOff ptr2 (sizeOfValues + dim * sizeOfInt)


        let values = peekArray
        let strides = fromShapeToStrides shape
        return $ Holor shape strides values

    {-# INLINE poke #-}
    poke ptr1 (Holor shape strides values) = do
        let ptr2 = castPtr ptr1 :: Ptr a

-}