00001 #ifndef BLAS_WRAPPER_HH
00002 #define BLAS_WRAPPER_HH
00003
#ifdef __APPLE__

#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>
#endif

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <string>
#include <utility>
#include <vector>
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030 struct blas {
00031
00032 static inline std::string name () { return "cpu-blas"; }
00033
00034 #ifdef dfkgskdd
00035
00036
00037
00038
00039
00040 static void check_status(const cublasStatus & status)
00041 {
00042
00043 std::string cublas_errors(" ");
00044
00045 if (status != CUBLAS_STATUS_SUCCESS)
00046 {
00047 if (status == CUBLAS_STATUS_NOT_INITIALIZED)
00048 cublas_errors += "cublas not initialized ";
00049 if (status == CUBLAS_STATUS_MAPPING_ERROR)
00050 cublas_errors +="mapping error ";
00051 if (status == CUBLAS_STATUS_INVALID_VALUE)
00052 cublas_errors +="invalid value ";
00053 if (status == CUBLAS_STATUS_ALLOC_FAILED)
00054 cublas_errors +="allocation failed ";
00055 if (status == CUBLAS_STATUS_ARCH_MISMATCH)
00056 cublas_errors +="architecture mismatch ";
00057 if (status == CUBLAS_STATUS_EXECUTION_FAILED)
00058 cublas_errors +="execution failed ";
00059 if (status == CUBLAS_STATUS_INTERNAL_ERROR)
00060 cublas_errors +="cublas internal error ";
00061
00062 if (cublas_errors == " ")
00063 cublas_errors = "unknown cublas error state";
00064
00065 #ifdef QT_NO_DEBUG
00066 AssertThrow(false, ::ExcMessage(cublas_errors.c_str() ) );
00067 #else
00068 AssertThrow(false, ::ExcMessage(cublas_errors.c_str() ) );
00069 #endif
00070 }
00071
00072 }
00073 #endif
00074
00075
00076
00077
00078
00079
00080
00081
00082 static void Init() {
00083
00084
00085
00086
00087
00088
00089
00090
00091
00092
00093
00094
00095
00096
00097
00098 std::cout << "ATLAS Init" << std::endl;
00099 }
00100
00101
00102
00103
00104
00105
00106 static void Shutdown() {
00107
00108
00109
00110
00111
00112
00113
00114
00115
00116
00117
00118 std::cout << "ATLAS Down" << std::endl;
00119
00120 }
00121
00122
00123
00124 template<typename T>
00125 class Data {
00126
00127 T * __data;
00128 size_t __n_el;
00129 public:
00130
00131 Data() : __data(0)
00132 {}
00133
00134 Data(size_t n)
00135 : __data(0), __n_el(0)
00136 { resize(n); }
00137
00138 void resize(size_t n)
00139 {
00140 #ifdef QT_NO_DEBUG
00141 AssertThrow(n > 0,
00142 ::ExcMessage("allocation of 0 elements not allowed"));
00143 #else
00144 Assert(n > 0,
00145 ::ExcMessage("allocation of 0 elements not allowed"));
00146 #endif
00147
00148 if (__n_el == 0)
00149 {
00150 __data = new T[n];
00151 __n_el = n;
00152 }
00153 else {
00154 if (__n_el != n)
00155 delete __data;
00156 __data = new T[n];
00157 __n_el = n;
00158 }
00159 };
00160
00161 ~Data()
00162 {
00163 if (__n_el > 0)
00164 {
00165 delete __data;
00166 __data = 0;
00167 }
00168 }
00169
00170 T * data() { return __data; }
00171
00172 const T * data() const { return __data; }
00173 };
00174
00175
00176 public:
00177
00178
00179
00180
00181
00182
00183
00184
00185 template<typename T>
00186 static void SetMatrix(int rows, int cols, const T *const &A,
00187 int , T *&B, int )
00188 {
00189
00190
00191
00192 copy(rows*cols, A, 1, B, 1);
00193
00194
00195 }
00196
00197
00198
00199
00200
00201
00202
00203
00204
00205 template<typename T>
00206 static void GetMatrix(int rows, int cols, const T * const &A,
00207 int lda, T *&B, int ldb)
00208 {
00209
00210
00211
00212
00213
00214 for (int c = 0; c < cols; c++)
00215 {
00216 copy(rows, A+c*lda, 1, B+c*ldb, 1);
00217
00218
00219 }
00220
00221
00222 }
00223
00224
00225
00226
00227
00228
00229
00230
00231 template<typename T>
00232 static void SetVector(int n_el, const T * const src, int inc_src, T *dst, int inc_dst)
00233 {
00234
00235
00236
00237
00238
00239 copy(n_el, src, inc_src, dst, inc_dst);
00240
00241 }
00242
00243
00244
00245
00246
00247
00248
00249
00250 template<typename T>
00251 static void GetVector(int n_el, const T * const &A, int inc_src, T *&B, int inc_dst)
00252 {
00253
00254
00255
00256 copy (n_el, A, inc_src, B, inc_dst);
00257
00258
00259 }
00260
00261
00262
00263
00264
00265
00266
00267
00268
00269
00270
00271
00272
00273
00274
00275
00276
00277
00278 static void
00279 axpy (int n, float alpha, const float *x,
00280 int incx, float *y, int incy)
00281 {
00282 cblas_saxpy(n, alpha, x, incx, y, incy);
00283
00284
00285
00286
00287 }
00288
00289 static void
00290 axpy (int n, double alpha, const double *x,
00291 int incx, double *y, int incy)
00292 {
00293 cblas_daxpy(n, alpha, x, incx, y, incy);
00294
00295
00296
00297
00298 }
00299
00300
00301
00302
00303
00304
00305
00306
00307
00308
00309 static void
00310 copy(int n, const float *x, int incx, float *y, int incy)
00311 {
00312 cblas_scopy(n, x, incx, y, incy);
00313
00314
00315
00316
00317 }
00318
00319 static void
00320 copy(int n, const double *x, int incx, double *y, int incy)
00321 {
00322 cblas_dcopy(n, x, incx, y, incy);
00323
00324
00325
00326
00327 }
00328
00329
00330
00331
00332
00333
00334
00335
00336
00337 static void
00338 scal (int n, float alpha, float *x, int incx)
00339 {
00340 cblas_sscal (n, alpha, x, incx);
00341
00342
00343
00344
00345 }
00346
00347 static void
00348 scal (int n, double alpha, double *x, int incx)
00349 {
00350 cblas_dscal (n, alpha, x, incx);
00351
00352
00353
00354
00355 }
00356
00357
00358
00359
00360
00361
00362
00363
00364
00365
00366
00367
00368
00369
00370
00371 static void gemv (char trans, int m, int n, float alpha,
00372 const float * const A, int lda,
00373 const float * const x, int incx, float beta,
00374 float *y, int incy)
00375 {
00376 CBLAS_TRANSPOSE tr = ( ( (trans == 't') || (trans == 'T') ) ? CblasTrans : CblasNoTrans );
00377 cblas_sgemv (CblasColMajor, tr, m, n, alpha, A, lda, x, incx, beta, y, incy);
00378
00379
00380
00381
00382 }
00383
00384 static void gemv (char trans, int m, int n, double alpha,
00385 const double * const A, int lda,
00386 const double * const x, int incx, double beta,
00387 double *y, int incy)
00388 {
00389 CBLAS_TRANSPOSE tr = ( ( (trans == 't') || (trans == 'T') ) ? CblasTrans : CblasNoTrans );
00390 cblas_dgemv (CblasColMajor, tr, m, n, alpha, A, lda, x, incx, beta, y, incy);
00391
00392
00393
00394
00395 }
00396
00397
00398
00399
00400
00401
00402
00403
00404
00405
00406
00407
00408
00409 static void
00410 ger(int m, int n, float alpha, const float *x,
00411 int incx, const float *y, int incy, float *A,
00412 int lda)
00413 {
00414 cblas_sger(CblasColMajor, m, n, alpha, x, incx, y, incy, A, lda);
00415
00416
00417
00418
00419 }
00420
00421 static void
00422 ger(int m, int n, double alpha, const double *x,
00423 int incx, const double *y, int incy, double *A,
00424 int lda)
00425 {
00426 cblas_dger(CblasColMajor, m, n, alpha, x, incx, y, incy, A, lda);
00427
00428
00429
00430
00431 }
00432
00433
00434
00435
00436
00437
00438
00439
00440
00441
00442
00443
00444
00445
00446
00447
00448
00449
00450 static void gemm(char transa, char transb, int m, int n, int k, float alpha,
00451 const float * const A, int lda, const float * const B, int ldb,
00452 float beta, float * C, int ldc)
00453 {
00454 CBLAS_TRANSPOSE tr_a = ( ( (transa == 't') || (transa == 'T') ) ? CblasTrans : CblasNoTrans );
00455 CBLAS_TRANSPOSE tr_b = ( ( (transb == 't') || (transb == 'T') ) ? CblasTrans : CblasNoTrans );
00456
00457 cblas_sgemm(CblasColMajor,
00458 tr_a, tr_b,
00459 m, n, k,
00460 alpha,
00461 A, lda,
00462 B, ldb,
00463 beta,
00464 C, ldc);
00465
00466
00467
00468
00469 }
00470
00471
00472 static void gemm(char transa, char transb, int m, int n, int k, double alpha,
00473 const double * const A, int lda, const double * const B, int ldb,
00474 double beta, double * C, int ldc)
00475 {
00476 CBLAS_TRANSPOSE tr_a = ( ( (transa == 't') || (transa == 'T') ) ? CblasTrans : CblasNoTrans );
00477 CBLAS_TRANSPOSE tr_b = ( ( (transb == 't') || (transb == 'T') ) ? CblasTrans : CblasNoTrans );
00478
00479 cblas_dgemm(CblasColMajor, tr_a, tr_b, m, n, k, alpha,
00480 A, lda, B, ldb,
00481 beta, C, ldc);
00482
00483
00484
00485
00486 }
00487
00488
00489
00490
00491
00492
00493
00494
00495 static float
00496 nrm2(int n, const float *x, int incx)
00497 {
00498 float result = cblas_snrm2 (n, x, incx);
00499
00500
00501
00502
00503
00504 return result;
00505 }
00506
00507 static double
00508 nrm2(int n, const double *x, int incx)
00509 {
00510 double result = cblas_dnrm2 (n, x, incx);
00511
00512
00513
00514
00515
00516 return result;
00517 }
00518
00519
00520
00521
00522
00523
00524
00525
00526
00527 static float dot(int n, const float *x, int incx, const float *y, int incy)
00528 {
00529 float result = cblas_sdot(n, x, incx, y, incy);
00530
00531
00532
00533
00534
00535 return result;
00536 }
00537
00538 static double dot(int n, const double *x,
00539 int incx, const double *y, int incy)
00540 {
00541 double result = cblas_ddot(n, x, incx, y, incy);
00542
00543
00544
00545
00546
00547 return result;
00548 }
00549
00550 };
00551
00552
00553
00554
00555 #endif // BLAS_WRAPPER_HH
00556