module Numeric.Tensor where
import GHC.TypeLits
import GHC.Prim
import Data.Proxy
-- | Value-level stand-in for a type-level list of naturals.
-- The single constructor carries no data; the list lives only in @shape@.
data Dim (shape :: [Nat]) = Dim
-- | Type-level lists of naturals whose entries can be read back at run time.
class Dimensions (shape :: [Nat]) where
  -- | The entries of @shape@ as a list of 'Int'; the argument only fixes
  -- the type.
  dims :: Dim shape -> [Int]
-- | 'Proxy' for the first entry of a non-empty dimension list.
-- The argument is never inspected; it only pins down the type.
headDim :: Dim (d ': ds) -> Proxy d
headDim = const Proxy
-- | 'Dim' witness for everything after the first entry of a non-empty
-- dimension list. The argument is never inspected.
tailDim :: Dim (d ': ds) -> Dim ds
tailDim = const Dim
-- | The empty dimension list has no entries to report.
instance Dimensions '[] where
  dims = const []
-- | A non-empty list reports its first entry (read with 'natVal' and
-- converted to 'Int') followed by the entries of the remaining list.
instance (KnownNat d, Dimensions ds) => Dimensions (d ': ds) where
  dims shape = leading : remaining
    where
      leading = fromIntegral (natVal (headDim shape))
      remaining = dims (tailDim shape)
-- | Render the run-time entries of a dimension list with the list 'Show'
-- format.
printCrazy :: Dimensions d => Dim d -> String
printCrazy = show . dims
-- | 'Dim' witness for a tensor's first index list. The tensor value itself
-- is ignored.
contraDimsType :: Tensor t n m -> Dim n
contraDimsType = const Dim
-- | 'Dim' witness for a tensor's second index list. The tensor value itself
-- is ignored.
coDimsType :: Tensor t n m -> Dim m
coDimsType = const Dim
-- | Run-time entries of a tensor's first index list.
contraDims :: Dimensions n => Tensor t n m -> [Int]
contraDims tensor = dims (contraDimsType tensor)
-- | Run-time entries of a tensor's second index list.
coDims :: Dimensions m => Tensor t n m -> [Int]
coDims tensor = dims (coDimsType tensor)
-- | Tensor of 'Double's whose first index list is @'[2]@ and whose second
-- index list is empty.
type Vec2 = Tensor Double '[2] '[]
-- | Build a 'Vec2' from two components, stored in a 'Vector2' and wrapped
-- in the 'T10' constructor.
vec2 :: Double -> Double -> Vec2
vec2 c0 c1 = T10 (Vector2 c0 c1)
-- | Associates an element type and two type-level index lists with a
-- concrete tensor representation.
--
-- @contra@ is the list reported by 'contraDims' and @co@ the list reported
-- by 'coDims'.
--
-- NOTE(review): among the instances visible next to this class, none covers
-- index-list lengths (2,1), (1,2) or (2,2) -- confirm whether that gap is
-- intended.
class TensorCalculus elt (contra :: [Nat]) (co :: [Nat]) where
  -- | The tensor type itself; each instance supplies its own constructor.
  data Tensor elt contra co
  -- | The underlying storage type used for this combination of index lists.
  type TensorStore elt contra co
-- | Both index lists empty: the tensor is a bare element, and the listed
-- classes are derived for the wrapper.
instance TensorCalculus a '[] '[] where
  newtype Tensor a '[] '[] = T00 a
    deriving
      ( Bounded, Enum, Eq, Ord
      , Num, Real, Integral
      , Fractional, Floating, RealFrac, RealFloat
      , Read, Show
      )
  type TensorStore a '[] '[] = a
-- | One entry in the first index list, none in the second: stored as a
-- 'SomeVector' of that length.
instance TensorCalculus a '[len] '[] where
  newtype Tensor a '[len] '[] = T10 (SomeVector a len)
  type TensorStore a '[len] '[] = SomeVector a len
-- | No entry in the first index list, one in the second: stored as a
-- 'SomeVector' of that length.
instance TensorCalculus a '[] '[len] where
  newtype Tensor a '[] '[len] = T01 (SomeVector a len)
  type TensorStore a '[] '[len] = SomeVector a len
-- | Two entries in the first index list, none in the second: stored as a
-- 'SomeMatrix' indexed by those two entries.
instance TensorCalculus a '[i, j] '[] where
  newtype Tensor a '[i, j] '[] = T20 (SomeMatrix a i j)
  type TensorStore a '[i, j] '[] = SomeMatrix a i j
-- | One entry in each index list: stored as a 'SomeMatrix' indexed by the
-- first-list entry and then the second-list entry.
instance TensorCalculus a '[i] '[j] where
  newtype Tensor a '[i] '[j] = T11 (SomeMatrix a i j)
  type TensorStore a '[i] '[j] = SomeMatrix a i j
-- | No entry in the first index list, two in the second: stored as a
-- 'SomeMatrix' indexed by those two entries.
instance TensorCalculus a '[] '[i, j] where
  newtype Tensor a '[] '[i, j] = T02 (SomeMatrix a i j)
  type TensorStore a '[] '[i, j] = SomeMatrix a i j
-- | Three or more entries in the first index list, none in the second:
-- stored as an 'NDArray'.
instance TensorCalculus a (c1 ': c2 ': c3 ': cs) '[] where
  newtype Tensor a (c1 ': c2 ': c3 ': cs) '[] = Tn0 (NDArray a)
    deriving Show
  type TensorStore a (c1 ': c2 ': c3 ': cs) '[] = NDArray a
-- | Three or more entries in the first index list, exactly one in the
-- second: stored as an 'NDArray'.
instance TensorCalculus a (c1 ': c2 ': c3 ': cs) '[k1] where
  newtype Tensor a (c1 ': c2 ': c3 ': cs) '[k1] = Tn1 (NDArray a)
    deriving Show
  type TensorStore a (c1 ': c2 ': c3 ': cs) '[k1] = NDArray a
-- | Three or more entries in the first index list, exactly two in the
-- second: stored as an 'NDArray'.
instance TensorCalculus a (c1 ': c2 ': c3 ': cs) '[k1, k2] where
  newtype Tensor a (c1 ': c2 ': c3 ': cs) '[k1, k2] = Tn2 (NDArray a)
    deriving Show
  type TensorStore a (c1 ': c2 ': c3 ': cs) '[k1, k2] = NDArray a
-- | No entry in the first index list, three or more in the second: stored
-- as an 'NDArray'.
instance TensorCalculus a '[] (k1 ': k2 ': k3 ': ks) where
  newtype Tensor a '[] (k1 ': k2 ': k3 ': ks) = T0m (NDArray a)
    deriving Show
  type TensorStore a '[] (k1 ': k2 ': k3 ': ks) = NDArray a
-- | Exactly one entry in the first index list, three or more in the second:
-- stored as an 'NDArray'.
instance TensorCalculus a '[c1] (k1 ': k2 ': k3 ': ks) where
  newtype Tensor a '[c1] (k1 ': k2 ': k3 ': ks) = T1m (NDArray a)
    deriving Show
  type TensorStore a '[c1] (k1 ': k2 ': k3 ': ks) = NDArray a
-- | Exactly two entries in the first index list, three or more in the
-- second: stored as an 'NDArray'.
instance TensorCalculus a '[c1, c2] (k1 ': k2 ': k3 ': ks) where
  newtype Tensor a '[c1, c2] (k1 ': k2 ': k3 ': ks) = T2m (NDArray a)
    deriving Show
  type TensorStore a '[c1, c2] (k1 ': k2 ': k3 ': ks) = NDArray a
-- | Three or more entries in both index lists: stored as an 'NDArray'.
instance TensorCalculus a (c1 ': c2 ': c3 ': cs) (k1 ': k2 ': k3 ': ks) where
  newtype Tensor a (c1 ': c2 ': c3 ': cs) (k1 ': k2 ': k3 ': ks) = Tnm (NDArray a)
    deriving Show
  type TensorStore a (c1 ': c2 ': c3 ': cs) (k1 ': k2 ': k3 ': ks) = NDArray a
-- Standalone-derived instances for the vector- and matrix-backed tensor
-- shapes. Each instance is conditional on the wrapped storage type having
-- the same instance.

-- First index list '[n], second list empty (wraps SomeVector a n).
deriving instance Show (SomeVector a n) => Show (Tensor a '[n] '[])
deriving instance Plus (SomeVector a n) => Plus (Tensor a '[n] '[])
deriving instance Num (SomeVector a n) => Num (Tensor a '[n] '[])

-- First index list empty, second list '[m] (wraps SomeVector a m).
deriving instance Show (SomeVector a m) => Show (Tensor a '[] '[m])
deriving instance Plus (SomeVector a m) => Plus (Tensor a '[] '[m])

-- Two indices in total (wraps SomeMatrix a i j).
deriving instance Show (SomeMatrix a i j) => Show (Tensor a '[i, j] '[])
deriving instance Show (SomeMatrix a i j) => Show (Tensor a '[i] '[j])
deriving instance Show (SomeMatrix a i j) => Show (Tensor a '[] '[i, j])
--contraV
-- | Types with a binary operation combining two values of the same type.
class Plus v where
  -- | Combine two values into one of the same type.
  plus :: v -> v -> v
-- | Adds the single components with '+'.
instance Num t => Plus (Vector1 t) where
  Vector1 x `plus` Vector1 y = Vector1 (x + y)
-- | Adds corresponding components with '+'.
instance Num t => Plus (Vector2 t) where
  Vector2 x0 x1 `plus` Vector2 y0 y1 = Vector2 (x0 + y0) (x1 + y1)
-- | Adds corresponding components with '+'.
instance Num t => Plus (Vector3 t) where
  Vector3 x0 x1 x2 `plus` Vector3 y0 y1 y2 =
    Vector3 (x0 + y0) (x1 + y1) (x2 + y2)
-- | Adds corresponding components with '+'.
instance Num t => Plus (Vector4 t) where
  Vector4 x0 x1 x2 x3 `plus` Vector4 y0 y1 y2 y3 =
    Vector4 (x0 + y0) (x1 + y1) (x2 + y2) (x3 + y3)
-- | Adds the two element lists pairwise with 'zipWith'; if the lists differ
-- in length the result has the length of the shorter one.
instance Num t => Plus (VectorN t n) where
  VectorN xs `plus` VectorN ys = VectorN (zipWith (+) xs ys)
-- | Wrapper around a single component.
newtype Vector1 a = Vector1 a
  deriving (Show)
-- | Two components of the same type.
data Vector2 a = Vector2 a a
  deriving (Show)
-- | Three components of the same type.
data Vector3 a = Vector3 a a a
  deriving (Show)
-- | Four components of the same type.
data Vector4 a = Vector4 a a a a
  deriving (Show)
-- | List-backed vector. The length parameter is a phantom: nothing in this
-- declaration ties it to the actual length of the list.
newtype VectorN a (len :: Nat) = VectorN [a]
  deriving (Show)
-- | Wrapper around a single matrix entry.
newtype Matrix1x1 a = Matrix1x1 a
  deriving (Show)
-- | Four entries of the same type. This declaration does not fix whether
-- they are laid out by rows or by columns.
data Matrix2x2 a = Matrix2x2 a a a a
  deriving (Show)
-- | Matrix backed by a list of lists. Both size parameters are phantoms:
-- nothing in this declaration ties them to the shape of the nested list.
newtype MatrixNxM a (rows :: Nat) (cols :: Nat) = MatrixNxM [[a]]
  deriving (Show)
-- | Wrapper around a raw unlifted 'ByteArray#'. The element type is a
-- phantom parameter; it does not appear in the stored field.
data NDArray a = NDArray ByteArray#
-- | Every array renders as the same fixed placeholder text; the contents
-- are never inspected.
instance Show (NDArray t) where
  showsPrec _ _ = showString "Big array"
-- | Storage type for a vector of the given length: lengths 1 to 4 get a
-- dedicated fixed-size type, and every other length falls through to the
-- list-backed 'VectorN'.
--
-- The injectivity annotation states that the result type determines both
-- the element type and the length.
type family SomeVector a (len :: Nat) = r | r -> a len where
  SomeVector a 1 = Vector1 a
  SomeVector a 2 = Vector2 a
  SomeVector a 3 = Vector3 a
  SomeVector a 4 = Vector4 a
  SomeVector a len = VectorN a len
-- | Storage type for a matrix with the given two sizes: 1x1 and 2x2 get a
-- dedicated type, and every other combination falls through to the
-- list-backed 'MatrixNxM'.
--
-- The injectivity annotation states that the result type determines the
-- element type and both sizes.
type family SomeMatrix a (i :: Nat) (j :: Nat) = r | r -> a i j where
  SomeMatrix a 1 1 = Matrix1x1 a
  SomeMatrix a 2 2 = Matrix2x2 a
  SomeMatrix a i j = MatrixNxM a i j
-- | Combine one element with a container to obtain a larger container.
-- The functional dependency makes the result type a function of the two
-- argument types.
class AppendDim x small big | x small -> big where
  -- | Produce the larger container from an element and a smaller container.
  appendDim :: x -> small -> big
-- Instances of AppendDim for the vector storage types. In every instance the
-- new element becomes the first component of the result and the existing
-- components follow in their original order.

-- | One component grows to two. Without this instance the fixed-size chain
-- would start at 'Vector2' and a 'Vector1' could not be extended at all.
instance AppendDim t (Vector1 t) (Vector2 t) where
  appendDim a (Vector1 b1) = Vector2 a b1

-- | Two components grow to three.
instance AppendDim t (Vector2 t) (Vector3 t) where
  appendDim a (Vector2 b1 b2) = Vector3 a b1 b2

-- | Three components grow to four.
instance AppendDim t (Vector3 t) (Vector4 t) where
  appendDim a (Vector3 b1 b2 b3) = Vector4 a b1 b2 b3

-- | Four components grow to five, which is the first length stored in the
-- list-backed 'VectorN'.
instance AppendDim t (Vector4 t) (VectorN t 5) where
  appendDim a (Vector4 b1 b2 b3 b4) = VectorN [a, b1, b2, b3, b4]

-- | A list-backed vector grows by one: the element is consed onto the list
-- and the phantom length becomes @n + 1@.
--
-- The bounds also admit @n = 4@, although 'SomeVector' maps length 4 to
-- 'Vector4' rather than to 'VectorN'; they are kept as they were so that
-- existing uses continue to type-check.
instance (m ~ (n+1), 5 <= m, 4 <= n) => AppendDim t (VectorN t n) (VectorN t m) where
  appendDim a (VectorN bs) = VectorN (a : bs)