{-# LANGUAGE PolyKinds #-} {-# LANGUAGE DataKinds #-} {-# OPTIONS_GHC -Wall #-} -- | Two-dimensional arrays. Two classes are supplied -- -- - 'Matrix' where shape information is held at type level, and -- - 'SomeMatrix' where shape is held at the value level. -- -- In both cases, the underlying data is contained as a flat vector for efficiency purposes. module NumHask.Matrix ( Matrix(..) , SomeMatrix(..) , ShapeM(..) -- * Conversion , someMatrix , unsafeToMatrix , toMatrix , unsafeFromVV , toCol , toRow , fromCol , fromRow , col , row -- * Operations , mmult ) where import qualified Protolude as P import Protolude (($), Functor(..), Show, Eq(..), (.), (<$>), Foldable(..), Int, Maybe(..)) import Data.Distributive as D import Data.Functor.Rep import Data.Proxy (Proxy(..)) import GHC.TypeLits import NumHask.Algebra import NumHask.Naperian import NumHask.Vector import Test.QuickCheck import qualified Data.Vector as V import GHC.Show import GHC.Exts -- | a two-dimensional array where shape is specified at the type level -- The main purpose of this, beyond safe typing, is to supply the Representable instance with an initial object. -- A single Boxed 'Data.Vector.Vector' is used underneath for efficient slicing, but this may change or become polymorphic in the future. newtype Matrix m n a = Matrix { flattenMatrix :: V.Vector a } deriving (Functor, Eq, Foldable) instance forall m n. (KnownNat m, KnownNat n) => HasShape (Matrix (m::Nat) (n::Nat)) where type Shape (Matrix m n) = (Int,Int) shape _= ( P.fromInteger $ natVal (Proxy :: Proxy m) , P.fromInteger $ natVal (Proxy :: Proxy n)) ndim = P.length . shape instance (KnownNat m, KnownNat n) => Naperian (Matrix (m::Nat) (n::Nat)) instance (Show a, KnownNat m, KnownNat n) => Show (Matrix (m::Nat) (n::Nat) a) where show = show . someMatrix instance (KnownNat m, KnownNat n, Arbitrary a, AdditiveUnital a) => Arbitrary (Matrix m n a) where arbitrary = frequency [ (1, P.pure zero) , (9,fromList <$> vector (m*n)) ] where n = P.fromInteger $ natVal (Proxy :: Proxy n) m = P.fromInteger $ natVal (Proxy :: Proxy m) instance (KnownNat m, KnownNat n) => Distributive (Matrix m n) where distribute f = Matrix $ V.generate (n*m) $ \i -> fmap (\(Matrix v) -> V.unsafeIndex v i) f where m = P.fromInteger $ natVal (Proxy :: Proxy m) n = P.fromInteger $ natVal (Proxy :: Proxy n) instance (KnownNat m, KnownNat n) => Representable (Matrix m n) where type Rep (Matrix m n) = (P.Int, P.Int) tabulate f = Matrix $ V.generate (m*n) (\x -> f (divMod x (m*n))) where m = P.fromInteger $ natVal (Proxy :: Proxy m) n = P.fromInteger $ natVal (Proxy :: Proxy n) index (Matrix xs) (i0,i1) = xs V.! (i0*m + i1) where m = P.fromInteger $ natVal (Proxy :: Proxy m) -- | a two-dimensional array where shape is specified at the value level as a '(Int,Int)' -- Use this to avoid type-level hasochism by demoting a 'Matrix' with 'someMatrix' data SomeMatrix a = SomeMatrix (Int,Int) (V.Vector a) deriving (Functor, Eq, Foldable) instance HasShape SomeMatrix where type Shape SomeMatrix = (Int,Int) shape (SomeMatrix sh _) = sh ndim = P.length . shape instance (Show a) => Show (SomeMatrix a) where show (SomeMatrix _ v) = show (P.toList v) -- ** conversion -- | convert from a 'Matrix' to a 'SomeMatrix' someMatrix :: (KnownNat m, KnownNat n) => Matrix (m::Nat) (n::Nat) a -> SomeMatrix a someMatrix v = SomeMatrix (shape v) (flattenMatrix v) -- | convert from a 'SomeMatrix' to a 'Matrix' with no shape check unsafeToMatrix :: SomeMatrix a -> Matrix (m::Nat) (n::Nat) a unsafeToMatrix (SomeMatrix _ v) = Matrix v -- | convert from a 'SomeMatrix' to a 'Matrix', checking shape toMatrix :: forall a m n. (KnownNat m, KnownNat n) => SomeMatrix a -> Maybe (Matrix (m::Nat) (n::Nat) a) toMatrix (SomeMatrix s v) = if s==(m,n) then Just $ Matrix v else Nothing where m = P.fromInteger $ natVal (Proxy :: Proxy m) n = P.fromInteger $ natVal (Proxy :: Proxy n) -- | from flat list instance (KnownNat m, KnownNat n, AdditiveUnital a) => IsList (Matrix m n a) where type Item (Matrix m n a) = a fromList l = Matrix $ V.fromList $ P.take (m*n) $ l P.++ P.repeat zero where m = P.fromInteger $ natVal (Proxy :: Proxy m) n = P.fromInteger $ natVal (Proxy :: Proxy n) toList (Matrix v) = V.toList v -- | from nested list instance IsList (SomeMatrix a) where type Item (SomeMatrix a) = [a] fromList l = SomeMatrix (P.length l,P.length $ P.head l) (V.fromList $ P.mconcat l) toList (SomeMatrix (m,n) v) = (\i -> V.toList $ V.unsafeSlice (i*n) n v) <$> [0..(m - 1)] -- | just used to get sensible arbitrary instances of SomeMatrix newtype ShapeM = ShapeM { unshapeM :: (Int,Int) } instance Arbitrary ShapeM where arbitrary = (\m n -> ShapeM (unshapeV m, unshapeV n)) <$> arbitrary P.<*> arbitrary instance (Arbitrary a) => Arbitrary (SomeMatrix a) where arbitrary = frequency [ (1, P.pure (SomeMatrix (zero,zero) V.empty)) , (9,fromList <$> (P.take <$> ((\m n -> unshapeV m * unshapeV n) <$> arbitrary P.<*> arbitrary) P.<*> vector 20)) ] -- | conversion from a double Vector representation unsafeFromVV :: forall a m n. ( ) => Vector m (Vector n a) -> Matrix m n a unsafeFromVV vv = Matrix $ P.foldr ((V.++) . toVec) V.empty vv -- | convert a 'Vector' to a column 'Matrix' toCol :: forall a n. ( ) => Vector n a -> Matrix 1 n a toCol v = Matrix $ toVec v -- | convert a 'Vector' to a row 'Matrix' toRow :: forall a m. ( ) => Vector m a -> Matrix m 1 a toRow v = Matrix $ toVec v -- | convert a row 'Matrix' to a 'Vector' fromCol :: forall a n. ( ) => Matrix 1 n a -> Vector n a fromCol m = Vector $ flattenMatrix m -- | convert a column 'Matrix' to a 'Vector' fromRow :: forall a m. ( ) => Matrix m 1 a -> Vector m a fromRow m = Vector $ flattenMatrix m -- | extract a row from a 'Matrix' as a 'Vector' row :: forall a m n. (KnownNat m, KnownNat n) => P.Int -> Matrix m n a -> Vector n a row i (Matrix a) = Vector $ V.unsafeSlice (i*m) n a where m = P.fromInteger $ natVal (Proxy :: Proxy m) n = P.fromInteger $ natVal (Proxy :: Proxy n) -- | extract a column from a 'Matrix' as a 'Vector' col :: forall a m n. (KnownNat m, KnownNat n) => P.Int -> Matrix m n a -> Vector m a col i (Matrix a) = Vector $ V.generate m (\x -> a V.! (i+x*n)) where m = P.fromInteger $ natVal (Proxy :: Proxy m) n = P.fromInteger $ natVal (Proxy :: Proxy n) -- ** Operations -- | matrix multiplication for a 'Matrix' mmult :: forall m n k a. (CRing a, KnownNat m, KnownNat n, KnownNat k) => Matrix m k a -> Matrix k n a -> Matrix m n a mmult x y = tabulate (\(i,j) -> row i x <.> col j y)