{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
module Data.Matrix.Class
( Mutable
, Matrix(..)
, rows
, cols
, (!)
, fromVector
, fromList
, empty
, toList
, fromLists
, matrix
, fromRows
, takeRow
, toRows
, takeColumn
, toColumns
, toLists
, create
) where
import Control.Monad.Primitive (PrimMonad, PrimState)
import Control.Monad.ST (ST, runST)
import qualified Data.Vector.Generic as G
import Text.Printf
import qualified Data.Matrix.Class.Mutable as MM
-- | Maps an immutable matrix type constructor to its mutable counterpart.
--
-- The argument @m@ is applied to a vector type constructor and an element
-- type (as in @m v a@ in 'Matrix').  The result is applied to a mutable
-- vector type constructor, a state-token type and an element type (as in
-- @(Mutable m) (G.Mutable v) (PrimState s) a@ in 'Matrix').
type family Mutable (m :: (* -> *) -> * -> *) :: (* -> * -> *) -> * -> * -> *
-- | Immutable two-dimensional matrices whose elements live in a generic
-- vector of type @v a@.
--
-- Every instance is paired with a mutable counterpart (see 'Mutable') that
-- must be an 'MM.MMatrix' over the mutable vector type @'G.Mutable' v@.
-- Instances provide the shape, unchecked indexing, construction from a flat
-- vector and the thaw/freeze conversions; the remaining methods have
-- default implementations written in terms of 'dim' and 'unsafeIndex'.
class (MM.MMatrix (Mutable m) (G.Mutable v) a, G.Vector v a) => Matrix m v a where
  -- | Shape as @(rows, columns)@.
  dim :: m v a -> (Int, Int)

  -- | Element at @(row, column)@.  The index is NOT bounds-checked.
  unsafeIndex :: m v a -> (Int, Int) -> a

  -- | Build a matrix of shape @(rows, columns)@ from a flat vector.  The
  -- vector length is NOT checked against the shape.
  unsafeFromVector :: (Int, Int) -> v a -> m v a

  -- | All elements in row-major order: flat position @k@ corresponds to
  -- index @k `divMod` columns@.
  flatten :: m v a -> v a
  flatten mat = G.generate (nr * nc) (unsafeIndex mat . (`divMod` nc))
    where
      (nr, nc) = dim mat
  {-# INLINE flatten #-}

  -- | Row @i@ as a vector.  The row index is NOT bounds-checked.
  unsafeTakeRow :: m v a -> Int -> v a
  unsafeTakeRow mat i = G.generate width (\col -> unsafeIndex mat (i, col))
    where
      width = snd (dim mat)
  {-# INLINE unsafeTakeRow #-}

  -- | Column @j@ as a vector.  The column index is NOT bounds-checked.
  unsafeTakeColumn :: m v a -> Int -> v a
  unsafeTakeColumn mat j = G.generate height (\row -> unsafeIndex mat (row, j))
    where
      height = fst (dim mat)
  {-# INLINE unsafeTakeColumn #-}

  -- | Main diagonal; its length is the smaller of the two dimensions.
  takeDiag :: m v a -> v a
  takeDiag mat = G.generate (min nr nc) (\k -> unsafeIndex mat (k, k))
    where
      (nr, nc) = dim mat
  {-# INLINE takeDiag #-}

  -- | Convert to the mutable counterpart.  (By the usual vector convention
  -- this copies while 'unsafeThaw' may share storage — confirm per instance.)
  thaw :: PrimMonad s => m v a -> s ((Mutable m) (G.Mutable v) (PrimState s) a)

  -- | Convert to the mutable counterpart without the safety guarantees of
  -- 'thaw' (see the note there).
  unsafeThaw :: PrimMonad s
             => m v a -> s ((Mutable m) (G.Mutable v) (PrimState s) a)

  -- | Convert a mutable matrix back to an immutable one.
  freeze :: PrimMonad s
         => (Mutable m) (G.Mutable v) (PrimState s) a -> s (m v a)

  -- | Convert a mutable matrix back to an immutable one without the safety
  -- guarantees of 'freeze'; used by 'create'.
  unsafeFreeze :: PrimMonad s
               => (Mutable m) (G.Mutable v) (PrimState s) a -> s (m v a)

  {-# MINIMAL dim, unsafeIndex, unsafeFromVector, thaw, unsafeThaw, freeze, unsafeFreeze #-}
-- | Number of rows (first component of 'dim').
rows :: Matrix m v a => m v a -> Int
rows mat = fst (dim mat)
{-# INLINE rows #-}
-- | Number of columns (second component of 'dim').
cols :: Matrix m v a => m v a -> Int
cols mat = snd (dim mat)
{-# INLINE cols #-}
-- | Bounds-checked element access at @(row, column)@.
--
-- Calls 'error' when either index lies outside @[0, rows)@ / @[0, cols)@.
(!) :: Matrix m v a => m v a -> (Int, Int) -> a
mat ! (i, j)
  | inRange i nr && inRange j nc = unsafeIndex mat (i, j)
  | otherwise = error "Index out of bounds"
  where
    (nr, nc) = dim mat
    inRange k bound = k >= 0 && k < bound
{-# INLINE (!) #-}
-- | All elements as a list, in the order produced by 'flatten'.
toList :: Matrix m v a => m v a -> [a]
toList mat = G.toList (flatten mat)
{-# INLINE toList #-}
-- | The 0×0 matrix.
empty :: Matrix m v a => m v a
-- 0 * 0 equals the length of an empty vector, so the checked
-- constructor's length test always succeeds here; build directly.
empty = unsafeFromVector (0, 0) G.empty
{-# INLINE empty #-}
-- | Build a matrix of shape @(rows, columns)@ from a flat vector.
--
-- Calls 'error' when either dimension is negative, or when the vector
-- length is not exactly @rows * columns@.  Negative dimensions are rejected
-- explicitly because a pair such as @(-1, -2)@ has a positive product and
-- would otherwise slip through the length check.
fromVector :: Matrix m v a => (Int, Int) -> v a -> m v a
fromVector (r, c) vec
  | r < 0 || c < 0 =
      error $ printf "fromVector: negative dimensions (%d, %d)" r c
  | r * c /= n =
      error $ printf "fromVector: incorrect length (%d * %d != %d)" r c n
  | otherwise = unsafeFromVector (r, c) vec
  where
    n = G.length vec
{-# INLINE fromVector #-}
-- | Build a matrix of shape @(rows, columns)@ from a flat list; the length
-- is validated by 'fromVector'.
fromList :: Matrix m v a => (Int, Int) -> [a] -> m v a
fromList shape xs = fromVector shape (G.fromList xs)
{-# INLINE fromList #-}
-- | Build a matrix from a column count and a flat list of elements; the
-- row count is inferred as @length xs `div` ncol@.
--
-- Calls 'error' when @ncol@ is not positive (previously @0@ crashed with
-- an opaque "divide by zero" and a negative value could yield negative
-- dimensions), or when the list length is not a multiple of @ncol@.
matrix :: Matrix m v a
       => Int   -- ^ number of columns; must be positive
       -> [a]   -- ^ elements, row after row
       -> m v a
matrix ncol xs
  | ncol <= 0 =
      error $ printf "matrix: number of columns must be positive (got %d)" ncol
  | n `mod` ncol /= 0 = error "incorrect length"
  | otherwise = unsafeFromVector (nrow, ncol) vec
  where
    vec = G.fromList xs
    n = G.length vec
    nrow = n `div` ncol
{-# INLINE matrix #-}
-- | Build a matrix from a list of rows.  An empty list yields 'empty'.
--
-- Calls 'error' when the rows do not all have the same length.  This is
-- checked explicitly: comparing only the total element count against
-- @rows * length (first row)@ would accept ragged input such as
-- @[[1,2],[3],[4,5,6]]@ and silently scramble the layout.
fromLists :: Matrix m v a => [[a]] -> m v a
fromLists [] = empty
fromLists xss@(firstRow : _)
  | any ((/= c) . length) xss = error "fromLists: rows have unequal lengths"
  | otherwise = fromVector (r, c) (G.fromList (concat xss))
  where
    r = length xss
    c = length firstRow
{-# INLINE fromLists #-}
-- | Build a matrix from a list of row vectors.  An empty list yields
-- 'empty'.
--
-- Calls 'error' when the row vectors do not all have the same length.
-- This is checked explicitly because the total length alone can match
-- @rows * length (first row)@ for ragged input, which would silently
-- scramble the layout.
fromRows :: Matrix m v a => [v a] -> m v a
fromRows [] = empty
fromRows vs@(firstRow : _)
  | any ((/= c) . G.length) vs = error "fromRows: rows have unequal lengths"
  | otherwise = fromVector (length vs, c) (G.concat vs)
  where
    c = G.length firstRow
{-# INLINE fromRows #-}
-- | Bounds-checked extraction of row @i@.
--
-- Calls 'error' when @i@ lies outside @[0, rows)@.
takeRow :: Matrix m v a => m v a -> Int -> v a
takeRow mat i
  | i >= 0 && i < nr = unsafeTakeRow mat i
  | otherwise = error $ printf "index out of bounds: (%d,%d)" i nr
  where
    nr = fst (dim mat)
{-# INLINE takeRow #-}
-- | All rows, top to bottom.
toRows :: Matrix m v a => m v a -> [v a]
toRows mat = [unsafeTakeRow mat i | i <- [0 .. fst (dim mat) - 1]]
{-# INLINE toRows #-}
-- | Bounds-checked extraction of column @j@.
--
-- Calls 'error' when @j@ lies outside @[0, cols)@.
takeColumn :: Matrix m v a => m v a -> Int -> v a
takeColumn mat j
  | j >= 0 && j < nc = unsafeTakeColumn mat j
  | otherwise = error $ printf "index out of bounds: (%d,%d)" j nc
  where
    nc = snd (dim mat)
{-# INLINE takeColumn #-}
-- | All columns, left to right.
toColumns :: Matrix m v a => m v a -> [v a]
toColumns mat = [unsafeTakeColumn mat j | j <- [0 .. snd (dim mat) - 1]]
{-# INLINE toColumns #-}
-- | The matrix as a list of rows, each row converted to a list.
toLists :: Matrix m v a => m v a -> [[a]]
toLists mat = [G.toList row | row <- toRows mat]
{-# INLINE toLists #-}
-- | Run an 'ST' action that builds a mutable matrix, then freeze the
-- result with 'unsafeFreeze'.  The mutable matrix never escapes the 'ST'
-- computation, since 'runST' encapsulates the state thread.
create :: Matrix m v a => (forall s . ST s ((Mutable m) (G.Mutable v) s a)) -> m v a
create action = runST (action >>= unsafeFreeze)
{-# INLINE create #-}