GetFEM  5.4.2
gmm_blas_interface.h
Go to the documentation of this file.
1 /* -*- c++ -*- (enables emacs c++ mode) */
2 /*===========================================================================
3 
4  Copyright (C) 2003-2020 Yves Renard
5 
6  This file is a part of GetFEM
7 
8  GetFEM is free software; you can redistribute it and/or modify it
9  under the terms of the GNU Lesser General Public License as published
10  by the Free Software Foundation; either version 3 of the License, or
11  (at your option) any later version along with the GCC Runtime Library
12  Exception either version 3.1 or (at your option) any later version.
13  This program is distributed in the hope that it will be useful, but
14  WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
15  or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
16  License and GCC Runtime Library Exception for more details.
17  You should have received a copy of the GNU Lesser General Public License
18  along with this program; if not, write to the Free Software Foundation,
19  Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA.
20 
21  As a special exception, you may use this file as it is a part of a free
22  software library without restriction. Specifically, if other files
23  instantiate templates or use macros or inline functions from this file,
24  or you compile this file and link it with other files to produce an
25  executable, this file does not by itself cause the resulting executable
26  to be covered by the GNU Lesser General Public License. This exception
27  does not however invalidate any other reasons why the executable file
28  might be covered by the GNU Lesser General Public License.
29 
30 ===========================================================================*/
31 
32 /**@file gmm_blas_interface.h
33  @author Yves Renard <Yves.Renard@insa-lyon.fr>
34  @date October 7, 2003.
35  @brief gmm interface for fortran BLAS.
36 */
37 
38 #if defined(GETFEM_USES_BLAS) || defined(GMM_USES_BLAS) \
39  || defined(GMM_USES_LAPACK) || defined(GMM_USES_ATLAS)
40 
41 #ifndef GMM_BLAS_INTERFACE_H
42 #define GMM_BLAS_INTERFACE_H
43 
44 #include "gmm_blas.h"
45 #include "gmm_interface.h"
46 #include "gmm_matrix.h"
47 
48 namespace gmm {
49 
50  // Use ./configure --enable-blas-interface to activate this interface.
51 
52 #define GMMLAPACK_TRACE(f)
53  // #define GMMLAPACK_TRACE(f) cout << "function " << f << " called" << endl;
54 
55 #if defined(WeirdNEC) || defined(GMM_USE_BLAS64_INTERFACE)
56  #define BLAS_INT long
57 #else // By default BLAS_INT will just be int in C
58  #define BLAS_INT int
59 #endif
60 
61  /* ********************************************************************* */
62  /* Operations interfaced for T = float, double, std::complex<float> */
63  /* or std::complex<double> : */
64  /* */
65  /* vect_norm2(std::vector<T>) */
66  /* */
67  /* vect_sp(std::vector<T>, std::vector<T>) */
68  /* vect_sp(scaled(std::vector<T>), std::vector<T>) */
69  /* vect_sp(std::vector<T>, scaled(std::vector<T>)) */
70  /* vect_sp(scaled(std::vector<T>), scaled(std::vector<T>)) */
71  /* */
72  /* vect_hp(std::vector<T>, std::vector<T>) */
73  /* vect_hp(scaled(std::vector<T>), std::vector<T>) */
74  /* vect_hp(std::vector<T>, scaled(std::vector<T>)) */
75  /* vect_hp(scaled(std::vector<T>), scaled(std::vector<T>)) */
76  /* */
77  /* add(std::vector<T>, std::vector<T>) */
78  /* add(scaled(std::vector<T>, a), std::vector<T>) */
79  /* */
80  /* mult(dense_matrix<T>, dense_matrix<T>, dense_matrix<T>) */
81  /* mult(transposed(dense_matrix<T>), dense_matrix<T>, dense_matrix<T>) */
82  /* mult(dense_matrix<T>, transposed(dense_matrix<T>), dense_matrix<T>) */
83  /* mult(transposed(dense_matrix<T>), transposed(dense_matrix<T>), */
84  /* dense_matrix<T>) */
85  /* mult(conjugated(dense_matrix<T>), dense_matrix<T>, dense_matrix<T>) */
86  /* mult(dense_matrix<T>, conjugated(dense_matrix<T>), dense_matrix<T>) */
87  /* mult(conjugated(dense_matrix<T>), conjugated(dense_matrix<T>), */
88  /* dense_matrix<T>) */
89  /* */
90  /* mult(dense_matrix<T>, std::vector<T>, std::vector<T>) */
91  /* mult(transposed(dense_matrix<T>), std::vector<T>, std::vector<T>) */
92  /* mult(conjugated(dense_matrix<T>), std::vector<T>, std::vector<T>) */
93  /* mult(dense_matrix<T>, scaled(std::vector<T>), std::vector<T>) */
94  /* mult(transposed(dense_matrix<T>), scaled(std::vector<T>), */
95  /* std::vector<T>) */
96  /* mult(conjugated(dense_matrix<T>), scaled(std::vector<T>), */
97  /* std::vector<T>) */
98  /* */
99  /* mult_add(dense_matrix<T>, std::vector<T>, std::vector<T>) */
100  /* mult_add(transposed(dense_matrix<T>), std::vector<T>, std::vector<T>) */
101  /* mult_add(conjugated(dense_matrix<T>), std::vector<T>, std::vector<T>) */
102  /* mult_add(dense_matrix<T>, scaled(std::vector<T>), std::vector<T>) */
103  /* mult_add(transposed(dense_matrix<T>), scaled(std::vector<T>), */
104  /* std::vector<T>) */
105  /* mult_add(conjugated(dense_matrix<T>), scaled(std::vector<T>), */
106  /* std::vector<T>) */
107  /* */
108  /* mult(dense_matrix<T>, std::vector<T>, std::vector<T>, std::vector<T>) */
109  /* mult(transposed(dense_matrix<T>), std::vector<T>, std::vector<T>, */
110  /* std::vector<T>) */
111  /* mult(conjugated(dense_matrix<T>), std::vector<T>, std::vector<T>, */
112  /* std::vector<T>) */
113  /* mult(dense_matrix<T>, scaled(std::vector<T>), std::vector<T>, */
114  /* std::vector<T>) */
115  /* mult(transposed(dense_matrix<T>), scaled(std::vector<T>), */
116  /* std::vector<T>, std::vector<T>) */
117  /* mult(conjugated(dense_matrix<T>), scaled(std::vector<T>), */
118  /* std::vector<T>, std::vector<T>) */
119  /* mult(dense_matrix<T>, std::vector<T>, scaled(std::vector<T>), */
120  /* std::vector<T>) */
121  /* mult(transposed(dense_matrix<T>), std::vector<T>, */
122  /* scaled(std::vector<T>), std::vector<T>) */
123  /* mult(conjugated(dense_matrix<T>), std::vector<T>, */
124  /* scaled(std::vector<T>), std::vector<T>) */
125  /* mult(dense_matrix<T>, scaled(std::vector<T>), scaled(std::vector<T>), */
126  /* std::vector<T>) */
127  /* mult(transposed(dense_matrix<T>), scaled(std::vector<T>), */
128  /* scaled(std::vector<T>), std::vector<T>) */
129  /* mult(conjugated(dense_matrix<T>), scaled(std::vector<T>), */
130  /* scaled(std::vector<T>), std::vector<T>) */
131  /* */
132  /* lower_tri_solve(dense_matrix<T>, std::vector<T>, k, b) */
133  /* upper_tri_solve(dense_matrix<T>, std::vector<T>, k, b) */
134  /* lower_tri_solve(transposed(dense_matrix<T>), std::vector<T>, k, b) */
135  /* upper_tri_solve(transposed(dense_matrix<T>), std::vector<T>, k, b) */
136  /* lower_tri_solve(conjugated(dense_matrix<T>), std::vector<T>, k, b) */
137  /* upper_tri_solve(conjugated(dense_matrix<T>), std::vector<T>, k, b) */
138  /* */
139  /* rank_one_update(dense_matrix<T>, std::vector<T>, std::vector<T>) */
140  /* rank_one_update(dense_matrix<T>, scaled(std::vector<T>), */
141  /* std::vector<T>) */
142  /* rank_one_update(dense_matrix<T>, std::vector<T>, */
143  /* scaled(std::vector<T>)) */
144  /* */
145  /* ********************************************************************* */
146 
147  /* ********************************************************************* */
148  /* Basic defines. */
149  /* ********************************************************************* */
150 
151 # define BLAS_S float
152 # define BLAS_D double
153 # define BLAS_C std::complex<float>
154 # define BLAS_Z std::complex<double>
155 
156  /* ********************************************************************* */
157  /* BLAS functions used. */
158  /* ********************************************************************* */
159  extern "C" {
160  void daxpy_(const BLAS_INT *n, const double *alpha, const double *x,
161  const BLAS_INT *incx, double *y, const BLAS_INT *incy);
162  void dgemm_(const char *tA, const char *tB, const BLAS_INT *m,
163  const BLAS_INT *n, const BLAS_INT *k, const double *alpha,
164  const double *A, const BLAS_INT *ldA, const double *B,
165  const BLAS_INT *ldB, const double *beta, double *C,
166  const BLAS_INT *ldC);
167  void sgemm_(...); void cgemm_(...); void zgemm_(...);
168  void sgemv_(...); void dgemv_(...); void cgemv_(...); void zgemv_(...);
169  void strsv_(...); void dtrsv_(...); void ctrsv_(...); void ztrsv_(...);
170  void saxpy_(...); /*void daxpy_(...); */void caxpy_(...); void zaxpy_(...);
171  BLAS_S sdot_ (...); BLAS_D ddot_ (...);
172  BLAS_C cdotu_(...); BLAS_Z zdotu_(...);
173  BLAS_C cdotc_(...); BLAS_Z zdotc_(...);
174  BLAS_S snrm2_(...); BLAS_D dnrm2_(...);
175  BLAS_S scnrm2_(...); BLAS_D dznrm2_(...);
176  void sger_(...); void dger_(...); void cgerc_(...); void zgerc_(...);
177  }
178 
179 #if 1
180 
181  /* ********************************************************************* */
182  /* vect_norm2(x). */
183  /* ********************************************************************* */
184 
185 # define nrm2_interface(param1, trans1, blas_name, base_type) \
186  inline number_traits<base_type >::magnitude_type \
187  vect_norm2(param1(base_type)) { \
188  GMMLAPACK_TRACE("nrm2_interface"); \
189  BLAS_INT inc(1), n(BLAS_INT(vect_size(x))); trans1(base_type); \
190  return blas_name(&n, &x[0], &inc); \
191  }
192 
193 # define nrm2_p1(base_type) const std::vector<base_type > &x
194 # define nrm2_trans1(base_type)
195 
196  nrm2_interface(nrm2_p1, nrm2_trans1, snrm2_ , BLAS_S)
197  nrm2_interface(nrm2_p1, nrm2_trans1, dnrm2_ , BLAS_D)
198  nrm2_interface(nrm2_p1, nrm2_trans1, scnrm2_, BLAS_C)
199  nrm2_interface(nrm2_p1, nrm2_trans1, dznrm2_, BLAS_Z)
200 
201  /* ********************************************************************* */
202  /* vect_sp(x, y). */
203  /* ********************************************************************* */
204 
// Generates vect_sp(x, y) (unconjugated scalar product) on top of the BLAS
// dot routine `blas_name`.
//   param1/param2 : expand to the two parameter declarations.
//   trans1/trans2 : expand to prologue statements that must leave vectors
//                   named `x` and `y` in scope (and the scale factors `a`/`b`
//                   for the scaled variants).
//   mult1/mult2   : prefix applied to the BLAS result (a cast or `a*` / `b*`).
// The operand sizes are checked in debug builds, since BLAS reads
// vect_size(y) elements from both vectors. Empty input returns 0 without
// evaluating &x[0] / &y[0], which is undefined on an empty std::vector.
# define dot_interface(param1, trans1, mult1, param2, trans2, mult2,       \
                       blas_name, base_type)                               \
  inline base_type vect_sp(param1(base_type), param2(base_type)) {         \
    GMMLAPACK_TRACE("dot_interface");                                      \
    trans1(base_type); trans2(base_type);                                  \
    GMM_ASSERT2(vect_size(x) == vect_size(y), "dimensions mismatch");      \
    BLAS_INT inc(1), n(BLAS_INT(vect_size(y)));                            \
    if (n == 0) return base_type(0);                                       \
    return mult1 mult2 blas_name(&n, &x[0], &inc, &y[0], &inc);            \
  }
213 
214 # define dot_p1(base_type) const std::vector<base_type > &x
215 # define dot_trans1(base_type)
216 # define dot_p1_s(base_type) \
217  const scaled_vector_const_ref<std::vector<base_type >, base_type > &x_
218 # define dot_trans1_s(base_type) \
219  std::vector<base_type > &x = \
220  const_cast<std::vector<base_type > &>(*(linalg_origin(x_))); \
221  base_type a(x_.r)
222 
223 # define dot_p2(base_type) const std::vector<base_type > &y
224 # define dot_trans2(base_type)
225 # define dot_p2_s(base_type) \
226  const scaled_vector_const_ref<std::vector<base_type >, base_type > &y_
227 # define dot_trans2_s(base_type) \
228  std::vector<base_type > &y = \
229  const_cast<std::vector<base_type > &>(*(linalg_origin(y_))); \
230  base_type b(y_.r)
231 
232  dot_interface(dot_p1, dot_trans1, (BLAS_S), dot_p2, dot_trans2, (BLAS_S),
233  sdot_ , BLAS_S)
234  dot_interface(dot_p1, dot_trans1, (BLAS_D), dot_p2, dot_trans2, (BLAS_D),
235  ddot_ , BLAS_D)
236  dot_interface(dot_p1, dot_trans1, (BLAS_C), dot_p2, dot_trans2, (BLAS_C),
237  cdotu_, BLAS_C)
238  dot_interface(dot_p1, dot_trans1, (BLAS_Z), dot_p2, dot_trans2, (BLAS_Z),
239  zdotu_, BLAS_Z)
240 
241  dot_interface(dot_p1_s, dot_trans1_s, a*, dot_p2, dot_trans2, (BLAS_S),
242  sdot_ ,BLAS_S)
243  dot_interface(dot_p1_s, dot_trans1_s, a*, dot_p2, dot_trans2, (BLAS_D),
244  ddot_ ,BLAS_D)
245  dot_interface(dot_p1_s, dot_trans1_s, a*, dot_p2, dot_trans2, (BLAS_C),
246  cdotu_,BLAS_C)
247  dot_interface(dot_p1_s, dot_trans1_s, a*, dot_p2, dot_trans2, (BLAS_Z),
248  zdotu_,BLAS_Z)
249 
250  dot_interface(dot_p1, dot_trans1, (BLAS_S), dot_p2_s, dot_trans2_s, b*,
251  sdot_ ,BLAS_S)
252  dot_interface(dot_p1, dot_trans1, (BLAS_D), dot_p2_s, dot_trans2_s, b*,
253  ddot_ ,BLAS_D)
254  dot_interface(dot_p1, dot_trans1, (BLAS_C), dot_p2_s, dot_trans2_s, b*,
255  cdotu_,BLAS_C)
256  dot_interface(dot_p1, dot_trans1, (BLAS_Z), dot_p2_s, dot_trans2_s, b*,
257  zdotu_,BLAS_Z)
258 
259  dot_interface(dot_p1_s,dot_trans1_s,a*,dot_p2_s,dot_trans2_s,b*,sdot_ ,
260  BLAS_S)
261  dot_interface(dot_p1_s,dot_trans1_s,a*,dot_p2_s,dot_trans2_s,b*,ddot_ ,
262  BLAS_D)
263  dot_interface(dot_p1_s,dot_trans1_s,a*,dot_p2_s,dot_trans2_s,b*,cdotu_,
264  BLAS_C)
265  dot_interface(dot_p1_s,dot_trans1_s,a*,dot_p2_s,dot_trans2_s,b*,zdotu_,
266  BLAS_Z)
267 
268 
269  /* ********************************************************************* */
270  /* vect_hp(x, y). */
271  /* ********************************************************************* */
272 
// Generates vect_hp(x, y) (hermitian product) on top of the BLAS dot routine
// `blas_name` (?dot for real types, ?dotc for complex types).
//   param1/param2 : expand to the two parameter declarations.
//   trans1/trans2 : expand to prologue statements leaving `x` and `y` (and
//                   the scale factors `a` / `b`) in scope.
//   mult1/mult2   : prefix applied to the BLAS result.
// The scaled adaptors conjugate the scale of the *second* operand, i.e. the
// product computed here is sum_i x[i] * conj(y[i]). BLAS ?dotc conjugates its
// *first* vector argument, so y is passed first and x second. For the real
// routines the order is irrelevant.
// The operand sizes are checked in debug builds. Empty input returns 0
// without evaluating &x[0] / &y[0], which is undefined on an empty vector.
# define dotc_interface(param1, trans1, mult1, param2, trans2, mult2,      \
                        blas_name, base_type)                              \
  inline base_type vect_hp(param1(base_type), param2(base_type)) {         \
    GMMLAPACK_TRACE("dotc_interface");                                     \
    trans1(base_type); trans2(base_type);                                  \
    GMM_ASSERT2(vect_size(x) == vect_size(y), "dimensions mismatch");      \
    BLAS_INT inc(1), n(BLAS_INT(vect_size(y)));                            \
    if (n == 0) return base_type(0);                                       \
    return mult1 mult2 blas_name(&n, &y[0], &inc, &x[0], &inc);            \
  }
281 
282 # define dotc_p1(base_type) const std::vector<base_type > &x
283 # define dotc_trans1(base_type)
284 # define dotc_p1_s(base_type) \
285  const scaled_vector_const_ref<std::vector<base_type >, base_type > &x_
286 # define dotc_trans1_s(base_type) \
287  std::vector<base_type > &x = \
288  const_cast<std::vector<base_type > &>(*(linalg_origin(x_))); \
289  base_type a(x_.r)
290 
291 # define dotc_p2(base_type) const std::vector<base_type > &y
292 # define dotc_trans2(base_type)
293 # define dotc_p2_s(base_type) \
294  const scaled_vector_const_ref<std::vector<base_type >, base_type > &y_
295 # define dotc_trans2_s(base_type) \
296  std::vector<base_type > &y = \
297  const_cast<std::vector<base_type > &>(*(linalg_origin(y_))); \
298  base_type b(gmm::conj(y_.r))
299 
300  dotc_interface(dotc_p1, dotc_trans1, (BLAS_S), dotc_p2, dotc_trans2,
301  (BLAS_S),sdot_ ,BLAS_S)
302  dotc_interface(dotc_p1, dotc_trans1, (BLAS_D), dotc_p2, dotc_trans2,
303  (BLAS_D),ddot_ ,BLAS_D)
304  dotc_interface(dotc_p1, dotc_trans1, (BLAS_C), dotc_p2, dotc_trans2,
305  (BLAS_C),cdotc_,BLAS_C)
306  dotc_interface(dotc_p1, dotc_trans1, (BLAS_Z), dotc_p2, dotc_trans2,
307  (BLAS_Z),zdotc_,BLAS_Z)
308 
309  dotc_interface(dotc_p1_s, dotc_trans1_s, a*, dotc_p2, dotc_trans2,
310  (BLAS_S),sdot_, BLAS_S)
311  dotc_interface(dotc_p1_s, dotc_trans1_s, a*, dotc_p2, dotc_trans2,
312  (BLAS_D),ddot_ , BLAS_D)
313  dotc_interface(dotc_p1_s, dotc_trans1_s, a*, dotc_p2, dotc_trans2,
314  (BLAS_C),cdotc_, BLAS_C)
315  dotc_interface(dotc_p1_s, dotc_trans1_s, a*, dotc_p2, dotc_trans2,
316  (BLAS_Z),zdotc_, BLAS_Z)
317 
318  dotc_interface(dotc_p1, dotc_trans1, (BLAS_S), dotc_p2_s, dotc_trans2_s,
319  b*,sdot_ , BLAS_S)
320  dotc_interface(dotc_p1, dotc_trans1, (BLAS_D), dotc_p2_s, dotc_trans2_s,
321  b*,ddot_ , BLAS_D)
322  dotc_interface(dotc_p1, dotc_trans1, (BLAS_C), dotc_p2_s, dotc_trans2_s,
323  b*,cdotc_, BLAS_C)
324  dotc_interface(dotc_p1, dotc_trans1, (BLAS_Z), dotc_p2_s, dotc_trans2_s,
325  b*,zdotc_, BLAS_Z)
326 
327  dotc_interface(dotc_p1_s,dotc_trans1_s,a*,dotc_p2_s,dotc_trans2_s,b*,sdot_ ,
328  BLAS_S)
329  dotc_interface(dotc_p1_s,dotc_trans1_s,a*,dotc_p2_s,dotc_trans2_s,b*,ddot_ ,
330  BLAS_D)
331  dotc_interface(dotc_p1_s,dotc_trans1_s,a*,dotc_p2_s,dotc_trans2_s,b*,cdotc_,
332  BLAS_C)
333  dotc_interface(dotc_p1_s,dotc_trans1_s,a*,dotc_p2_s,dotc_trans2_s,b*,zdotc_,
334  BLAS_Z)
335 
336  /* ********************************************************************* */
337  /* add(x, y). */
338  /* ********************************************************************* */
339  template<size_type N, class V1, class V2>
340  inline void add_fixed(const V1 &x, V2 &y)
341  {
342  for(size_type i = 0; i != N; ++i) y[i] += x[i];
343  }
344 
345  template<class V1, class V2>
346  inline void add_for_short_vectors(const V1 &x, V2 &y, size_type n)
347  {
348  switch(n)
349  {
350  case 1: add_fixed<1>(x, y); break;
351  case 2: add_fixed<2>(x, y); break;
352  case 3: add_fixed<3>(x, y); break;
353  case 4: add_fixed<4>(x, y); break;
354  case 5: add_fixed<5>(x, y); break;
355  case 6: add_fixed<6>(x, y); break;
356  case 7: add_fixed<7>(x, y); break;
357  case 8: add_fixed<8>(x, y); break;
358  case 9: add_fixed<9>(x, y); break;
359  case 10: add_fixed<10>(x, y); break;
360  case 11: add_fixed<11>(x, y); break;
361  case 12: add_fixed<12>(x, y); break;
362  case 13: add_fixed<13>(x, y); break;
363  case 14: add_fixed<14>(x, y); break;
364  case 15: add_fixed<15>(x, y); break;
365  case 16: add_fixed<16>(x, y); break;
366  case 17: add_fixed<17>(x, y); break;
367  case 18: add_fixed<18>(x, y); break;
368  case 19: add_fixed<19>(x, y); break;
369  case 20: add_fixed<20>(x, y); break;
370  case 21: add_fixed<21>(x, y); break;
371  case 22: add_fixed<22>(x, y); break;
372  case 23: add_fixed<23>(x, y); break;
373  case 24: add_fixed<24>(x, y); break;
374  default: GMM_ASSERT2(false, "add_for_short_vectors used with unsupported size"); break;
375  }
376  }
377 
378  template<size_type N, class V1, class V2, class T>
379  inline void add_fixed(const V1 &x, V2 &y, const T &a)
380  {
381  for(size_type i = 0; i != N; ++i) y[i] += a*x[i];
382  }
383 
384  template<class V1, class V2, class T>
385  inline void add_for_short_vectors(const V1 &x, V2 &y, const T &a, size_type n)
386  {
387  switch(n)
388  {
389  case 1: add_fixed<1>(x, y, a); break;
390  case 2: add_fixed<2>(x, y, a); break;
391  case 3: add_fixed<3>(x, y, a); break;
392  case 4: add_fixed<4>(x, y, a); break;
393  case 5: add_fixed<5>(x, y, a); break;
394  case 6: add_fixed<6>(x, y, a); break;
395  case 7: add_fixed<7>(x, y, a); break;
396  case 8: add_fixed<8>(x, y, a); break;
397  case 9: add_fixed<9>(x, y, a); break;
398  case 10: add_fixed<10>(x, y, a); break;
399  case 11: add_fixed<11>(x, y, a); break;
400  case 12: add_fixed<12>(x, y, a); break;
401  case 13: add_fixed<13>(x, y, a); break;
402  case 14: add_fixed<14>(x, y, a); break;
403  case 15: add_fixed<15>(x, y, a); break;
404  case 16: add_fixed<16>(x, y, a); break;
405  case 17: add_fixed<17>(x, y, a); break;
406  case 18: add_fixed<18>(x, y, a); break;
407  case 19: add_fixed<19>(x, y, a); break;
408  case 20: add_fixed<20>(x, y, a); break;
409  case 21: add_fixed<21>(x, y, a); break;
410  case 22: add_fixed<22>(x, y, a); break;
411  case 23: add_fixed<23>(x, y, a); break;
412  case 24: add_fixed<24>(x, y, a); break;
413  default: GMM_ASSERT2(false, "add_for_short_vectors used with unsupported size"); break;
414  }
415  }
416 
417 
// Generates add(x, y), i.e. y <- y + x, for std::vector<base_type>.
//   param1(base_type) : expands to the declaration of the first parameter.
//   trans1(base_type) : expands to a prologue that leaves `x` and the scale
//                       factor `a` in scope (a == 1 for the plain variant).
// Lengths 1..24 use the unrolled add_for_short_vectors helper; longer
// vectors go through the BLAS ?axpy routine `blas_name`.
// The length is taken from y, so x must be at least as long. This is
// checked in debug builds, because both paths read n elements from x.
# define axpy_interface(param1, trans1, blas_name, base_type)              \
  inline void add(param1(base_type), std::vector<base_type > &y) {         \
    GMMLAPACK_TRACE("axpy_interface");                                     \
    BLAS_INT inc(1), n(BLAS_INT(vect_size(y))); trans1(base_type);         \
    GMM_ASSERT2(vect_size(x) == vect_size(y), "dimensions mismatch");      \
    if(n == 0) return;                                                     \
    else if(n < 25) add_for_short_vectors(x, y, n);                        \
    else blas_name(&n, &a, &x[0], &inc, &y[0], &inc);                      \
  }
426 
// Generates add(scaled(x, a), y), i.e. y <- y + a*x, for
// std::vector<base_type>.
//   param1(base_type) : expands to the declaration of the first parameter.
//   trans1(base_type) : expands to a prologue that leaves the underlying
//                       vector `x` and its scale factor `a` in scope.
// Lengths 1..24 use the unrolled add_for_short_vectors helper; longer
// vectors go through the BLAS ?axpy routine `blas_name`.
// The length is taken from y, so x must be at least as long. This is
// checked in debug builds, because both paths read n elements from x.
# define axpy2_interface(param1, trans1, blas_name, base_type)             \
  inline void add(param1(base_type), std::vector<base_type > &y) {         \
    GMMLAPACK_TRACE("axpy2_interface");                                    \
    BLAS_INT inc(1), n(BLAS_INT(vect_size(y))); trans1(base_type);         \
    GMM_ASSERT2(vect_size(x) == vect_size(y), "dimensions mismatch");      \
    if(n == 0) return;                                                     \
    else if(n < 25) add_for_short_vectors(x, y, a, n);                     \
    else blas_name(&n, &a, &x[0], &inc, &y[0], &inc);                      \
  }
435 
436 # define axpy_p1(base_type) const std::vector<base_type > &x
437 # define axpy_trans1(base_type) base_type a(1)
438 # define axpy_p1_s(base_type) \
439  const scaled_vector_const_ref<std::vector<base_type >, base_type > &x_
440 # define axpy_trans1_s(base_type) \
441  std::vector<base_type > &x = \
442  const_cast<std::vector<base_type > &>(*(linalg_origin(x_))); \
443  base_type a(x_.r)
444 
445  axpy_interface(axpy_p1, axpy_trans1, saxpy_, BLAS_S)
446  axpy_interface(axpy_p1, axpy_trans1, daxpy_, BLAS_D)
447  axpy_interface(axpy_p1, axpy_trans1, caxpy_, BLAS_C)
448  axpy_interface(axpy_p1, axpy_trans1, zaxpy_, BLAS_Z)
449 
450  axpy2_interface(axpy_p1_s, axpy_trans1_s, saxpy_, BLAS_S)
451  axpy2_interface(axpy_p1_s, axpy_trans1_s, daxpy_, BLAS_D)
452  axpy2_interface(axpy_p1_s, axpy_trans1_s, caxpy_, BLAS_C)
453  axpy2_interface(axpy_p1_s, axpy_trans1_s, zaxpy_, BLAS_Z)
454 
455 
456  /* ********************************************************************* */
457  /* mult_add(A, x, z). */
458  /* ********************************************************************* */
459 
// Generates mult_add_spec(A, x, z, orien), i.e. z <- op(A)*x + z, on top of
// the BLAS ?gemv routine `blas_name` (called with beta = 1).
//   param1/trans1 : matrix parameter declaration and prologue; the prologue
//                   must leave the dense matrix `A` and the transpose flag
//                   `t` in scope.
//   param2/trans2 : vector parameter declaration and prologue; the prologue
//                   must leave the vector `x` and the factor `alpha` in scope.
// When A has a zero dimension the product contributes nothing, so z is left
// untouched. Clearing it there would discard the accumulated value.
# define gemv_interface(param1, trans1, param2, trans2, blas_name,         \
                        base_type, orien)                                  \
  inline void mult_add_spec(param1(base_type), param2(base_type),          \
                            std::vector<base_type > &z, orien) {           \
    GMMLAPACK_TRACE("gemv_interface");                                     \
    trans1(base_type); trans2(base_type); base_type beta(1);               \
    BLAS_INT m(BLAS_INT(mat_nrows(A))), lda(m);                            \
    BLAS_INT n(BLAS_INT(mat_ncols(A))), inc(1);                            \
    if (m && n) blas_name(&t, &m, &n, &alpha, &A(0,0), &lda, &x[0], &inc,  \
                          &beta, &z[0], &inc);                             \
  }
472 
// Matrix-argument adaptors (first parameter). Each gem_p1_* macro expands to
// the parameter declaration; the paired gem_trans1_* prologue leaves the
// underlying dense matrix `A` and the BLAS transpose flag `t` in scope.
//   _n  : dense_matrix as is                      -> t = 'N'
//   _t  : transposed(dense_matrix)                -> t = 'T'
//   _tc : transposed(const dense_matrix), paired with gem_trans1_t
//   _c  : conjugated(dense_matrix)                -> t = 'C'
# define gem_p1_n(base_type) const dense_matrix<base_type > &A
# define gem_trans1_n(base_type) const char t = 'N'

# define gem_p1_t(base_type)                                               \
  const transposed_col_ref<dense_matrix<base_type > *> &A_
# define gem_trans1_t(base_type)                                           \
  dense_matrix<base_type > &A =                                            \
    const_cast<dense_matrix<base_type > &>(*(linalg_origin(A_)));          \
  const char t = 'T'

# define gem_p1_tc(base_type)                                              \
  const transposed_col_ref<const dense_matrix<base_type > *> &A_

# define gem_p1_c(base_type)                                               \
  const conjugated_col_matrix_const_ref<dense_matrix<base_type > > &A_
# define gem_trans1_c(base_type)                                           \
  dense_matrix<base_type > &A =                                            \
    const_cast<dense_matrix<base_type > &>(*(linalg_origin(A_)));          \
  const char t = 'C'

// Vector-argument adaptors (second parameter). The prologue leaves the
// vector `x` and the scalar factor `alpha` in scope.
//   _n : plain std::vector             -> alpha = 1
//   _s : scaled(std::vector)           -> alpha = scale factor of the view
# define gemv_p2_n(base_type) const std::vector<base_type > &x
# define gemv_trans2_n(base_type) base_type alpha(1)
# define gemv_p2_s(base_type)                                              \
  const scaled_vector_const_ref<std::vector<base_type >, base_type > &x_
# define gemv_trans2_s(base_type)                                          \
  std::vector<base_type > &x =                                             \
    const_cast<std::vector<base_type > &>(*(linalg_origin(x_)));           \
  base_type alpha(x_.r)
497 
498  // Z <- AX + Z.
499  gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, sgemv_,
500  BLAS_S, col_major)
501  gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, dgemv_,
502  BLAS_D, col_major)
503  gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, cgemv_,
504  BLAS_C, col_major)
505  gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, zgemv_,
506  BLAS_Z, col_major)
507 
508  // Z <- transposed(A)X + Z.
509  gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, sgemv_,
510  BLAS_S, row_major)
511  gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, dgemv_,
512  BLAS_D, row_major)
513  gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, cgemv_,
514  BLAS_C, row_major)
515  gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, zgemv_,
516  BLAS_Z, row_major)
517 
518  // Z <- transposed(const A)X + Z.
519  gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, sgemv_,
520  BLAS_S, row_major)
521  gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, dgemv_,
522  BLAS_D, row_major)
523  gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, cgemv_,
524  BLAS_C, row_major)
525  gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, zgemv_,
526  BLAS_Z, row_major)
527 
528  // Z <- conjugated(A)X + Z.
529  gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, sgemv_,
530  BLAS_S, row_major)
531  gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, dgemv_,
532  BLAS_D, row_major)
533  gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, cgemv_,
534  BLAS_C, row_major)
535  gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, zgemv_,
536  BLAS_Z, row_major)
537 
538  // Z <- A scaled(X) + Z.
539  gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, sgemv_,
540  BLAS_S, col_major)
541  gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, dgemv_,
542  BLAS_D, col_major)
543  gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, cgemv_,
544  BLAS_C, col_major)
545  gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, zgemv_,
546  BLAS_Z, col_major)
547 
548  // Z <- transposed(A) scaled(X) + Z.
549  gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, sgemv_,
550  BLAS_S, row_major)
551  gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, dgemv_,
552  BLAS_D, row_major)
553  gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, cgemv_,
554  BLAS_C, row_major)
555  gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, zgemv_,
556  BLAS_Z, row_major)
557 
558  // Z <- transposed(const A) scaled(X) + Z.
559  gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, sgemv_,
560  BLAS_S, row_major)
561  gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, dgemv_,
562  BLAS_D, row_major)
563  gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, cgemv_,
564  BLAS_C, row_major)
565  gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, zgemv_,
566  BLAS_Z, row_major)
567 
568  // Z <- conjugated(A) scaled(X) + Z.
569  gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, sgemv_,
570  BLAS_S, row_major)
571  gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, dgemv_,
572  BLAS_D, row_major)
573  gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, cgemv_,
574  BLAS_C, row_major)
575  gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, zgemv_,
576  BLAS_Z, row_major)
577 
578 
579  /* ********************************************************************* */
580  /* mult(A, x, y). */
581  /* ********************************************************************* */
582 
// Generates mult_spec(A, x, z, orien), i.e. z <- op(A)*x, on top of the BLAS
// ?gemv routine `blas_name` (called with beta = 0).
//   param1/trans1 : matrix parameter declaration and prologue; the prologue
//                   must leave the dense matrix `A` and the transpose flag
//                   `t` in scope.
//   param2/trans2 : vector parameter declaration and prologue; the prologue
//                   must leave the vector `x` and the factor `alpha` in scope.
// When A has a zero dimension BLAS is not called and z is cleared instead.
# define gemv_interface2(param1, trans1, param2, trans2, blas_name,        \
                         base_type, orien)                                 \
  inline void mult_spec(param1(base_type), param2(base_type),              \
                        std::vector<base_type > &z, orien) {               \
    GMMLAPACK_TRACE("gemv_interface2");                                    \
    trans1(base_type); trans2(base_type);                                  \
    BLAS_INT nrows(BLAS_INT(mat_nrows(A)));                                \
    BLAS_INT ncols(BLAS_INT(mat_ncols(A)));                                \
    if (nrows == 0 || ncols == 0) {                                        \
      gmm::clear(z);                                                       \
      return;                                                              \
    }                                                                      \
    BLAS_INT ld(nrows), stride(1);                                         \
    base_type zero(0);                                                     \
    blas_name(&t, &nrows, &ncols, &alpha, &A(0,0), &ld, &x[0], &stride,    \
              &zero, &z[0], &stride);                                      \
  }
596 
597  // Y <- AX.
598  gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, sgemv_,
599  BLAS_S, col_major)
600  gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, dgemv_,
601  BLAS_D, col_major)
602  gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, cgemv_,
603  BLAS_C, col_major)
604  gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, zgemv_,
605  BLAS_Z, col_major)
606 
607  // Y <- transposed(A)X.
608  gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, sgemv_,
609  BLAS_S, row_major)
610  gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, dgemv_,
611  BLAS_D, row_major)
612  gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, cgemv_,
613  BLAS_C, row_major)
614  gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, zgemv_,
615  BLAS_Z, row_major)
616 
617  // Y <- transposed(const A)X.
618  gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, sgemv_,
619  BLAS_S, row_major)
620  gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, dgemv_,
621  BLAS_D, row_major)
622  gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, cgemv_,
623  BLAS_C, row_major)
624  gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, zgemv_,
625  BLAS_Z, row_major)
626 
627  // Y <- conjugated(A)X.
628  gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, sgemv_,
629  BLAS_S, row_major)
630  gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, dgemv_,
631  BLAS_D, row_major)
632  gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, cgemv_,
633  BLAS_C, row_major)
634  gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, zgemv_,
635  BLAS_Z, row_major)
636 
637  // Y <- A scaled(X).
638  gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, sgemv_,
639  BLAS_S, col_major)
640  gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, dgemv_,
641  BLAS_D, col_major)
642  gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, cgemv_,
643  BLAS_C, col_major)
644  gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, zgemv_,
645  BLAS_Z, col_major)
646 
647  // Y <- transposed(A) scaled(X).
648  gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, sgemv_,
649  BLAS_S, row_major)
650  gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, dgemv_,
651  BLAS_D, row_major)
652  gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, cgemv_,
653  BLAS_C, row_major)
654  gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, zgemv_,
655  BLAS_Z, row_major)
656 
657  // Y <- transposed(const A) scaled(X).
658  gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, sgemv_,
659  BLAS_S, row_major)
660  gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, dgemv_,
661  BLAS_D, row_major)
662  gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, cgemv_,
663  BLAS_C, row_major)
664  gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, zgemv_,
665  BLAS_Z, row_major)
666 
667  // Y <- conjugated(A) scaled(X).
668  gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, sgemv_,
669  BLAS_S, row_major)
670  gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, dgemv_,
671  BLAS_D, row_major)
672  gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, cgemv_,
673  BLAS_C, row_major)
674  gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, zgemv_,
675  BLAS_Z, row_major)
676 
677 
678  /* ********************************************************************* */
679  /* Rank one update. */
680  /* ********************************************************************* */
681 
682 # define ger_interface(blas_name, base_type) \
683  inline void rank_one_update(const dense_matrix<base_type > &A, \
684  const std::vector<base_type > &V, \
685  const std::vector<base_type > &W) { \
686  GMMLAPACK_TRACE("ger_interface"); \
687  BLAS_INT m(BLAS_INT(mat_nrows(A))), lda = m; \
688  BLAS_INT n(BLAS_INT(mat_ncols(A))); \
689  BLAS_INT incx = 1, incy = 1; \
690  base_type alpha(1); \
691  if (m && n) \
692  blas_name(&m, &n, &alpha, &V[0], &incx, &W[0], &incy, &A(0,0), &lda);\
693  }
694 
695  ger_interface(sger_, BLAS_S)
696  ger_interface(dger_, BLAS_D)
697  ger_interface(cgerc_, BLAS_C)
698  ger_interface(zgerc_, BLAS_Z)
699 
700 # define ger_interface_sn(blas_name, base_type) \
701  inline void rank_one_update(const dense_matrix<base_type > &A, \
702  gemv_p2_s(base_type), \
703  const std::vector<base_type > &W) { \
704  GMMLAPACK_TRACE("ger_interface"); \
705  gemv_trans2_s(base_type); \
706  BLAS_INT m(BLAS_INT(mat_nrows(A))), lda = m; \
707  BLAS_INT n(BLAS_INT(mat_ncols(A))); \
708  BLAS_INT incx = 1, incy = 1; \
709  if (m && n) \
710  blas_name(&m, &n, &alpha, &x[0], &incx, &W[0], &incy, &A(0,0), &lda);\
711  }
712 
713  ger_interface_sn(sger_, BLAS_S)
714  ger_interface_sn(dger_, BLAS_D)
715  ger_interface_sn(cgerc_, BLAS_C)
716  ger_interface_sn(zgerc_, BLAS_Z)
717 
/* ger_interface_ns(blas_name, base_type)
 *
 * Variant of rank_one_update whose third parameter is declared by
 * gemv_p2_s(base_type); the second one stays a plain std::vector V.
 * gemv_p2_s and gemv_trans2_s are defined outside this section; the body
 * relies on gemv_trans2_s to provide `x` and `alpha`.
 *
 * Here `x` is passed to blas_name as the second vector, and the scalar
 * passed is gmm::conj(alpha) rather than alpha itself.
 * (Presumably because the complex instantiations use ?gerc_, where the
 * second vector enters conjugated, so a scale attached to that vector has
 * to be conjugated too -- TODO confirm, and confirm that gmm::conj leaves
 * the BLAS_S / BLAS_D scalars unchanged.)
 *
 * NOTE(review): the trace label is "ger_interface", shared with the other
 * rank_one_update variants.
 */
718 # define ger_interface_ns(blas_name, base_type) \
719  inline void rank_one_update(const dense_matrix<base_type > &A, \
720  const std::vector<base_type > &V, \
721  gemv_p2_s(base_type)) { \
722  GMMLAPACK_TRACE("ger_interface"); \
  /* Expected to declare `x` and `alpha` used below. */ \
723  gemv_trans2_s(base_type); \
724  BLAS_INT m(BLAS_INT(mat_nrows(A))), lda = m; \
725  BLAS_INT n(BLAS_INT(mat_ncols(A))); \
726  BLAS_INT incx = 1, incy = 1; \
  /* The scalar handed to BLAS is the conjugate of alpha. */ \
727  base_type al2 = gmm::conj(alpha); \
728  if (m && n) \
729  blas_name(&m, &n, &al2, &V[0], &incx, &x[0], &incy, &A(0,0), &lda); \
730  }
731 
732  ger_interface_ns(sger_, BLAS_S)
733  ger_interface_ns(dger_, BLAS_D)
734  ger_interface_ns(cgerc_, BLAS_C)
735  ger_interface_ns(zgerc_, BLAS_Z)
736 
737  /* ********************************************************************* */
738  /* dense matrix x dense matrix multiplication. */
739  /* ********************************************************************* */
740 
/* gemm_interface_nn(blas_name, base_type)
 *
 * Generates mult_spec(A, B, C, c_mult) for three dense_matrix<base_type>
 * arguments, implemented by one call to blas_name with both operation
 * flags set to 'N', alpha = 1 and beta = 0.
 *
 * Sizes passed to BLAS: m = rows of A, k = columns of A, n = columns of B,
 * with leading dimensions lda = m, ldb = k, ldc = m. The dimensions of C
 * are not read here (assumed to be already m x n -- TODO confirm caller).
 *
 * If any of m, k, n is zero the BLAS routine is not called and C is
 * cleared with gmm::clear instead.
 */
741 # define gemm_interface_nn(blas_name, base_type) \
742  inline void mult_spec(const dense_matrix<base_type > &A, \
743  const dense_matrix<base_type > &B, \
744  dense_matrix<base_type > &C, c_mult) { \
745  GMMLAPACK_TRACE("gemm_interface_nn"); \
  /* The same 'N' flag is passed for both operands. */ \
746  const char t = 'N'; \
747  BLAS_INT m(BLAS_INT(mat_nrows(A))), lda = m; \
748  BLAS_INT k(BLAS_INT(mat_ncols(A))); \
749  BLAS_INT n(BLAS_INT(mat_ncols(B))); \
750  BLAS_INT ldb = k, ldc = m; \
751  base_type alpha(1), beta(0); \
  /* Degenerate sizes: bypass BLAS and zero the result. */ \
752  if (m && k && n) \
753  blas_name(&t, &t, &m, &n, &k, &alpha, \
754  &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
755  else gmm::clear(C); \
756  }
757 
758  gemm_interface_nn(sgemm_, BLAS_S)
759  gemm_interface_nn(dgemm_, BLAS_D)
760  gemm_interface_nn(cgemm_, BLAS_C)
761  gemm_interface_nn(zgemm_, BLAS_Z)
762 
763  /* ********************************************************************* */
764  /* transposed(dense matrix) x dense matrix multiplication. */
765  /* ********************************************************************* */
766 
/* gemm_interface_tn(blas_name, base_type, is_const)
 *
 * Generates mult_spec(A_, B, C, rcmult) where A_ is a
 * transposed_col_ref<is_const<base_type> *> and B, C are dense matrices.
 * is_const is the (possibly const-qualified) matrix template name; the
 * macro is instantiated below with both `dense_matrix` and
 * `const dense_matrix`.
 *
 * The matrix behind the view is recovered with linalg_origin and a
 * const_cast, and is passed to blas_name with operation flag 'T' ('N' for
 * B), alpha = 1 and beta = 0.
 *
 * Sizes are read from the underlying matrix A: m = columns of A,
 * k = rows of A, n = columns of B; lda = k, ldb = k, ldc = m.
 * If any of m, k, n is zero, C is cleared and BLAS is not called.
 */
767 # define gemm_interface_tn(blas_name, base_type, is_const) \
768  inline void mult_spec( \
769  const transposed_col_ref<is_const<base_type > *> &A_,\
770  const dense_matrix<base_type > &B, \
771  dense_matrix<base_type > &C, rcmult) { \
772  GMMLAPACK_TRACE("gemm_interface_tn"); \
  /* Work on the matrix the transposed view refers to. */ \
773  dense_matrix<base_type > &A \
774  = const_cast<dense_matrix<base_type > &>(*(linalg_origin(A_))); \
775  const char t = 'T', u = 'N'; \
776  BLAS_INT m(BLAS_INT(mat_ncols(A))), k(BLAS_INT(mat_nrows(A))); \
777  BLAS_INT n(BLAS_INT(mat_ncols(B))); \
778  BLAS_INT lda = k, ldb = k, ldc = m; \
779  base_type alpha(1), beta(0); \
780  if (m && k && n) \
781  blas_name(&t, &u, &m, &n, &k, &alpha, \
782  &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
783  else gmm::clear(C); \
784  }
785 
  /* One overload per base type, for non-const and const underlying matrix. */
786  gemm_interface_tn(sgemm_, BLAS_S, dense_matrix)
787  gemm_interface_tn(dgemm_, BLAS_D, dense_matrix)
788  gemm_interface_tn(cgemm_, BLAS_C, dense_matrix)
789  gemm_interface_tn(zgemm_, BLAS_Z, dense_matrix)
790  gemm_interface_tn(sgemm_, BLAS_S, const dense_matrix)
791  gemm_interface_tn(dgemm_, BLAS_D, const dense_matrix)
792  gemm_interface_tn(cgemm_, BLAS_C, const dense_matrix)
793  gemm_interface_tn(zgemm_, BLAS_Z, const dense_matrix)
794 
795  /* ********************************************************************* */
796  /* dense matrix x transposed(dense matrix) multiplication. */
797  /* ********************************************************************* */
798 
/* gemm_interface_nt(blas_name, base_type, is_const)
 *
 * Generates mult_spec(A, B_, C, r_mult) where A and C are dense matrices
 * and B_ is a transposed_col_ref<is_const<base_type> *>. is_const is the
 * (possibly const-qualified) matrix template name; both `dense_matrix` and
 * `const dense_matrix` are instantiated below.
 *
 * The matrix behind B_ is recovered with linalg_origin and a const_cast
 * and passed to blas_name with operation flag 'T' ('N' for A), alpha = 1
 * and beta = 0.
 *
 * Sizes: m = rows of A, k = columns of A, n = rows of the underlying B;
 * lda = m, ldb = n, ldc = m.
 * If any of m, k, n is zero, C is cleared and BLAS is not called.
 */
799 # define gemm_interface_nt(blas_name, base_type, is_const) \
800  inline void mult_spec(const dense_matrix<base_type > &A, \
801  const transposed_col_ref<is_const<base_type > *> &B_, \
802  dense_matrix<base_type > &C, r_mult) { \
803  GMMLAPACK_TRACE("gemm_interface_nt"); \
  /* Work on the matrix the transposed view refers to. */ \
804  dense_matrix<base_type > &B \
805  = const_cast<dense_matrix<base_type > &>(*(linalg_origin(B_))); \
806  const char t = 'N', u = 'T'; \
807  BLAS_INT m(BLAS_INT(mat_nrows(A))), lda = m; \
808  BLAS_INT k(BLAS_INT(mat_ncols(A))); \
809  BLAS_INT n(BLAS_INT(mat_nrows(B))); \
810  BLAS_INT ldb = n, ldc = m; \
811  base_type alpha(1), beta(0); \
812  if (m && k && n) \
813  blas_name(&t, &u, &m, &n, &k, &alpha, \
814  &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
815  else gmm::clear(C); \
816  }
817 
  /* One overload per base type, for non-const and const underlying matrix. */
818  gemm_interface_nt(sgemm_, BLAS_S, dense_matrix)
819  gemm_interface_nt(dgemm_, BLAS_D, dense_matrix)
820  gemm_interface_nt(cgemm_, BLAS_C, dense_matrix)
821  gemm_interface_nt(zgemm_, BLAS_Z, dense_matrix)
822  gemm_interface_nt(sgemm_, BLAS_S, const dense_matrix)
823  gemm_interface_nt(dgemm_, BLAS_D, const dense_matrix)
824  gemm_interface_nt(cgemm_, BLAS_C, const dense_matrix)
825  gemm_interface_nt(zgemm_, BLAS_Z, const dense_matrix)
826 
827  /* ********************************************************************* */
828  /* transposed(dense matrix) x transposed(dense matrix) multiplication. */
829  /* ********************************************************************* */
830 
/* gemm_interface_tt(blas_name, base_type, isA_const, isB_const)
 *
 * Generates mult_spec(A_, B_, C, r_mult) where both A_ and B_ are
 * transposed_col_ref views and C is a dense matrix. isA_const / isB_const
 * are the (possibly const-qualified) matrix template names of the two
 * views; all four const combinations are instantiated below.
 *
 * The matrices behind both views are recovered with linalg_origin and a
 * const_cast and passed to blas_name with operation flag 'T' for each,
 * alpha = 1 and beta = 0.
 *
 * Sizes, read from the underlying matrices: m = columns of A,
 * k = rows of A, n = rows of B; lda = k, ldb = n, ldc = m.
 * If any of m, k, n is zero, C is cleared and BLAS is not called.
 */
831 # define gemm_interface_tt(blas_name, base_type, isA_const, isB_const) \
832  inline void mult_spec( \
833  const transposed_col_ref<isA_const <base_type > *> &A_, \
834  const transposed_col_ref<isB_const <base_type > *> &B_, \
835  dense_matrix<base_type > &C, r_mult) { \
836  GMMLAPACK_TRACE("gemm_interface_tt"); \
  /* Work on the matrices the two transposed views refer to. */ \
837  dense_matrix<base_type > &A \
838  = const_cast<dense_matrix<base_type > &>(*(linalg_origin(A_))); \
839  dense_matrix<base_type > &B \
840  = const_cast<dense_matrix<base_type > &>(*(linalg_origin(B_))); \
841  const char t = 'T', u = 'T'; \
842  BLAS_INT m(BLAS_INT(mat_ncols(A))), k(BLAS_INT(mat_nrows(A))); \
843  BLAS_INT n(BLAS_INT(mat_nrows(B))); \
844  BLAS_INT lda = k, ldb = n, ldc = m; \
845  base_type alpha(1), beta(0); \
846  if (m && k && n) \
847  blas_name(&t, &u, &m, &n, &k, &alpha, \
848  &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
849  else gmm::clear(C); \
850  }
851 
  /* 4 base types x 4 combinations of const / non-const underlying matrices. */
852  gemm_interface_tt(sgemm_, BLAS_S, dense_matrix, dense_matrix)
853  gemm_interface_tt(dgemm_, BLAS_D, dense_matrix, dense_matrix)
854  gemm_interface_tt(cgemm_, BLAS_C, dense_matrix, dense_matrix)
855  gemm_interface_tt(zgemm_, BLAS_Z, dense_matrix, dense_matrix)
856  gemm_interface_tt(sgemm_, BLAS_S, const dense_matrix, dense_matrix)
857  gemm_interface_tt(dgemm_, BLAS_D, const dense_matrix, dense_matrix)
858  gemm_interface_tt(cgemm_, BLAS_C, const dense_matrix, dense_matrix)
859  gemm_interface_tt(zgemm_, BLAS_Z, const dense_matrix, dense_matrix)
860  gemm_interface_tt(sgemm_, BLAS_S, dense_matrix, const dense_matrix)
861  gemm_interface_tt(dgemm_, BLAS_D, dense_matrix, const dense_matrix)
862  gemm_interface_tt(cgemm_, BLAS_C, dense_matrix, const dense_matrix)
863  gemm_interface_tt(zgemm_, BLAS_Z, dense_matrix, const dense_matrix)
864  gemm_interface_tt(sgemm_, BLAS_S, const dense_matrix, const dense_matrix)
865  gemm_interface_tt(dgemm_, BLAS_D, const dense_matrix, const dense_matrix)
866  gemm_interface_tt(cgemm_, BLAS_C, const dense_matrix, const dense_matrix)
867  gemm_interface_tt(zgemm_, BLAS_Z, const dense_matrix, const dense_matrix)
868 
869 
870  /* ********************************************************************* */
871  /* conjugated(dense matrix) x dense matrix multiplication. */
872  /* ********************************************************************* */
873 
/* gemm_interface_cn(blas_name, base_type)
 *
 * Generates mult_spec(A_, B, C, rcmult) where A_ is a
 * conjugated_col_matrix_const_ref<dense_matrix<base_type> > and B, C are
 * dense matrices.
 *
 * The matrix behind the conjugated view is recovered with linalg_origin
 * and a const_cast and passed to blas_name with operation flag 'C'
 * ('N' for B), alpha = 1 and beta = 0.
 *
 * Sizes, read from the underlying matrix A: m = columns of A,
 * k = rows of A, n = columns of B; lda = k, ldb = k, ldc = m.
 * If any of m, k, n is zero, C is cleared and BLAS is not called.
 *
 * NOTE(review): the macro is also instantiated for sgemm_ / dgemm_ with
 * the 'C' flag; per the reference BLAS, real GEMM treats 'C' like 'T'.
 */
874 # define gemm_interface_cn(blas_name, base_type) \
875  inline void mult_spec( \
876  const conjugated_col_matrix_const_ref<dense_matrix<base_type > > &A_,\
877  const dense_matrix<base_type > &B, \
878  dense_matrix<base_type > &C, rcmult) { \
879  GMMLAPACK_TRACE("gemm_interface_cn"); \
  /* Work on the matrix the conjugated view refers to. */ \
880  dense_matrix<base_type > &A \
881  = const_cast<dense_matrix<base_type > &>(*(linalg_origin(A_))); \
882  const char t = 'C', u = 'N'; \
883  BLAS_INT m(BLAS_INT(mat_ncols(A))), k(BLAS_INT(mat_nrows(A))); \
884  BLAS_INT n(BLAS_INT(mat_ncols(B))); \
885  BLAS_INT lda = k, ldb = k, ldc = m; \
886  base_type alpha(1), beta(0); \
887  if (m && k && n) \
888  blas_name(&t, &u, &m, &n, &k, &alpha, \
889  &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
890  else gmm::clear(C); \
891  }
892 
893  gemm_interface_cn(sgemm_, BLAS_S)
894  gemm_interface_cn(dgemm_, BLAS_D)
895  gemm_interface_cn(cgemm_, BLAS_C)
896  gemm_interface_cn(zgemm_, BLAS_Z)
897 
898  /* ********************************************************************* */
899  /* dense matrix x conjugated(dense matrix) multiplication. */
900  /* ********************************************************************* */
901 
/* gemm_interface_nc(blas_name, base_type)
 *
 * Generates mult_spec(A, B_, C, c_mult, row_major) where A and C are dense
 * matrices and B_ is a
 * conjugated_col_matrix_const_ref<dense_matrix<base_type> >.
 *
 * The matrix behind the conjugated view is recovered with linalg_origin
 * and a const_cast and passed to blas_name with operation flag 'C'
 * ('N' for A), alpha = 1 and beta = 0.
 *
 * Sizes: m = rows of A, k = columns of A, n = rows of the underlying B;
 * lda = m, ldb = n, ldc = m.
 * If any of m, k, n is zero, C is cleared and BLAS is not called.
 *
 * NOTE(review): this is the only mult_spec overload in this section that
 * takes two trailing tag parameters (c_mult and row_major); the sibling
 * overloads take a single tag. Verify that the generic dispatcher really
 * calls mult_spec with this five-argument form, otherwise this overload
 * is never selected.
 */
902 # define gemm_interface_nc(blas_name, base_type) \
903  inline void mult_spec(const dense_matrix<base_type > &A, \
904  const conjugated_col_matrix_const_ref<dense_matrix<base_type > > &B_,\
905  dense_matrix<base_type > &C, c_mult, row_major) { \
906  GMMLAPACK_TRACE("gemm_interface_nc"); \
  /* Work on the matrix the conjugated view refers to. */ \
907  dense_matrix<base_type > &B \
908  = const_cast<dense_matrix<base_type > &>(*(linalg_origin(B_))); \
909  const char t = 'N', u = 'C'; \
910  BLAS_INT m(BLAS_INT(mat_nrows(A))), lda = m; \
911  BLAS_INT k(BLAS_INT(mat_ncols(A))); \
912  BLAS_INT n(BLAS_INT(mat_nrows(B))), ldb = n, ldc = m; \
913  base_type alpha(1), beta(0); \
914  if (m && k && n) \
915  blas_name(&t, &u, &m, &n, &k, &alpha, \
916  &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
917  else gmm::clear(C); \
918  }
919 
920  gemm_interface_nc(sgemm_, BLAS_S)
921  gemm_interface_nc(dgemm_, BLAS_D)
922  gemm_interface_nc(cgemm_, BLAS_C)
923  gemm_interface_nc(zgemm_, BLAS_Z)
924 
925  /* ********************************************************************* */
926  /* conjugated(dense matrix) x conjugated(dense matrix) multiplication. */
927  /* ********************************************************************* */
928 
/* gemm_interface_cc(blas_name, base_type)
 *
 * Generates mult_spec(A_, B_, C, r_mult) where both A_ and B_ are
 * conjugated_col_matrix_const_ref<dense_matrix<base_type> > views and C is
 * a dense matrix.
 *
 * The matrices behind both views are recovered with linalg_origin and a
 * const_cast and passed to blas_name with operation flag 'C' for each,
 * alpha = 1 and beta = 0.
 *
 * Sizes, read from the underlying matrices: m = columns of A,
 * k = rows of A, n = rows of B; lda = k, ldb = n, ldc = m.
 * If any of m, k, n is zero, C is cleared and BLAS is not called.
 */
929 # define gemm_interface_cc(blas_name, base_type) \
930  inline void mult_spec( \
931  const conjugated_col_matrix_const_ref<dense_matrix<base_type > > &A_,\
932  const conjugated_col_matrix_const_ref<dense_matrix<base_type > > &B_,\
933  dense_matrix<base_type > &C, r_mult) { \
934  GMMLAPACK_TRACE("gemm_interface_cc"); \
  /* Work on the matrices the two conjugated views refer to. */ \
935  dense_matrix<base_type > &A \
936  = const_cast<dense_matrix<base_type > &>(*(linalg_origin(A_))); \
937  dense_matrix<base_type > &B \
938  = const_cast<dense_matrix<base_type > &>(*(linalg_origin(B_))); \
939  const char t = 'C', u = 'C'; \
940  BLAS_INT m(BLAS_INT(mat_ncols(A))), k(BLAS_INT(mat_nrows(A))); \
941  BLAS_INT lda = k, n(BLAS_INT(mat_nrows(B))), ldb = n, ldc = m; \
942  base_type alpha(1), beta(0); \
943  if (m && k && n) \
944  blas_name(&t, &u, &m, &n, &k, &alpha, \
945  &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
946  else gmm::clear(C); \
947  }
948 
949  gemm_interface_cc(sgemm_, BLAS_S)
950  gemm_interface_cc(dgemm_, BLAS_D)
951  gemm_interface_cc(cgemm_, BLAS_C)
952  gemm_interface_cc(zgemm_, BLAS_Z)
953 
954  /* ********************************************************************* */
955  /* Tri solve. */
956  /* ********************************************************************* */
957 
/* trsv_interface(f_name, loru, param1, trans1, blas_name, base_type)
 *
 * Generates an inline function
 *   f_name(<param1(base_type)>, std::vector<base_type> &x,
 *          size_type k, bool is_unit)
 * implemented by one call to the Fortran routine blas_name.
 *
 *   f_name    name of the generated function.
 *   loru      statement placed at the top of the body; with trsv_upper /
 *             trsv_lower (defined just below) it declares `const char l`
 *             as 'U' or 'L'.
 *   param1    macro producing the declaration of the matrix parameter.
 *   trans1    macro producing body statements. param1 and trans1 are
 *             defined outside this section; between them they must supply
 *             the names `A` and `t`, which the body uses without
 *             declaring them.
 *
 * In the generated body the diagonal flag d is 'U' when is_unit is true
 * and 'N' otherwise, the order passed to BLAS is k, the vector increment
 * is 1 and the leading dimension is mat_nrows(A). The BLAS call is skipped
 * when that leading dimension is zero.
 *
 * NOTE(review): x is assumed to hold at least k entries; no size check is
 * made here.
 */
958 # define trsv_interface(f_name, loru, param1, trans1, blas_name, base_type)\
959  inline void f_name(param1(base_type), std::vector<base_type > &x, \
960  size_type k, bool is_unit) { \
961  GMMLAPACK_TRACE("trsv_interface"); \
  /* Declares l (via loru), whatever trans1 provides, and d. */ \
962  loru; trans1(base_type); char d = is_unit ? 'U' : 'N'; \
963  BLAS_INT lda(BLAS_INT(mat_nrows(A))), inc(1), n = BLAS_INT(k); \
964  if (lda) blas_name(&l, &t, &d, &n, &A(0,0), &lda, &x[0], &inc); \
965  }
966 
/* Values for the loru argument of trsv_interface: each declares the
 * triangle selector `l` read by the generated body. */
967 # define trsv_upper const char l = 'U'
968 # define trsv_lower const char l = 'L'
969 
/* Instantiations of trsv_interface, four per group (one for each of
 * BLAS_S, BLAS_D, BLAS_C, BLAS_Z with the matching ?trsv_ routine).
 *
 * For the plain matrix parameter (gem_p1_n) the triangle selector matches
 * the function name. For the transposed and conjugated views (gem_p1_t,
 * gem_p1_tc, gem_p1_c) it is the opposite one: lower_tri_solve is paired
 * with trsv_upper and upper_tri_solve with trsv_lower. Presumably the
 * solve runs on the underlying stored matrix with a transpose / conjugate
 * flag supplied by gem_trans1_t / gem_trans1_c, both defined outside this
 * section -- confirm there.
 */
970  // X <- LOWER(A)^{-1}X.
971  trsv_interface(lower_tri_solve, trsv_lower, gem_p1_n, gem_trans1_n,
972  strsv_, BLAS_S)
973  trsv_interface(lower_tri_solve, trsv_lower, gem_p1_n, gem_trans1_n,
974  dtrsv_, BLAS_D)
975  trsv_interface(lower_tri_solve, trsv_lower, gem_p1_n, gem_trans1_n,
976  ctrsv_, BLAS_C)
977  trsv_interface(lower_tri_solve, trsv_lower, gem_p1_n, gem_trans1_n,
978  ztrsv_, BLAS_Z)
979 
980  // X <- UPPER(A)^{-1}X.
981  trsv_interface(upper_tri_solve, trsv_upper, gem_p1_n, gem_trans1_n,
982  strsv_, BLAS_S)
983  trsv_interface(upper_tri_solve, trsv_upper, gem_p1_n, gem_trans1_n,
984  dtrsv_, BLAS_D)
985  trsv_interface(upper_tri_solve, trsv_upper, gem_p1_n, gem_trans1_n,
986  ctrsv_, BLAS_C)
987  trsv_interface(upper_tri_solve, trsv_upper, gem_p1_n, gem_trans1_n,
988  ztrsv_, BLAS_Z)
989 
990  // X <- LOWER(transposed(A))^{-1}X.
991  trsv_interface(lower_tri_solve, trsv_upper, gem_p1_t, gem_trans1_t,
992  strsv_, BLAS_S)
993  trsv_interface(lower_tri_solve, trsv_upper, gem_p1_t, gem_trans1_t,
994  dtrsv_, BLAS_D)
995  trsv_interface(lower_tri_solve, trsv_upper, gem_p1_t, gem_trans1_t,
996  ctrsv_, BLAS_C)
997  trsv_interface(lower_tri_solve, trsv_upper, gem_p1_t, gem_trans1_t,
998  ztrsv_, BLAS_Z)
999 
1000  // X <- UPPER(transposed(A))^{-1}X.
1001  trsv_interface(upper_tri_solve, trsv_lower, gem_p1_t, gem_trans1_t,
1002  strsv_, BLAS_S)
1003  trsv_interface(upper_tri_solve, trsv_lower, gem_p1_t, gem_trans1_t,
1004  dtrsv_, BLAS_D)
1005  trsv_interface(upper_tri_solve, trsv_lower, gem_p1_t, gem_trans1_t,
1006  ctrsv_, BLAS_C)
1007  trsv_interface(upper_tri_solve, trsv_lower, gem_p1_t, gem_trans1_t,
1008  ztrsv_, BLAS_Z)
1009 
1010  // X <- LOWER(transposed(const A))^{-1}X.
1011  trsv_interface(lower_tri_solve, trsv_upper, gem_p1_tc, gem_trans1_t,
1012  strsv_, BLAS_S)
1013  trsv_interface(lower_tri_solve, trsv_upper, gem_p1_tc, gem_trans1_t,
1014  dtrsv_, BLAS_D)
1015  trsv_interface(lower_tri_solve, trsv_upper, gem_p1_tc, gem_trans1_t,
1016  ctrsv_, BLAS_C)
1017  trsv_interface(lower_tri_solve, trsv_upper, gem_p1_tc, gem_trans1_t,
1018  ztrsv_, BLAS_Z)
1019 
1020  // X <- UPPER(transposed(const A))^{-1}X.
1021  trsv_interface(upper_tri_solve, trsv_lower, gem_p1_tc, gem_trans1_t,
1022  strsv_, BLAS_S)
1023  trsv_interface(upper_tri_solve, trsv_lower, gem_p1_tc, gem_trans1_t,
1024  dtrsv_, BLAS_D)
1025  trsv_interface(upper_tri_solve, trsv_lower, gem_p1_tc, gem_trans1_t,
1026  ctrsv_, BLAS_C)
1027  trsv_interface(upper_tri_solve, trsv_lower, gem_p1_tc, gem_trans1_t,
1028  ztrsv_, BLAS_Z)
1029 
1030  // X <- LOWER(conjugated(A))^{-1}X.
1031  trsv_interface(lower_tri_solve, trsv_upper, gem_p1_c, gem_trans1_c,
1032  strsv_, BLAS_S)
1033  trsv_interface(lower_tri_solve, trsv_upper, gem_p1_c, gem_trans1_c,
1034  dtrsv_, BLAS_D)
1035  trsv_interface(lower_tri_solve, trsv_upper, gem_p1_c, gem_trans1_c,
1036  ctrsv_, BLAS_C)
1037  trsv_interface(lower_tri_solve, trsv_upper, gem_p1_c, gem_trans1_c,
1038  ztrsv_, BLAS_Z)
1039 
1040  // X <- UPPER(conjugated(A))^{-1}X.
1041  trsv_interface(upper_tri_solve, trsv_lower, gem_p1_c, gem_trans1_c,
1042  strsv_, BLAS_S)
1043  trsv_interface(upper_tri_solve, trsv_lower, gem_p1_c, gem_trans1_c,
1044  dtrsv_, BLAS_D)
1045  trsv_interface(upper_tri_solve, trsv_lower, gem_p1_c, gem_trans1_c,
1046  ctrsv_, BLAS_C)
1047  trsv_interface(upper_tri_solve, trsv_lower, gem_p1_c, gem_trans1_c,
1048  ztrsv_, BLAS_Z)
1049 
1050 #endif
1051 }
1052 
1053 #endif // GMM_BLAS_INTERFACE_H
1054 
1055 #endif // GMM_USES_BLAS
bgeot::size_type
size_t size_type
used as the common size type in the library
Definition: bgeot_poly.h:49
gmm_interface.h
gmm interface for STL vectors.
gmm_matrix.h
Declaration of some matrix types (gmm::dense_matrix, gmm::row_matrix, gmm::col_matrix,...
gmm_blas.h
Basic linear algebra functions.