1 | // This file is part of Eigen, a lightweight C++ template library |
---|
2 | // for linear algebra. |
---|
3 | // |
---|
4 | // Copyright (C) 2008-2010 Gael Guennebaud <gael.guennebaud@inria.fr> |
---|
5 | // |
---|
6 | // This Source Code Form is subject to the terms of the Mozilla |
---|
7 | // Public License v. 2.0. If a copy of the MPL was not distributed |
---|
8 | // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. |
---|
9 | |
---|
10 | #ifndef EIGEN_SPARSEDENSEPRODUCT_H |
---|
11 | #define EIGEN_SPARSEDENSEPRODUCT_H |
---|
12 | |
---|
13 | namespace Eigen { |
---|
14 | |
---|
// Selects the expression type returned by `sparse * dense`.
// General case (InnerSize != 1): a SparseTimeDenseProduct.
// NOTE(review): InnerSize is presumably the compile-time inner dimension of the
// product; operator* below passes only two arguments, so a default for it must
// be declared elsewhere — confirm in the forward declaration.
template<typename Lhs, typename Rhs, int InnerSize> struct SparseDenseProductReturnType
{
  typedef SparseTimeDenseProduct<Lhs,Rhs> Type;
};
---|
19 | |
---|
// InnerSize == 1: `sparse * dense` is an outer product, represented by
// SparseDenseOuterProduct with the sparse operand first and Tr = false.
template<typename Lhs, typename Rhs> struct SparseDenseProductReturnType<Lhs,Rhs,1>
{
  typedef SparseDenseOuterProduct<Lhs,Rhs,false> Type;
};
---|
24 | |
---|
// Selects the expression type returned by `dense * sparse`.
// General case (InnerSize != 1): a DenseTimeSparseProduct.
template<typename Lhs, typename Rhs, int InnerSize> struct DenseSparseProductReturnType
{
  typedef DenseTimeSparseProduct<Lhs,Rhs> Type;
};
---|
29 | |
---|
// InnerSize == 1: `dense * sparse` is an outer product. The operands are
// swapped so that SparseDenseOuterProduct always receives the sparse factor
// (here Rhs) first; Tr = true records that the product was written the other
// way round.
template<typename Lhs, typename Rhs> struct DenseSparseProductReturnType<Lhs,Rhs,1>
{
  typedef SparseDenseOuterProduct<Rhs,Lhs,true> Type;
};
---|
34 | |
---|
35 | namespace internal { |
---|
36 | |
---|
// Traits of the lazy outer-product expression SparseDenseOuterProduct.
// Lhs is always the sparse factor (its InnerIterator drives iteration) and Rhs
// the dense one. Tr == true means the user wrote `dense * sparse`
// (see DenseSparseProductReturnType<...,1>).
template<typename Lhs, typename Rhs, bool Tr>
struct traits<SparseDenseOuterProduct<Lhs,Rhs,Tr> >
{
  typedef Sparse StorageKind;
  typedef typename scalar_product_traits<typename traits<Lhs>::Scalar,
                                         typename traits<Rhs>::Scalar>::ReturnType Scalar;
  typedef typename Lhs::Index Index;
  typedef typename Lhs::Nested LhsNested;
  typedef typename Rhs::Nested RhsNested;
  typedef typename remove_all<LhsNested>::type _LhsNested;
  typedef typename remove_all<RhsNested>::type _RhsNested;

  enum {
    LhsCoeffReadCost = traits<_LhsNested>::CoeffReadCost,
    RhsCoeffReadCost = traits<_RhsNested>::CoeffReadCost,

    // For Tr == false rows come from Lhs and columns from Rhs; for Tr == true
    // the roles are swapped because the operands were swapped on construction.
    RowsAtCompileTime = Tr ? int(traits<Rhs>::RowsAtCompileTime) : int(traits<Lhs>::RowsAtCompileTime),
    ColsAtCompileTime = Tr ? int(traits<Lhs>::ColsAtCompileTime) : int(traits<Rhs>::ColsAtCompileTime),
    MaxRowsAtCompileTime = Tr ? int(traits<Rhs>::MaxRowsAtCompileTime) : int(traits<Lhs>::MaxRowsAtCompileTime),
    MaxColsAtCompileTime = Tr ? int(traits<Lhs>::MaxColsAtCompileTime) : int(traits<Rhs>::MaxColsAtCompileTime),

    // Column-major for sparse*dense, row-major for dense*sparse, so that the
    // result's outer index always selects one coefficient of the dense factor
    // (the InnerIterator reads rhs().coeff(outer)).
    Flags = Tr ? RowMajorBit : 0,

    // One read of each operand plus one multiplication per coefficient.
    CoeffReadCost = LhsCoeffReadCost + RhsCoeffReadCost + NumTraits<Scalar>::MulCost
  };
};
---|
63 | |
---|
64 | } // end namespace internal |
---|
65 | |
---|
// Lazy sparse expression for the outer product of a sparse vector (Lhs) and a
// dense vector (Rhs). Tr == true means the product was written `dense * sparse`;
// the operands are still stored as (sparse, dense).
template<typename Lhs, typename Rhs, bool Tr>
class SparseDenseOuterProduct
 : public SparseMatrixBase<SparseDenseOuterProduct<Lhs,Rhs,Tr> >
{
  public:

    typedef SparseMatrixBase<SparseDenseOuterProduct> Base;
    EIGEN_DENSE_PUBLIC_INTERFACE(SparseDenseOuterProduct)
    typedef internal::traits<SparseDenseOuterProduct> Traits;

  private:

    typedef typename Traits::LhsNested LhsNested;
    typedef typename Traits::RhsNested RhsNested;
    typedef typename Traits::_LhsNested _LhsNested;
    typedef typename Traits::_RhsNested _RhsNested;

  public:

    class InnerIterator;

    // sparse * dense: arguments arrive in (sparse, dense) order; only valid for Tr == false.
    EIGEN_STRONG_INLINE SparseDenseOuterProduct(const Lhs& lhs, const Rhs& rhs)
      : m_lhs(lhs), m_rhs(rhs)
    {
      EIGEN_STATIC_ASSERT(!Tr,YOU_MADE_A_PROGRAMMING_MISTAKE);
    }

    // dense * sparse: arguments arrive in (dense, sparse) order and are stored
    // swapped; only valid for Tr == true.
    EIGEN_STRONG_INLINE SparseDenseOuterProduct(const Rhs& rhs, const Lhs& lhs)
      : m_lhs(lhs), m_rhs(rhs)
    {
      EIGEN_STATIC_ASSERT(Tr,YOU_MADE_A_PROGRAMMING_MISTAKE);
    }

    // Dimensions follow the same Tr-dependent rule as the traits.
    EIGEN_STRONG_INLINE Index rows() const { return Tr ? m_rhs.rows() : m_lhs.rows(); }
    EIGEN_STRONG_INLINE Index cols() const { return Tr ? m_lhs.cols() : m_rhs.cols(); }

    // Sparse factor / dense factor, independent of the original operand order.
    EIGEN_STRONG_INLINE const _LhsNested& lhs() const { return m_lhs; }
    EIGEN_STRONG_INLINE const _RhsNested& rhs() const { return m_rhs; }

  protected:
    // Operands held through their Nested types (reference vs. copy is decided
    // by each operand's Nested typedef, declared elsewhere).
    LhsNested m_lhs;
    RhsNested m_rhs;
};
---|
109 | |
---|
110 | template<typename Lhs, typename Rhs, bool Transpose> |
---|
111 | class SparseDenseOuterProduct<Lhs,Rhs,Transpose>::InnerIterator : public _LhsNested::InnerIterator |
---|
112 | { |
---|
113 | typedef typename _LhsNested::InnerIterator Base; |
---|
114 | public: |
---|
115 | EIGEN_STRONG_INLINE InnerIterator(const SparseDenseOuterProduct& prod, Index outer) |
---|
116 | : Base(prod.lhs(), 0), m_outer(outer), m_factor(prod.rhs().coeff(outer)) |
---|
117 | { |
---|
118 | } |
---|
119 | |
---|
120 | inline Index outer() const { return m_outer; } |
---|
121 | inline Index row() const { return Transpose ? Base::row() : m_outer; } |
---|
122 | inline Index col() const { return Transpose ? m_outer : Base::row(); } |
---|
123 | |
---|
124 | inline Scalar value() const { return Base::value() * m_factor; } |
---|
125 | |
---|
126 | protected: |
---|
127 | int m_outer; |
---|
128 | Scalar m_factor; |
---|
129 | }; |
---|
130 | |
---|
131 | namespace internal { |
---|
// Traits of `sparse * dense` (non-outer case): inherit ProductBase's traits and
// override the storage kind, because the product of a sparse matrix and a dense
// matrix is evaluated into a dense matrix expression.
template<typename Lhs, typename Rhs>
struct traits<SparseTimeDenseProduct<Lhs,Rhs> >
 : traits<ProductBase<SparseTimeDenseProduct<Lhs,Rhs>, Lhs, Rhs> >
{
  typedef Dense StorageKind;
  typedef MatrixXpr XprKind;
};
---|
139 | |
---|
// Kernel dispatcher for res += alpha * lhs * rhs (lhs sparse, rhs/res dense).
// Specializations are selected by:
//  - LhsStorageOrder: RowMajor if the sparse operand has RowMajorBit set, else ColMajor;
//  - ColPerCol: true when rhs is column-major or has a single compile-time
//    column (process rhs one column at a time), false otherwise (process whole
//    rhs rows at once).
template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,
         int LhsStorageOrder = ((SparseLhsType::Flags&RowMajorBit)==RowMajorBit) ? RowMajor : ColMajor,
         bool ColPerCol = ((DenseRhsType::Flags&RowMajorBit)==0) || DenseRhsType::ColsAtCompileTime==1>
struct sparse_time_dense_product_impl;
---|
144 | |
---|
145 | template<typename SparseLhsType, typename DenseRhsType, typename DenseResType> |
---|
146 | struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, RowMajor, true> |
---|
147 | { |
---|
148 | typedef typename internal::remove_all<SparseLhsType>::type Lhs; |
---|
149 | typedef typename internal::remove_all<DenseRhsType>::type Rhs; |
---|
150 | typedef typename internal::remove_all<DenseResType>::type Res; |
---|
151 | typedef typename Lhs::Index Index; |
---|
152 | typedef typename Lhs::InnerIterator LhsInnerIterator; |
---|
153 | static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha) |
---|
154 | { |
---|
155 | for(Index c=0; c<rhs.cols(); ++c) |
---|
156 | { |
---|
157 | int n = lhs.outerSize(); |
---|
158 | for(Index j=0; j<n; ++j) |
---|
159 | { |
---|
160 | typename Res::Scalar tmp(0); |
---|
161 | for(LhsInnerIterator it(lhs,j); it ;++it) |
---|
162 | tmp += it.value() * rhs.coeff(it.index(),c); |
---|
163 | res.coeffRef(j,c) = alpha * tmp; |
---|
164 | } |
---|
165 | } |
---|
166 | } |
---|
167 | }; |
---|
168 | |
---|
169 | template<typename SparseLhsType, typename DenseRhsType, typename DenseResType> |
---|
170 | struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, ColMajor, true> |
---|
171 | { |
---|
172 | typedef typename internal::remove_all<SparseLhsType>::type Lhs; |
---|
173 | typedef typename internal::remove_all<DenseRhsType>::type Rhs; |
---|
174 | typedef typename internal::remove_all<DenseResType>::type Res; |
---|
175 | typedef typename Lhs::InnerIterator LhsInnerIterator; |
---|
176 | typedef typename Lhs::Index Index; |
---|
177 | static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha) |
---|
178 | { |
---|
179 | for(Index c=0; c<rhs.cols(); ++c) |
---|
180 | { |
---|
181 | for(Index j=0; j<lhs.outerSize(); ++j) |
---|
182 | { |
---|
183 | typename Res::Scalar rhs_j = alpha * rhs.coeff(j,c); |
---|
184 | for(LhsInnerIterator it(lhs,j); it ;++it) |
---|
185 | res.coeffRef(it.index(),c) += it.value() * rhs_j; |
---|
186 | } |
---|
187 | } |
---|
188 | } |
---|
189 | }; |
---|
190 | |
---|
191 | template<typename SparseLhsType, typename DenseRhsType, typename DenseResType> |
---|
192 | struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, RowMajor, false> |
---|
193 | { |
---|
194 | typedef typename internal::remove_all<SparseLhsType>::type Lhs; |
---|
195 | typedef typename internal::remove_all<DenseRhsType>::type Rhs; |
---|
196 | typedef typename internal::remove_all<DenseResType>::type Res; |
---|
197 | typedef typename Lhs::InnerIterator LhsInnerIterator; |
---|
198 | typedef typename Lhs::Index Index; |
---|
199 | static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha) |
---|
200 | { |
---|
201 | for(Index j=0; j<lhs.outerSize(); ++j) |
---|
202 | { |
---|
203 | typename Res::RowXpr res_j(res.row(j)); |
---|
204 | for(LhsInnerIterator it(lhs,j); it ;++it) |
---|
205 | res_j += (alpha*it.value()) * rhs.row(it.index()); |
---|
206 | } |
---|
207 | } |
---|
208 | }; |
---|
209 | |
---|
210 | template<typename SparseLhsType, typename DenseRhsType, typename DenseResType> |
---|
211 | struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, ColMajor, false> |
---|
212 | { |
---|
213 | typedef typename internal::remove_all<SparseLhsType>::type Lhs; |
---|
214 | typedef typename internal::remove_all<DenseRhsType>::type Rhs; |
---|
215 | typedef typename internal::remove_all<DenseResType>::type Res; |
---|
216 | typedef typename Lhs::InnerIterator LhsInnerIterator; |
---|
217 | typedef typename Lhs::Index Index; |
---|
218 | static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha) |
---|
219 | { |
---|
220 | for(Index j=0; j<lhs.outerSize(); ++j) |
---|
221 | { |
---|
222 | typename Rhs::ConstRowXpr rhs_j(rhs.row(j)); |
---|
223 | for(LhsInnerIterator it(lhs,j); it ;++it) |
---|
224 | res.row(it.index()) += (alpha*it.value()) * rhs_j; |
---|
225 | } |
---|
226 | } |
---|
227 | }; |
---|
228 | |
---|
229 | template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,typename AlphaType> |
---|
230 | inline void sparse_time_dense_product(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha) |
---|
231 | { |
---|
232 | sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType>::run(lhs, rhs, res, alpha); |
---|
233 | } |
---|
234 | |
---|
235 | } // end namespace internal |
---|
236 | |
---|
// Expression for `sparse * dense` when the inner dimension is not 1.
// Derives from ProductBase (declared elsewhere); presumably ProductBase
// evaluates the product by calling scaleAndAddTo() below — confirm there.
template<typename Lhs, typename Rhs>
class SparseTimeDenseProduct
  : public ProductBase<SparseTimeDenseProduct<Lhs,Rhs>, Lhs, Rhs>
{
  public:
    EIGEN_PRODUCT_PUBLIC_INTERFACE(SparseTimeDenseProduct)

    SparseTimeDenseProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs)
    {}

    // dest += alpha * (lhs * rhs), delegated to the storage-order-specific kernel.
    template<typename Dest> void scaleAndAddTo(Dest& dest, Scalar alpha) const
    {
      internal::sparse_time_dense_product(m_lhs, m_rhs, dest, alpha);
    }

  private:
    // Declared but never defined: the expression object is not assignable.
    SparseTimeDenseProduct& operator=(const SparseTimeDenseProduct&);
};
---|
255 | |
---|
256 | |
---|
257 | // dense = dense * sparse |
---|
258 | namespace internal { |
---|
// Traits of `dense * sparse` (non-outer case): ProductBase's traits with a
// dense storage kind.
// NOTE(review): unlike traits<SparseTimeDenseProduct>, XprKind is not
// overridden here; it is inherited from ProductBase's traits — confirm that is intended.
template<typename Lhs, typename Rhs>
struct traits<DenseTimeSparseProduct<Lhs,Rhs> >
 : traits<ProductBase<DenseTimeSparseProduct<Lhs,Rhs>, Lhs, Rhs> >
{
  typedef Dense StorageKind;
};
---|
265 | } // end namespace internal |
---|
266 | |
---|
// Expression for `dense * sparse` when the inner dimension is not 1.
template<typename Lhs, typename Rhs>
class DenseTimeSparseProduct
  : public ProductBase<DenseTimeSparseProduct<Lhs,Rhs>, Lhs, Rhs>
{
  public:
    EIGEN_PRODUCT_PUBLIC_INTERFACE(DenseTimeSparseProduct)

    DenseTimeSparseProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs)
    {}

    // dest += alpha * (lhs * rhs), computed through transposition:
    // (D*S)^T = S^T * D^T, so dest^T += alpha * rhs^T * lhs^T, which reuses the
    // sparse * dense kernels with the (transposed) sparse operand on the left.
    template<typename Dest> void scaleAndAddTo(Dest& dest, Scalar alpha) const
    {
      Transpose<const _LhsNested> lhs_t(m_lhs);
      Transpose<const _RhsNested> rhs_t(m_rhs);
      Transpose<Dest> dest_t(dest);
      internal::sparse_time_dense_product(rhs_t, lhs_t, dest_t, alpha);
    }

  private:
    // Declared but never defined: the expression object is not assignable.
    DenseTimeSparseProduct& operator=(const DenseTimeSparseProduct&);
};
---|
288 | |
---|
289 | // sparse * dense |
---|
// sparse_expression * dense_expression: returns the product expression chosen
// by SparseDenseProductReturnType, i.e. SparseDenseOuterProduct for the
// InnerSize == 1 case and SparseTimeDenseProduct otherwise.
// NOTE(review): only two template arguments are passed, so InnerSize relies on
// a default declared with the forward declaration elsewhere — confirm.
template<typename Derived>
template<typename OtherDerived>
inline const typename SparseDenseProductReturnType<Derived,OtherDerived>::Type
SparseMatrixBase<Derived>::operator*(const MatrixBase<OtherDerived> &other) const
{
  return typename SparseDenseProductReturnType<Derived,OtherDerived>::Type(derived(), other.derived());
}
---|
297 | |
---|
298 | } // end namespace Eigen |
---|
299 | |
---|
300 | #endif // EIGEN_SPARSEDENSEPRODUCT_H |
---|