module Numerical.HBLAS.MatrixTypes where
import qualified Data.Vector.Storable as S
import qualified Data.Vector.Storable.Mutable as SM
import Control.Monad.Primitive
import Data.Typeable
-- | Storage layout of a dense matrix. With 'Row' the elements of a row are
-- contiguous in the buffer; with 'Column' the elements of a column are.
data Orientation
  = Row
  | Column
  deriving (Typeable, Eq, Show)
-- | Type-level shorthands for the promoted 'Orientation' constructors, so
-- indices can be written as @Row@ / @Column@ without the promotion tick.
type Row = 'Row
type Column = 'Column
-- | Singleton for 'Orientation': the runtime witness of a matrix's
-- type-level storage orientation. Pattern matching on 'SRow' or 'SColumn'
-- refines the orientation index to 'Row' or 'Column' respectively.
data SOrientation :: Orientation -> * where
    SRow    :: SOrientation Row
    SColumn :: SOrientation Column
-- Explicit Typeable deriving, enabled only when compiling with GHC 7.7 or later.
#if defined(__GLASGOW_HASKELL__) && (__GLASGOW_HASKELL__ >= 707)
  deriving (Typeable)
#endif
-- | Shows the singleton as its constructor name. A single polymorphic
-- instance covers both indices, so values whose orientation is only known
-- abstractly can be shown as well.
instance Show (SOrientation ornt) where
    show SRow    = "SRow"
    show SColumn = "SColumn"

-- | Each index of a singleton type has exactly one constructor, so any two
-- values of the same 'SOrientation' type are equal.
instance Eq (SOrientation ornt) where
    _ == _ = True
-- | Flip an orientation singleton: 'SRow' becomes 'SColumn' and vice versa.
-- The equality constraints tie the result index to 'TransposeF' of the input.
sTranpose :: (x ~ TransposeF y, y ~ TransposeF x) => SOrientation x -> SOrientation y
sTranpose sornt = case sornt of
  SRow    -> SColumn
  SColumn -> SRow
-- | Transpose / conjugation flag for a matrix operand.
-- NOTE(review): presumably mirrors the BLAS @TRANS@ argument, plus a
-- conjugate-without-transpose variant; confirm against the FFI layer.
data Transpose
  = NoTranspose
  | Transpose
  | ConjTranspose
  | ConjNoTranspose
  deriving (Eq, Show, Typeable)
-- | Upper / lower selector.
-- NOTE(review): presumably mirrors the BLAS @UPLO@ argument for
-- triangular or symmetric operands; confirm.
data MatUpLo
  = MatUpper
  | MatLower
  deriving (Eq, Show, Typeable)
-- | Unit / non-unit diagonal flag.
-- NOTE(review): presumably mirrors the BLAS @DIAG@ argument; confirm.
data MatDiag
  = MatUnit
  | MatNonUnit
  deriving (Eq, Show, Typeable)
-- | Left / right side selector.
-- NOTE(review): presumably mirrors the BLAS @SIDE@ argument; confirm.
data EquationSide
  = LeftSide
  | RightSide
  deriving (Eq, Show, Typeable)
-- | Type-level transpose of an 'Orientation': 'Row' and 'Column' swap.
-- Used in the constraints of 'sTranpose' and 'transposeDenseMatrix' to
-- relate input and output orientations.
type family TransposeF (x :: Orientation) :: Orientation
type instance TransposeF Row = Column
type instance TransposeF Column = Row
-- | Index distinguishing 'Direct' from 'Implicit' dense vectors; the
-- 'SVariant' singleton records the extra data each variant carries.
data Variant
  = Direct
  | Implicit
  deriving (Eq, Show, Typeable)
-- | Singleton for 'Variant'. 'SImplicit' carries two strict padding
-- amounts; 'SDirect' carries nothing.
-- NOTE(review): the paddings are presumably element counts before/after
-- the logical data in the buffer; confirm at the use sites.
data SVariant :: Variant -> * where
  SImplicit :: { _frontPadding :: !Int
               , _endPadding   :: !Int
               } -> SVariant Implicit
  SDirect   :: SVariant Direct
-- | An immutable dense vector over a storable buffer.
--
-- Fields, in order: the 'SVariant' singleton for the variant index, the
-- logical dimension, the stride, and the backing 'S.Vector'.
-- NOTE(review): the stride presumably counts buffer elements between
-- consecutive logical entries; confirm against the BLAS call sites.
data DenseVector :: Variant -> * -> * where
    DenseVector :: { _VariantDenseVect      :: !(SVariant varnt)
                   , _LogicalDimDenseVector :: !Int
                   , _StrideDenseVector     :: !Int
                   , _bufferDenseVector     :: !(S.Vector elem)
                   } -> DenseVector varnt elem
-- Explicit Typeable deriving, enabled only when compiling with GHC 7.7 or later.
#if defined(__GLASGOW_HASKELL__) && (__GLASGOW_HASKELL__ >= 707)
  deriving (Typeable)
#endif
-- | Mutable counterpart of 'DenseVector', indexed by the state token @s@.
--
-- Fields, in order: the 'SVariant' singleton, the logical dimension, the
-- stride, and the mutable storable buffer.
data MDenseVector :: * -> Variant -> * -> * where
    MutableDenseVector :: { _VariantMutDenseVect      :: !(SVariant varnt)
                          , _LogicalDimMutDenseVector :: !Int
                          , _StrideMutDenseVector     :: !Int
                          , _bufferMutDenseVector     :: !(S.MVector s elem)
                          } -> MDenseVector s varnt elem
-- Explicit Typeable deriving, enabled only when compiling with GHC 7.7 or later.
#if defined(__GLASGOW_HASKELL__) && (__GLASGOW_HASKELL__ >= 707)
  deriving (Typeable)
#endif
-- | An immutable dense matrix over a storable buffer.
--
-- Fields, in order:
--
-- * the orientation singleton (strict, like the variant field of the vector types);
-- * @xdim@, the number of columns;
-- * @ydim@, the number of rows;
-- * the stride (leading dimension): element @(x, y)@ lives at
--   @x + y * stride@ for 'SRow' and at @y + x * stride@ for 'SColumn';
-- * the backing 'S.Vector'.
data DenseMatrix :: Orientation -> * -> * where
    DenseMatrix :: { _OrientationMat :: !(SOrientation ornt)
                   , _XdimDenMat     :: !Int
                   , _YdimDenMat     :: !Int
                   , _StrideDenMat   :: !Int
                   , _bufferDenMat   :: !(S.Vector elem)
                   } -> DenseMatrix ornt elem
-- Explicit Typeable deriving, enabled only when compiling with GHC 7.7 or later.
#if defined(__GLASGOW_HASKELL__) && (__GLASGOW_HASKELL__ >= 707)
  deriving (Typeable)
#endif
-- | Snapshot the current contents of a mutable storable vector as a list.
--
-- The vector is copied with 'S.freeze' before conversion. 'S.toList' is
-- lazy and storable elements are only read when forced. Converting an
-- 'S.unsafeFreeze'd view would therefore let later writes to @mv@ show up
-- in list elements that had not yet been evaluated.
mutableVectorToList :: (PrimMonad m, S.Storable a) => S.MVector (PrimState m) a -> m [a]
mutableVectorToList mv = do
  snapshot <- S.freeze mv
  return (S.toList snapshot)
-- | A compact row-major matrix (stride equal to the column count) is shown
-- with its raw buffer. Any other row-major matrix is first compacted with
-- 'mapDenseMatrix' and then shown.
instance (Show el, SM.Storable el) => Show (DenseMatrix Row el) where
  show mat = case mat of
    DenseMatrix SRow xdim ydim stride buffer
      | stride /= xdim -> show (mapDenseMatrix id mat)
      | otherwise ->
          concat [ "DenseMatrix SRow ", " ", show xdim, " ", show ydim
                 , " ", show stride, "(", show buffer, ")" ]

-- | A compact column-major matrix (stride equal to the row count) is shown
-- with its raw buffer. Any other column-major matrix is first compacted
-- with 'mapDenseMatrix' and then shown.
instance (Show el, SM.Storable el) => Show (DenseMatrix Column el) where
  show mat = case mat of
    DenseMatrix SColumn xdim ydim stride buffer
      | stride /= ydim -> show (mapDenseMatrix id mat)
      | otherwise ->
          concat [ "DenseMatrix SColumn ", " ", show xdim, " ", show ydim
                 , " ", show stride, "(", show buffer, ")" ]
-- | Mutable counterpart of 'DenseMatrix', indexed by the state token @s@.
--
-- Fields mirror 'DenseMatrix' positionally: orientation singleton (strict),
-- column count @xdim@, row count @ydim@, stride, and the mutable storable
-- buffer.
data MDenseMatrix :: * -> Orientation -> * -> * where
    MutableDenseMatrix :: { _OrientationMutMat :: !(SOrientation ornt)
                          , _XdimDenMutMat     :: !Int
                          , _YdimDenMutMat     :: !Int
                          , _StrideDenMutMat   :: !Int
                          , _bufferDenMutMat   :: !(SM.MVector s elem)
                          } -> MDenseMatrix s ornt elem
-- Explicit Typeable deriving, enabled only when compiling with GHC 7.7 or later.
#if defined(__GLASGOW_HASKELL__) && (__GLASGOW_HASKELL__ >= 707)
  deriving (Typeable)
#endif
-- | 'MDenseMatrix' specialised to the 'IO' state token ('RealWorld').
type IODenseMatrix = MDenseMatrix RealWorld
-- | View a mutable dense matrix as an immutable one without copying.
--
-- The buffer goes through 'S.unsafeFreeze', so it is shared with the
-- mutable matrix. The mutable matrix must not be written afterwards.
unsafeFreezeDenseMatrix :: (SM.Storable elem, PrimMonad m) => MDenseMatrix (PrimState m) or elem -> m (DenseMatrix or elem)
unsafeFreezeDenseMatrix (MutableDenseMatrix sornt xdim ydim stride mbuf) =
  S.unsafeFreeze mbuf >>= \buf -> return $! DenseMatrix sornt xdim ydim stride buf
-- | View an immutable dense matrix as a mutable one without copying.
--
-- The buffer goes through 'S.unsafeThaw', so it is shared with the input
-- matrix. The input must not be used after this call.
unsafeThawDenseMatrix :: (SM.Storable elem, PrimMonad m) => DenseMatrix or elem -> m (MDenseMatrix (PrimState m) or elem)
unsafeThawDenseMatrix (DenseMatrix sornt xdim ydim stride buf) =
  S.unsafeThaw buf >>= \mbuf -> return $! MutableDenseMatrix sornt xdim ydim stride mbuf
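-- TODO(review): freezeDenseMatrix, presumably a copying (safe) counterpart of unsafeFreezeDenseMatrix, is not defined here.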
-- | Number of rows of the matrix (its @ydim@ field).
getDenseMatrixRow :: DenseMatrix or elem -> Int
getDenseMatrixRow = _YdimDenMat
-- | Number of columns of the matrix (its @xdim@ field).
getDenseMatrixColumn :: DenseMatrix or elem -> Int
getDenseMatrixColumn = _XdimDenMat
-- | Leading-dimension stride of the matrix buffer.
getDenseMatrixLeadingDimStride :: DenseMatrix or elem -> Int
getDenseMatrixLeadingDimStride = _StrideDenMat
-- | The storable buffer backing the matrix, returned as-is (no copy).
getDenseMatrixArray :: DenseMatrix or elem -> S.Vector elem
getDenseMatrixArray = _bufferDenMat
-- | The orientation singleton stored in the matrix.
getDenseMatrixOrientation :: DenseMatrix or elem -> SOrientation or
getDenseMatrixOrientation (DenseMatrix sornt _ _ _ _) = sornt
-- | Read element @(x, y)@ (x = column, y = row) with no bounds checking.
-- Row-major matrices store it at @x + y * stride@; column-major matrices
-- store it at @y + x * stride@.
uncheckedDenseMatrixIndex :: (S.Storable elem) => DenseMatrix or elem -> (Int, Int) -> elem
uncheckedDenseMatrixIndex mat = case mat of
  DenseMatrix SRow _ _ stride buf ->
    \(x, y) -> S.unsafeIndex buf (y * stride + x)
  DenseMatrix SColumn _ _ stride buf ->
    \(x, y) -> S.unsafeIndex buf (x * stride + y)
-- | Monadic variant of the unchecked element read. The element is forced
-- before being returned, so no lazy read of the buffer escapes into @m@.
uncheckedDenseMatrixIndexM :: (Monad m, S.Storable elem) => DenseMatrix or elem -> (Int, Int) -> m elem
uncheckedDenseMatrixIndexM mat = case mat of
  DenseMatrix SRow _ _ stride buf ->
    \(x, y) -> let !e = S.unsafeIndex buf (y * stride + x) in return e
  DenseMatrix SColumn _ _ stride buf ->
    \(x, y) -> let !e = S.unsafeIndex buf (x * stride + y) in return e
-- | Read element @(x, y)@ of a mutable matrix with no bounds checking,
-- using the same layout as the immutable index functions.
uncheckedMutableDenseMatrixIndexM :: (PrimMonad m, S.Storable elem) => MDenseMatrix (PrimState m) or elem -> (Int, Int) -> m elem
uncheckedMutableDenseMatrixIndexM mat = case mat of
  MutableDenseMatrix SRow _ _ stride buf ->
    \(x, y) -> SM.unsafeRead buf (y * stride + x)
  MutableDenseMatrix SColumn _ _ stride buf ->
    \(x, y) -> SM.unsafeRead buf (x * stride + y)
-- | Exchange the components of a pair, forcing both components first.
swap :: (a, b) -> (b, a)
swap (!l, !r) = (r, l)
-- | Apply @f@ to every logical element. The result is a freshly allocated,
-- compact matrix of the same orientation: its stride equals the column
-- count for 'SRow' and the row count for 'SColumn'.
mapDenseMatrix :: (S.Storable a, S.Storable b) => (a -> b) -> DenseMatrix or a -> DenseMatrix or b
mapDenseMatrix f mat = case mat of
    DenseMatrix SRow xdim ydim _ _ ->
      DenseMatrix SRow xdim ydim xdim $!
        S.generate (xdim * ydim) (\ix -> f $! elemAt (swap (quotRem ix xdim)))
    DenseMatrix SColumn xdim ydim _ _ ->
      DenseMatrix SColumn xdim ydim ydim $!
        S.generate (xdim * ydim) (\ix -> f $! elemAt (quotRem ix ydim))
  where
    -- Linear position -> (x, y) is computed per layout above; this reads the source.
    elemAt = uncheckedDenseMatrixIndex mat
-- | Like 'mapDenseMatrix', but @f@ also receives the @(x, y)@ index of
-- each element. The result is compact, via 'generateDenseMatrix'.
imapDenseMatrix :: (S.Storable a, S.Storable b) => ((Int, Int) -> a -> b) -> DenseMatrix or a -> DenseMatrix or b
imapDenseMatrix f mat@(DenseMatrix sornt xdim ydim _ _) =
    generateDenseMatrix sornt (xdim, ydim) withValue
  where
    withValue ix = let !v = uncheckedDenseMatrixIndex mat ix in f ix v
-- | Successor of index @(x, y)@ in the matrix's storage order: x varies
-- fastest for 'SRow', y fastest for 'SColumn'. Walks the logical
-- @xdim * ydim@ grid; the stride is ignored.
--
-- Returns 'Nothing' when @(x, y)@ is the last element (or lies beyond it),
-- so iterating from @(0, 0)@ terminates. The previous test
-- (@x >= xdim && y >= ydim@) was never reached from a valid index. Because
-- the end check comes first, a zero dimension yields 'Nothing' instead of
-- dividing by zero. The input itself is otherwise not validated.
uncheckedDenseMatrixNextTuple :: DenseMatrix or elem -> (Int, Int) -> Maybe (Int, Int)
uncheckedDenseMatrixNextTuple (DenseMatrix SRow xdim ydim _ _) =
  \(!x, !y) ->
    let !next = x + xdim * y + 1
    in if next >= xdim * ydim
         then Nothing
         else Just $! swap $! quotRem next xdim
uncheckedDenseMatrixNextTuple (DenseMatrix SColumn xdim ydim _ _) =
  \(!x, !y) ->
    let !next = y + ydim * x + 1
    in if next >= xdim * ydim
         then Nothing
         else Just $! quotRem next ydim
-- | Build a compact @xdim@ by @ydim@ matrix of the given orientation whose
-- element at @(x, y)@ is @f (x, y)@. The stride is @xdim@ for 'SRow' and
-- @ydim@ for 'SColumn'.
generateDenseMatrix :: (S.Storable a) => SOrientation x -> (Int, Int) -> ((Int, Int) -> a) -> DenseMatrix x a
generateDenseMatrix sornt (xdim, ydim) f = case sornt of
    SRow    -> DenseMatrix SRow xdim ydim xdim $! fill (\ix -> swap (quotRem ix xdim))
    SColumn -> DenseMatrix SColumn xdim ydim ydim $! fill (\ix -> quotRem ix ydim)
  where
    -- toIndex maps a linear buffer position to its (x, y) coordinates.
    fill toIndex = S.generate (xdim * ydim) $ \ix ->
      let !ixtup@(!_, !_) = toIndex ix in f ixtup
-- | Allocate a fresh mutable compact matrix whose element at @(x, y)@ is
-- @fun (x, y)@. Layout and stride follow 'generateDenseMatrix'.
--
-- The buffer is obtained with the copying 'S.thaw' on every run of the
-- action. Using 'S.unsafeThaw' on the pure generated vector would make
-- repeated runs of one action value (e.g. via @replicateM@) return mutable
-- matrices that alias a single buffer.
generateMutableDenseMatrix :: (S.Storable a, PrimMonad m) =>
  SOrientation x -> (Int, Int) -> ((Int, Int) -> a) -> m (MDenseMatrix (PrimState m) x a)
generateMutableDenseMatrix sor dims fun =
  case generateDenseMatrix sor dims fun of
    DenseMatrix ornt xdim ydim stride buf -> do
      mbuf <- S.thaw buf
      return $! MutableDenseMatrix ornt xdim ydim stride mbuf
-- | Allocate a fresh mutable 'Direct' dense vector of logical length @size@
-- and stride 1, whose element @i@ is @init i@.
--
-- Storage is allocated with 'SM.new' inside the action, so each run yields
-- an independent buffer. Thawing a pure 'S.generate' result with
-- 'S.unsafeThaw' would let repeated runs of one action value share a
-- buffer. A non-positive @size@ gives an empty buffer, as 'S.generate' does.
generateMutableDenseVector :: (S.Storable a, PrimMonad m) => Int -> (Int -> a) ->
  m (MDenseVector (PrimState m) Direct a)
generateMutableDenseVector size init = do
  mv <- SM.new (max 0 size)
  let fill !i
        | i >= size = return ()
        | otherwise = SM.unsafeWrite mv i (init i) >> fill (i + 1)
  fill 0
  return $! MutableDenseVector SDirect size 1 mv
-- | Sub-matrix spanning the inclusive corners @(xstart, ystart)@ and
-- @(xend, yend)@, sharing the parent's buffer (no copy).
--
-- Element @(x, y)@ of the result is the parent's @(xstart + x, ystart + y)@.
-- Relative to the first selected element, it sits at @x + y * stride@
-- (row-major) or @y + x * stride@ (column-major), so the result keeps the
-- parent's stride. The buffer slice runs from the first through the last
-- selected element inclusive, hence the @+ 1@ in its length.
--
-- The corners are not validated; 'S.slice' still raises if the computed
-- range falls outside the parent buffer.
uncheckedDenseMatrixSlice :: (S.Storable elem) => DenseMatrix or elem -> (Int, Int) -> (Int, Int) -> DenseMatrix or elem
uncheckedDenseMatrixSlice (DenseMatrix SRow _ _ ystride arr) (xstart, ystart) (xend, yend) =
    DenseMatrix SRow (xend - xstart + 1)
                     (yend - ystart + 1)
                     ystride
                     (S.slice ixStart (ixEnd - ixStart + 1) arr)
  where
    !ixStart = xstart + ystart * ystride
    !ixEnd   = xend + yend * ystride
uncheckedDenseMatrixSlice (DenseMatrix SColumn _ _ xstride arr) (xstart, ystart) (xend, yend) =
    DenseMatrix SColumn (xend - xstart + 1)
                        (yend - ystart + 1)
                        xstride
                        (S.slice ixStart (ixEnd - ixStart + 1) arr)
  where
    !ixStart = ystart + xstart * xstride
    !ixEnd   = yend + xend * xstride
-- | O(1) transpose: reuse the buffer and stride unchanged, flip the
-- orientation tag and swap the column and row counts. A row-major matrix
-- read as column-major (or vice versa) is exactly its transpose.
transposeDenseMatrix :: (inor ~ TransposeF outor, outor ~ TransposeF inor) => DenseMatrix inor elem -> DenseMatrix outor elem
transposeDenseMatrix mat = case mat of
  DenseMatrix sornt cols rows stride buf ->
    DenseMatrix (sTranpose sornt) rows cols stride buf