00001 #ifndef CUBLAS_WRAPPER_HH
00002 #define CUBLAS_WRAPPER_HH
00003
00004 #include <cutil_inline.h>
00005 #include <cublas.h>
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024 struct cublas {
00025
00026
00027
00028
00029
00030
/**
 * Returns the identifier string of this BLAS backend, "cublas".
 */
static inline std::string name()
{
    return std::string("cublas");
}
00032
00033
00034
00035
00036
00037
00038
00039 static void check_status(const cublasStatus & status)
00040 {
00041
00042
00043 #ifdef DEBUG
00044 std::string cublas_errors(" ");
00045
00046 if (status != CUBLAS_STATUS_SUCCESS)
00047 {
00048 if (status == CUBLAS_STATUS_NOT_INITIALIZED)
00049 cublas_errors += "cublas not initialized ";
00050 if (status == CUBLAS_STATUS_MAPPING_ERROR)
00051 cublas_errors +="mapping error ";
00052 if (status == CUBLAS_STATUS_INVALID_VALUE)
00053 cublas_errors +="invalid value ";
00054 if (status == CUBLAS_STATUS_ALLOC_FAILED)
00055 cublas_errors +="allocation failed ";
00056 if (status == CUBLAS_STATUS_ARCH_MISMATCH)
00057 cublas_errors +="architecture mismatch ";
00058 if (status == CUBLAS_STATUS_EXECUTION_FAILED)
00059 cublas_errors +="execution failed ";
00060 if (status == CUBLAS_STATUS_INTERNAL_ERROR)
00061 cublas_errors +="cublas internal error ";
00062
00063 if (cublas_errors == " ")
00064 cublas_errors = "unknown cublas error state";
00065
00066 #ifdef QT_NO_DEBUG
00067 AssertThrow(false, ::ExcMessage(cublas_errors.c_str() ) );
00068 #else
00069 Assert(false, ::ExcMessage(cublas_errors.c_str() ) );
00070 #endif
00071 }
00072 #endif
00073 }
00074
00075
00076
00077
00078
00079
00080
00081
00082
00083 static void Init() {
00084 cublasStatus s = cublasInit();
00085
00086 if (s == CUBLAS_STATUS_SUCCESS)
00087 std::cout << "cublas init succeeded" << std::endl;
00088
00089 #ifndef DEBUG
00090 AssertThrow(s == CUBLAS_STATUS_SUCCESS,
00091 ::ExcMessage("cublas init failed"));
00092 #else
00093 Assert(s == CUBLAS_STATUS_SUCCESS,
00094 ::ExcMessage("cublas init failed"));
00095 #endif
00096 }
00097
00098
00099
00100
00101
00102
00103 static void Shutdown() {
00104
00105 cublasStatus s = cublasShutdown();
00106
00107 if (s == CUBLAS_STATUS_SUCCESS)
00108 std::cout << "cublas shutdown succeeded" << std::endl;
00109
00110 #ifdef QT_NO_DEBUG
00111 AssertThrow(s == CUBLAS_STATUS_SUCCESS,
00112 ::ExcMessage("cublas shutdown failed"));
00113 #else
00114 Assert(s == CUBLAS_STATUS_SUCCESS,
00115 ::ExcMessage("cublas shutdown failed"));
00116 #endif
00117
00118 }
00119
00120
00121 template<typename T>
00122 struct Data {
00123
00124 Data() : dev_ptr(0) {}
00125
00126 Data(size_t n) { alloc(n); }
00127
00128 ~Data() { free_dev_ptr(); }
00129 T * data() { return dev_ptr; }
00130
00131 const T * data() const { return dev_ptr; }
00132
00133 void resize(size_t n)
00134 {
00135 free_dev_ptr();
00136
00137 alloc(n);
00138 }
00139
00140 private:
00141 void free_dev_ptr() {
00142
00143
00144
00145 if (dev_ptr == 0) return;
00146
00147 cublasStatus status;
00148
00149 status = cublasFree(dev_ptr);
00150
00151 check_status(status);
00152
00153 dev_ptr = 0;
00154
00155 std::cout << "GPU mem freed " << std::endl;
00156
00157 }
00158
00159 void alloc(size_t n)
00160 {
00161 cublasStatus status = cublasAlloc( n, sizeof(T), (void**)&dev_ptr);
00162
00163 check_status(status);
00164 }
00165
00166
00167 T * dev_ptr;
00168 };
00169
00170
00171 private:
00172
00173
00174
00175
00176 template<typename T>
00177 static void Alloc(T *&dev_ptr, size_t n) {
00178
00179 #ifdef QT_NO_DEBUG
00180 AssertThrow(n > 0,
00181 ::ExcMessage("allocation of 0 elements not allowed"));
00182 #else
00183 Assert(n > 0,
00184 ::ExcMessage("allocation of 0 elements not allowed"));
00185 #endif
00186
00187 cublasStatus status = cublasAlloc( n, sizeof(T), (void**)&dev_ptr);
00188
00189 #ifdef QT_NO_DEBUG
00190 AssertThrow(status == CUBLAS_STATUS_SUCCESS,
00191 ::ExcMessage("cublas allocation failed" ) );
00192 #else
00193 AssertThrow(status == CUBLAS_STATUS_SUCCESS,
00194 ::ExcMessage("cublas allocation failed" ) );
00195 #endif
00196
00197
00198 }
00199
00200
00201
00202 template<typename T>
00203 static void Free(T *&dev_ptr) {
00204
00205 cublasStatus status = cublasFree(dev_ptr);
00206
00207 #ifdef QT_NO_DEBUG
00208 AssertThrow(status == CUBLAS_STATUS_SUCCESS,
00209 ::ExcMessage("cublas deallocation failed" ) );
00210 #else
00211 AssertThrow(status == CUBLAS_STATUS_SUCCESS,
00212 ::ExcMessage("cublas deallocation failed" ) );
00213 #endif
00214
00215 }
00216
00217 public:
00218
00219
00220
00221
00222
00223
00224
00225
00226
/**
 * Forwards to cublasSetMatrix() with an element size of sizeof(T) and hands
 * the returned status to check_status().
 *
 * According to the cuBLAS reference, cublasSetMatrix copies a
 * @p rows x @p cols tile from host matrix @p A (leading dimension @p lda)
 * to device matrix @p B (leading dimension @p ldb).
 */
template<typename T>
static void SetMatrix(int rows, int cols, const T *const &A,
                      int lda, T *&B, int ldb)
{
    check_status(cublasSetMatrix(rows, cols, sizeof(T), A, lda, B, ldb));
}
00236
00237
00238
00239
00240
00241
00242
00243
00244
/**
 * Forwards to cublasGetMatrix() with an element size of sizeof(T) and hands
 * the returned status to check_status().
 *
 * According to the cuBLAS reference, cublasGetMatrix copies a
 * @p rows x @p cols tile from device matrix @p A (leading dimension @p lda)
 * to host matrix @p B (leading dimension @p ldb).
 */
template<typename T>
static void GetMatrix(int rows, int cols, const T * const &A,
                      int lda, T *&B, int ldb)
{
    check_status(cublasGetMatrix(rows, cols, sizeof(T), A, lda, B, ldb));
}
00254
00255
00256
00257
00258
00259
00260
00261
/**
 * Forwards to cublasSetVector() with an element size of sizeof(T) and hands
 * the returned status to check_status().
 *
 * According to the cuBLAS reference, cublasSetVector copies @p n_el
 * elements from host vector @p src (stride @p inc_src) to device vector
 * @p dst (stride @p inc_dst).
 */
template<typename T>
static void SetVector(int n_el, const T * const src, int inc_src, T *dst, int inc_dst)
{
    check_status(cublasSetVector(n_el, sizeof(T), src, inc_src, dst, inc_dst));
}
00271
00272
00273
00274
00275
00276
00277
00278
/**
 * Forwards to cublasGetVector() with an element size of sizeof(T) and hands
 * the returned status to check_status().
 *
 * According to the cuBLAS reference, cublasGetVector copies @p n_el
 * elements from device vector @p A (stride @p inc_src) to host vector @p B
 * (stride @p inc_dst).
 */
template<typename T>
static void GetVector(int n_el, const T * const &A, int inc_src, T *&B, int inc_dst)
{
    check_status(cublasGetVector(n_el, sizeof(T), A, inc_src, B, inc_dst));
}
00287
00288
00289
00290
00291
00292
00293 static float
00294 asum (int n, const float *x,
00295 int incx)
00296 {
00297 float sum = cublasSasum(n, x, incx);
00298
00299 cublasStatus status = cublasGetError();
00300
00301 check_status(status);
00302 return sum;
00303 }
00304
00305 static double
00306 asum (int n, const double *x,
00307 int incx)
00308 {
00309 double sum = cublasDasum(n, x, incx);
00310
00311 cublasStatus status = cublasGetError();
00312
00313 check_status(status);
00314 return sum;
00315 }
00316
00317 static float
00318 asum (int n, const cuComplex *x,
00319 int incx)
00320 {
00321 float sum = cublasScasum(n, x, incx);
00322
00323 cublasStatus status = cublasGetError();
00324
00325 check_status(status);
00326 return sum;
00327 }
00328
00329 static double
00330 asum (int n, const cuDoubleComplex *x,
00331 int incx)
00332 {
00333 double sum = cublasDzasum(n, x, incx);
00334
00335 cublasStatus status = cublasGetError();
00336
00337 check_status(status);
00338 return sum;
00339 }
00340
00341
00342
00343
00344
00345
00346
00347
00348
00349
00350 static void
00351 axpy (int n, float alpha, const float *x,
00352 int incx, float *y, int incy)
00353 {
00354 cublasSaxpy(n, alpha, x, incx, y, incy);
00355
00356 cublasStatus status = cublasGetError();
00357
00358 check_status(status);
00359 }
00360
00361 static void
00362 axpy (int n, double alpha, const double *x,
00363 int incx, double *y, int incy)
00364 {
00365 cublasDaxpy(n, alpha, x, incx, y, incy);
00366
00367 cublasStatus status = cublasGetError();
00368
00369 check_status(status);
00370 }
00371
00372 static void
00373 axpy (int n, cuComplex alpha, const cuComplex *x,
00374 int incx, cuComplex *y, int incy)
00375 {
00376 cublasCaxpy(n, alpha, x, incx, y, incy);
00377
00378 cublasStatus status = cublasGetError();
00379
00380 check_status(status);
00381 }
00382
00383 static void
00384 axpy (int n, cuDoubleComplex alpha, const cuDoubleComplex *x,
00385 int incx, cuDoubleComplex *y, int incy)
00386 {
00387 cublasZaxpy(n, alpha, x, incx, y, incy);
00388
00389 cublasStatus status = cublasGetError();
00390
00391 check_status(status);
00392 }
00393
00394
00395
00396
00397
00398
00399
00400
00401
00402 static void
00403 copy(int n, const float *x, int incx, float *y, int incy)
00404 {
00405 cublasScopy(n, x, incx, y, incy);
00406
00407 cublasStatus status = cublasGetError();
00408
00409 check_status(status);
00410 }
00411
00412 static void
00413 copy(int n, const double *x, int incx, double *y, int incy)
00414 {
00415 cublasDcopy(n, x, incx, y, incy);
00416
00417 cublasStatus status = cublasGetError();
00418
00419 check_status(status);
00420 }
00421 static void
00422 copy(int n, const cuComplex *x, int incx, cuComplex *y, int incy)
00423 {
00424 cublasCcopy(n, x, incx, y, incy);
00425
00426 cublasStatus status = cublasGetError();
00427
00428 check_status(status);
00429 }
00430
00431 static void
00432 copy(int n, const cuDoubleComplex *x, int incx, cuDoubleComplex *y, int incy)
00433 {
00434 cublasZcopy(n, x, incx, y, incy);
00435
00436 cublasStatus status = cublasGetError();
00437
00438 check_status(status);
00439 }
00440
00441
00442
00443
00444
00445
00446
00447
00448 static void
00449 scal (int n, float alpha, float *x, int incx)
00450 {
00451 cublasSscal (n, alpha, x, incx);
00452
00453 cublasStatus status = cublasGetError();
00454
00455 check_status(status);
00456 }
00457
00458 static void
00459 scal (int n, double alpha, double *x, int incx)
00460 {
00461 cublasDscal (n, alpha, x, incx);
00462
00463 cublasStatus status = cublasGetError();
00464
00465 check_status(status);
00466 }
00467
00468
00469
00470 static void
00471 scal (int n, cuComplex alpha, cuComplex *x, int incx)
00472 {
00473 cublasCscal (n, alpha, x, incx);
00474
00475 cublasStatus status = cublasGetError();
00476
00477 check_status(status);
00478 }
00479
00480
00481
00482 static void
00483 scal (int n, cuDoubleComplex alpha, cuDoubleComplex *x, int incx)
00484 {
00485 cublasZscal (n, alpha, x, incx);
00486
00487 cublasStatus status = cublasGetError();
00488
00489 check_status(status);
00490 }
00491
00492
00493
00494
00495
00496
00497
00498
00499
00500
00501
00502
00503
00504
00505
00506 static void gemv (char trans, int m, int n, float alpha,
00507 const float * const A, int lda,
00508 const float * const x, int incx, float beta,
00509 float *y, int incy)
00510 {
00511 cublasSgemv (trans, m, n, alpha, A, lda, x, incx, beta, y, incy);
00512
00513 cublasStatus status = cublasGetError();
00514
00515 check_status(status);
00516 }
00517
00518 static void gemv (char trans, int m, int n, double alpha,
00519 const double * const A, int lda,
00520 const double * const x, int incx, double beta,
00521 double *y, int incy)
00522 {
00523 cublasDgemv (trans, m, n, alpha, A, lda, x, incx, beta, y, incy);
00524
00525 cublasStatus status = cublasGetError();
00526
00527 check_status(status);
00528 }
00529
00530
00531
00532
00533
00534
00535
00536
00537
00538
00539
00540
00541
00542 static void
00543 ger(int m, int n, float alpha, const float *x,
00544 int incx, const float *y, int incy, float *A,
00545 int lda)
00546 {
00547 cublasSger(m, n, alpha, x, incx, y, incy, A, lda);
00548
00549 cublasStatus status = cublasGetError();
00550
00551 check_status(status);
00552 }
00553
00554 static void
00555 ger(int m, int n, double alpha, const double *x,
00556 int incx, const double *y, int incy, double *A,
00557 int lda)
00558 {
00559 cublasDger(m, n, alpha, x, incx, y, incy, A, lda);
00560
00561 cublasStatus status = cublasGetError();
00562
00563 check_status(status);
00564 }
00565
00566
00567
00568
00569
00570
00571
00572
00573
00574
00575
00576
00577
00578
00579
00580
00581
00582
00583 static void gemm(char transa, char transb, int m, int n, int k, float alpha,
00584 const float * const A, int lda, const float * const B, int ldb,
00585 float beta, float * C, int ldc)
00586 {
00587 cublasSgemm(transa, transb, m, n, k, alpha,
00588 A, lda, B, ldb,
00589 beta, C, ldc);
00590
00591 cublasStatus status = cublasGetError();
00592
00593 check_status(status);
00594 }
00595
00596
00597 static void gemm(char transa, char transb, int m, int n, int k, double alpha,
00598 const double * const A, int lda, const double * const B, int ldb,
00599 double beta, double * C, int ldc)
00600 {
00601 cublasDgemm(transa, transb, m, n, k, alpha,
00602 A, lda, B, ldb,
00603 beta, C, ldc);
00604
00605 cublasStatus status = cublasGetError();
00606
00607 check_status(status);
00608 }
00609
00610
00611
00612
00613
00614
00615
00616
00617 static float
00618 nrm2(int n, const float *x, int incx)
00619 {
00620 float result = cublasSnrm2 (n, x, incx);
00621
00622 cublasStatus status = cublasGetError();
00623
00624 check_status(status);
00625
00626 return result;
00627 }
00628
00629 static double
00630 nrm2(int n, const double *x, int incx)
00631 {
00632 double result = cublasDnrm2 (n, x, incx);
00633
00634 cublasStatus status = cublasGetError();
00635
00636 check_status(status);
00637
00638 return result;
00639 }
00640
00641 static cuComplex
00642 nrm2(int n, const cuComplex *x, int incx)
00643 {
00644 cuComplex result = dotc(n, x, incx, x, incx);
00645 result.x = sqrt(result.x);
00646 result.y = 0.;
00647
00648 return result;
00649
00650 }
00651
00652
00653
00654 static cuDoubleComplex
00655 nrm2(int n, const cuDoubleComplex *x, int incx)
00656 {
00657 cuDoubleComplex result = dotc(n, x, incx, x, incx);
00658 result.x = sqrt(result.x);
00659 result.y = 0.;
00660
00661 return result;
00662
00663 }
00664
00665
00666
00667
00668
00669
00670
00671
00672
00673
00674 static float dot(int n, const float *x, int incx, const float *y, int incy)
00675 {
00676 float result = cublasSdot(n, x, incx, y, incy);
00677
00678 cublasStatus status = cublasGetError();
00679
00680 check_status(status);
00681
00682 return result;
00683 }
00684
00685 static double dot(int n, const double *x,
00686 int incx, const double *y, int incy)
00687 {
00688 double result = cublasDdot(n, x, incx, y, incy);
00689
00690 cublasStatus status = cublasGetError();
00691
00692 check_status(status);
00693
00694 return result;
00695 }
00696
00697 static cuComplex dot(int n, const cuComplex *x,
00698 int incx, const cuComplex *y, int incy)
00699 {
00700 return dotc(n, x, incx, y, incy);
00701 }
00702
00703 static cuDoubleComplex dot(int n, const cuDoubleComplex *x,
00704 int incx, const cuDoubleComplex *y, int incy)
00705 {
00706 return dotc(n, x, incx, y, incy);
00707 }
00708
00709
00710
00711
00712
00713
00714
00715
00716
00717
00718 static cuComplex dotu(int n, const cuComplex *x,
00719 int incx, const cuComplex *y, int incy)
00720 {
00721 cuComplex result = cublasCdotu(n, x, incx, y, incy);
00722
00723 cublasStatus status = cublasGetError();
00724
00725 check_status(status);
00726
00727 std::cout << __FUNCTION__ << " : " << " c*c^* = " << result.x << ", " << result.y << std::endl;
00728
00729 return result;
00730 }
00731
00732
00733 static cuDoubleComplex dotu(int n, const cuDoubleComplex *x,
00734 int incx, const cuDoubleComplex *y, int incy)
00735 {
00736 cuDoubleComplex result = cublasZdotu(n, x, incx, y, incy);
00737
00738 cublasStatus status = cublasGetError();
00739
00740 check_status(status);
00741
00742 return result;
00743 }
00744
00745
00746
00747
00748
00749
00750
00751
00752
00753
00754 static cuComplex dotc(int n, const cuComplex *x,
00755 int incx, const cuComplex *y, int incy)
00756 {
00757 cuComplex result = cublasCdotc(n, x, incx, y, incy);
00758
00759 cublasStatus status = cublasGetError();
00760
00761 check_status(status);
00762
00763
00764
00765 return result;
00766 }
00767
00768
00769 static cuDoubleComplex dotc(int n, const cuDoubleComplex *x,
00770 int incx, const cuDoubleComplex *y, int incy)
00771 {
00772 cuDoubleComplex result = cublasZdotc(n, x, incx, y, incy);
00773
00774 cublasStatus status = cublasGetError();
00775
00776 check_status(status);
00777
00778 return result;
00779 }
00780
00781
00782
00783
00784 inline void trsm(char side, char uplo, char transa, char diag, int m, int n,
00785 float alpha,
00786 const float * A, int lda, float * B, int ldb)
00787 {
00788 cublasStrsm(side, uplo, transa, diag, m, n, alpha,
00789 A, lda, B, ldb);
00790
00791 cublasStatus status = cublasGetError();
00792
00793 check_status(status);
00794 }
00795
00796 inline void trsm(char side, char uplo, char transa, char diag, int m, int n,
00797 double alpha,
00798 const double * A, int lda, double * B, int ldb)
00799 {
00800 cublasDtrsm(side, uplo, transa, diag, m, n, alpha,
00801 A, lda, B, ldb);
00802
00803 cublasStatus status = cublasGetError();
00804
00805 check_status(status);
00806 }
00807
00808 };
00809
00810
00811
00812
00813 #endif // CUBLAS_WRAPPER_HH
00814