module Data.Sized.Matrix where
import Prelude as P hiding (all)
import Control.Applicative
import qualified Data.Traversable as T
import qualified Data.Foldable as F
import qualified Data.List as L hiding (all)
import Data.Array.Base as B
import Data.Array.IArray as I
import GHC.TypeLits
import Data.Typeable
import Numeric
import Data.Sized.Fin
-- | A matrix is an immutable 'Array' wrapped in a newtype. The index type
-- @ix@ is the first type parameter and the element type @a@ the second.
-- Equality and ordering are derived, so they are those of the wrapped array.
newtype Matrix ix a = Matrix (Array ix a)
  deriving (Typeable, Eq, Ord)
-- | A one-dimensional matrix indexed by @Fin n@, where @n@ is a type-level
-- natural.
type Vector (n :: Nat) a = Matrix (Fin n) a

-- | A two-dimensional matrix indexed by a pair @(Fin r, Fin c)@, where @r@
-- and @c@ are type-level naturals.
type Vector2 (r :: Nat) (c :: Nat) a = Matrix (Fin r, Fin c) a
-- | Mapping over a matrix maps over the wrapped array; the bounds are
-- untouched.
instance (Ix ix) => Functor (Matrix ix) where
  fmap g (Matrix arr) = Matrix (g <$> arr)
-- | 'Matrix' implements the 'IArray' interface by unwrapping the newtype and
-- delegating every method to the underlying 'Array'.
instance IArray Matrix a where
  bounds (Matrix inner)         = B.bounds inner
  numElements (Matrix inner)    = B.numElements inner
  unsafeArray (lo, hi) pairs    = Matrix (B.unsafeArray (lo, hi) pairs)
  unsafeAt (Matrix inner) offset = B.unsafeAt inner offset
-- | Position-wise applicative. 'pure' replaces every element of 'coord'
-- with the given value; '<*>' applies the function stored at each index of
-- the left matrix to the value stored at the same index of the right one.
instance (Bounded i, Ix i) => Applicative (Matrix i) where
  pure x    = const x <$> coord
  fs <*> xs = forAll (\ix -> (fs ! ix) (xs ! ix))
-- | Build a matrix from a list of elements.
--
-- The array bounds are always @('minBound', 'maxBound')@ of the index type,
-- and the elements are handed to 'I.listArray' in list order. The list must
-- contain exactly @rangeSize (minBound, maxBound)@ elements; any other
-- length raises an 'error' reporting the expected and actual counts.
--
-- Note that the whole spine of the list is traversed to check its length,
-- so an infinite list never terminates.
matrix :: forall i a . (Bounded i, Ix i) => [a] -> Matrix i a
matrix xs
  | expected == actual = I.listArray (low, high) xs
  | otherwise = error $ "bad length of list for matrix, "
      ++ "expecting " ++ show expected ++ " elements"
      ++ ", found " ++ show actual ++ " elements."
  where
    -- Number of indices the index type can hold.
    expected = rangeSize (low, high)
    -- Both counts are already 'Int', so no numeric conversion is needed.
    actual = L.length xs
    low :: i
    low = minBound
    high :: i
    high = maxBound
-- | Number of indices spanned by the full range of the index type, i.e.
-- @rangeSize (minBound, maxBound)@. The matrix argument is never inspected;
-- it only fixes the index type.
population :: forall i a . (Bounded i, Ix i) => Matrix i a -> Int
population _ = rangeSize (lowest, highest)
  where
    lowest, highest :: i
    lowest  = minBound
    highest = maxBound
-- | The list produced by 'universe' at the matrix's index type. The matrix
-- argument is ignored and serves only to select that type.
allIndices :: (Bounded i, Ix i) => Matrix i a -> [i]
allIndices = const universe
-- | The 'minBound' of the matrix's index type. The matrix argument is
-- ignored and serves only to select that type.
zeroOf :: (Bounded i, Ix i) => Matrix i a -> i
zeroOf = const minBound
-- | A matrix of indices, built by handing the 'universe' list to 'matrix'.
-- Presumably each position therefore holds its own index (this relies on
-- 'universe' listing the indices in 'Ix' order -- confirm in
-- "Data.Sized.Fin").
coord :: (Bounded i, Ix i) => Matrix i i
coord = matrix everyIndex
  where
    everyIndex = universe
-- | Combine two matrices of the same index type position by position: the
-- result at each index is the function applied to the two elements stored
-- at that index.
zipWith :: (Bounded i, Ix i) => (a -> b -> c) -> Matrix i a -> Matrix i b -> Matrix i c
zipWith f xs ys = forAll combineAt
  where
    combineAt ix = f (xs ! ix) (ys ! ix)
-- | Zip the matrix with 'coord' using the supplied function, so the
-- function receives the 'coord' entry for a position together with the
-- element stored at that position.
forEach :: (Bounded i, Ix i) => Matrix i a -> (i -> a -> b) -> Matrix i b
forEach mat visit = withCoords mat
  where
    withCoords = Data.Sized.Matrix.zipWith visit coord
-- | Build a matrix by mapping a function over 'coord'.
forAll :: (Bounded i, Ix i) => (i -> a) -> Matrix i a
forAll mk = mk <$> coord
-- | Matrix multiplication. The entry at @(i, j)@ of the result is the sum,
-- over every @r@ drawn from 'universe', of @a ! (i, r) * b ! (r, j)@.
mm :: (Bounded m, Ix m, Bounded n, Ix n, Bounded o, Ix o, Num a) => Matrix (m,n) a -> Matrix (n,o) a -> Matrix (m,o) a
mm a b = forAll cell
  where
    cell (i, j) = sum (map (\r -> a ! (i, r) * b ! (r, j)) universe)
-- | Swap the two index components of a matrix. Implemented with 'ixmap':
-- the element at @(j, i)@ of the result is read from @(i, j)@ of the
-- argument, and the new bounds come from 'corners'.
transpose :: (Bounded x, Ix x, Bounded y, Ix y) => Matrix (x,y) a -> Matrix (y,x) a
transpose = ixmap corners swapIx
  where
    swapIx (j, i) = (i, j)
-- | Square matrix obtained by mapping over 'coord': an entry becomes @1@
-- when its two index components are equal and @0@ otherwise.
identity :: (Bounded x, Ix x, Num a) => Matrix (x,x) a
identity = fmap kronecker coord
  where
    kronecker (r, c)
      | r == c    = 1
      | otherwise = 0
-- | Concatenate two vectors: the result is built with 'matrix' from the
-- elements of the first vector followed by the elements of the second.
append :: (SingI left, SingI right, SingI (left + right))
       => Vector left a -> Vector right a -> Vector (left + right) a
append front back = matrix $ I.elems front ++ I.elems back
-- | Stack two two-dimensional matrices with the same second dimension. The
-- result is built with 'matrix' from the 'I.elems' of the first argument
-- followed by the 'I.elems' of the second.
above :: (SingI top, SingI bottom, SingI y, SingI (top + bottom))
      => Vector2 top y a -> Vector2 bottom y a -> Vector2 (top + bottom) y a
above upper lower = matrix $ I.elems upper ++ I.elems lower
-- | Join two two-dimensional matrices with the same first dimension.
-- Implemented by transposing both arguments, combining them with 'above',
-- and transposing the result back.
beside :: (SingI left, SingI right, SingI x, SingI (left + right))
       => Vector2 x left a -> Vector2 x right a -> Vector2 x (left + right) a
beside lhs rhs = transpose $ above (transpose lhs) (transpose rhs)
-- | Re-index a matrix through a functor. For each entry @i@ of 'coord', the
-- function yields a container of source indices @f j@; every @j@ inside it
-- is replaced by the element of the source matrix at @j@.
ixfmap :: (Bounded i, Ix i, Bounded j, Ix j, Functor f) => (i -> f j) -> Matrix j a -> Matrix i (f a)
ixfmap f m = fmap (fmap (m !) . f) coord
-- | Split a two-dimensional matrix along its first index component. For
-- each first component @r@ taken from 'coord', the inner matrix is built
-- from @a ! (r, c)@ for every @c@ drawn from 'universe'.
rows :: (Bounded n, Ix n, Bounded m, Ix m) => Matrix (m,n) a -> Matrix m (Matrix n a)
rows a = fmap rowAt coord
  where
    rowAt r = matrix (map (\c -> a ! (r, c)) universe)
-- | Split a two-dimensional matrix along its second index component, by
-- taking the 'rows' of its 'transpose'.
columns :: (Bounded n, Ix n, Bounded m, Ix m) => Matrix (m,n) a -> Matrix n (Matrix m a)
columns mat = rows (transpose mat)
-- | Flatten a matrix of matrices into a two-dimensional matrix: the entry
-- at @(r, c)@ is element @c@ of the inner matrix stored at @r@.
joinRows :: (Bounded n, Ix n, Bounded m, Ix m) => Matrix m (Matrix n a) -> Matrix (m,n) a
joinRows nested = fmap pick coord
  where
    pick (r, c) = nested ! r ! c
-- | Flatten a matrix of matrices into a two-dimensional matrix: the entry
-- at @(r, c)@ is element @r@ of the inner matrix stored at @c@.
joinColumns :: (Bounded n, Ix n, Bounded m, Ix m) => Matrix n (Matrix m a) -> Matrix (m,n) a
joinColumns nested = fmap pick coord
  where
    pick (r, c) = nested ! c ! r
-- | Traverse the elements in 'I.elems' order as a list, then rebuild the
-- matrix from the resulting list with 'matrix'.
instance (Bounded ix, Ix ix) => T.Traversable (Matrix ix) where
  traverse f = fmap matrix . T.traverse f . I.elems
-- | Fold over the elements in 'I.elems' order, using the list's
-- 'F.foldMap'.
instance (Bounded ix, Ix ix) => F.Foldable (Matrix ix) where
  foldMap f = F.foldMap f . I.elems
-- | Render a two-dimensional matrix as multi-line text, one line per row.
--
-- Every element is converted with 'show' and left-padded with spaces to the
-- width of the widest entry in its column. Each entry is preceded by one
-- space and followed by a comma, except the entry whose index is
-- @(maxBound, maxBound)@. The first line is prefixed with @[@, every later
-- line with a space, and the last line is suffixed with @ ]@.
--
-- A matrix with no rows renders as the single line @[ ]@.
show2D :: (Bounded n, Ix n, Bounded m, Ix m, Show a) => Matrix (m, n) a -> String
show2D m0 = joinLines (map showRow m_rows)
  where
    -- All elements rendered as strings.
    m = fmap show m0
    -- Each string tagged with whether it sits at the very last index, so
    -- that 'showEle' can omit the trailing comma there.
    m' = forEach m $ \ (x, y) a -> (x == maxBound && y == maxBound, a)
    joinLines = unlines . addTail . L.zipWith (++) ("[" : repeat " ")
    -- Append the closing bracket to the last line. Written by pattern
    -- matching so that an empty list of rows does not crash.
    addTail []       = ["[ ]"]
    addTail [l]      = [l ++ " ]"]
    addTail (l : ls) = l : addTail ls
    showRow r = concat (I.elems $ Data.Sized.Matrix.zipWith showEle r m_cols_size)
    -- Pad to the column width @s@; 'replicate' yields no padding when the
    -- difference is zero or negative.
    showEle (f, str) s =
      replicate (s - L.length str) ' ' ++ " " ++ str ++ (if f then "" else ",")
    m_cols = columns m
    m_rows = I.elems $ rows m'
    -- Width of the widest rendered entry in each column.
    m_cols_size = fmap (maximum . map L.length . I.elems) m_cols
-- | Shows a matrix as the word @matrix@, then its bounds, then the list of
-- its elements, separated by single spaces.
instance (Show a, Show ix, Bounded ix, Ix ix) => Show (Matrix ix a) where
  show m = unwords ["matrix", show (I.bounds m), show (I.elems m)]
-- | A string wrapper whose 'Show' instance emits the wrapped text verbatim,
-- without the quotes and escaping that 'show' applies to a 'String'.
newtype S = S String

instance Show S where
  showsPrec _ (S str) = showString str
-- | Format a floating-point value with 'showEFloat', passing the given
-- 'Int' as the digit count, and wrap the text in 'S'.
showAsE :: (RealFloat a) => Int -> a -> S
showAsE digits value = S (showEFloat (Just digits) value [])
-- | Format a floating-point value with 'showFFloat', passing the given
-- 'Int' as the digit count, and wrap the text in 'S'.
showAsF :: (RealFloat a) => Int -> a -> S
showAsF digits value = S (showFFloat (Just digits) value [])