1 | ///
|
---|
2 | /// This file is part of ILNumerics Community Edition.
|
---|
3 | ///
|
---|
4 | /// ILNumerics Community Edition - high performance computing for applications.
|
---|
5 | /// Copyright (C) 2006 - 2012 Haymo Kutschbach, http://ilnumerics.net
|
---|
6 | ///
|
---|
7 | /// ILNumerics Community Edition is free software: you can redistribute it and/or modify
|
---|
8 | /// it under the terms of the GNU General Public License version 3 as published by
|
---|
9 | /// the Free Software Foundation.
|
---|
10 | ///
|
---|
11 | /// ILNumerics Community Edition is distributed in the hope that it will be useful,
|
---|
12 | /// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
---|
13 | /// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
---|
14 | /// GNU General Public License for more details.
|
---|
15 | ///
|
---|
16 | /// You should have received a copy of the GNU General Public License
|
---|
17 | /// along with ILNumerics Community Edition. See the file License.txt in the root
|
---|
18 | /// of your distribution package. If not, see <http://www.gnu.org/licenses/>.
|
---|
19 | ///
|
---|
20 | /// In addition this software uses the following components and/or licenses:
|
---|
21 | ///
|
---|
22 | /// =================================================================================
|
---|
23 | /// The Open Toolkit Library License
|
---|
24 | ///
|
---|
25 | /// Copyright (c) 2006 - 2009 the Open Toolkit library.
|
---|
26 | ///
|
---|
27 | /// Permission is hereby granted, free of charge, to any person obtaining a copy
|
---|
28 | /// of this software and associated documentation files (the "Software"), to deal
|
---|
29 | /// in the Software without restriction, including without limitation the rights to
|
---|
30 | /// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
---|
31 | /// the Software, and to permit persons to whom the Software is furnished to do
|
---|
32 | /// so, subject to the following conditions:
|
---|
33 | ///
|
---|
34 | /// The above copyright notice and this permission notice shall be included in all
|
---|
35 | /// copies or substantial portions of the Software.
|
---|
36 | ///
|
---|
37 | /// =================================================================================
|
---|
38 | ///
|
---|
39 |
|
---|
40 | using System;
|
---|
41 | using System.Collections.Generic;
|
---|
42 | using System.Linq;
|
---|
43 | using System.Text;
|
---|
44 | using ILNumerics.Misc;
|
---|
45 |
|
---|
46 | namespace ILNumerics {
|
---|
47 |
|
---|
48 | public partial class ILMath {
|
---|
49 | #region managed mult
|
---|
/// <summary>
/// Accumulates the matrix product A * B into C, splitting the columns of the
/// result between one pooled worker and the calling thread.
/// </summary>
/// <remarks>
/// The kernel indexes all three arrays in column-major order: A with leading
/// dimension <paramref name="m"/>, B with leading dimension <paramref name="k"/>
/// and C with leading dimension <paramref name="m"/>. Products are ADDED to the
/// existing content of C, so C must hold zeros (or the values to accumulate
/// onto) when this method is called.
/// Columns 0 .. n/2-1 are computed by the queued worker, columns n/2 .. n-1 by
/// the calling thread. Each side uses its own pair of pack buffers.
/// NOTE(review): the pack buffers are taken from ILMemoryPool and are not handed
/// back here - confirm whether the pool expects an explicit release.
/// </remarks>
/// <param name="A">left operand, m x k, column-major</param>
/// <param name="B">right operand, k x n, column-major</param>
/// <param name="C">result accumulator, m x n, column-major</param>
/// <param name="m">number of rows of A and C</param>
/// <param name="n">number of columns of B and C</param>
/// <param name="k">number of columns of A / rows of B</param>
/// <param name="kc">panel width used to block the k dimension; must be positive when k is positive</param>
/// <exception cref="ArgumentOutOfRangeException">kc is not positive although k is positive</exception>
unsafe public static void MMultBlockedThreaded(double[] A, double[] B, double[] C, int m, int n, int k, int kc) {

    // a non-positive panel width would never advance the k loop of the kernel
    if (k > 0 && kc <= 0)
        throw new ArgumentOutOfRangeException("kc", "The panel width kc must be positive.");

    // block parameters
    int mc = 512;
    int mr = 4;
    int nr = 4;

    // one private set of pack buffers per participating thread
    double[] Bpack1 = ILMemoryPool.Pool.New<double>(kc * (n) + ALIGN / sizeof(double));
    double[] Apack1 = ILMemoryPool.Pool.New<double>(kc * mc + ALIGN / sizeof(double));
    double[] Bpack2 = ILMemoryPool.Pool.New<double>(kc * (n) + ALIGN / sizeof(double));
    double[] Apack2 = ILMemoryPool.Pool.New<double>(kc * mc + ALIGN / sizeof(double));
    fixed (double* pAArr = A)
    fixed (double* pBArr = B)
    fixed (double* pCArr = C)
    fixed (double* pBpack1 = Bpack1)
    fixed (double* pApack1 = Apack1)
    fixed (double* pBpack2 = Bpack2)
    fixed (double* pApack2 = Apack2) {

        int half = n / 2;
        int workerCount = 1;

        // Pinned pointers cannot be captured by the lambda, so they travel
        // to the worker as IntPtr members of the argument structure.
        Action<object> func = data => {
            try {
                MatMultArguments args = (MatMultArguments)data;
                inner_k_loop_managed(m, n, k, kc, mc, mr, nr,
                    (double*)args.pArr, (double*)args.pBrr, (double*)args.pCrr, (double*)args.pBPack, (double*)args.pAPack,
                    args.n_start, args.n_end, args.m_start, args.m_end);
            } finally {
                // always signal completion, otherwise the waiting thread would never resume
                System.Threading.Interlocked.Decrement(ref workerCount);
            }
        };

        MatMultArguments args4Thread;
        args4Thread.pArr = (IntPtr)pAArr;
        args4Thread.pBrr = (IntPtr)pBArr;
        args4Thread.pCrr = (IntPtr)pCArr;
        args4Thread.pAPack = (IntPtr)pApack1;
        args4Thread.pBPack = (IntPtr)pBpack1;
        args4Thread.n_start = 0;
        args4Thread.n_end = half;
        args4Thread.m_start = 0;
        args4Thread.m_end = m;
        ILNumerics.Misc.ILThreadPool.QueueUserWorkItem(0, func, args4Thread);

        try {
            // second half of the columns on the calling thread
            inner_k_loop_managed(m, n, k, kc, mc, mr, nr, pAArr, pBArr, pCArr, pBpack2, pApack2, half, n, 0, m);
        } finally {
            // The arrays must stay pinned until the worker is done with them,
            // even if the call above failed.
            ILThreadPool.Wait4Workers(ref workerCount);
        }
    }
}
|
---|
/// <summary>
/// Accumulates the matrix product A * B into C on the calling thread, using
/// the blocked managed kernel.
/// </summary>
/// <remarks>
/// The kernel indexes all three arrays in column-major order: A with leading
/// dimension <paramref name="m"/>, B with leading dimension <paramref name="k"/>
/// and C with leading dimension <paramref name="m"/>. Products are ADDED to the
/// existing content of C, so C must hold zeros (or the values to accumulate
/// onto) when this method is called.
/// NOTE(review): the pack buffers are taken from ILMemoryPool and are not handed
/// back here - confirm whether the pool expects an explicit release.
/// </remarks>
/// <param name="A">left operand, m x k, column-major</param>
/// <param name="B">right operand, k x n, column-major</param>
/// <param name="C">result accumulator, m x n, column-major</param>
/// <param name="m">number of rows of A and C</param>
/// <param name="n">number of columns of B and C</param>
/// <param name="k">number of columns of A / rows of B</param>
/// <param name="kc">panel width used to block the k dimension; must be positive when k is positive</param>
/// <exception cref="ArgumentOutOfRangeException">kc is not positive although k is positive</exception>
unsafe public static void MMultBlocked(double[] A, double[] B, double[] C, int m, int n, int k, int kc) {

    // a non-positive panel width would never advance the k loop of the kernel
    if (k > 0 && kc <= 0)
        throw new ArgumentOutOfRangeException("kc", "The panel width kc must be positive.");

    // block parameters
    int mc = 512;
    int mr = 4;
    int nr = 4;

    // a single pair of pack buffers is sufficient for the single-threaded path
    double[] Bpack = ILMemoryPool.Pool.New<double>(kc * (n) + ALIGN / sizeof(double));
    double[] Apack = ILMemoryPool.Pool.New<double>(kc * mc + ALIGN / sizeof(double));
    fixed (double* pAArr = A)
    fixed (double* pBArr = B)
    fixed (double* pCArr = C)
    fixed (double* pBpack = Bpack)
    fixed (double* pApack = Apack) {

        // full column range [0, n) and full row range [0, m)
        inner_k_loop_managed(m, n, k, kc, mc, mr, nr, pAArr, pBArr, pCArr, pBpack, pApack, 0, n, 0, m);
    }
}
|
---|
123 |
|
---|
/// <summary>
/// Blocked kernel: adds the product of the rows m_start .. m_end-1 of A with
/// the columns n_start .. n_end-1 of B to the corresponding part of C.
/// </summary>
/// <remarks>
/// All matrices are addressed in column-major order (A: leading dimension m,
/// B: leading dimension k, C: leading dimension m). For every panel of at most
/// kc entries of the k dimension the method
/// 1) copies the panel of the selected B columns contiguously into pBpack,
/// 2) copies blocks of at most mc rows of A TRANSPOSED into pApack, so that a
///    row of A becomes contiguous, and
/// 3) forms dot products of packed A rows with packed B columns in tiles of
///    at most mr x nr and adds them to C.
/// pBpack must provide room for kc * (n_end - n_start) values and pApack for
/// kc * mc values. The pack pointers are used as given; no alignment is applied.
/// </remarks>
/// <param name="m">number of rows of A and C (leading dimension of both)</param>
/// <param name="n">number of columns of B and C (not read by the kernel; the column range is given by n_start / n_end)</param>
/// <param name="k">number of columns of A / rows of B</param>
/// <param name="kc">panel width for the k dimension</param>
/// <param name="mc">maximum number of rows of A packed at once</param>
/// <param name="mr">tile height of the inner loops</param>
/// <param name="nr">tile width of the inner loops</param>
/// <param name="pAArr">first element of A</param>
/// <param name="pBArr">first element of B</param>
/// <param name="pCArr">first element of C</param>
/// <param name="pBpack">scratch buffer receiving the packed B panel</param>
/// <param name="pApack">scratch buffer receiving the packed (transposed) A block</param>
/// <param name="n_start">first column of B / C to process</param>
/// <param name="n_end">one past the last column of B / C to process</param>
/// <param name="m_start">first row of A / C to process</param>
/// <param name="m_end">one past the last row of A / C to process</param>
unsafe private static void inner_k_loop_managed(int m, int n, int k, int kc, int mc, int mr, int nr,
    double* pAArr, double* pBArr, double* pCArr, double* pBpack, double* pApack,
    int n_start, int n_end, int m_start, int m_end) {

    int n_len = n_end - n_start;
    int m_len = m_end - m_start;

    for (int ki = 0; ki < k; ki += kc) {
        // the last panel may be narrower than the requested width
        if (k - ki < kc) kc = k - ki;

        #region pack B
        // column (nb + n_start) of B, rows ki .. ki+kc-1  ->  pBpack[nb * kc ..]
        double* pBDst = pBpack;
        for (int nb = 0; nb < n_len; nb++) {
            double* pBSrc = pBArr + ki + k * (nb + n_start);
            int c = 0;
            for (; c < kc - 8; c += 8) {
                pBDst[0] = pBSrc[0];
                pBDst[1] = pBSrc[1];
                pBDst[2] = pBSrc[2];
                pBDst[3] = pBSrc[3];
                pBDst[4] = pBSrc[4];
                pBDst[5] = pBSrc[5];
                pBDst[6] = pBSrc[6];
                pBDst[7] = pBSrc[7];
                pBDst += 8; pBSrc += 8;
            }
            for (; c < kc; c++) {
                *pBDst++ = *pBSrc++;
            }
        }
        #endregion

        int mcc = mc;
        for (int ai = 0; ai < m_len; ai += mcc) {
            // the last row block may be shorter than mc
            if (m_len - ai < mcc) mcc = m_len - ai;

            #region pack A
            // rows (m_start + ai) .. of column (ki + ca) of A are scattered with
            // stride kc, i.e. the block is stored transposed: pApack[row * kc + ca]
            for (int ca = 0; ca < kc; ca++) {
                double* pADst = pApack + ca;
                // m_start is a ROW offset and therefore must not be scaled by m
                double* pASrc = pAArr + m_start + ai + m * (ki + ca);
                int ra = 0;
                for (; ra < mcc - 8; ra += 8) {
                    pADst[(ra) * kc] = pASrc[0];
                    pADst[(ra + 1) * kc] = pASrc[1];
                    pADst[(ra + 2) * kc] = pASrc[2];
                    pADst[(ra + 3) * kc] = pASrc[3];
                    pADst[(ra + 4) * kc] = pASrc[4];
                    pADst[(ra + 5) * kc] = pASrc[5];
                    pADst[(ra + 6) * kc] = pASrc[6];
                    pADst[(ra + 7) * kc] = pASrc[7];
                    pASrc += 8;
                }
                for (; ra < mcc; ra++) {
                    pADst[ra * kc] = *pASrc++;
                }
            }
            #endregion

            #region subblocked
            int nrLen = nr;
            for (int nri = 0; nri < n_len; nri += nrLen) {
                if (n_len - nri < nrLen) nrLen = n_len - nri;
                int mrLen = mr;
                for (int mri = 0; mri < mcc; mri += mrLen) {
                    if (mcc - mri < mrLen) mrLen = mcc - mri;

                    for (int nii = 0; nii < nrLen; nii++) {
                        // destination: rows (m_start + ai + mri) .. of column (n_start + nri + nii) of C
                        double* pCCol = pCArr + m_start + ai + mri + (nri + nii + n_start) * m;
                        double* pBColStart = pBpack + (nri + nii) * kc;
                        for (int mii = 0; mii < mrLen; mii++) {
                            double* pARow = pApack + (mri + mii) * kc; // <-- transposed packed!
                            double* pBCol = pBColStart;
                            double sum = 0;
                            int jj = 0;
                            // the grouping of the additions is kept as is, since it
                            // determines the floating point rounding of the result
                            for (; jj < kc - 8; jj += 8) {
                                sum += pARow[0] * pBCol[0]
                                     + pARow[1] * pBCol[1]
                                     + pARow[2] * pBCol[2]
                                     + pARow[3] * pBCol[3]
                                     + pARow[4] * pBCol[4]
                                     + pARow[5] * pBCol[5]
                                     + pARow[6] * pBCol[6]
                                     + pARow[7] * pBCol[7];
                                pARow += 8;
                                pBCol += 8;
                            }
                            for (; jj < kc; jj++) {
                                sum += *pARow++ * *pBCol++;
                            }
                            pCCol[mii] += sum;
                        }
                    }
                }
            }
            #endregion
        }
    }
}
|
---|
322 |
|
---|
323 | #endregion
|
---|
324 | }
|
---|
325 | }
|
---|