MatrixProductMMA.h (Eigen 3.4.0)
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2020 Everton Constantino (everton.constantino@ibm.com)
// Copyright (C) 2021 Chip Kerchner (chip.kerchner@ibm.com)
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
#define EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H

#pragma GCC target("cpu=power10,htm")

#ifdef __has_builtin
#if !__has_builtin(__builtin_vsx_assemble_pair)
#define __builtin_vsx_assemble_pair __builtin_mma_assemble_pair
#endif
#endif

namespace Eigen {

namespace internal {

template<typename Scalar, typename Packet>
EIGEN_ALWAYS_INLINE void bsetzeroMMA(__vector_quad* acc)
{
  __builtin_mma_xxsetaccz(acc);
}
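
// A __vector_quad is the POWER10 MMA accumulator: a 512-bit register that
// holds a 4x4 tile of float32 (or a 4x2 tile of float64) partial sums and is
// disassembled back into four 128-bit vectors when the tile is stored.
// bsetzeroMMA simply clears one such tile before the depth loop starts.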

template<typename DataMapper, typename Index, typename Packet, const Index accCols>
EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, const DataMapper& data, const Packet& alpha, __vector_quad* acc)
{
  PacketBlock<Packet, 4> result;
  __builtin_mma_disassemble_acc(&result.packet, acc);

  PacketBlock<Packet, 4> tRes;
  bload<DataMapper, Packet, Index, accCols, ColMajor, false, 4>(tRes, data, i, 0);

  bscale<Packet, 4>(tRes, result, alpha);

  data.template storePacketBlock<Packet, 4>(i, 0, tRes);
}
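
// Effectively C_tile += alpha * acc: the finished accumulator is disassembled
// into four packets, multiply-added against the destination block loaded by
// bload, and written back through the DataMapper.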

template<typename DataMapper, typename Index, typename Packet, typename Packetc, const Index accColsC>
EIGEN_ALWAYS_INLINE void storeComplexAccumulator(Index i, const DataMapper& data, const Packet& alphaReal, const Packet& alphaImag, __vector_quad* accReal, __vector_quad* accImag)
{
  PacketBlock<Packet, 4> resultReal, resultImag;
  __builtin_mma_disassemble_acc(&resultReal.packet, accReal);
  __builtin_mma_disassemble_acc(&resultImag.packet, accImag);

  PacketBlock<Packetc, 8> tRes;
  bload<DataMapper, Packetc, Index, accColsC, ColMajor, true, 4>(tRes, data, i, 0);

  PacketBlock<Packet,4> taccReal, taccImag;
  bscalec<Packet,4>(resultReal, resultImag, alphaReal, alphaImag, taccReal, taccImag);

  PacketBlock<Packetc, 4> acc1, acc2;
  bcouple<Packet, Packetc, 4>(taccReal, taccImag, tRes, acc1, acc2);

  data.template storePacketBlock<Packetc, 4>(i, 0, acc1);
  data.template storePacketBlock<Packetc, 4>(i + accColsC, 0, acc2);
}
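
// Complex analogue of storeAccumulator: both accumulators are disassembled,
// complex-scaled by (alphaReal, alphaImag) via bscalec, and then bcouple
// interleaves the real and imaginary packets (together with the destination
// values already loaded into tRes) into complex packets for the two stores.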

// Defaults to float32; since Eigen still supports C++03, we can't use default template arguments.
template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const RhsPacket& a, const LhsPacket& b)
{
  if(NegativeAccumulate)
  {
    __builtin_mma_xvf32gernp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
  } else {
    __builtin_mma_xvf32gerpp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
  }
}
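
// xvf32gerpp/xvf32gernp perform a rank-1 update of the 4x4 accumulator;
// conceptually:
//
//   for (i = 0; i < 4; i++)
//     for (j = 0; j < 4; j++)
//       acc[i][j] += (NegativeAccumulate ? -1 : 1) * a[i] * b[j];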

template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const PacketBlock<Packet2d,2>& a, const Packet2d& b)
{
  __vector_pair* a0 = (__vector_pair *)(&a.packet[0]);
  if(NegativeAccumulate)
  {
    __builtin_mma_xvf64gernp(acc, *a0, (__vector unsigned char)b);
  } else {
    __builtin_mma_xvf64gerpp(acc, *a0, (__vector unsigned char)b);
  }
}

template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const __vector_pair& a, const Packet2d& b)
{
  if(NegativeAccumulate)
  {
    __builtin_mma_xvf64gernp(acc, (__vector_pair)a, (__vector unsigned char)b);
  } else {
    __builtin_mma_xvf64gerpp(acc, (__vector_pair)a, (__vector unsigned char)b);
  }
}

template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad*, const __vector_pair&, const Packet4f&)
{
  // Just for compilation
}

template<typename Scalar, typename Packet, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void pgercMMA(__vector_quad* accReal, __vector_quad* accImag, const Packet& lhsV, const Packet& lhsVi, const RhsPacket& rhsV, const RhsPacket& rhsVi)
{
  pgerMMA<Packet, RhsPacket, false>(accReal, rhsV, lhsV);
  if(LhsIsReal) {
    pgerMMA<Packet, RhsPacket, ConjugateRhs>(accImag, rhsVi, lhsV);
  } else {
    if(!RhsIsReal) {
      pgerMMA<Packet, RhsPacket, ConjugateLhs == ConjugateRhs>(accReal, rhsVi, lhsVi);
      pgerMMA<Packet, RhsPacket, ConjugateRhs>(accImag, rhsVi, lhsV);
    } else {
      EIGEN_UNUSED_VARIABLE(rhsVi);
    }
    pgerMMA<Packet, RhsPacket, ConjugateLhs>(accImag, rhsV, lhsVi);
  }
}
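
// Complex rank-1 update assembled from real GER operations, following
//   (a + bi)(c + di) = (ac - bd) + (ad + bc)i.
// Conjugation only flips the sign of a partial product, which is folded into
// the NegativeAccumulate template flag of each pgerMMA call; when one operand
// is purely real, the partial products involving its imaginary part are
// skipped entirely.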

// This is necessary because ploadRhs for double returns a pair of vectors when MMA is enabled.
template<typename Scalar, typename Packet>
EIGEN_ALWAYS_INLINE void ploadRhsMMA(const Scalar* rhs, Packet& rhsV)
{
  rhsV = ploadRhs<Scalar, Packet>(rhs);
}

template<>
EIGEN_ALWAYS_INLINE void ploadRhsMMA<double, PacketBlock<Packet2d, 2> >(const double* rhs, PacketBlock<Packet2d, 2>& rhsV)
{
  rhsV.packet[0] = ploadRhs<double, Packet2d>((const double *)((Packet2d *)rhs      ));
  rhsV.packet[1] = ploadRhs<double, Packet2d>((const double *)(((Packet2d *)rhs) + 1));
}

template<>
EIGEN_ALWAYS_INLINE void ploadRhsMMA<double, __vector_pair>(const double* rhs, __vector_pair& rhsV)
{
#if EIGEN_COMP_LLVM
  __builtin_vsx_assemble_pair(&rhsV,
    (__vector unsigned char)(ploadRhs<double, Packet2d>((const double *)(((Packet2d *)rhs) + 1))),
    (__vector unsigned char)(ploadRhs<double, Packet2d>((const double *)((Packet2d *)rhs      ))));
#else
  __asm__ ("lxvp %x0,%1" : "=wa" (rhsV) : "Y" (*rhs));
#endif
}
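
// On GCC the 32-byte RHS pair is fetched with a single lxvp instruction; on
// LLVM it is assembled from two 16-byte loads, with the second half passed
// first in the order __builtin_vsx_assemble_pair expects.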

template<>
EIGEN_ALWAYS_INLINE void ploadRhsMMA(const float*, __vector_pair&)
{
  // Just for compilation
}

// PEEL_MMA loop factor: the number of k (depth) iterations unrolled/peeled
// per pass of the main GEMM loop below.
#define PEEL_MMA 7

#define MICRO_MMA_UNROLL(func) \
  func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)
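
// The micro-kernel is generated by macro expansion: MICRO_MMA_UNROLL repeats
// an action for up to eight row blocks (one __vector_quad accumulator each),
// and the *_TYPE_PEEL macros repeat the load/multiply step for up to PEEL_MMA
// depth iterations. Both `unroll_factor` (a template parameter) and `peel`
// are compile-time constants, so the `if (unroll_factor > iter)` and
// `if (PEEL_MMA > peel)` guards fold away and dead branches are dropped.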

#define MICRO_MMA_LOAD_ONE(iter) \
  if (unroll_factor > iter) { \
    lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr##iter); \
    lhs_ptr##iter += accCols; \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhsV##iter); \
  }

#define MICRO_MMA_WORK_ONE(iter, type, peel) \
  if (unroll_factor > iter) { \
    pgerMMA<Packet, type, false>(&accZero##iter, rhsV##peel, lhsV##iter); \
  }

#define MICRO_MMA_TYPE_PEEL(func, func2, type, peel) \
  if (PEEL_MMA > peel) { \
    Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
    ploadRhsMMA<Scalar, type>(rhs_ptr + (accRows * peel), rhsV##peel); \
    MICRO_MMA_UNROLL(func2); \
    func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) \
    func(4,type,peel) func(5,type,peel) func(6,type,peel) func(7,type,peel) \
  } else { \
    EIGEN_UNUSED_VARIABLE(rhsV##peel); \
  }

#define MICRO_MMA_UNROLL_TYPE_PEEL(func, func2, type) \
  type rhsV0, rhsV1, rhsV2, rhsV3, rhsV4, rhsV5, rhsV6, rhsV7; \
  MICRO_MMA_TYPE_PEEL(func,func2,type,0); MICRO_MMA_TYPE_PEEL(func,func2,type,1); \
  MICRO_MMA_TYPE_PEEL(func,func2,type,2); MICRO_MMA_TYPE_PEEL(func,func2,type,3); \
  MICRO_MMA_TYPE_PEEL(func,func2,type,4); MICRO_MMA_TYPE_PEEL(func,func2,type,5); \
  MICRO_MMA_TYPE_PEEL(func,func2,type,6); MICRO_MMA_TYPE_PEEL(func,func2,type,7);

#define MICRO_MMA_UNROLL_TYPE_ONE(func, func2, type) \
  type rhsV0; \
  MICRO_MMA_TYPE_PEEL(func,func2,type,0);

#define MICRO_MMA_ONE_PEEL \
  if (sizeof(Scalar) == sizeof(float)) { \
    MICRO_MMA_UNROLL_TYPE_PEEL(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, RhsPacket); \
  } else { \
    MICRO_MMA_UNROLL_TYPE_PEEL(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, __vector_pair); \
  } \
  rhs_ptr += (accRows * PEEL_MMA);

#define MICRO_MMA_ONE \
  if (sizeof(Scalar) == sizeof(float)) { \
    MICRO_MMA_UNROLL_TYPE_ONE(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, RhsPacket); \
  } else { \
    MICRO_MMA_UNROLL_TYPE_ONE(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, __vector_pair); \
  } \
  rhs_ptr += accRows;

#define MICRO_MMA_DST_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    bsetzeroMMA<Scalar, Packet>(&accZero##iter); \
  } else { \
    EIGEN_UNUSED_VARIABLE(accZero##iter); \
  }

#define MICRO_MMA_DST_PTR MICRO_MMA_UNROLL(MICRO_MMA_DST_PTR_ONE)

#define MICRO_MMA_SRC_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols; \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \
  }

#define MICRO_MMA_SRC_PTR MICRO_MMA_UNROLL(MICRO_MMA_SRC_PTR_ONE)

#define MICRO_MMA_PREFETCH_ONE(iter) \
  if (unroll_factor > iter) { \
    EIGEN_POWER_PREFETCH(lhs_ptr##iter); \
  }

#define MICRO_MMA_PREFETCH MICRO_MMA_UNROLL(MICRO_MMA_PREFETCH_ONE)

#define MICRO_MMA_STORE_ONE(iter) \
  if (unroll_factor > iter) { \
    storeAccumulator<DataMapper, Index, Packet, accCols>(row + iter*accCols, res, pAlpha, &accZero##iter); \
  }

#define MICRO_MMA_STORE MICRO_MMA_UNROLL(MICRO_MMA_STORE_ONE)

template<int unroll_factor, typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols>
EIGEN_ALWAYS_INLINE void gemm_unrolled_MMA_iteration(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index& row,
  const Packet& pAlpha)
{
  const Scalar* rhs_ptr = rhs_base;
  const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, * lhs_ptr7 = NULL;
  __vector_quad accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;

  MICRO_MMA_SRC_PTR
  MICRO_MMA_DST_PTR

  Index k = 0;
  for(; k + PEEL_MMA <= depth; k += PEEL_MMA)
  {
    EIGEN_POWER_PREFETCH(rhs_ptr);
    MICRO_MMA_PREFETCH
    MICRO_MMA_ONE_PEEL
  }
  for(; k < depth; k++)
  {
    MICRO_MMA_ONE
  }
  MICRO_MMA_STORE

  row += unroll_factor*accCols;
}
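
// One micro-kernel call computes an (unroll_factor * accCols) x accRows tile
// of the result: zero the accumulators, run the depth loop (a PEEL_MMA-way
// unrolled main loop plus a scalar remainder loop), then scale-and-store each
// accumulator and advance `row` past the tile.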

template<typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols>
EIGEN_ALWAYS_INLINE void gemmMMA_cols(
  const DataMapper& res,
  const Scalar* blockA,
  const Scalar* blockB,
  Index depth,
  Index strideA,
  Index offsetA,
  Index strideB,
  Index offsetB,
  Index col,
  Index rows,
  Index cols,
  Index remaining_rows,
  const Packet& pAlpha,
  const Packet& pMask)
{
  const DataMapper res3 = res.getSubMapper(0, col);

  const Scalar* rhs_base = blockB + col*strideB + accRows*offsetB;
  const Scalar* lhs_base = blockA + accCols*offsetA;
  Index row = 0;

#define MAX_MMA_UNROLL 7
  while(row + MAX_MMA_UNROLL*accCols <= rows) {
    gemm_unrolled_MMA_iteration<MAX_MMA_UNROLL, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
  }
  switch( (rows-row)/accCols ) {
#if MAX_MMA_UNROLL > 7
    case 7:
      gemm_unrolled_MMA_iteration<7, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
#if MAX_MMA_UNROLL > 6
    case 6:
      gemm_unrolled_MMA_iteration<6, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
#if MAX_MMA_UNROLL > 5
    case 5:
      gemm_unrolled_MMA_iteration<5, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
#if MAX_MMA_UNROLL > 4
    case 4:
      gemm_unrolled_MMA_iteration<4, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
#if MAX_MMA_UNROLL > 3
    case 3:
      gemm_unrolled_MMA_iteration<3, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
#if MAX_MMA_UNROLL > 2
    case 2:
      gemm_unrolled_MMA_iteration<2, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
#if MAX_MMA_UNROLL > 1
    case 1:
      gemm_unrolled_MMA_iteration<1, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
      break;
#endif
    default:
      break;
  }
#undef MAX_MMA_UNROLL

  if(remaining_rows > 0)
  {
    gemm_extra_row<Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, blockA, rhs_base, depth, strideA, offsetA, row, col, rows, cols, remaining_rows, pAlpha, pMask);
  }
}

template<typename Scalar, typename Index, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
  const Index remaining_rows = rows % accCols;

  if( strideA == -1 ) strideA = depth;
  if( strideB == -1 ) strideB = depth;

  const Packet pAlpha = pset1<Packet>(alpha);
  const Packet pMask = bmask<Packet>((const int)(remaining_rows));

  Index col = 0;
  for(; col + accRows <= cols; col += accRows)
  {
    gemmMMA_cols<Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
  }

  gemm_extra_cols<Scalar, Packet, DataMapper, Index, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
}
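
// Top-level real MMA GEMM: res += alpha * A * B over the packed blocks. Full
// column panels of accRows are handled by gemmMMA_cols above, which picks the
// largest unroll factor that still fits in the remaining rows; gemm_extra_cols
// (and gemm_extra_row inside gemmMMA_cols) handle the right and bottom
// remainders, with pMask masking off the out-of-range lanes of a partial
// row block.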

#define accColsC (accCols / 2)
// Packed complex data advances two scalars per element unless the
// corresponding side is purely real.
#define advanceRows ((LhsIsReal) ? 1 : 2)
#define advanceCols ((RhsIsReal) ? 1 : 2)

// PEEL_COMPLEX_MMA loop factor: depth unroll for the complex kernels, kept
// smaller than PEEL_MMA because each step holds twice as much live state
// (separate real and imaginary vectors and accumulators).
#define PEEL_COMPLEX_MMA 3

#define MICRO_COMPLEX_MMA_UNROLL(func) \
  func(0) func(1) func(2) func(3)

#define MICRO_COMPLEX_MMA_LOAD_ONE(iter) \
  if (unroll_factor > iter) { \
    lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr_real##iter); \
    if(!LhsIsReal) { \
      lhsVi##iter = ploadLhs<Scalar, Packet>(lhs_ptr_real##iter + imag_delta); \
    } else { \
      EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
    } \
    lhs_ptr_real##iter += accCols; \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhsV##iter); \
    EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
  }

#define MICRO_COMPLEX_MMA_WORK_ONE(iter, type, peel) \
  if (unroll_factor > iter) { \
    pgercMMA<Scalar, Packet, type, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \
  }

#define MICRO_COMPLEX_MMA_TYPE_PEEL(func, func2, type, peel) \
  if (PEEL_COMPLEX_MMA > peel) { \
    Packet lhsV0, lhsV1, lhsV2, lhsV3; \
    Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3; \
    ploadRhsMMA<Scalar, type>(rhs_ptr_real + (accRows * peel), rhsV##peel); \
    if(!RhsIsReal) { \
      ploadRhsMMA<Scalar, type>(rhs_ptr_imag + (accRows * peel), rhsVi##peel); \
    } else { \
      EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
    } \
    MICRO_COMPLEX_MMA_UNROLL(func2); \
    func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) \
  } else { \
    EIGEN_UNUSED_VARIABLE(rhsV##peel); \
    EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
  }

#define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(func, func2, type) \
  type rhsV0, rhsV1, rhsV2, rhsV3; \
  type rhsVi0, rhsVi1, rhsVi2, rhsVi3; \
  MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,0); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,1); \
  MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,2); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,3);

#define MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(func, func2, type) \
  type rhsV0, rhsVi0; \
  MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,0);

#define MICRO_COMPLEX_MMA_ONE_PEEL \
  if (sizeof(Scalar) == sizeof(float)) { \
    MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, RhsPacket); \
  } else { \
    MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, __vector_pair); \
  } \
  rhs_ptr_real += (accRows * PEEL_COMPLEX_MMA); \
  if(!RhsIsReal) rhs_ptr_imag += (accRows * PEEL_COMPLEX_MMA);

#define MICRO_COMPLEX_MMA_ONE \
  if (sizeof(Scalar) == sizeof(float)) { \
    MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, RhsPacket); \
  } else { \
    MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, __vector_pair); \
  } \
  rhs_ptr_real += accRows; \
  if(!RhsIsReal) rhs_ptr_imag += accRows;

#define MICRO_COMPLEX_MMA_DST_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    bsetzeroMMA<Scalar, Packet>(&accReal##iter); \
    bsetzeroMMA<Scalar, Packet>(&accImag##iter); \
  } else { \
    EIGEN_UNUSED_VARIABLE(accReal##iter); \
    EIGEN_UNUSED_VARIABLE(accImag##iter); \
  }

#define MICRO_COMPLEX_MMA_DST_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_DST_PTR_ONE)

#define MICRO_COMPLEX_MMA_SRC_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    lhs_ptr_real##iter = lhs_base + ( ((advanceRows*row)/accCols) + iter*advanceRows )*strideA*accCols; \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhs_ptr_real##iter); \
  }

#define MICRO_COMPLEX_MMA_SRC_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_SRC_PTR_ONE)

#define MICRO_COMPLEX_MMA_PREFETCH_ONE(iter) \
  if (unroll_factor > iter) { \
    EIGEN_POWER_PREFETCH(lhs_ptr_real##iter); \
  }

#define MICRO_COMPLEX_MMA_PREFETCH MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_PREFETCH_ONE)

#define MICRO_COMPLEX_MMA_STORE_ONE(iter) \
  if (unroll_factor > iter) { \
    storeComplexAccumulator<DataMapper, Index, Packet, Packetc, accColsC>(row + iter*accCols, res, pAlphaReal, pAlphaImag, &accReal##iter, &accImag##iter); \
  }

#define MICRO_COMPLEX_MMA_STORE MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_STORE_ONE)

template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_MMA_iteration(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index strideB,
  Index& row,
  const Packet& pAlphaReal,
  const Packet& pAlphaImag)
{
  const Scalar* rhs_ptr_real = rhs_base;
  const Scalar* rhs_ptr_imag = NULL;
  const Index imag_delta = accCols*strideA;
  if(!RhsIsReal) {
    rhs_ptr_imag = rhs_base + accRows*strideB;
  } else {
    EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
  }
  const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_real1 = NULL;
  const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_real3 = NULL;
  __vector_quad accReal0, accImag0, accReal1, accImag1, accReal2, accImag2, accReal3, accImag3;

  MICRO_COMPLEX_MMA_SRC_PTR
  MICRO_COMPLEX_MMA_DST_PTR

  Index k = 0;
  for(; k + PEEL_COMPLEX_MMA <= depth; k += PEEL_COMPLEX_MMA)
  {
    EIGEN_POWER_PREFETCH(rhs_ptr_real);
    if(!RhsIsReal) {
      EIGEN_POWER_PREFETCH(rhs_ptr_imag);
    }
    MICRO_COMPLEX_MMA_PREFETCH
    MICRO_COMPLEX_MMA_ONE_PEEL
  }
  for(; k < depth; k++)
  {
    MICRO_COMPLEX_MMA_ONE
  }
  MICRO_COMPLEX_MMA_STORE

  row += unroll_factor*accCols;
}
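
// The complex micro-kernel mirrors the real one with paired real/imaginary
// accumulators, at a smaller maximum unroll (4 row blocks, PEEL_COMPLEX_MMA
// depth steps). In the packed LHS panel the imaginary components live
// imag_delta (accCols*strideA) elements past the reals, and the packed RHS
// imaginary panel starts accRows*strideB past the real one.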

template<typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void gemmMMA_complex_cols(
  const DataMapper& res,
  const Scalar* blockA,
  const Scalar* blockB,
  Index depth,
  Index strideA,
  Index offsetA,
  Index strideB,
  Index offsetB,
  Index col,
  Index rows,
  Index cols,
  Index remaining_rows,
  const Packet& pAlphaReal,
  const Packet& pAlphaImag,
  const Packet& pMask)
{
  const DataMapper res3 = res.getSubMapper(0, col);

  const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB;
  const Scalar* lhs_base = blockA + accCols*offsetA;
  Index row = 0;

#define MAX_COMPLEX_MMA_UNROLL 4
  while(row + MAX_COMPLEX_MMA_UNROLL*accCols <= rows) {
    gemm_complex_unrolled_MMA_iteration<MAX_COMPLEX_MMA_UNROLL, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
  }
  switch( (rows-row)/accCols ) {
#if MAX_COMPLEX_MMA_UNROLL > 4
    case 4:
      gemm_complex_unrolled_MMA_iteration<4, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
      break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 3
    case 3:
      gemm_complex_unrolled_MMA_iteration<3, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
      break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 2
    case 2:
      gemm_complex_unrolled_MMA_iteration<2, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
      break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 1
    case 1:
      gemm_complex_unrolled_MMA_iteration<1, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
      break;
#endif
    default:
      break;
  }
#undef MAX_COMPLEX_MMA_UNROLL

  if(remaining_rows > 0)
  {
    gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
  }
}

template<typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Index, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
void gemm_complexMMA(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
  const Index remaining_rows = rows % accCols;

  if( strideA == -1 ) strideA = depth;
  if( strideB == -1 ) strideB = depth;

  const Packet pAlphaReal = pset1<Packet>(alpha.real());
  const Packet pAlphaImag = pset1<Packet>(alpha.imag());
  const Packet pMask = bmask<Packet>((const int)(remaining_rows));

  const Scalar* blockA = (Scalar *) blockAc;
  const Scalar* blockB = (Scalar *) blockBc;

  Index col = 0;
  for(; col + accRows <= cols; col += accRows)
  {
    gemmMMA_complex_cols<Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
  }

  gemm_complex_extra_cols<Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
}

#undef accColsC
#undef advanceRows
#undef advanceCols

#pragma GCC reset_options
} // end namespace internal

} // end namespace Eigen

#endif // EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H