module Data.Sparse.SpVector where
import Control.Exception
import Control.Monad.Catch (MonadThrow (..))
import Control.Exception.Common
import GHC.Exts
import Data.Sparse.Utils
import Data.Sparse.Types
import Data.Sparse.Internal.IntM
import Numeric.Eps
import Numeric.LinearAlgebra.Class
import Data.Complex
import Data.Maybe
import qualified Data.IntMap.Strict as IM
import qualified Data.Foldable as F
import qualified Data.Vector as V
import Data.VectorSpace hiding (magnitude)
-- | A sparse vector: its declared dimension together with the map of
-- explicitly stored entries.
data SpVector a = SV
  { svDim  :: !Int       -- ^ Declared dimension (length) of the vector
  , svData :: !(IntM a)  -- ^ Explicitly stored entries, keyed by index
  } deriving Eq
-- | Renders the dimension followed by the list of stored entries,
-- e.g. @SV (3) [(0,1.0)]@.
instance Show a => Show (SpVector a) where
  show (SV d x) = concat ["SV (", show d, ") ", show (toList x)]
-- | Sparsity ratio: number of stored entries divided by the dimension.
-- For a zero-dimensional vector this is whatever '/' gives for a zero
-- denominator in the chosen result type.
spySV :: Fractional b => SpVector a -> b
spySV (SV d x) = stored / fromIntegral d
  where
    stored = fromIntegral (size x)
-- | Number of explicitly stored entries.
nzSV :: SpVector a -> Int
nzSV = size . dat
-- | Human-readable summary: dimension, number of stored entries and the
-- sparsity ratio (rendered via 'show' at the defaulted type).
sizeStrSV :: SpVector a -> String
sizeStrSV sv = unwords pieces
  where
    pieces =
      [ "(", show (dim sv), "elements ) , "
      , show (nzSV sv), "NZ ( sparsity", show (spy sv), ")"
      ]
-- | Maps over the stored entries only; the dimension is kept as is.
instance Functor SpVector where
  fmap f sv = SV (svDim sv) (f <$> svData sv)
-- | Pointwise combination delegated to the 'Set' instance of the underlying
-- 'IntM'; the result takes the larger of the two dimensions.
instance Set SpVector where
  liftU2 g u v = SV (max (svDim u) (svDim v)) (liftU2 g (svData u) (svData v))
  liftI2 g u v = SV (max (svDim u) (svDim v)) (liftI2 g (svData u) (svData v))
-- | Folds over the stored entries only (implicit zeros are not visited).
instance Foldable SpVector where
  foldr f z = F.foldr f z . svData
-- | The dimension of a sparse vector is its declared length.
instance FiniteDim SpVector where
  type FDSize SpVector = Int
  dim (SV n _) = n
-- | Exposes the underlying entry map; 'nnz' counts the stored entries.
instance HasData SpVector a where
  type HDData SpVector a = IntM a
  dat (SV _ x) = x
  nnz = length . svData
-- | Sparsity ratio, see 'spySV'.
instance Sparse SpVector a where
  spy v = spySV v
-- | Container interface indexed by 'Int'. Lookup via '@@' returns 0 for
-- entries that are not stored.
instance Elt a => SpContainer SpVector a where
  type ScIx SpVector = Int
  scInsert i x v = insertSpVector i x v
  scLookup = flip lookupSV
  scToList v = toListSV v
  (@@) = flip lookupDenseSV
-- The SpVectorInstance macro, given a scalar type t, expands to the
-- AdditiveGroup, VectorSpace, InnerSpace and Normed instances for both
-- SpVector t and SpVector of Complex t.
-- zeroV is the zero-dimensional vector with no stored entries; addition
-- uses liftU2 (+), scaling uses scale, inner products use dotS, and the
-- norms delegate to the Normed instance of the underlying IntM.
-- In the complex case Scalar is Complex t while RealScalar and Magnitude
-- are t.
-- NOTE(review): this section requires the CPP extension; the pragma is not
-- visible in this chunk, so confirm it is enabled for the module.
#define SpVectorInstance(t) \
instance AdditiveGroup (SpVector (t)) where { zeroV = SV 0 empty; (^+^) = liftU2 (+); negateV = fmap negate };\
instance AdditiveGroup (SpVector (Complex t)) where { zeroV = SV 0 empty; (^+^) = liftU2 (+); negateV = fmap negate };\
instance VectorSpace (SpVector t) where { type (Scalar (SpVector t)) = t; n *^ v = scale n v};\
instance VectorSpace (SpVector (Complex t)) where { type (Scalar (SpVector (Complex t))) = Complex t; n *^ v = scale n v};\
instance InnerSpace (SpVector (t)) where { (<.>) = dotS };\
instance InnerSpace (SpVector (Complex (t))) where { (<.>) = dotS };\
instance Normed (SpVector (t)) where {type RealScalar (SpVector (t)) = t; type Magnitude (SpVector (t)) = t; norm1 (SV _ v) = norm1 v; norm2Sq (SV _ v) = norm2Sq v ; normP p (SV _ v) = normP p v; normalize p (SV n v) = SV n (normalize p v); normalize2 (SV n v) = SV n (normalize2 v)};\
instance Normed (SpVector (Complex t)) where {type RealScalar (SpVector (Complex t)) = t; type Magnitude (SpVector (Complex t)) = t; norm1 (SV _ v) = norm1 v; norm2Sq (SV _ v) = norm2Sq v ; normP p (SV _ v) = normP p v; normalize p (SV n v) = SV n (normalize p v); normalize2 (SV n v) = SV n (normalize2 v)}
-- Instantiate the instances above for Double and Complex Double.
SpVectorInstance(Double)
-- | Inner product of two sparse vectors of equal dimension.
-- Calls 'error' when the dimensions differ; see 'dotSSafe' for a
-- non-crashing variant.
dotS :: InnerSpace (IntM t) => SpVector t -> SpVector t -> Scalar (IntM t)
dotS (SV m a) (SV n b)
  | m /= n    = error $ unwords ["<.> : Incompatible dimensions:", show m, show n]
  | otherwise = a <.> b
-- | Inner product of two sparse vectors; throws 'DotSizeMismatch' in the
-- ambient 'MonadThrow' when the dimensions differ.
dotSSafe :: (InnerSpace (IntM t), MonadThrow m) =>
     SpVector t -> SpVector t -> m (Scalar (IntM t))
dotSSafe (SV m a) (SV n b)
  | m /= n    = throwM (DotSizeMismatch m n)
  | otherwise = pure (a <.> b)
-- | A vector of the given dimension with no stored entries.
zeroSV :: Int -> SpVector a
zeroSV = flip SV empty
-- | A one-dimensional vector whose only entry (index 0) is the argument.
singletonSV :: a -> SpVector a
singletonSV = SV 1 . singleton 0
-- | @ei n i@: the @n@-dimensional unit vector along the @i@-th basis
-- direction, with @i@ counted from 1; the 1 is stored at index @i - 1@.
-- No bounds check is performed on @i@.
ei :: Num a => Int -> IM.Key -> SpVector a
ei n i = SV n (insert (i - 1) 1 empty)
-- | Builds a vector of dimension @d@ from an 'IM.IntMap', keeping only the
-- entries that are nonzero (per 'isNz') and whose index is in bounds.
mkSpVector :: Epsilon a => Int -> IM.IntMap a -> SpVector a
mkSpVector d = SV d . IntM . IM.filterWithKey keep
  where
    keep k v = isNz v && inBounds0 d k
-- | Like 'mkSpVector' but keeps every in-bounds entry, zeros included.
mkSpVector1 :: Int -> IM.IntMap a -> SpVector a
mkSpVector1 d = SV d . IntM . IM.filterWithKey (const . inBounds0 d)
-- | Real vector of dimension @d@ built from a list via 'mkIm'.
-- The list length is not checked against @d@ here.
mkSpVR :: Int -> [Double] -> SpVector Double
mkSpVR d = SV d . mkIm
-- | Complex vector of dimension @d@ built from a list via 'mkImC'.
-- The list length is not checked against @d@ here.
mkSpVC :: Int -> [Complex Double] -> SpVector (Complex Double)
mkSpVC d = SV d . mkImC
-- | Vector of dimension @d@ from a dense list; only the first @d@ elements
-- are used, indexed via 'denseIxArray' (presumably from 0).
fromListDenseSV :: Int -> [a] -> SpVector a
fromListDenseSV d = SV d . fromList . denseIxArray . take d
-- | Samples @f@ at the given indices, keeping only in-bounds indices whose
-- value is nonzero (per 'isNz').
spVectorDenseIx :: Epsilon a => (Int -> a) -> UB -> [Int] -> SpVector a
spVectorDenseIx f n ix =
  fromListSV n [ (i, v) | i <- ix, let v = f i, inBounds0 n i, isNz v ]
-- | 'spVectorDenseIx' over the inclusive index range @[lo .. hi]@.
spVectorDenseLoHi :: Epsilon a => (Int -> a) -> UB -> Int -> Int -> SpVector a
spVectorDenseLoHi f n lo hi = spVectorDenseIx f n (enumFromTo lo hi)
-- | Unchecked one-hot vector: dimension @n@, a single 1 stored at index @k@.
oneHotSVU :: Num a => Int -> IxRow -> SpVector a
oneHotSVU n = SV n . flip singleton 1
-- | Checked one-hot vector: dimension @n@, a single 1 at index @k@.
-- Calls 'error' unless @k@ passes 'inBounds0' @n@, the same check this
-- module uses for valid indices @0 .. n-1@ of a dimension-@n@ vector.
oneHotSV :: Num a => Int -> IxRow -> SpVector a
oneHotSV n k
  | inBounds0 n k = oneHotSVU n k
  | otherwise     = error "`oneHotSV n k` must satisfy 0 <= k < n"
-- | Dimension-@d@ vector with every entry explicitly stored as 1.
onesSV :: Num a => Int -> SpVector a
onesSV d = SV d entries
  where
    entries = fromList (denseIxArray (replicate d 1))
-- | Dimension-@d@ vector built from a dense list of @d@ zeros, so the zeros
-- are stored explicitly (contrast with 'zeroSV', which stores nothing).
zerosSV :: Num a => Int -> SpVector a
zerosSV d = SV d entries
  where
    entries = fromList (denseIxArray (replicate d 0))
-- | Converts a dense 'V.Vector' into a sparse vector of the same length,
-- storing every element (zeros included) at its position.
fromVector :: V.Vector a -> SpVector a
fromVector qv = V.ifoldl' step (zeroSV (V.length qv)) qv
  where
    step acc i x = insertSpVector i x acc
-- | Vector of the stored values only, in 'toListSV' order (indices dropped).
toVector :: SpVector a -> V.Vector a
toVector = V.fromList . map snd . toListSV
-- | Dense 'V.Vector' with zeros filled in for missing entries.
toVectorDense :: Num a => SpVector a -> V.Vector a
toVectorDense sv = V.fromList (toDenseListSV sv)
-- | Stores @x@ at index @i@, replacing any existing entry there.
-- Throws 'OOBIxError' (the same exception 'insertSpVectorSafe' raises)
-- when @i@ is outside the vector's bounds. Previously the guard had no
-- fallback and failed with an uninformative non-exhaustive-guards error.
insertSpVector :: IM.Key -> a -> SpVector a -> SpVector a
insertSpVector i x (SV d xim)
  | inBounds0 d i = SV d (insert i x xim)
  | otherwise     = throw (OOBIxError "insertSpVector" i)
-- | Like 'insertSpVector', but signals an out-of-bounds index with
-- 'OOBIxError' through 'MonadThrow' instead of crashing.
insertSpVectorSafe :: MonadThrow m => Int -> a -> SpVector a -> m (SpVector a)
insertSpVectorSafe i x (SV d xim)
  | not (inBounds0 d i) = throwM (OOBIxError "insertSpVector" i)
  | otherwise           = pure (SV d (insert i x xim))
-- | Builds a dimension-@d@ vector from (index, value) pairs, silently
-- dropping out-of-bounds indices. Pairs are inserted right to left, so for
-- a repeated index the leftmost pair wins.
fromListSV :: Foldable t => Int -> t (Int, a) -> SpVector a
fromListSV d =
  SV d . foldr (uncurry insert) empty . filter (inBounds0 d . fst) . F.toList
-- | Stored entries as (index, value) pairs.
toListSV :: SpVector a -> [(IM.Key, a)]
toListSV = toList . dat
-- | Dense list of length @d@ (indices @0 .. d - 1@), with 0 for every
-- index that has no stored entry. The upper bound was previously the
-- unbound name @d1@ (a lost minus sign).
toDenseListSV :: Num b => SpVector b -> [b]
toDenseListSV (SV d (IntM im)) = [ IM.findWithDefault 0 i im | i <- [0 .. d - 1] ]
-- | Right fold over the stored entries with access to each entry's index.
ifoldSV :: (IM.Key -> a -> b -> b) -> b -> SpVector a -> b
ifoldSV f e = IM.foldrWithKey f e . unIM . svData
-- | Stored value at index @i@, or 'Nothing' if no entry is stored there.
lookupSV :: IM.Key -> SpVector a -> Maybe a
lookupSV i = IM.lookup i . unIM . svData
-- | Stored value at index @i@, or @def@ if no entry is stored there.
lookupDefaultSV :: a -> IM.Key -> SpVector a -> a
lookupDefaultSV def i = IM.findWithDefault def i . unIM . svData
-- | Value at index @i@, treating missing entries as 0.
lookupDenseSV :: Num a => IM.Key -> SpVector a -> a
lookupDenseSV i v = lookupDefaultSV 0 i v
-- | Drops index 0 and shifts the remaining entries down by one, reducing
-- the dimension by one (clamped at 0 so an empty vector stays empty).
-- The previous text referred to the unbound @n1@ and applied @i@ to 1
-- (lost minus signs).
tailSV :: SpVector a -> SpVector a
tailSV (SV n (IntM sv)) = SV (max 0 (n - 1)) (IntM shifted)
  where
    shifted = IM.mapKeys (subtract 1) (IM.delete 0 sv)
-- | Value at index 0, or 0 if nothing is stored there.
headSV :: Num a => SpVector a -> a
headSV (SV _ (IntM im)) = IM.findWithDefault 0 0 im
-- | @takeSV n@ keeps the entries with index below @n@ and sets the
-- dimension to @n@.
--
-- @dropSV n@ removes the entries with index below @n@ and shifts the rest
-- down by @n@. The dimension shrinks by @n@, clamped at 0. The previous
-- text applied @n0@ to @n@ (a lost minus sign).
takeSV, dropSV :: Int -> SpVector a -> SpVector a
takeSV n (SV _ sv) = SV n $ filterWithKey (\i _ -> i < n) sv
dropSV n (SV n0 (IntM sv)) =
  SV (max 0 (n0 - n)) $ IntM $ IM.mapKeys (subtract n) $ IM.filterWithKey (\i _ -> i >= n) sv
-- | Keeps the entries in the half-open index range @[rmin, rmax)@ and
-- shifts them so that @rmin@ becomes index 0. The result has dimension
-- @rmax - rmin@.
--
-- Calls 'error' unless @0 < rmax - rmin <= n@, where @n@ is the input
-- dimension. The previous inclusive filter (@i <= rmax@) stored an entry at
-- key @rmax - rmin@, which lies outside the result's dimension.
rangeSV :: (IM.Key, IM.Key) -> SpVector a -> SpVector a
rangeSV (rmin, rmax) (SV n (IntM sv))
  | len > 0 && len <= n = SV len (IntM sv')
  | otherwise = error $ unwords ["rangeSV : invalid bounds", show (rmin, rmax) ]
  where
    len = rmax - rmin
    sv' = IM.mapKeys (subtract rmin) $ IM.filterWithKey (\i _ -> i >= rmin && i < rmax) sv
-- | Concatenation: the entries of the second vector are shifted up by the
-- first vector's dimension, and the dimensions add.
concatSV :: SpVector a -> SpVector a -> SpVector a
concatSV (SV n1 (IntM s1)) (SV n2 (IntM s2)) =
  let shifted = IM.mapKeys (n1 +) s2
  in SV (n1 + n2) (IntM (s1 `IM.union` shifted))
-- | Keeps only the stored entries whose value satisfies @q@; the dimension
-- is unchanged.
filterSV :: (a -> Bool) -> SpVector a -> SpVector a
filterSV q (SV d (IntM im)) = SV d (IntM (IM.filter q im))
-- | Keeps only the stored entries whose (index, value) satisfy @q@; the
-- dimension is unchanged.
ifilterSV :: (Int -> a -> Bool) -> SpVector a -> SpVector a
ifilterSV q (SV d x) = SV d (filterWithKey q x)
-- | Removes stored entries that 'isNz' classifies as zero.
sparsifySV :: Epsilon a => SpVector a -> SpVector a
sparsifySV sv = filterSV isNz sv
-- | Builds a deterministic (not random) vector @u@ of the same dimension
-- as @v@ that is orthogonal to it. All entries of @u@ after the first are
-- 1, and the first entry is @-(ones `dot` tail v) / head v@, so that
-- @head v * u0 + ones `dot` tail v == 0@.
--
-- Requires the first entry of @v@ to be nonzero, because it is used as a
-- divisor. NOTE(review): for complex scalars, orthogonality depends on the
-- conjugation convention of the underlying inner product; confirm.
--
-- Fixes from the previous text: the missing negation of the first entry
-- (which made @u . v@ twice the tail sum instead of 0), and @onesSV (n 1)@
-- (a lost minus sign).
orthogonalSV :: (Scalar (SpVector t) ~ t, InnerSpace (SpVector t), Fractional t) =>
     SpVector t -> SpVector t
orthogonalSV v = u
  where
    (h, t) = (headSV v, tailSV v)
    n = dim v
    ones = onesSV (n - 1)
    -- chosen so that h * y + (ones `dot` t) == 0
    y = negate ((ones `dot` t) / h)
    u = concatSV (singletonSV y) ones