1 | // This file is part of Eigen, a lightweight C++ template library |
---|
2 | // for linear algebra. |
---|
3 | // |
---|
4 | // Copyright (C) 2008-2011 Gael Guennebaud <gael.guennebaud@inria.fr> |
---|
5 | // |
---|
6 | // This Source Code Form is subject to the terms of the Mozilla |
---|
7 | // Public License v. 2.0. If a copy of the MPL was not distributed |
---|
8 | // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. |
---|
9 | |
---|
10 | #ifndef EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H |
---|
11 | #define EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H |
---|
12 | |
---|
13 | namespace Eigen { |
---|
14 | |
---|
15 | namespace internal { |
---|
16 | |
---|
17 | |
---|
18 | // perform a pseudo in-place sparse * sparse product assuming all matrices are col major |
---|
19 | template<typename Lhs, typename Rhs, typename ResultType> |
---|
20 | static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res, typename ResultType::RealScalar tolerance) |
---|
21 | { |
---|
22 | // return sparse_sparse_product_with_pruning_impl2(lhs,rhs,res); |
---|
23 | |
---|
24 | typedef typename remove_all<Lhs>::type::Scalar Scalar; |
---|
25 | typedef typename remove_all<Lhs>::type::Index Index; |
---|
26 | |
---|
27 | // make sure to call innerSize/outerSize since we fake the storage order. |
---|
28 | Index rows = lhs.innerSize(); |
---|
29 | Index cols = rhs.outerSize(); |
---|
30 | //int size = lhs.outerSize(); |
---|
31 | eigen_assert(lhs.outerSize() == rhs.innerSize()); |
---|
32 | |
---|
33 | // allocate a temporary buffer |
---|
34 | AmbiVector<Scalar,Index> tempVector(rows); |
---|
35 | |
---|
36 | // estimate the number of non zero entries |
---|
37 | // given a rhs column containing Y non zeros, we assume that the respective Y columns |
---|
38 | // of the lhs differs in average of one non zeros, thus the number of non zeros for |
---|
39 | // the product of a rhs column with the lhs is X+Y where X is the average number of non zero |
---|
40 | // per column of the lhs. |
---|
41 | // Therefore, we have nnz(lhs*rhs) = nnz(lhs) + nnz(rhs) |
---|
42 | Index estimated_nnz_prod = lhs.nonZeros() + rhs.nonZeros(); |
---|
43 | |
---|
44 | // mimics a resizeByInnerOuter: |
---|
45 | if(ResultType::IsRowMajor) |
---|
46 | res.resize(cols, rows); |
---|
47 | else |
---|
48 | res.resize(rows, cols); |
---|
49 | |
---|
50 | res.reserve(estimated_nnz_prod); |
---|
51 | double ratioColRes = double(estimated_nnz_prod)/double(lhs.rows()*rhs.cols()); |
---|
52 | for (Index j=0; j<cols; ++j) |
---|
53 | { |
---|
54 | // FIXME: |
---|
55 | //double ratioColRes = (double(rhs.innerVector(j).nonZeros()) + double(lhs.nonZeros())/double(lhs.cols()))/double(lhs.rows()); |
---|
56 | // let's do a more accurate determination of the nnz ratio for the current column j of res |
---|
57 | tempVector.init(ratioColRes); |
---|
58 | tempVector.setZero(); |
---|
59 | for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt) |
---|
60 | { |
---|
61 | // FIXME should be written like this: tmp += rhsIt.value() * lhs.col(rhsIt.index()) |
---|
62 | tempVector.restart(); |
---|
63 | Scalar x = rhsIt.value(); |
---|
64 | for (typename Lhs::InnerIterator lhsIt(lhs, rhsIt.index()); lhsIt; ++lhsIt) |
---|
65 | { |
---|
66 | tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x; |
---|
67 | } |
---|
68 | } |
---|
69 | res.startVec(j); |
---|
70 | for (typename AmbiVector<Scalar,Index>::Iterator it(tempVector,tolerance); it; ++it) |
---|
71 | res.insertBackByOuterInner(j,it.index()) = it.value(); |
---|
72 | } |
---|
73 | res.finalize(); |
---|
74 | } |
---|
75 | |
---|
76 | template<typename Lhs, typename Rhs, typename ResultType, |
---|
77 | int LhsStorageOrder = traits<Lhs>::Flags&RowMajorBit, |
---|
78 | int RhsStorageOrder = traits<Rhs>::Flags&RowMajorBit, |
---|
79 | int ResStorageOrder = traits<ResultType>::Flags&RowMajorBit> |
---|
80 | struct sparse_sparse_product_with_pruning_selector; |
---|
81 | |
---|
82 | template<typename Lhs, typename Rhs, typename ResultType> |
---|
83 | struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor> |
---|
84 | { |
---|
85 | typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar; |
---|
86 | typedef typename ResultType::RealScalar RealScalar; |
---|
87 | |
---|
88 | static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, RealScalar tolerance) |
---|
89 | { |
---|
90 | typename remove_all<ResultType>::type _res(res.rows(), res.cols()); |
---|
91 | internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,ResultType>(lhs, rhs, _res, tolerance); |
---|
92 | res.swap(_res); |
---|
93 | } |
---|
94 | }; |
---|
95 | |
---|
96 | template<typename Lhs, typename Rhs, typename ResultType> |
---|
97 | struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor> |
---|
98 | { |
---|
99 | typedef typename ResultType::RealScalar RealScalar; |
---|
100 | static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, RealScalar tolerance) |
---|
101 | { |
---|
102 | // we need a col-major matrix to hold the result |
---|
103 | typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType; |
---|
104 | SparseTemporaryType _res(res.rows(), res.cols()); |
---|
105 | internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,SparseTemporaryType>(lhs, rhs, _res, tolerance); |
---|
106 | res = _res; |
---|
107 | } |
---|
108 | }; |
---|
109 | |
---|
110 | template<typename Lhs, typename Rhs, typename ResultType> |
---|
111 | struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor> |
---|
112 | { |
---|
113 | typedef typename ResultType::RealScalar RealScalar; |
---|
114 | static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, RealScalar tolerance) |
---|
115 | { |
---|
116 | // let's transpose the product to get a column x column product |
---|
117 | typename remove_all<ResultType>::type _res(res.rows(), res.cols()); |
---|
118 | internal::sparse_sparse_product_with_pruning_impl<Rhs,Lhs,ResultType>(rhs, lhs, _res, tolerance); |
---|
119 | res.swap(_res); |
---|
120 | } |
---|
121 | }; |
---|
122 | |
---|
123 | template<typename Lhs, typename Rhs, typename ResultType> |
---|
124 | struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor> |
---|
125 | { |
---|
126 | typedef typename ResultType::RealScalar RealScalar; |
---|
127 | static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, RealScalar tolerance) |
---|
128 | { |
---|
129 | typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix; |
---|
130 | ColMajorMatrix colLhs(lhs); |
---|
131 | ColMajorMatrix colRhs(rhs); |
---|
132 | internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrix,ColMajorMatrix,ResultType>(colLhs, colRhs, res, tolerance); |
---|
133 | |
---|
134 | // let's transpose the product to get a column x column product |
---|
135 | // typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType; |
---|
136 | // SparseTemporaryType _res(res.cols(), res.rows()); |
---|
137 | // sparse_sparse_product_with_pruning_impl<Rhs,Lhs,SparseTemporaryType>(rhs, lhs, _res); |
---|
138 | // res = _res.transpose(); |
---|
139 | } |
---|
140 | }; |
---|
141 | |
---|
// NOTE: the two other cases (col row *) must never occur since they are caught
// by ProductReturnType, which transforms them into (col col *) by evaluating rhs.
---|
144 | |
---|
145 | } // end namespace internal |
---|
146 | |
---|
147 | } // end namespace Eigen |
---|
148 | |
---|
149 | #endif // EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H |
---|