38 #if defined(GETFEM_USES_BLAS) || defined(GMM_USES_BLAS) \
39 || defined(GMM_USES_LAPACK) || defined(GMM_USES_ATLAS)
41 #ifndef GMM_BLAS_INTERFACE_H
42 #define GMM_BLAS_INTERFACE_H
// Trace hook invoked at the top of every generated wrapper; here it expands
// to nothing.
52 #define GMMLAPACK_TRACE(f)
// Selection of BLAS_INT, the integer type handed to BLAS routines.
55 #if defined(WeirdNEC) || defined(GMM_USE_BLAS64_INTERFACE)
57 #else // By default BLAS_INT will just be int in C
// Scalar types of the four BLAS precisions: S = float, D = double,
// C = single-precision complex, Z = double-precision complex.
151 # define BLAS_S float
152 # define BLAS_D double
153 # define BLAS_C std::complex<float>
154 # define BLAS_Z std::complex<double>
// Prototype of the double-precision BLAS AXPY routine (per the BLAS
// reference: y := alpha * x + y). Every argument, scalars included, is
// passed by address; x is read-only and y is the writable operand.
160 void daxpy_(
const BLAS_INT *n,
const double *alpha,
const double *x,
161 const BLAS_INT *incx,
double *y,
const BLAS_INT *incy);
// Prototype of the double-precision BLAS GEMM routine (per the BLAS
// reference: C := alpha * op(A) * op(B) + beta * C). tA / tB select op()
// for A and B; m, n, k are the product dimensions and ldA / ldB / ldC the
// leading dimensions. Only C is writable.
162 void dgemm_(
const char *tA,
const char *tB,
const BLAS_INT *m,
163 const BLAS_INT *n,
const BLAS_INT *k,
const double *alpha,
164 const double *A,
const BLAS_INT *ldA,
const double *B,
165 const BLAS_INT *ldB,
const double *beta,
double *C,
166 const BLAS_INT *ldC);
// Remaining BLAS entry points used by the wrappers below. They are declared
// with an unchecked (...) parameter list, so argument types are not verified
// at the call sites.
// Matrix-matrix product.
167 void sgemm_(...);
void cgemm_(...);
void zgemm_(...);
// Matrix-vector product.
168 void sgemv_(...);
void dgemv_(...);
void cgemv_(...);
void zgemv_(...);
// Triangular solve.
169 void strsv_(...);
void dtrsv_(...);
void ctrsv_(...);
void ztrsv_(...);
// Scaled vector addition (daxpy_ is fully prototyped separately).
170 void saxpy_(...);
void caxpy_(...);
void zaxpy_(...);
// Dot products: real, complex unconjugated (dotu), complex conjugated (dotc).
171 BLAS_S sdot_ (...); BLAS_D ddot_ (...);
172 BLAS_C cdotu_(...); BLAS_Z zdotu_(...);
173 BLAS_C cdotc_(...); BLAS_Z zdotc_(...);
// Euclidean norms; the complex variants return a real value.
174 BLAS_S snrm2_(...); BLAS_D dnrm2_(...);
175 BLAS_S scnrm2_(...); BLAS_D dznrm2_(...);
// Rank-one update; the complex variants used are the conjugating ones (gerc).
176 void sger_(...);
void dger_(...);
void cgerc_(...);
void zgerc_(...);
// Generates vect_norm2(x) for one scalar type by forwarding to the BLAS
// routine `blas_name` with unit stride over vect_size(x) entries.
//   param1    - macro expanding to the parameter declaration (must name x)
//   trans1    - macro expanding to preparatory statements (may be empty)
//   blas_name - BLAS norm routine to call
//   base_type - scalar type; the result is its magnitude_type
185 # define nrm2_interface(param1, trans1, blas_name, base_type) \
186 inline number_traits<base_type >::magnitude_type \
187 vect_norm2(param1(base_type)) { \
188 GMMLAPACK_TRACE("nrm2_interface"); \
189 BLAS_INT inc(1), n(BLAS_INT(vect_size(x))); trans1(base_type); \
190 return blas_name(&n, &x[0], &inc); \
// Parameter pack for nrm2_interface: x is a plain std::vector and no
// preparatory statement is needed.
193 # define nrm2_p1(base_type) const std::vector<base_type > &x
194 # define nrm2_trans1(base_type)
// vect_norm2 overloads for float, double and both complex precisions.
196 nrm2_interface(nrm2_p1, nrm2_trans1, snrm2_ , BLAS_S)
197 nrm2_interface(nrm2_p1, nrm2_trans1, dnrm2_ , BLAS_D)
198 nrm2_interface(nrm2_p1, nrm2_trans1, scnrm2_, BLAS_C)
199 nrm2_interface(nrm2_p1, nrm2_trans1, dznrm2_, BLAS_Z)
// Generates vect_sp(x, y) for one scalar type by forwarding to the BLAS dot
// routine `blas_name` with unit strides; the length is taken from y.
//   param1 / param2 - macros expanding to the two parameter declarations
//   trans1 / trans2 - macros expanding to preparatory statements that make
//                     x and y (and, for scaled operands, a / b) available
//   mult1 / mult2   - prefixes applied to the BLAS result: either a cast
//                     such as (BLAS_S) or a scale factor such as a* / b*
205 # define dot_interface(param1, trans1, mult1, param2, trans2, mult2, \
206 blas_name, base_type) \
207 inline base_type vect_sp(param1(base_type), param2(base_type)) { \
208 GMMLAPACK_TRACE("dot_interface"); \
209 trans1(base_type); trans2(base_type); \
210 BLAS_INT inc(1), n(BLAS_INT(vect_size(y))); \
211 return mult1 mult2 blas_name(&n, &x[0], &inc, &y[0], &inc); \
// Parameter packs for dot_interface.
// First operand as a plain std::vector: nothing to prepare.
214 # define dot_p1(base_type) const std::vector<base_type > &x
215 # define dot_trans1(base_type)
// First operand as a scaled vector reference x_: the preparation binds x to
// the underlying std::vector obtained through linalg_origin (const removed).
216 # define dot_p1_s(base_type) \
217 const scaled_vector_const_ref<std::vector<base_type >, base_type > &x_
218 # define dot_trans1_s(base_type) \
219 std::vector<base_type > &x = \
220 const_cast<std::vector<base_type > &>(*(linalg_origin(x_))); \
223 # define dot_p2(base_type) const std::vector<base_type > &y
224 # define dot_trans2(base_type)
225 # define dot_p2_s(base_type) \
226 const scaled_vector_const_ref<std::vector<base_type >, base_type > &y_
227 # define dot_trans2_s(base_type) \
228 std::vector<base_type > &y = \
229 const_cast<std::vector<base_type > &>(*(linalg_origin(y_))); \
// vect_sp overloads, four scalar types per operand combination, in the
// order: plain/plain, scaled/plain, plain/scaled, scaled/scaled. A plain
// operand contributes a cast prefix, a scaled operand its factor (a* or b*).
232 dot_interface(dot_p1, dot_trans1, (BLAS_S), dot_p2, dot_trans2, (BLAS_S),
234 dot_interface(dot_p1, dot_trans1, (BLAS_D), dot_p2, dot_trans2, (BLAS_D),
236 dot_interface(dot_p1, dot_trans1, (BLAS_C), dot_p2, dot_trans2, (BLAS_C),
238 dot_interface(dot_p1, dot_trans1, (BLAS_Z), dot_p2, dot_trans2, (BLAS_Z),
241 dot_interface(dot_p1_s, dot_trans1_s, a*, dot_p2, dot_trans2, (BLAS_S),
243 dot_interface(dot_p1_s, dot_trans1_s, a*, dot_p2, dot_trans2, (BLAS_D),
245 dot_interface(dot_p1_s, dot_trans1_s, a*, dot_p2, dot_trans2, (BLAS_C),
247 dot_interface(dot_p1_s, dot_trans1_s, a*, dot_p2, dot_trans2, (BLAS_Z),
250 dot_interface(dot_p1, dot_trans1, (BLAS_S), dot_p2_s, dot_trans2_s, b*,
252 dot_interface(dot_p1, dot_trans1, (BLAS_D), dot_p2_s, dot_trans2_s, b*,
254 dot_interface(dot_p1, dot_trans1, (BLAS_C), dot_p2_s, dot_trans2_s, b*,
256 dot_interface(dot_p1, dot_trans1, (BLAS_Z), dot_p2_s, dot_trans2_s, b*,
259 dot_interface(dot_p1_s,dot_trans1_s,a*,dot_p2_s,dot_trans2_s,b*,sdot_ ,
261 dot_interface(dot_p1_s,dot_trans1_s,a*,dot_p2_s,dot_trans2_s,b*,ddot_ ,
263 dot_interface(dot_p1_s,dot_trans1_s,a*,dot_p2_s,dot_trans2_s,b*,cdotu_,
265 dot_interface(dot_p1_s,dot_trans1_s,a*,dot_p2_s,dot_trans2_s,b*,zdotu_,
// Generates vect_hp(x, y) for one scalar type. Same shape as dot_interface;
// the instantiations pass the conjugating BLAS routines (cdotc_ / zdotc_)
// for the complex types.
//   param1 / param2 - macros expanding to the two parameter declarations
//   trans1 / trans2 - macros expanding to preparatory statements
//   mult1 / mult2   - cast or scale-factor prefixes applied to the result
273 # define dotc_interface(param1, trans1, mult1, param2, trans2, mult2, \
274 blas_name, base_type) \
275 inline base_type vect_hp(param1(base_type), param2(base_type)) { \
276 GMMLAPACK_TRACE("dotc_interface"); \
277 trans1(base_type); trans2(base_type); \
278 BLAS_INT inc(1), n(BLAS_INT(vect_size(y))); \
279 return mult1 mult2 blas_name(&n, &x[0], &inc, &y[0], &inc); \
// Parameter packs for dotc_interface; they mirror the dot_* packs.
// Plain first operand.
282 # define dotc_p1(base_type) const std::vector<base_type > &x
283 # define dotc_trans1(base_type)
// Scaled first operand: x is bound to the origin vector of x_.
284 # define dotc_p1_s(base_type) \
285 const scaled_vector_const_ref<std::vector<base_type >, base_type > &x_
286 # define dotc_trans1_s(base_type) \
287 std::vector<base_type > &x = \
288 const_cast<std::vector<base_type > &>(*(linalg_origin(x_))); \
291 # define dotc_p2(base_type) const std::vector<base_type > &y
292 # define dotc_trans2(base_type)
// Scaled second operand: y is bound to the origin vector of y_ and the
// factor b is the conjugate of the scaling coefficient y_.r.
293 # define dotc_p2_s(base_type) \
294 const scaled_vector_const_ref<std::vector<base_type >, base_type > &y_
295 # define dotc_trans2_s(base_type) \
296 std::vector<base_type > &y = \
297 const_cast<std::vector<base_type > &>(*(linalg_origin(y_))); \
298 base_type b(gmm::conj(y_.r))
// vect_hp overloads, four scalar types per operand combination, in the
// order: plain/plain, scaled/plain, plain/scaled, scaled/scaled. Real types
// use sdot_ / ddot_, complex types cdotc_ / zdotc_.
300 dotc_interface(dotc_p1, dotc_trans1, (BLAS_S), dotc_p2, dotc_trans2,
301 (BLAS_S),sdot_ ,BLAS_S)
302 dotc_interface(dotc_p1, dotc_trans1, (BLAS_D), dotc_p2, dotc_trans2,
303 (BLAS_D),ddot_ ,BLAS_D)
304 dotc_interface(dotc_p1, dotc_trans1, (BLAS_C), dotc_p2, dotc_trans2,
305 (BLAS_C),cdotc_,BLAS_C)
306 dotc_interface(dotc_p1, dotc_trans1, (BLAS_Z), dotc_p2, dotc_trans2,
307 (BLAS_Z),zdotc_,BLAS_Z)
309 dotc_interface(dotc_p1_s, dotc_trans1_s, a*, dotc_p2, dotc_trans2,
310 (BLAS_S),sdot_, BLAS_S)
311 dotc_interface(dotc_p1_s, dotc_trans1_s, a*, dotc_p2, dotc_trans2,
312 (BLAS_D),ddot_ , BLAS_D)
313 dotc_interface(dotc_p1_s, dotc_trans1_s, a*, dotc_p2, dotc_trans2,
314 (BLAS_C),cdotc_, BLAS_C)
315 dotc_interface(dotc_p1_s, dotc_trans1_s, a*, dotc_p2, dotc_trans2,
316 (BLAS_Z),zdotc_, BLAS_Z)
318 dotc_interface(dotc_p1, dotc_trans1, (BLAS_S), dotc_p2_s, dotc_trans2_s,
320 dotc_interface(dotc_p1, dotc_trans1, (BLAS_D), dotc_p2_s, dotc_trans2_s,
322 dotc_interface(dotc_p1, dotc_trans1, (BLAS_C), dotc_p2_s, dotc_trans2_s,
324 dotc_interface(dotc_p1, dotc_trans1, (BLAS_Z), dotc_p2_s, dotc_trans2_s,
327 dotc_interface(dotc_p1_s,dotc_trans1_s,a*,dotc_p2_s,dotc_trans2_s,b*,sdot_ ,
329 dotc_interface(dotc_p1_s,dotc_trans1_s,a*,dotc_p2_s,dotc_trans2_s,b*,ddot_ ,
331 dotc_interface(dotc_p1_s,dotc_trans1_s,a*,dotc_p2_s,dotc_trans2_s,b*,cdotc_,
333 dotc_interface(dotc_p1_s,dotc_trans1_s,a*,dotc_p2_s,dotc_trans2_s,b*,zdotc_,
339 template<
size_type N,
class V1,
class V2>
340 inline void add_fixed(
const V1 &x, V2 &y)
342 for(
size_type i = 0; i != N; ++i) y[i] += x[i];
345 template<
class V1,
class V2>
346 inline void add_for_short_vectors(
const V1 &x, V2 &y,
size_type n)
350 case 1: add_fixed<1>(x, y);
break;
351 case 2: add_fixed<2>(x, y);
break;
352 case 3: add_fixed<3>(x, y);
break;
353 case 4: add_fixed<4>(x, y);
break;
354 case 5: add_fixed<5>(x, y);
break;
355 case 6: add_fixed<6>(x, y);
break;
356 case 7: add_fixed<7>(x, y);
break;
357 case 8: add_fixed<8>(x, y);
break;
358 case 9: add_fixed<9>(x, y);
break;
359 case 10: add_fixed<10>(x, y);
break;
360 case 11: add_fixed<11>(x, y);
break;
361 case 12: add_fixed<12>(x, y);
break;
362 case 13: add_fixed<13>(x, y);
break;
363 case 14: add_fixed<14>(x, y);
break;
364 case 15: add_fixed<15>(x, y);
break;
365 case 16: add_fixed<16>(x, y);
break;
366 case 17: add_fixed<17>(x, y);
break;
367 case 18: add_fixed<18>(x, y);
break;
368 case 19: add_fixed<19>(x, y);
break;
369 case 20: add_fixed<20>(x, y);
break;
370 case 21: add_fixed<21>(x, y);
break;
371 case 22: add_fixed<22>(x, y);
break;
372 case 23: add_fixed<23>(x, y);
break;
373 case 24: add_fixed<24>(x, y);
break;
374 default: GMM_ASSERT2(
false,
"add_for_short_vectors used with unsupported size");
break;
378 template<
size_type N,
class V1,
class V2,
class T>
379 inline void add_fixed(
const V1 &x, V2 &y,
const T &a)
381 for(
size_type i = 0; i != N; ++i) y[i] += a*x[i];
384 template<
class V1,
class V2,
class T>
385 inline void add_for_short_vectors(
const V1 &x, V2 &y,
const T &a,
size_type n)
389 case 1: add_fixed<1>(x, y, a);
break;
390 case 2: add_fixed<2>(x, y, a);
break;
391 case 3: add_fixed<3>(x, y, a);
break;
392 case 4: add_fixed<4>(x, y, a);
break;
393 case 5: add_fixed<5>(x, y, a);
break;
394 case 6: add_fixed<6>(x, y, a);
break;
395 case 7: add_fixed<7>(x, y, a);
break;
396 case 8: add_fixed<8>(x, y, a);
break;
397 case 9: add_fixed<9>(x, y, a);
break;
398 case 10: add_fixed<10>(x, y, a);
break;
399 case 11: add_fixed<11>(x, y, a);
break;
400 case 12: add_fixed<12>(x, y, a);
break;
401 case 13: add_fixed<13>(x, y, a);
break;
402 case 14: add_fixed<14>(x, y, a);
break;
403 case 15: add_fixed<15>(x, y, a);
break;
404 case 16: add_fixed<16>(x, y, a);
break;
405 case 17: add_fixed<17>(x, y, a);
break;
406 case 18: add_fixed<18>(x, y, a);
break;
407 case 19: add_fixed<19>(x, y, a);
break;
408 case 20: add_fixed<20>(x, y, a);
break;
409 case 21: add_fixed<21>(x, y, a);
break;
410 case 22: add_fixed<22>(x, y, a);
break;
411 case 23: add_fixed<23>(x, y, a);
break;
412 case 24: add_fixed<24>(x, y, a);
break;
413 default: GMM_ASSERT2(
false,
"add_for_short_vectors used with unsupported size");
break;
// Generates add(x, y) for one scalar type. The length is taken from y.
// Vectors shorter than 25 entries go through add_for_short_vectors without
// a scale factor; longer ones are handed to the BLAS routine `blas_name`
// with unit strides and the factor `a` supplied by trans1.
//   param1 - macro expanding to the declaration of the first operand
//   trans1 - macro expanding to statements that define x (if needed) and a
418 # define axpy_interface(param1, trans1, blas_name, base_type) \
419 inline void add(param1(base_type), std::vector<base_type > &y) { \
420 GMMLAPACK_TRACE("axpy_interface"); \
421 BLAS_INT inc(1), n(BLAS_INT(vect_size(y))); trans1(base_type); \
423 else if(n < 25) add_for_short_vectors(x, y, n); \
424 else blas_name(&n, &a, &x[0], &inc, &y[0], &inc); \
// Variant of axpy_interface for a scaled first operand: the short-vector
// path forwards the factor `a` to add_for_short_vectors, the long-vector
// path passes it to the BLAS routine `blas_name`.
// The trace label is "axpy_interface", the same string as in axpy_interface.
427 # define axpy2_interface(param1, trans1, blas_name, base_type) \
428 inline void add(param1(base_type), std::vector<base_type > &y) { \
429 GMMLAPACK_TRACE("axpy_interface"); \
430 BLAS_INT inc(1), n(BLAS_INT(vect_size(y))); trans1(base_type); \
432 else if(n < 25) add_for_short_vectors(x, y, a, n); \
433 else blas_name(&n, &a, &x[0], &inc, &y[0], &inc); \
// Parameter packs for the axpy wrappers.
// Plain first operand: the scale factor a is fixed to 1.
436 # define axpy_p1(base_type) const std::vector<base_type > &x
437 # define axpy_trans1(base_type) base_type a(1)
// Scaled first operand x_: x is bound to its origin vector (const removed).
438 # define axpy_p1_s(base_type) \
439 const scaled_vector_const_ref<std::vector<base_type >, base_type > &x_
440 # define axpy_trans1_s(base_type) \
441 std::vector<base_type > &x = \
442 const_cast<std::vector<base_type > &>(*(linalg_origin(x_))); \
// add(x, y) overloads for plain x.
445 axpy_interface(axpy_p1, axpy_trans1, saxpy_, BLAS_S)
446 axpy_interface(axpy_p1, axpy_trans1, daxpy_, BLAS_D)
447 axpy_interface(axpy_p1, axpy_trans1, caxpy_, BLAS_C)
448 axpy_interface(axpy_p1, axpy_trans1, zaxpy_, BLAS_Z)
// add(x, y) overloads for scaled x.
450 axpy2_interface(axpy_p1_s, axpy_trans1_s, saxpy_, BLAS_S)
451 axpy2_interface(axpy_p1_s, axpy_trans1_s, daxpy_, BLAS_D)
452 axpy2_interface(axpy_p1_s, axpy_trans1_s, caxpy_, BLAS_C)
453 axpy2_interface(axpy_p1_s, axpy_trans1_s, zaxpy_, BLAS_Z)
// Generates mult_add_spec(A, x, z, orien) for one scalar type: the product
// of the matrix operand with the vector operand is accumulated into z
// (beta = 1) through the BLAS routine `blas_name`, with unit strides.
//   param1 / trans1 - matrix parameter declaration and the statements that
//                     make the dense matrix A and the op flag t available
//   param2 / trans2 - vector parameter declaration and the statements that
//                     make x and the factor alpha available
//   blas_name       - BLAS matrix-vector routine to call
//   base_type       - scalar type
//   orien           - tag type of the trailing dispatch parameter
// m and n are the dimensions of the stored matrix A; lda equals m.
# define gemv_interface(param1, trans1, param2, trans2, blas_name,         \
                        base_type, orien)                                  \
  inline void mult_add_spec(param1(base_type), param2(base_type),          \
              std::vector<base_type > &z, orien) {                         \
    GMMLAPACK_TRACE("gemv_interface");                                     \
    trans1(base_type); trans2(base_type); base_type beta(1);               \
    BLAS_INT m(BLAS_INT(mat_nrows(A))), lda(m);                            \
    BLAS_INT n(BLAS_INT(mat_ncols(A))), inc(1);                            \
    /* An empty product adds nothing: when A has a zero dimension z    */ \
    /* must keep its current contents, so there is no else branch that */ \
    /* clears it (clearing is only right for the overwriting variant). */ \
    if (m && n) blas_name(&t, &m, &n, &alpha, &A(0,0), &lda, &x[0], &inc,  \
                          &beta, &z[0], &inc);                             \
  }
// Matrix parameter packs shared by the gemv and trsv wrappers.
// Plain dense matrix: passed as A directly, op flag 'N'.
474 # define gem_p1_n(base_type) const dense_matrix<base_type > &A
475 # define gem_trans1_n(base_type) const char t = 'N'
// Transposed reference to a non-const dense matrix: A is bound to the
// origin matrix obtained through linalg_origin (const removed).
476 # define gem_p1_t(base_type) \
477 const transposed_col_ref<dense_matrix<base_type > *> &A_
478 # define gem_trans1_t(base_type) dense_matrix<base_type > &A = \
479 const_cast<dense_matrix<base_type > &>(*(linalg_origin(A_))); \
481 # define gem_p1_tc(base_type) \
482 const transposed_col_ref<const dense_matrix<base_type > *> &A_
// Conjugated reference to a dense matrix: A is bound to the origin matrix.
483 # define gem_p1_c(base_type) \
484 const conjugated_col_matrix_const_ref<dense_matrix<base_type > > &A_
485 # define gem_trans1_c(base_type) dense_matrix<base_type > &A = \
486 const_cast<dense_matrix<base_type > &>(*(linalg_origin(A_))); \
// Vector parameter packs for the gemv wrappers.
// Plain vector: alpha is fixed to 1.
490 # define gemv_p2_n(base_type) const std::vector<base_type > &x
491 # define gemv_trans2_n(base_type) base_type alpha(1)
// Scaled vector x_: x is bound to its origin vector and alpha takes the
// scaling coefficient x_.r.
492 # define gemv_p2_s(base_type) \
493 const scaled_vector_const_ref<std::vector<base_type >, base_type > &x_
494 # define gemv_trans2_s(base_type) std::vector<base_type > &x = \
495 const_cast<std::vector<base_type > &>(*(linalg_origin(x_))); \
496 base_type alpha(x_.r)
// mult_add_spec overloads, four scalar types per combination. Matrix forms
// in order: plain (gem_p1_n), transposed (gem_p1_t), transposed with const
// origin (gem_p1_tc), conjugated (gem_p1_c); first with a plain vector, then
// with a scaled vector.
499 gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, sgemv_,
501 gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, dgemv_,
503 gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, cgemv_,
505 gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, zgemv_,
509 gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, sgemv_,
511 gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, dgemv_,
513 gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, cgemv_,
515 gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, zgemv_,
519 gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, sgemv_,
521 gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, dgemv_,
523 gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, cgemv_,
525 gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, zgemv_,
529 gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, sgemv_,
531 gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, dgemv_,
533 gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, cgemv_,
535 gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, zgemv_,
539 gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, sgemv_,
541 gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, dgemv_,
543 gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, cgemv_,
545 gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, zgemv_,
549 gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, sgemv_,
551 gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, dgemv_,
553 gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, cgemv_,
555 gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, zgemv_,
559 gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, sgemv_,
561 gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, dgemv_,
563 gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, cgemv_,
565 gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, zgemv_,
569 gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, sgemv_,
571 gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, dgemv_,
573 gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, cgemv_,
575 gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, zgemv_,
// Generates mult_spec(A, x, z, orien) for one scalar type: same call shape
// as gemv_interface but with beta = 0, so the BLAS result overwrites z.
// The else branch clears z.
//   param1 / trans1 - matrix parameter declaration and preparation (A, t)
//   param2 / trans2 - vector parameter declaration and preparation (x, alpha)
583 # define gemv_interface2(param1, trans1, param2, trans2, blas_name, \
585 inline void mult_spec(param1(base_type), param2(base_type), \
586 std::vector<base_type > &z, orien) { \
587 GMMLAPACK_TRACE("gemv_interface2"); \
588 trans1(base_type); trans2(base_type); base_type beta(0); \
589 BLAS_INT m(BLAS_INT(mat_nrows(A))), lda(m); \
590 BLAS_INT n(BLAS_INT(mat_ncols(A))), inc(1); \
592 blas_name(&t, &m, &n, &alpha, &A(0,0), &lda, &x[0], &inc, &beta, \
594 else gmm::clear(z); \
// mult_spec (matrix-vector) overloads, four scalar types per combination.
// Matrix forms in order: plain, transposed, transposed with const origin,
// conjugated; first with a plain vector, then with a scaled vector.
598 gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, sgemv_,
600 gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, dgemv_,
602 gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, cgemv_,
604 gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, zgemv_,
608 gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, sgemv_,
610 gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, dgemv_,
612 gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, cgemv_,
614 gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, zgemv_,
618 gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, sgemv_,
620 gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, dgemv_,
622 gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, cgemv_,
624 gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, zgemv_,
628 gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, sgemv_,
630 gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, dgemv_,
632 gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, cgemv_,
634 gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, zgemv_,
638 gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, sgemv_,
640 gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, dgemv_,
642 gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, cgemv_,
644 gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, zgemv_,
648 gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, sgemv_,
650 gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, dgemv_,
652 gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, cgemv_,
654 gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, zgemv_,
658 gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, sgemv_,
660 gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, dgemv_,
662 gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, cgemv_,
664 gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, zgemv_,
668 gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, sgemv_,
670 gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, dgemv_,
672 gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, cgemv_,
674 gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, zgemv_,
// Generates rank_one_update(A, V, W) for one scalar type by forwarding to
// the BLAS rank-one routine `blas_name` with alpha = 1 and unit strides.
// m, n are the dimensions of A and lda equals m.
// NOTE(review): A is declared const, yet &A(0,0) is passed in the position
// of the matrix the BLAS routine updates - confirm this is intentional.
682 # define ger_interface(blas_name, base_type) \
683 inline void rank_one_update(const dense_matrix<base_type > &A, \
684 const std::vector<base_type > &V, \
685 const std::vector<base_type > &W) { \
686 GMMLAPACK_TRACE("ger_interface"); \
687 BLAS_INT m(BLAS_INT(mat_nrows(A))), lda = m; \
688 BLAS_INT n(BLAS_INT(mat_ncols(A))); \
689 BLAS_INT incx = 1, incy = 1; \
690 base_type alpha(1); \
692 blas_name(&m, &n, &alpha, &V[0], &incx, &W[0], &incy, &A(0,0), &lda);\
// Overloads for the four scalar types; the complex ones use the conjugating
// routines cgerc_ / zgerc_.
695 ger_interface(sger_, BLAS_S)
696 ger_interface(dger_, BLAS_D)
697 ger_interface(cgerc_, BLAS_C)
698 ger_interface(zgerc_, BLAS_Z)
// Generates rank_one_update(A, x_, W) where the first vector is a scaled
// reference. gemv_p2_s / gemv_trans2_s are reused to declare x_ and to bind
// x to its origin vector with alpha = x_.r; alpha is then passed to BLAS.
// The trace label is "ger_interface".
700 # define ger_interface_sn(blas_name, base_type) \
701 inline void rank_one_update(const dense_matrix<base_type > &A, \
702 gemv_p2_s(base_type), \
703 const std::vector<base_type > &W) { \
704 GMMLAPACK_TRACE("ger_interface"); \
705 gemv_trans2_s(base_type); \
706 BLAS_INT m(BLAS_INT(mat_nrows(A))), lda = m; \
707 BLAS_INT n(BLAS_INT(mat_ncols(A))); \
708 BLAS_INT incx = 1, incy = 1; \
710 blas_name(&m, &n, &alpha, &x[0], &incx, &W[0], &incy, &A(0,0), &lda);\
// Overloads for the four scalar types.
713 ger_interface_sn(sger_, BLAS_S)
714 ger_interface_sn(dger_, BLAS_D)
715 ger_interface_sn(cgerc_, BLAS_C)
716 ger_interface_sn(zgerc_, BLAS_Z)
// Generates rank_one_update(A, V, x_) where the second vector is a scaled
// reference. x is bound to the origin of x_ with alpha = x_.r, and the
// factor actually passed to BLAS is al2 = conj(alpha).
// The trace label is "ger_interface".
718 # define ger_interface_ns(blas_name, base_type) \
719 inline void rank_one_update(const dense_matrix<base_type > &A, \
720 const std::vector<base_type > &V, \
721 gemv_p2_s(base_type)) { \
722 GMMLAPACK_TRACE("ger_interface"); \
723 gemv_trans2_s(base_type); \
724 BLAS_INT m(BLAS_INT(mat_nrows(A))), lda = m; \
725 BLAS_INT n(BLAS_INT(mat_ncols(A))); \
726 BLAS_INT incx = 1, incy = 1; \
727 base_type al2 = gmm::conj(alpha); \
729 blas_name(&m, &n, &al2, &V[0], &incx, &x[0], &incy, &A(0,0), &lda); \
// Overloads for the four scalar types.
732 ger_interface_ns(sger_, BLAS_S)
733 ger_interface_ns(dger_, BLAS_D)
734 ger_interface_ns(cgerc_, BLAS_C)
735 ger_interface_ns(zgerc_, BLAS_Z)
// Generates mult_spec(A, B, C, c_mult) for two plain dense matrices: the
// BLAS routine `blas_name` is called with both op flags 'N', alpha = 1 and
// beta = 0. m = rows(A), k = cols(A), n = cols(B); the leading dimensions
// are lda = m, ldb = k, ldc = m. The else branch clears C.
741 # define gemm_interface_nn(blas_name, base_type) \
742 inline void mult_spec(const dense_matrix<base_type > &A, \
743 const dense_matrix<base_type > &B, \
744 dense_matrix<base_type > &C, c_mult) { \
745 GMMLAPACK_TRACE("gemm_interface_nn"); \
746 const char t = 'N'; \
747 BLAS_INT m(BLAS_INT(mat_nrows(A))), lda = m; \
748 BLAS_INT k(BLAS_INT(mat_ncols(A))); \
749 BLAS_INT n(BLAS_INT(mat_ncols(B))); \
750 BLAS_INT ldb = k, ldc = m; \
751 base_type alpha(1), beta(0); \
753 blas_name(&t, &t, &m, &n, &k, &alpha, \
754 &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
755 else gmm::clear(C); \
// Overloads for the four scalar types.
758 gemm_interface_nn(sgemm_, BLAS_S)
759 gemm_interface_nn(dgemm_, BLAS_D)
760 gemm_interface_nn(cgemm_, BLAS_C)
761 gemm_interface_nn(zgemm_, BLAS_Z)
// Generates mult_spec(A_, B, C, rcmult) where the left factor is a
// transposed reference. A is bound to the origin matrix of A_ and passed
// with op flag 'T'; B is passed with 'N'. m = cols(A), k = rows(A),
// n = cols(B); lda = ldb = k, ldc = m. alpha = 1, beta = 0; the else
// branch clears C.
//   is_const - `dense_matrix` or `const dense_matrix`, selecting the
//              constness of the matrix the transposed reference points to
767 # define gemm_interface_tn(blas_name, base_type, is_const) \
768 inline void mult_spec( \
769 const transposed_col_ref<is_const<base_type > *> &A_,\
770 const dense_matrix<base_type > &B, \
771 dense_matrix<base_type > &C, rcmult) { \
772 GMMLAPACK_TRACE("gemm_interface_tn"); \
773 dense_matrix<base_type > &A \
774 = const_cast<dense_matrix<base_type > &>(*(linalg_origin(A_))); \
775 const char t = 'T', u = 'N'; \
776 BLAS_INT m(BLAS_INT(mat_ncols(A))), k(BLAS_INT(mat_nrows(A))); \
777 BLAS_INT n(BLAS_INT(mat_ncols(B))); \
778 BLAS_INT lda = k, ldb = k, ldc = m; \
779 base_type alpha(1), beta(0); \
781 blas_name(&t, &u, &m, &n, &k, &alpha, \
782 &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
783 else gmm::clear(C); \
// Overloads for the four scalar types, non-const then const origin.
786 gemm_interface_tn(sgemm_, BLAS_S, dense_matrix)
787 gemm_interface_tn(dgemm_, BLAS_D, dense_matrix)
788 gemm_interface_tn(cgemm_, BLAS_C, dense_matrix)
789 gemm_interface_tn(zgemm_, BLAS_Z, dense_matrix)
790 gemm_interface_tn(sgemm_, BLAS_S,
const dense_matrix)
791 gemm_interface_tn(dgemm_, BLAS_D,
const dense_matrix)
792 gemm_interface_tn(cgemm_, BLAS_C,
const dense_matrix)
793 gemm_interface_tn(zgemm_, BLAS_Z,
const dense_matrix)
// Generates mult_spec(A, B_, C, r_mult) where the right factor is a
// transposed reference. B is bound to the origin matrix of B_ and passed
// with op flag 'T'; A is passed with 'N'. m = rows(A), k = cols(A),
// n = rows(B); lda = m, ldb = n, ldc = m. alpha = 1, beta = 0; the else
// branch clears C.
//   is_const - `dense_matrix` or `const dense_matrix`, selecting the
//              constness of the matrix the transposed reference points to
799 # define gemm_interface_nt(blas_name, base_type, is_const) \
800 inline void mult_spec(const dense_matrix<base_type > &A, \
801 const transposed_col_ref<is_const<base_type > *> &B_, \
802 dense_matrix<base_type > &C, r_mult) { \
803 GMMLAPACK_TRACE("gemm_interface_nt"); \
804 dense_matrix<base_type > &B \
805 = const_cast<dense_matrix<base_type > &>(*(linalg_origin(B_))); \
806 const char t = 'N', u = 'T'; \
807 BLAS_INT m(BLAS_INT(mat_nrows(A))), lda = m; \
808 BLAS_INT k(BLAS_INT(mat_ncols(A))); \
809 BLAS_INT n(BLAS_INT(mat_nrows(B))); \
810 BLAS_INT ldb = n, ldc = m; \
811 base_type alpha(1), beta(0); \
813 blas_name(&t, &u, &m, &n, &k, &alpha, \
814 &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
815 else gmm::clear(C); \
// Overloads for the four scalar types, non-const then const origin.
818 gemm_interface_nt(sgemm_, BLAS_S, dense_matrix)
819 gemm_interface_nt(dgemm_, BLAS_D, dense_matrix)
820 gemm_interface_nt(cgemm_, BLAS_C, dense_matrix)
821 gemm_interface_nt(zgemm_, BLAS_Z, dense_matrix)
822 gemm_interface_nt(sgemm_, BLAS_S,
const dense_matrix)
823 gemm_interface_nt(dgemm_, BLAS_D,
const dense_matrix)
824 gemm_interface_nt(cgemm_, BLAS_C,
const dense_matrix)
825 gemm_interface_nt(zgemm_, BLAS_Z,
const dense_matrix)
// Generates mult_spec(A_, B_, C, r_mult) where both factors are transposed
// references. A and B are bound to the respective origin matrices and both
// passed with op flag 'T'. m = cols(A), k = rows(A), n = rows(B);
// lda = k, ldb = n, ldc = m. alpha = 1, beta = 0; the else branch clears C.
//   isA_const / isB_const - `dense_matrix` or `const dense_matrix` for the
//                           matrix each transposed reference points to
831 # define gemm_interface_tt(blas_name, base_type, isA_const, isB_const) \
832 inline void mult_spec( \
833 const transposed_col_ref<isA_const <base_type > *> &A_, \
834 const transposed_col_ref<isB_const <base_type > *> &B_, \
835 dense_matrix<base_type > &C, r_mult) { \
836 GMMLAPACK_TRACE("gemm_interface_tt"); \
837 dense_matrix<base_type > &A \
838 = const_cast<dense_matrix<base_type > &>(*(linalg_origin(A_))); \
839 dense_matrix<base_type > &B \
840 = const_cast<dense_matrix<base_type > &>(*(linalg_origin(B_))); \
841 const char t = 'T', u = 'T'; \
842 BLAS_INT m(BLAS_INT(mat_ncols(A))), k(BLAS_INT(mat_nrows(A))); \
843 BLAS_INT n(BLAS_INT(mat_nrows(B))); \
844 BLAS_INT lda = k, ldb = n, ldc = m; \
845 base_type alpha(1), beta(0); \
847 blas_name(&t, &u, &m, &n, &k, &alpha, \
848 &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
849 else gmm::clear(C); \
// Overloads for the four scalar types and every const / non-const pairing
// of the two origins.
852 gemm_interface_tt(sgemm_, BLAS_S, dense_matrix, dense_matrix)
853 gemm_interface_tt(dgemm_, BLAS_D, dense_matrix, dense_matrix)
854 gemm_interface_tt(cgemm_, BLAS_C, dense_matrix, dense_matrix)
855 gemm_interface_tt(zgemm_, BLAS_Z, dense_matrix, dense_matrix)
856 gemm_interface_tt(sgemm_, BLAS_S,
const dense_matrix, dense_matrix)
857 gemm_interface_tt(dgemm_, BLAS_D,
const dense_matrix, dense_matrix)
858 gemm_interface_tt(cgemm_, BLAS_C,
const dense_matrix, dense_matrix)
859 gemm_interface_tt(zgemm_, BLAS_Z,
const dense_matrix, dense_matrix)
860 gemm_interface_tt(sgemm_, BLAS_S, dense_matrix,
const dense_matrix)
861 gemm_interface_tt(dgemm_, BLAS_D, dense_matrix,
const dense_matrix)
862 gemm_interface_tt(cgemm_, BLAS_C, dense_matrix,
const dense_matrix)
863 gemm_interface_tt(zgemm_, BLAS_Z, dense_matrix,
const dense_matrix)
864 gemm_interface_tt(sgemm_, BLAS_S,
const dense_matrix,
const dense_matrix)
865 gemm_interface_tt(dgemm_, BLAS_D,
const dense_matrix,
const dense_matrix)
866 gemm_interface_tt(cgemm_, BLAS_C,
const dense_matrix,
const dense_matrix)
867 gemm_interface_tt(zgemm_, BLAS_Z,
const dense_matrix,
const dense_matrix)
// Generates mult_spec(A_, B, C, rcmult) where the left factor is a
// conjugated reference. A is bound to the origin matrix of A_ and passed
// with op flag 'C'; B is passed with 'N'. m = cols(A), k = rows(A),
// n = cols(B); lda = ldb = k, ldc = m. alpha = 1, beta = 0; the else
// branch clears C.
874 # define gemm_interface_cn(blas_name, base_type) \
875 inline void mult_spec( \
876 const conjugated_col_matrix_const_ref<dense_matrix<base_type > > &A_,\
877 const dense_matrix<base_type > &B, \
878 dense_matrix<base_type > &C, rcmult) { \
879 GMMLAPACK_TRACE("gemm_interface_cn"); \
880 dense_matrix<base_type > &A \
881 = const_cast<dense_matrix<base_type > &>(*(linalg_origin(A_))); \
882 const char t = 'C', u = 'N'; \
883 BLAS_INT m(BLAS_INT(mat_ncols(A))), k(BLAS_INT(mat_nrows(A))); \
884 BLAS_INT n(BLAS_INT(mat_ncols(B))); \
885 BLAS_INT lda = k, ldb = k, ldc = m; \
886 base_type alpha(1), beta(0); \
888 blas_name(&t, &u, &m, &n, &k, &alpha, \
889 &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
890 else gmm::clear(C); \
// Overloads for the four scalar types.
893 gemm_interface_cn(sgemm_, BLAS_S)
894 gemm_interface_cn(dgemm_, BLAS_D)
895 gemm_interface_cn(cgemm_, BLAS_C)
896 gemm_interface_cn(zgemm_, BLAS_Z)
// Generates mult_spec(A, B_, C, c_mult, row_major) where the right factor
// is a conjugated reference. B is bound to the origin matrix of B_ and
// passed with op flag 'C'; A is passed with 'N'. m = rows(A), k = cols(A),
// n = rows(B); lda = m, ldb = n, ldc = m. alpha = 1, beta = 0; the else
// branch clears C. Unlike the other gemm wrappers this overload takes two
// trailing tag parameters.
902 # define gemm_interface_nc(blas_name, base_type) \
903 inline void mult_spec(const dense_matrix<base_type > &A, \
904 const conjugated_col_matrix_const_ref<dense_matrix<base_type > > &B_,\
905 dense_matrix<base_type > &C, c_mult, row_major) { \
906 GMMLAPACK_TRACE("gemm_interface_nc"); \
907 dense_matrix<base_type > &B \
908 = const_cast<dense_matrix<base_type > &>(*(linalg_origin(B_))); \
909 const char t = 'N', u = 'C'; \
910 BLAS_INT m(BLAS_INT(mat_nrows(A))), lda = m; \
911 BLAS_INT k(BLAS_INT(mat_ncols(A))); \
912 BLAS_INT n(BLAS_INT(mat_nrows(B))), ldb = n, ldc = m; \
913 base_type alpha(1), beta(0); \
915 blas_name(&t, &u, &m, &n, &k, &alpha, \
916 &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
917 else gmm::clear(C); \
// Overloads for the four scalar types.
920 gemm_interface_nc(sgemm_, BLAS_S)
921 gemm_interface_nc(dgemm_, BLAS_D)
922 gemm_interface_nc(cgemm_, BLAS_C)
923 gemm_interface_nc(zgemm_, BLAS_Z)
// Generates mult_spec(A_, B_, C, r_mult) where both factors are conjugated
// references. A and B are bound to the respective origin matrices and both
// passed with op flag 'C'. m = cols(A), k = rows(A), n = rows(B);
// lda = k, ldb = n, ldc = m. alpha = 1, beta = 0; the else branch clears C.
929 # define gemm_interface_cc(blas_name, base_type) \
930 inline void mult_spec( \
931 const conjugated_col_matrix_const_ref<dense_matrix<base_type > > &A_,\
932 const conjugated_col_matrix_const_ref<dense_matrix<base_type > > &B_,\
933 dense_matrix<base_type > &C, r_mult) { \
934 GMMLAPACK_TRACE("gemm_interface_cc"); \
935 dense_matrix<base_type > &A \
936 = const_cast<dense_matrix<base_type > &>(*(linalg_origin(A_))); \
937 dense_matrix<base_type > &B \
938 = const_cast<dense_matrix<base_type > &>(*(linalg_origin(B_))); \
939 const char t = 'C', u = 'C'; \
940 BLAS_INT m(BLAS_INT(mat_ncols(A))), k(BLAS_INT(mat_nrows(A))); \
941 BLAS_INT lda = k, n(BLAS_INT(mat_nrows(B))), ldb = n, ldc = m; \
942 base_type alpha(1), beta(0); \
944 blas_name(&t, &u, &m, &n, &k, &alpha, \
945 &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
946 else gmm::clear(C); \
// Overloads for the four scalar types.
949 gemm_interface_cc(sgemm_, BLAS_S)
950 gemm_interface_cc(dgemm_, BLAS_D)
951 gemm_interface_cc(cgemm_, BLAS_C)
952 gemm_interface_cc(zgemm_, BLAS_Z)
// Generates the triangular solver f_name(A, x, k, is_unit) for one scalar
// type by forwarding to the BLAS routine `blas_name` with unit stride.
//   f_name          - name of the generated function
//   loru            - statement defining the uplo flag l (trsv_upper/lower)
//   param1 / trans1 - matrix parameter declaration and the statements that
//                     make the dense matrix A and the op flag t available
// The diagonal flag d is 'U' when is_unit is true and 'N' otherwise; the
// order passed to BLAS is k and the leading dimension is rows(A). The call
// is skipped when A has no rows.
958 # define trsv_interface(f_name, loru, param1, trans1, blas_name, base_type)\
959 inline void f_name(param1(base_type), std::vector<base_type > &x, \
960 size_type k, bool is_unit) { \
961 GMMLAPACK_TRACE("trsv_interface"); \
962 loru; trans1(base_type); char d = is_unit ? 'U' : 'N'; \
963 BLAS_INT lda(BLAS_INT(mat_nrows(A))), inc(1), n = BLAS_INT(k); \
964 if (lda) blas_name(&l, &t, &d, &n, &A(0,0), &lda, &x[0], &inc); \
// Uplo flag definitions usable as the `loru` argument.
967 # define trsv_upper const char l = 'U'
968 # define trsv_lower const char l = 'L'
// lower_tri_solve / upper_tri_solve overloads, four per combination.
// For a plain matrix the uplo flag matches the function name. For
// transposed and conjugated references the flag is swapped
// (lower_tri_solve uses trsv_upper and vice versa).
// NOTE(review): presumably because BLAS is handed the untransposed origin
// matrix - confirm.
971 trsv_interface(lower_tri_solve, trsv_lower, gem_p1_n, gem_trans1_n,
973 trsv_interface(lower_tri_solve, trsv_lower, gem_p1_n, gem_trans1_n,
975 trsv_interface(lower_tri_solve, trsv_lower, gem_p1_n, gem_trans1_n,
977 trsv_interface(lower_tri_solve, trsv_lower, gem_p1_n, gem_trans1_n,
981 trsv_interface(upper_tri_solve, trsv_upper, gem_p1_n, gem_trans1_n,
983 trsv_interface(upper_tri_solve, trsv_upper, gem_p1_n, gem_trans1_n,
985 trsv_interface(upper_tri_solve, trsv_upper, gem_p1_n, gem_trans1_n,
987 trsv_interface(upper_tri_solve, trsv_upper, gem_p1_n, gem_trans1_n,
991 trsv_interface(lower_tri_solve, trsv_upper, gem_p1_t, gem_trans1_t,
993 trsv_interface(lower_tri_solve, trsv_upper, gem_p1_t, gem_trans1_t,
995 trsv_interface(lower_tri_solve, trsv_upper, gem_p1_t, gem_trans1_t,
997 trsv_interface(lower_tri_solve, trsv_upper, gem_p1_t, gem_trans1_t,
1001 trsv_interface(upper_tri_solve, trsv_lower, gem_p1_t, gem_trans1_t,
1003 trsv_interface(upper_tri_solve, trsv_lower, gem_p1_t, gem_trans1_t,
1005 trsv_interface(upper_tri_solve, trsv_lower, gem_p1_t, gem_trans1_t,
1007 trsv_interface(upper_tri_solve, trsv_lower, gem_p1_t, gem_trans1_t,
1011 trsv_interface(lower_tri_solve, trsv_upper, gem_p1_tc, gem_trans1_t,
1013 trsv_interface(lower_tri_solve, trsv_upper, gem_p1_tc, gem_trans1_t,
1015 trsv_interface(lower_tri_solve, trsv_upper, gem_p1_tc, gem_trans1_t,
1017 trsv_interface(lower_tri_solve, trsv_upper, gem_p1_tc, gem_trans1_t,
1021 trsv_interface(upper_tri_solve, trsv_lower, gem_p1_tc, gem_trans1_t,
1023 trsv_interface(upper_tri_solve, trsv_lower, gem_p1_tc, gem_trans1_t,
1025 trsv_interface(upper_tri_solve, trsv_lower, gem_p1_tc, gem_trans1_t,
1027 trsv_interface(upper_tri_solve, trsv_lower, gem_p1_tc, gem_trans1_t,
1031 trsv_interface(lower_tri_solve, trsv_upper, gem_p1_c, gem_trans1_c,
1033 trsv_interface(lower_tri_solve, trsv_upper, gem_p1_c, gem_trans1_c,
1035 trsv_interface(lower_tri_solve, trsv_upper, gem_p1_c, gem_trans1_c,
1037 trsv_interface(lower_tri_solve, trsv_upper, gem_p1_c, gem_trans1_c,
1041 trsv_interface(upper_tri_solve, trsv_lower, gem_p1_c, gem_trans1_c,
1043 trsv_interface(upper_tri_solve, trsv_lower, gem_p1_c, gem_trans1_c,
1045 trsv_interface(upper_tri_solve, trsv_lower, gem_p1_c, gem_trans1_c,
1047 trsv_interface(upper_tri_solve, trsv_lower, gem_p1_c, gem_trans1_c,
1053 #endif // GMM_BLAS_INTERFACE_H
1055 #endif // GMM_USES_BLAS