// NOTE(review): THIS COPY OF THE HEADER IS DAMAGED AND WILL NOT COMPILE.
//   * Newlines were collapsed, so each physical line below holds many
//     declarations. Because the first physical line starts with "//", that
//     entire line -- license text, include guard, "namespace Eigen {", the
//     return-type selectors and the start of the traits specialization -- is
//     a single line comment. Likewise, everything after
//     "// end namespace internal" on the 2nd and 4th physical lines is
//     commented out.
//   * Template parameter/argument lists are missing (e.g.
//     "template struct SparseDenseProductReturnType",
//     "internal::remove_all::type"). Lists that began with a space or "_"
//     after "<" survived (e.g. "internal::conditional< Lhs::IsRowMajor, ...",
//     "traits<_LhsNested>"). That pattern suggests an HTML tag stripper --
//     TODO confirm.
//   * In the four sparse_time_dense_product_impl::run() specializations, the
//     for-loop text from the "<" comparison up to the next ">" is gone
//     ("for(Index c=0; c struct ..."), so the loop bodies, the closing braces
//     and the next "template<...>" header are lost.
//   Restore this file from the upstream Eigen sources rather than patching it
//   in place. Do not attempt to reconstruct the missing text by hand.
//
// Contents, as far as the visible text shows:
//   * SparseDenseProductReturnType / DenseSparseProductReturnType: type
//     selectors. The primary templates yield SparseTimeDenseProduct /
//     DenseTimeSparseProduct. The specializations pick a
//     SparseDenseOuterProduct variant via internal::conditional on
//     Lhs::IsRowMajor / Rhs::IsRowMajor.
//   * internal::traits for SparseDenseOuterProduct: Sparse storage; Scalar
//     from scalar_product_traits of both operand scalars; Index from Lhs.
//     Compile-time sizes are chosen by the Tr flag. Flags is RowMajorBit when
//     Tr, else 0. CoeffReadCost is the sum of both operands' read costs plus
//     one scalar multiply.
//   * SparseDenseOuterProduct: a sparse expression storing both operands
//     (nested) plus an InnerIterator that scales the sparse operand's values.
//   * SparseTimeDenseProduct / DenseTimeSparseProduct: ProductBase-derived
//     dense-result products evaluated through
//     internal::sparse_time_dense_product.
// This file is part of Eigen, a lightweight C++ template library // for linear algebra. // // Copyright (C) 2008-2010 Gael Guennebaud // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. #ifndef EIGEN_SPARSEDENSEPRODUCT_H #define EIGEN_SPARSEDENSEPRODUCT_H namespace Eigen { template struct SparseDenseProductReturnType { typedef SparseTimeDenseProduct Type; }; template struct SparseDenseProductReturnType { typedef typename internal::conditional< Lhs::IsRowMajor, SparseDenseOuterProduct, SparseDenseOuterProduct >::type Type; }; template struct DenseSparseProductReturnType { typedef DenseTimeSparseProduct Type; }; template struct DenseSparseProductReturnType { typedef typename internal::conditional< Rhs::IsRowMajor, SparseDenseOuterProduct, SparseDenseOuterProduct >::type Type; }; namespace internal { template struct traits > { typedef Sparse StorageKind; typedef typename scalar_product_traits::Scalar, typename traits::Scalar>::ReturnType Scalar; typedef typename Lhs::Index Index; typedef typename Lhs::Nested LhsNested; typedef typename Rhs::Nested RhsNested; typedef typename remove_all::type _LhsNested; typedef typename remove_all::type _RhsNested; enum { LhsCoeffReadCost = traits<_LhsNested>::CoeffReadCost, RhsCoeffReadCost = traits<_RhsNested>::CoeffReadCost, RowsAtCompileTime = Tr ? int(traits::RowsAtCompileTime) : int(traits::RowsAtCompileTime), ColsAtCompileTime = Tr ? int(traits::ColsAtCompileTime) : int(traits::ColsAtCompileTime), MaxRowsAtCompileTime = Tr ? int(traits::MaxRowsAtCompileTime) : int(traits::MaxRowsAtCompileTime), MaxColsAtCompileTime = Tr ? int(traits::MaxColsAtCompileTime) : int(traits::MaxColsAtCompileTime), Flags = Tr ?
// The next physical line continues the traits enum:
//   * Flags ends as "RowMajorBit : 0".
//   * CoeffReadCost = LhsCoeffReadCost + RhsCoeffReadCost + MulCost.
// It then closes namespace internal. NOTE(review): from that
// "// end namespace internal" onward, the rest of the line is a comment.
//
// The line then defines SparseDenseOuterProduct:
//   * ctor (lhs, rhs): static-asserts !Tr.
//   * ctor (rhs, lhs): static-asserts Tr. Both store lhs in m_lhs and rhs in
//     m_rhs.
//   * rows() is Tr ? m_rhs.rows() : m_lhs.rows().
//   * cols() is Tr ? m_lhs.cols() : m_rhs.cols().
//   * lhs()/rhs() return the stored nested operands.
//
// It also starts SparseDenseOuterProduct::InnerIterator:
//   * It derives from the lhs operand's InnerIterator and always walks inner
//     vector 0 of lhs ("Base(prod.lhs(), 0)").
//   * m_factor is fetched once from rhs at index `outer`, via get()
//     overloaded on the rhs StorageKind.
//   * row()/col() report m_outer or Base::index() depending on the iterator's
//     Transpose template flag (presumably the same as the class's Tr --
//     parameter list lost, confirm).
RowMajorBit : 0, CoeffReadCost = LhsCoeffReadCost + RhsCoeffReadCost + NumTraits::MulCost }; }; } // end namespace internal template class SparseDenseOuterProduct : public SparseMatrixBase > { public: typedef SparseMatrixBase Base; EIGEN_DENSE_PUBLIC_INTERFACE(SparseDenseOuterProduct) typedef internal::traits Traits; private: typedef typename Traits::LhsNested LhsNested; typedef typename Traits::RhsNested RhsNested; typedef typename Traits::_LhsNested _LhsNested; typedef typename Traits::_RhsNested _RhsNested; public: class InnerIterator; EIGEN_STRONG_INLINE SparseDenseOuterProduct(const Lhs& lhs, const Rhs& rhs) : m_lhs(lhs), m_rhs(rhs) { EIGEN_STATIC_ASSERT(!Tr,YOU_MADE_A_PROGRAMMING_MISTAKE); } EIGEN_STRONG_INLINE SparseDenseOuterProduct(const Rhs& rhs, const Lhs& lhs) : m_lhs(lhs), m_rhs(rhs) { EIGEN_STATIC_ASSERT(Tr,YOU_MADE_A_PROGRAMMING_MISTAKE); } EIGEN_STRONG_INLINE Index rows() const { return Tr ? m_rhs.rows() : m_lhs.rows(); } EIGEN_STRONG_INLINE Index cols() const { return Tr ? m_lhs.cols() : m_rhs.cols(); } EIGEN_STRONG_INLINE const _LhsNested& lhs() const { return m_lhs; } EIGEN_STRONG_INLINE const _RhsNested& rhs() const { return m_rhs; } protected: LhsNested m_lhs; RhsNested m_rhs; }; template class SparseDenseOuterProduct::InnerIterator : public _LhsNested::InnerIterator { typedef typename _LhsNested::InnerIterator Base; typedef typename SparseDenseOuterProduct::Index Index; public: EIGEN_STRONG_INLINE InnerIterator(const SparseDenseOuterProduct& prod, Index outer) : Base(prod.lhs(), 0), m_outer(outer), m_factor(get(prod.rhs(), outer, typename internal::traits::StorageKind() )) { } inline Index outer() const { return m_outer; } inline Index row() const { return Transpose ? m_outer : Base::index(); } inline Index col() const { return Transpose ?
// The next physical line finishes InnerIterator:
//   * value() is Base::value() * m_factor.
//   * get(..., Dense) returns rhs.coeff(outer).
//   * get(..., Sparse) returns the first stored entry of rhs inner vector
//     `outer` when its inner index is 0, otherwise Scalar(0).
//
// It then declares the traits for SparseTimeDenseProduct: they inherit the
// ProductBase traits with Dense storage and MatrixXpr kind.
//
// Next come the primary template and four specializations of
// sparse_time_dense_product_impl. Each declares
// run(lhs, rhs, res, alpha) and strips references/cv from the operand types
// via internal::remove_all.
// NOTE(review): every run() body is truncated right after "for(Index c=0; c"
// or "for(Index j=0; j". The loop bounds, the loop bodies and the next
// template header are lost (see the file-level note).
Base::index() : m_outer; } inline Scalar value() const { return Base::value() * m_factor; } protected: static Scalar get(const _RhsNested &rhs, Index outer, Dense = Dense()) { return rhs.coeff(outer); } static Scalar get(const _RhsNested &rhs, Index outer, Sparse = Sparse()) { typename Traits::_RhsNested::InnerIterator it(rhs, outer); if (it && it.index()==0) return it.value(); return Scalar(0); } Index m_outer; Scalar m_factor; }; namespace internal { template struct traits > : traits, Lhs, Rhs> > { typedef Dense StorageKind; typedef MatrixXpr XprKind; }; template struct sparse_time_dense_product_impl; template struct sparse_time_dense_product_impl { typedef typename internal::remove_all::type Lhs; typedef typename internal::remove_all::type Rhs; typedef typename internal::remove_all::type Res; typedef typename Lhs::Index Index; typedef typename Lhs::InnerIterator LhsInnerIterator; static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha) { for(Index c=0; c struct sparse_time_dense_product_impl { typedef typename internal::remove_all::type Lhs; typedef typename internal::remove_all::type Rhs; typedef typename internal::remove_all::type Res; typedef typename Lhs::InnerIterator LhsInnerIterator; typedef typename Lhs::Index Index; static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha) { for(Index c=0; c struct sparse_time_dense_product_impl { typedef typename internal::remove_all::type Lhs; typedef typename internal::remove_all::type Rhs; typedef typename internal::remove_all::type Res; typedef typename Lhs::InnerIterator LhsInnerIterator; typedef typename Lhs::Index Index; static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha) { for(Index j=0; j struct sparse_time_dense_product_impl { typedef typename internal::remove_all::type Lhs; typedef typename
// The next physical line continues the typedef left open at the end of the
// previous line and finishes the last impl specialization (its run() loop is
// also truncated).
//
// It then defines internal::sparse_time_dense_product(lhs, rhs, res, alpha),
// which forwards to sparse_time_dense_product_impl<...>::run.
// NOTE(review): after the following "// end namespace internal", the rest of
// this physical line is a comment.
//
// SparseTimeDenseProduct::scaleAndAddTo(dest, alpha) calls
// sparse_time_dense_product(m_lhs, m_rhs, dest, alpha). Per the ProductBase
// naming this presumably accumulates alpha * (lhs * rhs) into dest -- confirm.
//
// DenseTimeSparseProduct::scaleAndAddTo reuses the same sparse*dense kernel
// on transposed views:
//   (lhs * rhs)^T = rhs^T * lhs^T
// so it passes (rhs_t, lhs_t, dest_t).
//
// Both classes declare a private copy-assignment operator with no definition
// visible here, which disables assignment.
internal::remove_all::type Rhs; typedef typename internal::remove_all::type Res; typedef typename Lhs::InnerIterator LhsInnerIterator; typedef typename Lhs::Index Index; static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha) { for(Index j=0; j inline void sparse_time_dense_product(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha) { sparse_time_dense_product_impl::run(lhs, rhs, res, alpha); } } // end namespace internal template class SparseTimeDenseProduct : public ProductBase, Lhs, Rhs> { public: EIGEN_PRODUCT_PUBLIC_INTERFACE(SparseTimeDenseProduct) SparseTimeDenseProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {} template void scaleAndAddTo(Dest& dest, const Scalar& alpha) const { internal::sparse_time_dense_product(m_lhs, m_rhs, dest, alpha); } private: SparseTimeDenseProduct& operator=(const SparseTimeDenseProduct&); }; // dense = dense * sparse namespace internal { template struct traits > : traits, Lhs, Rhs> > { typedef Dense StorageKind; }; } // end namespace internal template class DenseTimeSparseProduct : public ProductBase, Lhs, Rhs> { public: EIGEN_PRODUCT_PUBLIC_INTERFACE(DenseTimeSparseProduct) DenseTimeSparseProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {} template void scaleAndAddTo(Dest& dest, const Scalar& alpha) const { Transpose lhs_t(m_lhs); Transpose rhs_t(m_rhs); Transpose dest_t(dest); internal::sparse_time_dense_product(rhs_t, lhs_t, dest_t, alpha); } private: DenseTimeSparseProduct& operator=(const DenseTimeSparseProduct&); }; } // end namespace Eigen #endif // EIGEN_SPARSEDENSEPRODUCT_H