// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2008-2016 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_GENERAL_MATRIX_VECTOR_H
#define EIGEN_GENERAL_MATRIX_VECTOR_H

namespace Eigen {

namespace internal {

enum GEMVPacketSizeType {
  GEMVPacketFull = 0,
  GEMVPacketHalf,
  GEMVPacketQuarter
};

template <int N, typename T1, typename T2, typename T3>
struct gemv_packet_cond { typedef T3 type; };

template <typename T1, typename T2, typename T3>
struct gemv_packet_cond<GEMVPacketFull, T1, T2, T3> { typedef T1 type; };

template <typename T1, typename T2, typename T3>
struct gemv_packet_cond<GEMVPacketHalf, T1, T2, T3> { typedef T2 type; };

template<typename LhsScalar, typename RhsScalar, int _PacketSize=GEMVPacketFull>
class gemv_traits
{
  typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;

  // Select the full, half, or quarter packet type of each operand's scalar,
  // depending on the requested _PacketSize.
#define PACKET_DECL_COND_PREFIX(prefix, name, packet_size)                        \
  typedef typename gemv_packet_cond<packet_size,                                  \
                                    typename packet_traits<name ## Scalar>::type, \
                                    typename packet_traits<name ## Scalar>::half, \
                                    typename unpacket_traits<typename packet_traits<name ## Scalar>::half>::half>::type \
  prefix ## name ## Packet

  PACKET_DECL_COND_PREFIX(_, Lhs, _PacketSize);
  PACKET_DECL_COND_PREFIX(_, Rhs, _PacketSize);
  PACKET_DECL_COND_PREFIX(_, Res, _PacketSize);
#undef PACKET_DECL_COND_PREFIX

public:
  enum {
        Vectorizable = unpacket_traits<_LhsPacket>::vectorizable &&
                       unpacket_traits<_RhsPacket>::vectorizable &&
                       int(unpacket_traits<_LhsPacket>::size)==int(unpacket_traits<_RhsPacket>::size),
        LhsPacketSize = Vectorizable ? unpacket_traits<_LhsPacket>::size : 1,
        RhsPacketSize = Vectorizable ? unpacket_traits<_RhsPacket>::size : 1,
        ResPacketSize = Vectorizable ? unpacket_traits<_ResPacket>::size : 1
  };

  typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket;
  typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket;
  typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket;
};

/* Optimized col-major matrix * vector product:
 * This algorithm processes the matrix per vertical panels,
 * which are then processed horizontally per chunk of 8*PacketSize x 1 vertical segments.
 *
 * Mixing type logic: C += alpha * A * B
 *  |  A  |  B  |alpha| comments
 *  |real |cplx |cplx | no vectorization
 *  |real |cplx |real | alpha is converted to a cplx when calling the run function, no vectorization
 *  |cplx |real |cplx | invalid, the caller has to do tmp = A * B; C += alpha*tmp
 *  |cplx |real |real | optimal case, vectorization possible via real-cplx mul
 *
 * The same reasoning applies to the transposed case.
 */
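// Illustrative sketch only (not part of Eigen's kernels and not called by them):
// a plain scalar reference of what both specializations below compute, namely
// res += alpha * mat * vec for a col-major 'mat'. The helper name and its
// raw-pointer interface are hypothetical; the real kernels access the same data
// through LhsMapper/RhsMapper and packet loads, and peel the row loop by
// 8/4/3/2/1 packets plus half/quarter-packet and scalar tails.
template<typename Scalar, typename Index>
void gemv_colmajor_reference(Index rows, Index cols,
                             const Scalar* mat, Index matStride,
                             const Scalar* vec, Scalar* res, Scalar alpha)
{
  for(Index j=0; j<cols; ++j)      // one vertical panel per column
    for(Index i=0; i<rows; ++i)    // accumulate this column's contribution into res
      res[i] += alpha * mat[i + j*matStride] * vec[j];
}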
template<typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version>
struct general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>
{
  typedef gemv_traits<LhsScalar,RhsScalar> Traits;
  typedef gemv_traits<LhsScalar,RhsScalar,GEMVPacketHalf> HalfTraits;
  typedef gemv_traits<LhsScalar,RhsScalar,GEMVPacketQuarter> QuarterTraits;

  typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;

  typedef typename Traits::LhsPacket LhsPacket;
  typedef typename Traits::RhsPacket RhsPacket;
  typedef typename Traits::ResPacket ResPacket;

  typedef typename HalfTraits::LhsPacket LhsPacketHalf;
  typedef typename HalfTraits::RhsPacket RhsPacketHalf;
  typedef typename HalfTraits::ResPacket ResPacketHalf;

  typedef typename QuarterTraits::LhsPacket LhsPacketQuarter;
  typedef typename QuarterTraits::RhsPacket RhsPacketQuarter;
  typedef typename QuarterTraits::ResPacket ResPacketQuarter;

  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run(
    Index rows, Index cols,
    const LhsMapper& lhs,
    const RhsMapper& rhs,
    ResScalar* res, Index resIncr,
    RhsScalar alpha);
};

template<typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version>
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>::run(
    Index rows, Index cols,
    const LhsMapper& alhs,
    const RhsMapper& rhs,
    ResScalar* res, Index resIncr,
    RhsScalar alpha)
{
  EIGEN_UNUSED_VARIABLE(resIncr);
  eigen_internal_assert(resIncr==1);

  // The following copy tells the compiler that lhs's attributes are not modified outside this function.
  // This helps GCC to generate proper code.
  LhsMapper lhs(alhs);

  conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj;
  conj_helper<LhsPacket,RhsPacket,ConjugateLhs,ConjugateRhs> pcj;
  conj_helper<LhsPacketHalf,RhsPacketHalf,ConjugateLhs,ConjugateRhs> pcj_half;
  conj_helper<LhsPacketQuarter,RhsPacketQuarter,ConjugateLhs,ConjugateRhs> pcj_quarter;

  const Index lhsStride = lhs.stride();
  // TODO: for padded aligned inputs, we could enable aligned reads
  enum { LhsAlignment = Unaligned,
         ResPacketSize = Traits::ResPacketSize,
         ResPacketSizeHalf = HalfTraits::ResPacketSize,
         ResPacketSizeQuarter = QuarterTraits::ResPacketSize,
         LhsPacketSize = Traits::LhsPacketSize,
         HasHalf = (int)ResPacketSizeHalf < (int)ResPacketSize,
         HasQuarter = (int)ResPacketSizeQuarter < (int)ResPacketSizeHalf
  };
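  // Each nX below is an exclusive bound on the row index i: i < nX is equivalent to
  // i + X*ResPacketSize <= rows, i.e. a block of X packets of rows still fits.
  // The row loop peels blocks of 8 packets, then falls through 4, 3, 2 and 1 packets,
  // then one half packet, one quarter packet, and finally plain scalars.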
  const Index n8 = rows-8*ResPacketSize+1;
  const Index n4 = rows-4*ResPacketSize+1;
  const Index n3 = rows-3*ResPacketSize+1;
  const Index n2 = rows-2*ResPacketSize+1;
  const Index n1 = rows-1*ResPacketSize+1;
  const Index n_half = rows-1*ResPacketSizeHalf+1;
  const Index n_quarter = rows-1*ResPacketSizeQuarter+1;

  // TODO: improve the following heuristic:
  const Index block_cols = cols<128 ? cols : (lhsStride*sizeof(LhsScalar)<32000?16:4);
  ResPacket palpha = pset1<ResPacket>(alpha);
  ResPacketHalf palpha_half = pset1<ResPacketHalf>(alpha);
  ResPacketQuarter palpha_quarter = pset1<ResPacketQuarter>(alpha);

  for(Index j2=0; j2<cols; j2+=block_cols)
  {
    Index jend = numext::mini(j2+block_cols,cols);
    Index i=0;
    for(; i<n8; i+=ResPacketSize*8)
    {
      ResPacket c0 = pset1<ResPacket>(ResScalar(0)),
                c1 = pset1<ResPacket>(ResScalar(0)),
                c2 = pset1<ResPacket>(ResScalar(0)),
                c3 = pset1<ResPacket>(ResScalar(0)),
                c4 = pset1<ResPacket>(ResScalar(0)),
                c5 = pset1<ResPacket>(ResScalar(0)),
                c6 = pset1<ResPacket>(ResScalar(0)),
                c7 = pset1<ResPacket>(ResScalar(0));

      for(Index j=j2; j<jend; j+=1)
      {
        RhsPacket b0 = pset1<RhsPacket>(rhs(j,0));
        c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*0,j),b0,c0);
        c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*1,j),b0,c1);
        c2 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*2,j),b0,c2);
        c3 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*3,j),b0,c3);
        c4 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*4,j),b0,c4);
        c5 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*5,j),b0,c5);
        c6 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*6,j),b0,c6);
        c7 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*7,j),b0,c7);
      }
      pstoreu(res+i+ResPacketSize*0, pmadd(c0,palpha,ploadu<ResPacket>(res+i+ResPacketSize*0)));
      pstoreu(res+i+ResPacketSize*1, pmadd(c1,palpha,ploadu<ResPacket>(res+i+ResPacketSize*1)));
      pstoreu(res+i+ResPacketSize*2, pmadd(c2,palpha,ploadu<ResPacket>(res+i+ResPacketSize*2)));
      pstoreu(res+i+ResPacketSize*3, pmadd(c3,palpha,ploadu<ResPacket>(res+i+ResPacketSize*3)));
      pstoreu(res+i+ResPacketSize*4, pmadd(c4,palpha,ploadu<ResPacket>(res+i+ResPacketSize*4)));
      pstoreu(res+i+ResPacketSize*5, pmadd(c5,palpha,ploadu<ResPacket>(res+i+ResPacketSize*5)));
      pstoreu(res+i+ResPacketSize*6, pmadd(c6,palpha,ploadu<ResPacket>(res+i+ResPacketSize*6)));
      pstoreu(res+i+ResPacketSize*7, pmadd(c7,palpha,ploadu<ResPacket>(res+i+ResPacketSize*7)));
    }
    if(i<n4)
    {
      ResPacket c0 = pset1<ResPacket>(ResScalar(0)),
                c1 = pset1<ResPacket>(ResScalar(0)),
                c2 = pset1<ResPacket>(ResScalar(0)),
                c3 = pset1<ResPacket>(ResScalar(0));

      for(Index j=j2; j<jend; j+=1)
      {
        RhsPacket b0 = pset1<RhsPacket>(rhs(j,0));
        c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*0,j),b0,c0);
        c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*1,j),b0,c1);
        c2 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*2,j),b0,c2);
        c3 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*3,j),b0,c3);
      }
      pstoreu(res+i+ResPacketSize*0, pmadd(c0,palpha,ploadu<ResPacket>(res+i+ResPacketSize*0)));
      pstoreu(res+i+ResPacketSize*1, pmadd(c1,palpha,ploadu<ResPacket>(res+i+ResPacketSize*1)));
      pstoreu(res+i+ResPacketSize*2, pmadd(c2,palpha,ploadu<ResPacket>(res+i+ResPacketSize*2)));
      pstoreu(res+i+ResPacketSize*3, pmadd(c3,palpha,ploadu<ResPacket>(res+i+ResPacketSize*3)));

      i+=ResPacketSize*4;
    }
    if(i<n3)
    {
      ResPacket c0 = pset1<ResPacket>(ResScalar(0)),
                c1 = pset1<ResPacket>(ResScalar(0)),
                c2 = pset1<ResPacket>(ResScalar(0));

      for(Index j=j2; j<jend; j+=1)
      {
        RhsPacket b0 = pset1<RhsPacket>(rhs(j,0));
        c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*0,j),b0,c0);
        c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*1,j),b0,c1);
        c2 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*2,j),b0,c2);
      }
      pstoreu(res+i+ResPacketSize*0, pmadd(c0,palpha,ploadu<ResPacket>(res+i+ResPacketSize*0)));
      pstoreu(res+i+ResPacketSize*1, pmadd(c1,palpha,ploadu<ResPacket>(res+i+ResPacketSize*1)));
      pstoreu(res+i+ResPacketSize*2, pmadd(c2,palpha,ploadu<ResPacket>(res+i+ResPacketSize*2)));

      i+=ResPacketSize*3;
    }
    if(i<n2)
    {
      ResPacket c0 = pset1<ResPacket>(ResScalar(0)),
                c1 = pset1<ResPacket>(ResScalar(0));

      for(Index j=j2; j<jend; j+=1)
      {
        RhsPacket b0 = pset1<RhsPacket>(rhs(j,0));
        c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*0,j),b0,c0);
        c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*1,j),b0,c1);
      }
      pstoreu(res+i+ResPacketSize*0, pmadd(c0,palpha,ploadu<ResPacket>(res+i+ResPacketSize*0)));
      pstoreu(res+i+ResPacketSize*1, pmadd(c1,palpha,ploadu<ResPacket>(res+i+ResPacketSize*1)));
      i+=ResPacketSize*2;
    }
    if(i<n1)
    {
      ResPacket c0 = pset1<ResPacket>(ResScalar(0));
      for(Index j=j2; j<jend; j+=1)
      {
        RhsPacket b0 = pset1<RhsPacket>(rhs(j,0));
        c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+0,j),b0,c0);
      }
      pstoreu(res+i+ResPacketSize*0, pmadd(c0,palpha,ploadu<ResPacket>(res+i+ResPacketSize*0)));
      i+=ResPacketSize;
    }
    if(HasHalf && i<n_half)
    {
      ResPacketHalf c0 = pset1<ResPacketHalf>(ResScalar(0));
      for(Index j=j2; j<jend; j+=1)
      {
        RhsPacketHalf b0 = pset1<RhsPacketHalf>(rhs(j,0));
        c0 = pcj_half.pmadd(lhs.template load<LhsPacketHalf,LhsAlignment>(i+0,j),b0,c0);
      }
      pstoreu(res+i+ResPacketSizeHalf*0, pmadd(c0,palpha_half,ploadu<ResPacketHalf>(res+i+ResPacketSizeHalf*0)));
      i+=ResPacketSizeHalf;
    }
    if(HasQuarter && i<n_quarter)
    {
      ResPacketQuarter c0 = pset1<ResPacketQuarter>(ResScalar(0));
      for(Index j=j2; j<jend; j+=1)
      {
        RhsPacketQuarter b0 = pset1<RhsPacketQuarter>(rhs(j,0));
        c0 = pcj_quarter.pmadd(lhs.template load<LhsPacketQuarter,LhsAlignment>(i+0,j),b0,c0);
      }
      pstoreu(res+i+ResPacketSizeQuarter*0, pmadd(c0,palpha_quarter,ploadu<ResPacketQuarter>(res+i+ResPacketSizeQuarter*0)));
      i+=ResPacketSizeQuarter;
    }
    // Scalar tail for the remaining rows of this panel.
    for(;i<rows;++i)
    {
      ResScalar d0(0);
      Index j=j2;
      do {
        d0 += cj.pmul(lhs(i,j), rhs(j,0));
      } while(++j<jend);
      res[i] += alpha*d0;
    }
  }
}
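// For illustration only: at the API level these kernels are typically reached through
// dense matrix*vector products (assuming <Eigen/Dense> is included). With Eigen's
// default col-major storage, A*x maps to the ColMajor specialization above, while a
// transposed lhs maps to the RowMajor specialization below.
//
//   Eigen::MatrixXf A = Eigen::MatrixXf::Random(512, 512);
//   Eigen::VectorXf x = Eigen::VectorXf::Random(512);
//   Eigen::VectorXf y = Eigen::VectorXf::Zero(512);
//   y.noalias() += 2.0f * A * x;              // gemv, col-major kernel
//   y.noalias() += 2.0f * A.transpose() * x;  // gemv, row-major kernel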
template<typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version>
struct general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>
{
  typedef gemv_traits<LhsScalar,RhsScalar> Traits;
  typedef gemv_traits<LhsScalar,RhsScalar,GEMVPacketHalf> HalfTraits;
  typedef gemv_traits<LhsScalar,RhsScalar,GEMVPacketQuarter> QuarterTraits;

  typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;

  typedef typename Traits::LhsPacket LhsPacket;
  typedef typename Traits::RhsPacket RhsPacket;
  typedef typename Traits::ResPacket ResPacket;

  typedef typename HalfTraits::LhsPacket LhsPacketHalf;
  typedef typename HalfTraits::RhsPacket RhsPacketHalf;
  typedef typename HalfTraits::ResPacket ResPacketHalf;

  typedef typename QuarterTraits::LhsPacket LhsPacketQuarter;
  typedef typename QuarterTraits::RhsPacket RhsPacketQuarter;
  typedef typename QuarterTraits::ResPacket ResPacketQuarter;

  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run(
    Index rows, Index cols,
    const LhsMapper& lhs,
    const RhsMapper& rhs,
    ResScalar* res, Index resIncr,
    ResScalar alpha);
};

template<typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version>
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>::run(
    Index rows, Index cols,
    const LhsMapper& alhs,
    const RhsMapper& rhs,
    ResScalar* res, Index resIncr,
    ResScalar alpha)
{
  // The following copy tells the compiler that lhs's attributes are not modified outside this function.
  // This helps GCC to generate proper code.
  LhsMapper lhs(alhs);

  eigen_internal_assert(rhs.stride()==1);
  conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj;
  conj_helper<LhsPacket,RhsPacket,ConjugateLhs,ConjugateRhs> pcj;
  conj_helper<LhsPacketHalf,RhsPacketHalf,ConjugateLhs,ConjugateRhs> pcj_half;
  conj_helper<LhsPacketQuarter,RhsPacketQuarter,ConjugateLhs,ConjugateRhs> pcj_quarter;
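  // Note: 32000 bytes is on the order of a typical L1 data cache. Our reading of the
  // constant (a gloss on the TODO below, not a documented contract) is that once a
  // single row stride exceeds that, the 8 parallel row streams of the main loop would
  // keep evicting each other from L1, so n8 is set to 0 to disable the 8-rows path.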
  // TODO: fine tune the following heuristic. The rationale is that if the matrix is very large,
  //       processing 8 rows at once might be counterproductive wrt cache.
  const Index n8 = lhs.stride()*sizeof(LhsScalar)>32000 ? 0 : rows-7;
  const Index n4 = rows-3;
  const Index n2 = rows-1;

  // TODO: for padded aligned inputs, we could enable aligned reads
  enum { LhsAlignment = Unaligned,
         ResPacketSize = Traits::ResPacketSize,
         ResPacketSizeHalf = HalfTraits::ResPacketSize,
         ResPacketSizeQuarter = QuarterTraits::ResPacketSize,
         LhsPacketSize = Traits::LhsPacketSize,
         LhsPacketSizeHalf = HalfTraits::LhsPacketSize,
         LhsPacketSizeQuarter = QuarterTraits::LhsPacketSize,
         HasHalf = (int)ResPacketSizeHalf < (int)ResPacketSize,
         HasQuarter = (int)ResPacketSizeQuarter < (int)ResPacketSizeHalf
  };

  Index i=0;
  for(; i<n8; i+=8)
  {
    ResPacket c0 = pset1<ResPacket>(ResScalar(0)),
              c1 = pset1<ResPacket>(ResScalar(0)),
              c2 = pset1<ResPacket>(ResScalar(0)),
              c3 = pset1<ResPacket>(ResScalar(0)),
              c4 = pset1<ResPacket>(ResScalar(0)),
              c5 = pset1<ResPacket>(ResScalar(0)),
              c6 = pset1<ResPacket>(ResScalar(0)),
              c7 = pset1<ResPacket>(ResScalar(0));

    Index j=0;
    for(; j+LhsPacketSize<=cols; j+=LhsPacketSize)
    {
      RhsPacket b0 = rhs.template load<RhsPacket,Unaligned>(j,0);

      c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+0,j),b0,c0);
      c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+1,j),b0,c1);
      c2 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+2,j),b0,c2);
      c3 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+3,j),b0,c3);
      c4 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+4,j),b0,c4);
      c5 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+5,j),b0,c5);
      c6 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+6,j),b0,c6);
      c7 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+7,j),b0,c7);
    }
    // Horizontal reduction of each row's packet accumulator, then a scalar tail.
    ResScalar cc0 = predux(c0);
    ResScalar cc1 = predux(c1);
    ResScalar cc2 = predux(c2);
    ResScalar cc3 = predux(c3);
    ResScalar cc4 = predux(c4);
    ResScalar cc5 = predux(c5);
    ResScalar cc6 = predux(c6);
    ResScalar cc7 = predux(c7);
    for(; j<cols; ++j)
    {
      RhsScalar b0 = rhs(j,0);

      cc0 += cj.pmul(lhs(i+0,j), b0);
      cc1 += cj.pmul(lhs(i+1,j), b0);
      cc2 += cj.pmul(lhs(i+2,j), b0);
      cc3 += cj.pmul(lhs(i+3,j), b0);
      cc4 += cj.pmul(lhs(i+4,j), b0);
      cc5 += cj.pmul(lhs(i+5,j), b0);
      cc6 += cj.pmul(lhs(i+6,j), b0);
      cc7 += cj.pmul(lhs(i+7,j), b0);
    }
    res[(i+0)*resIncr] += alpha*cc0;
    res[(i+1)*resIncr] += alpha*cc1;
    res[(i+2)*resIncr] += alpha*cc2;
    res[(i+3)*resIncr] += alpha*cc3;
    res[(i+4)*resIncr] += alpha*cc4;
    res[(i+5)*resIncr] += alpha*cc5;
    res[(i+6)*resIncr] += alpha*cc6;
    res[(i+7)*resIncr] += alpha*cc7;
  }
  for(; i<n4; i+=4)
  {
    ResPacket c0 = pset1<ResPacket>(ResScalar(0)),
              c1 = pset1<ResPacket>(ResScalar(0)),
              c2 = pset1<ResPacket>(ResScalar(0)),
              c3 = pset1<ResPacket>(ResScalar(0));

    Index j=0;
    for(; j+LhsPacketSize<=cols; j+=LhsPacketSize)
    {
      RhsPacket b0 = rhs.template load<RhsPacket,Unaligned>(j,0);

      c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+0,j),b0,c0);
      c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+1,j),b0,c1);
      c2 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+2,j),b0,c2);
      c3 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+3,j),b0,c3);
    }
    ResScalar cc0 = predux(c0);
    ResScalar cc1 = predux(c1);
    ResScalar cc2 = predux(c2);
    ResScalar cc3 = predux(c3);
    for(; j<cols; ++j)
    {
      RhsScalar b0 = rhs(j,0);

      cc0 += cj.pmul(lhs(i+0,j), b0);
      cc1 += cj.pmul(lhs(i+1,j), b0);
      cc2 += cj.pmul(lhs(i+2,j), b0);
      cc3 += cj.pmul(lhs(i+3,j), b0);
    }
    res[(i+0)*resIncr] += alpha*cc0;
    res[(i+1)*resIncr] += alpha*cc1;
    res[(i+2)*resIncr] += alpha*cc2;
    res[(i+3)*resIncr] += alpha*cc3;
  }
  for(; i<n2; i+=2)
  {
    ResPacket c0 = pset1<ResPacket>(ResScalar(0)),
              c1 = pset1<ResPacket>(ResScalar(0));

    Index j=0;
    for(; j+LhsPacketSize<=cols; j+=LhsPacketSize)
    {
      RhsPacket b0 = rhs.template load<RhsPacket,Unaligned>(j,0);

      c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+0,j),b0,c0);
      c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+1,j),b0,c1);
    }
    ResScalar cc0 = predux(c0);
    ResScalar cc1 = predux(c1);
    for(; j<cols; ++j)
    {
      RhsScalar b0 = rhs(j,0);

      cc0 += cj.pmul(lhs(i+0,j), b0);
      cc1 += cj.pmul(lhs(i+1,j), b0);
    }
    res[(i+0)*resIncr] += alpha*cc0;
    res[(i+1)*resIncr] += alpha*cc1;
  }
  for(; i<rows; ++i)
  {
    ResPacket c0 = pset1<ResPacket>(ResScalar(0));
    ResPacketHalf c0_h = pset1<ResPacketHalf>(ResScalar(0));
    ResPacketQuarter c0_q = pset1<ResPacketQuarter>(ResScalar(0));
    Index j=0;
    for(; j+LhsPacketSize<=cols; j+=LhsPacketSize)
    {
      RhsPacket b0 = rhs.template load<RhsPacket,Unaligned>(j,0);
      c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i,j),b0,c0);
    }
    ResScalar cc0 = predux(c0);
    if (HasHalf) {
      for(; j+LhsPacketSizeHalf<=cols; j+=LhsPacketSizeHalf)
      {
        RhsPacketHalf b0 = rhs.template load<RhsPacketHalf,Unaligned>(j,0);
        c0_h = pcj_half.pmadd(lhs.template load<LhsPacketHalf,LhsAlignment>(i,j),b0,c0_h);
      }
      cc0 += predux(c0_h);
    }
    if (HasQuarter) {
      for(; j+LhsPacketSizeQuarter<=cols; j+=LhsPacketSizeQuarter)
      {
        RhsPacketQuarter b0 = rhs.template load<RhsPacketQuarter,Unaligned>(j,0);
        c0_q = pcj_quarter.pmadd(lhs.template load<LhsPacketQuarter,LhsAlignment>(i,j),b0,c0_q);
      }
      cc0 += predux(c0_q);
    }
    for(; j<cols; ++j)
    {
      cc0 += cj.pmul(lhs(i,j), rhs(j,0));
    }
    res[i*resIncr] += alpha*cc0;
  }
}

} // end namespace internal

} // end namespace Eigen

#endif // EIGEN_GENERAL_MATRIX_VECTOR_H