// This file is part of Eigen, a lightweight C++ template library // for linear algebra. // // Copyright (C) 2009-2010 Gael Guennebaud // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. #ifndef EIGEN_PRODUCTBASE_H #define EIGEN_PRODUCTBASE_H namespace Eigen { /** \class ProductBase * \ingroup Core_Module * */ namespace internal { template struct traits > { typedef MatrixXpr XprKind; typedef typename remove_all<_Lhs>::type Lhs; typedef typename remove_all<_Rhs>::type Rhs; typedef typename scalar_product_traits::ReturnType Scalar; typedef typename promote_storage_type::StorageKind, typename traits::StorageKind>::ret StorageKind; typedef typename promote_index_type::Index, typename traits::Index>::type Index; enum { RowsAtCompileTime = traits::RowsAtCompileTime, ColsAtCompileTime = traits::ColsAtCompileTime, MaxRowsAtCompileTime = traits::MaxRowsAtCompileTime, MaxColsAtCompileTime = traits::MaxColsAtCompileTime, Flags = (MaxRowsAtCompileTime==1 ? RowMajorBit : 0) | EvalBeforeNestingBit | EvalBeforeAssigningBit | NestByRefBit, // Note that EvalBeforeNestingBit and NestByRefBit // are not used in practice because nested is overloaded for products CoeffReadCost = 0 // FIXME why is it needed ? }; }; } #define EIGEN_PRODUCT_PUBLIC_INTERFACE(Derived) \ typedef ProductBase Base; \ EIGEN_DENSE_PUBLIC_INTERFACE(Derived) \ typedef typename Base::LhsNested LhsNested; \ typedef typename Base::_LhsNested _LhsNested; \ typedef typename Base::LhsBlasTraits LhsBlasTraits; \ typedef typename Base::ActualLhsType ActualLhsType; \ typedef typename Base::_ActualLhsType _ActualLhsType; \ typedef typename Base::RhsNested RhsNested; \ typedef typename Base::_RhsNested _RhsNested; \ typedef typename Base::RhsBlasTraits RhsBlasTraits; \ typedef typename Base::ActualRhsType ActualRhsType; \ typedef typename Base::_ActualRhsType _ActualRhsType; \ using Base::m_lhs; \ using Base::m_rhs; template class ProductBase : public MatrixBase { public: typedef MatrixBase Base; EIGEN_DENSE_PUBLIC_INTERFACE(ProductBase) typedef typename Lhs::Nested LhsNested; typedef typename internal::remove_all::type _LhsNested; typedef internal::blas_traits<_LhsNested> LhsBlasTraits; typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType; typedef typename internal::remove_all::type _ActualLhsType; typedef typename internal::traits::Scalar LhsScalar; typedef typename Rhs::Nested RhsNested; typedef typename internal::remove_all::type _RhsNested; typedef internal::blas_traits<_RhsNested> RhsBlasTraits; typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType; typedef typename internal::remove_all::type _ActualRhsType; typedef typename internal::traits::Scalar RhsScalar; // Diagonal of a product: no need to evaluate the arguments because they are going to be evaluated only once typedef CoeffBasedProduct FullyLazyCoeffBaseProductType; public: #ifndef EIGEN_NO_MALLOC typedef typename Base::PlainObject BasePlainObject; typedef Matrix DynPlainObject; typedef typename internal::conditional<(BasePlainObject::SizeAtCompileTime==Dynamic) || (BasePlainObject::SizeAtCompileTime*int(sizeof(Scalar)) < int(EIGEN_STACK_ALLOCATION_LIMIT)), BasePlainObject, DynPlainObject>::type PlainObject; #else typedef typename Base::PlainObject PlainObject; #endif ProductBase(const Lhs& a_lhs, const Rhs& a_rhs) : m_lhs(a_lhs), m_rhs(a_rhs) { eigen_assert(a_lhs.cols() == a_rhs.rows() && "invalid matrix product" && "if you wanted a coeff-wise or a dot product use the respective explicit functions"); } inline Index rows() const { return m_lhs.rows(); } inline Index cols() const { return m_rhs.cols(); } template inline void evalTo(Dest& dst) const { dst.setZero(); scaleAndAddTo(dst,Scalar(1)); } template inline void addTo(Dest& dst) const { scaleAndAddTo(dst,Scalar(1)); } template inline void subTo(Dest& dst) const { scaleAndAddTo(dst,Scalar(-1)); } template inline void scaleAndAddTo(Dest& dst, const Scalar& alpha) const { derived().scaleAndAddTo(dst,alpha); } const _LhsNested& lhs() const { return m_lhs; } const _RhsNested& rhs() const { return m_rhs; } // Implicit conversion to the nested type (trigger the evaluation of the product) operator const PlainObject& () const { m_result.resize(m_lhs.rows(), m_rhs.cols()); derived().evalTo(m_result); return m_result; } const Diagonal diagonal() const { return FullyLazyCoeffBaseProductType(m_lhs, m_rhs); } template const Diagonal diagonal() const { return FullyLazyCoeffBaseProductType(m_lhs, m_rhs); } const Diagonal diagonal(Index index) const { return FullyLazyCoeffBaseProductType(m_lhs, m_rhs).diagonal(index); } // restrict coeff accessors to 1x1 expressions. No need to care about mutators here since this isnt a Lvalue expression typename Base::CoeffReturnType coeff(Index row, Index col) const { #ifdef EIGEN2_SUPPORT return lhs().row(row).cwiseProduct(rhs().col(col).transpose()).sum(); #else EIGEN_STATIC_ASSERT_SIZE_1x1(Derived) eigen_assert(this->rows() == 1 && this->cols() == 1); Matrix result = *this; return result.coeff(row,col); #endif } typename Base::CoeffReturnType coeff(Index i) const { EIGEN_STATIC_ASSERT_SIZE_1x1(Derived) eigen_assert(this->rows() == 1 && this->cols() == 1); Matrix result = *this; return result.coeff(i); } const Scalar& coeffRef(Index row, Index col) const { EIGEN_STATIC_ASSERT_SIZE_1x1(Derived) eigen_assert(this->rows() == 1 && this->cols() == 1); return derived().coeffRef(row,col); } const Scalar& coeffRef(Index i) const { EIGEN_STATIC_ASSERT_SIZE_1x1(Derived) eigen_assert(this->rows() == 1 && this->cols() == 1); return derived().coeffRef(i); } protected: LhsNested m_lhs; RhsNested m_rhs; mutable PlainObject m_result; }; // here we need to overload the nested rule for products // such that the nested type is a const reference to a plain matrix namespace internal { template struct nested, N, PlainObject> { typedef typename GeneralProduct::PlainObject const& type; }; template struct nested, N, PlainObject> { typedef typename GeneralProduct::PlainObject const& type; }; } template class ScaledProduct; // Note that these two operator* functions are not defined as member // functions of ProductBase, because, otherwise we would have to // define all overloads defined in MatrixBase. Furthermore, Using // "using Base::operator*" would not work with MSVC. // // Also note that here we accept any compatible scalar types template const ScaledProduct operator*(const ProductBase& prod, const typename Derived::Scalar& x) { return ScaledProduct(prod.derived(), x); } template typename internal::enable_if::value, const ScaledProduct >::type operator*(const ProductBase& prod, const typename Derived::RealScalar& x) { return ScaledProduct(prod.derived(), x); } template const ScaledProduct operator*(const typename Derived::Scalar& x,const ProductBase& prod) { return ScaledProduct(prod.derived(), x); } template typename internal::enable_if::value, const ScaledProduct >::type operator*(const typename Derived::RealScalar& x,const ProductBase& prod) { return ScaledProduct(prod.derived(), x); } namespace internal { template struct traits > : traits, typename NestedProduct::_LhsNested, typename NestedProduct::_RhsNested> > { typedef typename traits::StorageKind StorageKind; }; } template class ScaledProduct : public ProductBase, typename NestedProduct::_LhsNested, typename NestedProduct::_RhsNested> { public: typedef ProductBase, typename NestedProduct::_LhsNested, typename NestedProduct::_RhsNested> Base; typedef typename Base::Scalar Scalar; typedef typename Base::PlainObject PlainObject; // EIGEN_PRODUCT_PUBLIC_INTERFACE(ScaledProduct) ScaledProduct(const NestedProduct& prod, const Scalar& x) : Base(prod.lhs(),prod.rhs()), m_prod(prod), m_alpha(x) {} template inline void evalTo(Dest& dst) const { dst.setZero(); scaleAndAddTo(dst, Scalar(1)); } template inline void addTo(Dest& dst) const { scaleAndAddTo(dst, Scalar(1)); } template inline void subTo(Dest& dst) const { scaleAndAddTo(dst, Scalar(-1)); } template inline void scaleAndAddTo(Dest& dst, const Scalar& a_alpha) const { m_prod.derived().scaleAndAddTo(dst,a_alpha * m_alpha); } const Scalar& alpha() const { return m_alpha; } protected: const NestedProduct& m_prod; Scalar m_alpha; }; /** \internal * Overloaded to perform an efficient C = (A*B).lazy() */ template template Derived& MatrixBase::lazyAssign(const ProductBase& other) { other.derived().evalTo(derived()); return derived(); } } // end namespace Eigen #endif // EIGEN_PRODUCTBASE_H