00001 #ifndef cublas_Matrix_H
00002 #define cublas_Matrix_H
00003
00004
#include <cmath>
#include <iomanip>
#include <iostream>
#include <vector>

#include <QtGlobal>

#include <base/subscriptor.h>

#include <lac/cublas_Array.h>
#include <lac/expression_template.h>
#include <lac/cublas_Vector.h>
00014
// Tag type naming the matrix-matrix product in X_read_read expressions;
// it is defined further down in this header.
struct mmu;

// Only declared here; a complete definition has to be visible wherever the
// Matrix members that use it get instantiated.
template<typename> class FullMatrixAccessor;

// Sibling class templates of bw_types that Matrix befriends or mentions in
// its interface.
namespace bw_types {

    template<typename, typename> class Vector;
    template<typename, typename> class VectorView;
    template<typename, typename> class SubMatrixView;

}
00028
00029
00030 namespace bw_types {
00031
00032
00033
00034 template<typename T, typename BW>
00035 class Matrix
00036 :
00037 protected bw_types::Array<T, BW>,
00038 public ::Subscriptor {
00039
00040 friend class Vector<T, BW>;
00041
00042 friend class VectorView<T, Matrix<T, BW> >;
00043
00044 friend class SubMatrixView<T, BW>;
00045
00046 friend class FullMatrixAccessor<T>;
00047
00048 public:
00049
00050 typedef T Number;
00051
00052 typedef T value_type;
00053
00054 typedef BW blas_wrapper_type;
00055
00056 Matrix();
00057
00058 Matrix(int n_rows, int n_cols);
00059
00060 Matrix(int n_rows, int n_cols, const Matrix<T, BW> & src_data);
00061
00062 Matrix(const Matrix<T, BW> & other);
00063
00064 template<typename L, typename R>
00065 Matrix(const X_read_read<L, mmu, R> & AB);
00066
00067
00068 Matrix(const ::IdentityMatrix & Id);
00069
00070 void reinit(int n_rows, int n_cols);
00071
00072 Matrix<T, BW> & operator = (const FullMatrixAccessor<T> & matrix);
00073
00074 Matrix<T, BW> & operator = (const ::IdentityMatrix & Id);
00075
00076 Matrix<T, BW> & operator = (const Matrix<T, BW> & array);
00077
00078
00079 template<typename L, typename R>
00080 Matrix<T, BW> & operator = (const X_read_read<L, mmu, R> & AB);
00081
00082
00083 Matrix<T, BW> & operator += (const Matrix<T, BW> & other);
00084
00085 Matrix<T, BW> & operator -= (const Matrix<T, BW> & other);
00086
00087
00088 template<typename VECTOR1, typename VECTOR2>
00089 void vmult(VECTOR1& dst, const VECTOR2& src) const;
00090
00091 template<typename VECTOR1, typename VECTOR2>
00092 void Tvmult(VECTOR1& dst, const VECTOR2& src) const;
00093
00094 void scaled_vmult(T beta, Vector<T, BW>& dst,
00095 T alpha, const Vector<T, BW>& src) const;
00096
00097
00098 void mmult(Matrix<T, BW>& dst, const Matrix<T, BW>& src) const;
00099
00100 void mmult(SubMatrixView<T, BW>& dst, const Matrix<T, BW>& src) const;
00101
00102
00103 void mTmult(Matrix<T, BW>& dst, const Matrix<T, BW>& src) const;
00104
00105
00106 void Tmmult(Matrix<T, BW>& dst, const Matrix<T, BW>& src) const;
00107
00108
00109 void TmTmult(Matrix<T, BW>& dst, const Matrix<T, BW>& src) const;
00110
00111 template<typename VECTOR1, typename VECTOR2>
00112 void add_scaled_outer_product(T alpha, const VECTOR1& col, const VECTOR2& row);
00113
00114
00115
00116 int n_rows() const { return __n_rows; }
00117
00118 int n_cols() const { return __n_cols; }
00119
00120 void print() const;
00121
00122 T operator () (const unsigned int i, const unsigned int j) const;
00123
00124 T l2_norm() const;
00125
00126 inline bw_types::Array<T, BW> & array() { return *this; }
00127
00128 inline const bw_types::Array<T, BW> & array() const { return *this; }
00129
00130 private:
00131
00132 Matrix<T, BW> & operator = (const Array<T, BW> & src);
00133
00134 int __n_rows;
00135 int __n_cols;
00136 };
00137
00138 }
00139
00140
00141
00142
00143
00144 template<typename T, typename BW>
00145 bw_types::Matrix<T, BW>::Matrix()
00146 :
00147 Array<T, BW>(), __n_rows(0), __n_cols(0)
00148 {}
00149
00150
00151
00152 template<typename T, typename BW>
00153 bw_types::Matrix<T, BW>::Matrix(int n_rows, int n_cols)
00154 :
00155 Array<T, BW>(n_rows*n_cols),
00156 __n_rows(n_rows),
00157 __n_cols(n_cols)
00158 {
00159 const std::vector<T> tmp(n_rows*n_cols, 0);
00160
00161 const T * const tmp_ptr = &tmp[0];
00162
00163 T * this_data = this->data();
00164
00165 BW::SetMatrix(n_rows, n_cols, tmp_ptr, n_rows,
00166 this_data, n_rows);
00167 }
00168
00169
00170 template<typename T, typename BW>
00171 bw_types::Matrix<T, BW>::Matrix(int n_rows, int n_cols,
00172 const Matrix<T, BW> & src_data)
00173 :
00174 Array<T, BW>(n_rows*n_cols),
00175 __n_rows(n_rows),
00176 __n_cols(n_cols)
00177 {
00178 const std::vector<T> tmp(n_rows*n_cols, 0);
00179
00180 BW::SetMatrix(n_rows, n_cols, &tmp[0], n_rows,
00181 this->data(), n_rows);
00182
00183 Array<T, BW> & self = *this;
00184
00185 self = src_data;
00186 }
00187
00188
00189 template<typename T, typename BW>
00190 bw_types::Matrix<T, BW>::Matrix(const ::IdentityMatrix & Id)
00191 :
00192 Array<T, BW>(), __n_rows(0), __n_cols(0)
00193 {
00194 *this = Id;
00195 }
00196
00197
00198
00199
00200 template<typename T, typename BW>
00201 bw_types::Matrix<T, BW>::Matrix(const Matrix<T, BW> & other)
00202 {
00203 *this = other;
00204 }
00205
00206
00207
00208 template<typename T, typename BW>
00209 void bw_types::Matrix<T, BW>::reinit(int n_rows, int n_cols)
00210 {
00211
00212 this->Array<T, BW>::reinit(n_rows*n_cols);
00213 __n_rows = n_rows;
00214 __n_cols = n_cols;
00215
00216
00217
00218
00219
00220
00221
00222 }
00223
00224
00225
00226
00227
00228
00229 template<typename T, typename BW>
00230 bw_types::Matrix<T, BW> &
00231 bw_types::Matrix<T, BW>::operator = (const Array<T, BW> & src)
00232 {
00233 Assert(this->n_elements() <= src.n_elements(),
00234 ::ExcMessage("n_element mismatch") );
00235
00236
00237 int inc_src = 1;
00238 int inc_this = 1;
00239
00240 BW::copy(this->n_elements(), src.data(), inc_src,
00241 this->data(), inc_this);
00242
00243 return *this;
00244 }
00245
00246 template<typename T, typename BW>
00247 bw_types::Matrix<T, BW> &
00248 bw_types::Matrix<T, BW>::operator = (const ::IdentityMatrix & Id)
00249 {
00250 this->__n_rows = Id.m();
00251 this->__n_cols = Id.n();
00252
00253
00254
00255 int n_dofs = Id.m();
00256 int n_src_el = n_dofs*n_dofs;
00257
00258
00259 this->Array<T,BW>::reinit(n_src_el);
00260
00261 FullMatrixAccessor<T> tmp_id(Id);
00262
00263
00264
00265
00266
00267 const T * id_val = tmp_id.val();
00268 T * dst_val = this->data();
00269 BW::SetMatrix(n_dofs, n_dofs, id_val,
00270 n_dofs, dst_val, n_dofs );
00271
00272 return *this;
00273 }
00274
00275
00276
00277 template<typename T, typename BW>
00278 bw_types::Matrix<T, BW> &
00279 bw_types::Matrix<T, BW>::operator = (const FullMatrixAccessor<T> & src_matrix)
00280 {
00281
00282
00283 Assert(src_matrix.is_column_major(),
00284 ::ExcMessage("bw_types:Matrix expects a matrix in column major"
00285 " format as source.") );
00286
00287 int n_src_el = src_matrix.n_elements();
00288
00289 this->Array<T, BW>::reinit(n_src_el);
00290
00291 int nr = src_matrix.n_rows();
00292 int nc = src_matrix.n_cols();
00293
00294 const T * tmp_src = src_matrix.val();
00295
00296
00297 T * tmp_dst = this->data();
00298 BW::SetMatrix(nr, nc, tmp_src, nr, tmp_dst, nr);
00299
00300
00301 this->__n_rows = nr;
00302 this->__n_cols = nc;
00303
00304 return *this;
00305 }
00306
00307
00308
00309
00310 template<typename T, typename BW>
00311 bw_types::Matrix<T, BW> &
00312 bw_types::Matrix<T, BW>::operator = (const Matrix<T, BW> & other)
00313 {
00314 this->__n_rows = other.__n_rows;
00315 this->__n_cols = other.__n_cols;
00316
00317 this->Array<T, BW>::reinit(__n_rows*__n_cols);
00318
00319
00320 int inc_src = 1;
00321 int inc_this = 1;
00322
00323 BW::copy(this->n_elements(), other.data(), inc_src,
00324 this->data(), inc_this);
00325
00326 return *this;
00327 }
00328
00329
00330
00331
00332
00333 template<typename T, typename BW>
00334 bw_types::Matrix<T, BW> & bw_types::Matrix<T, BW>::operator += (const Matrix<T, BW> & other)
00335 {
00336
00337 Assert((this->n_rows() == other.n_rows()) && (this->n_cols() == other.n_cols()),
00338 ::ExcMessage("Dimension mismatch"));
00339
00340 int n = this->n_elements();
00341 T alpha = 1.;
00342
00343 const T * const x = other.array().val();
00344 int incx = 1;
00345
00346 T * y = this->val();
00347 int incy = 1;
00348
00349 BW::axpy(n, alpha, x, incx, y, incy);
00350
00351 return *this;
00352 }
00353
00354
00355
00356
00357
00358
00359 template<typename T, typename BW>
00360 bw_types::Matrix<T, BW> & bw_types::Matrix<T, BW>::operator -= (const Matrix<T, BW> & other)
00361 {
00362
00363 Assert((this->n_rows() == other.n_rows()) && (this->n_cols() == other.n_cols()),
00364 ::ExcMessage("Dimension mismatch"));
00365
00366 int n = this->n_elements();
00367 T alpha = -1.;
00368
00369 const T * const x = other.array().val();
00370 int incx = 1;
00371
00372 T * y = this->val();
00373 int incy = 1;
00374
00375 BW::axpy(n, alpha, x, incx, y, incy);
00376
00377 return *this;
00378 }
00379
00380
00381
00382
00383
00384 template<typename T, typename BW>
00385 T bw_types::Matrix<T, BW>::l2_norm() const
00386 {
00387
00388 T result = BW::nrm2(this->n_elements(), this->val(), 1);
00389
00390 return result;
00391
00392 }
00393
00394
00395
00396
00397
00398
00399 template<typename T, typename BW>
00400 template<typename VECTOR1, typename VECTOR2>
00401 inline
00402 void bw_types::Matrix<T, BW>::vmult(VECTOR1& dst, const VECTOR2& src) const
00403 {
00404
00405 T alpha = 1.;
00406 T beta = 0.;
00407 int leading_dim_A = this->__n_rows;
00408
00409 T * dst_ptr = dst.val();
00410 const T * src_ptr = src.val();
00411
00412 Assert(src.size() == this->__n_cols,
00413 ::ExcMessage("Dimension mismatch"));
00414
00415 BW::gemv('n', this->__n_rows, this->__n_cols, alpha, this->data(),
00416 leading_dim_A, src_ptr, 1, beta, dst_ptr, 1);
00417 }
00418
00419
00420
00421
00422
00423
00424
00425
00426 template<typename T, typename BW>
00427 template<typename VECTOR1, typename VECTOR2>
00428 inline
00429 void bw_types::Matrix<T, BW>::Tvmult(VECTOR1& dst, const VECTOR2& src) const
00430 {
00431
00432 T alpha = 1.;
00433 T beta = 0.;
00434 int leading_dim_A = this->__n_rows;
00435
00436 BW::gemv('t', this->__n_rows, this->__n_cols, alpha, this->data(),
00437 leading_dim_A, src.val(), 1, beta, dst.val(), 1);
00438 }
00439
00440
00441
00442
00443
00444
00445
00446 template<typename T, typename BW>
00447 void
00448 bw_types::Matrix<T, BW>::scaled_vmult(T beta, Vector<T, BW>& dst ,
00449 T alpha, const Vector<T, BW>& src )
00450 const
00451 {
00452 BW::gemv('n', this->__n_rows, this->__n_cols, alpha, this->data(),
00453 this->__n_rows, src.val(), 1, beta, dst.val(), 1);
00454 }
00455
00456
00457
00458
00459
00460
00461
00462
00463 template<typename T, typename BW>
00464 void
00465 bw_types::Matrix<T, BW>::mmult(Matrix<T, BW>& dst, const Matrix<T, BW>& src) const
00466 {
00467 T alpha = 1;
00468 T beta = 0;
00469
00470 int lda = this->__n_rows;
00471 int ldb = src.__n_rows;
00472 int ldc = dst.__n_rows;
00473
00474 BW::gemm('n', 'n',
00475 this->__n_rows ,
00476 dst.__n_cols ,
00477 this->__n_cols ,
00478 alpha,
00479 this->data(), lda,
00480 src.data(), ldb,
00481 beta,
00482 dst.data(), ldc);
00483 }
00484
00485
00486 template<typename T, typename BW>
00487 void
00488 bw_types::Matrix<T, BW>::mmult(SubMatrixView<T, BW>& dst, const Matrix<T, BW>& src) const
00489 {
00490 T alpha = 1;
00491 T beta = 0;
00492
00493
00494 int lda = this->__n_rows;
00495 int ldb = src.__n_rows;
00496 int ldc = dst.__n_rows;
00497
00498 BW::gemm('n', 'n',
00499 this->__n_rows ,
00500 dst.__n_cols ,
00501 this->__n_cols ,
00502 alpha,
00503 this->data(), lda,
00504 src.data(), ldb,
00505 beta,
00506 dst.data(), ldc);
00507 }
00508
00509
00510
00511
00512
00513
00514
00515 template<typename T, typename BW>
00516 void
00517 bw_types::Matrix<T, BW>::mTmult(Matrix<T, BW>& dst, const Matrix<T, BW>& src) const
00518 {
00519 T alpha = 1;
00520 T beta = 0;
00521
00522 int lda = this->__n_rows;
00523 int ldb = src.__n_rows;
00524 int ldc = dst.__n_rows;
00525
00526 BW::gemm('n', 't',
00527 this->__n_rows ,
00528 dst.__n_cols ,
00529 this->__n_cols ,
00530 alpha,
00531 this->data(), lda,
00532 src.data(), ldb,
00533 beta,
00534 dst.data(), ldc);
00535 }
00536
00537
00538
00539
00540
00541
00542 template<typename T, typename BW>
00543 void
00544 bw_types::Matrix<T, BW>::Tmmult(Matrix<T, BW>& dst, const Matrix<T, BW>& src) const
00545 {
00546 T alpha = 1;
00547 T beta = 0;
00548
00549 int lda = this->__n_rows;
00550 int ldb = src.__n_rows;
00551 int ldc = dst.__n_rows;
00552
00553 BW::gemm('t', 'n',
00554 this->__n_cols ,
00555 dst.__n_cols ,
00556 this->__n_rows ,
00557 alpha,
00558 this->data(), lda,
00559 src.data(), ldb,
00560 beta,
00561 dst.data(), ldc);
00562 }
00563
00564
00565
00566
00567
00568
00569 template<typename T, typename BW>
00570 void
00571 bw_types::Matrix<T, BW>::TmTmult(Matrix<T, BW>& , const Matrix<T, BW>& ) const
00572 {
00573 AssertThrow(false, ::ExcNotImplemented() );
00574 }
00575
00576
00577
00578
00579
00580
00581
00582 template<typename T, typename BW>
00583 template<typename VECTOR1, typename VECTOR2>
00584 void
00585 bw_types::Matrix<T, BW>::add_scaled_outer_product(T alpha,
00586 const VECTOR1& x,
00587 const VECTOR2& y)
00588 {
00589
00590 int m = this->__n_rows;
00591 int n = this->__n_cols;
00592 int lda = this->__n_rows;
00593
00594 #ifdef DEBUG
00595 if (x.size() != m) std::cout << "Dimension mismatch. "
00596 "Vector x.size() should be " << m << " but is " << x.size() << std::endl;
00597 if (y.size() != n) std::cout << "Dimension mismatch. "
00598 "Vector y.size() should be " << n << " but is " << y.size() << std::endl;
00599 #endif
00600 Assert(x.size() == m, ::ExcMessage("Dimension mismatch"));
00601 Assert(y.size() == n, ::ExcMessage("Dimension mismatch"));
00602
00603 int incx = x._stride;
00604 int incy = y._stride;
00605
00606 BW::ger(m, n, alpha,
00607 x.val(), incx,
00608 y.val(), incy,
00609 this->data(), lda);
00610 }
00611
00612
00613
00614
00615
00616
00617 template <typename T, typename BW>
00618 inline
00619 T
00620 bw_types::Matrix<T, BW>::operator () (const unsigned int r,
00621 const unsigned int c) const
00622 {
00623 int lead_dim = this->n_rows();
00624 const T * tmp_d = & this->data()[c*this->__n_rows+r];
00625 T entry;
00626 T * p_e = &entry;
00627 BW::GetMatrix(1, 1,
00628 tmp_d, lead_dim, p_e, 1);
00629
00630 return entry;
00631 }
00632
00633
00634
00635
00636
00637
00638 template<typename T, typename BW>
00639 void
00640 bw_types::Matrix<T, BW>::print() const
00641 {
00642
00643
00644 std::cout << "Matrix dims : " << this->__n_rows << " "
00645 << this->__n_cols << std::endl;
00646
00647 int n_el = this->__n_rows * this->__n_cols;
00648 T * tmp = new T[n_el];
00649
00650 BW::GetMatrix(this->__n_rows, this->__n_cols,
00651 this->data(), this->__n_rows, tmp, this->__n_rows);
00652
00653 for (int r = 0; r < this->__n_rows; ++r)
00654 {
00655 for (int c = 0; c < this->__n_cols; ++c)
00656 std::cout << std::setprecision(4) << std::fixed
00657 << std::setw(15) <<
00658
00659
00660 tmp[c*this->__n_rows + r]
00661
00662 << " ";
00663 std::cout <<";" << std::endl;
00664 }
00665
00666 delete [] tmp;
00667
00668 }
00669
00670
00671
00672
00673
00674
00675
00676
00677
00678
00679 struct mmu
00680 {
00681 template<typename T, typename BW>
00682 static void apply( bw_types::Matrix<T, BW> & C,
00683 const bw_types::Matrix<T, BW> & A,
00684 const bw_types::Matrix<T, BW> & B)
00685 {
00686 C.reinit(A.n_rows(), B.n_cols()); A.mmult(C,B);
00687 }
00688
00689 template<typename T, typename BW>
00690 static void apply( bw_types::Matrix<T, BW> & C,
00691 const bw_types::SubMatrixView<T, BW> & A,
00692 const bw_types::Matrix<T, BW> & B)
00693 {
00694 AssertThrow(false, ::ExcMessage("Not implemented"));
00695
00696 }
00697
00698
00699 template<typename T, typename BW>
00700 static void apply( bw_types::Matrix<T, BW> & C,
00701 const bw_types::Matrix<T, BW> & A,
00702 const transpose<bw_types::Matrix<T, BW> > & B_t)
00703 {
00704 C.reinit(A.n_rows(), B_t.A.n_rows()); A.mTmult(C, B_t.A);
00705 }
00706
00707
00708 template<typename T, typename BW>
00709 static void apply( bw_types::Matrix<T, BW> & C,
00710 const transpose<bw_types::Matrix<T, BW> > & A_t,
00711 const bw_types::Matrix<T, BW> & B)
00712 {
00713 C.reinit(A_t.A.n_cols(), B.n_cols()); A_t.A.Tmmult(C, B);
00714 }
00715
00716
00717 template<typename T, typename BW>
00718 static void apply( bw_types::Matrix<T, BW> & C,
00719 const transpose<bw_types::Matrix<T, BW> > & A_t,
00720 const transpose<bw_types::Matrix<T, BW> > & B_t)
00721 {
00722 C.reinit(A_t.A.n_cols(), B_t.A.n_rows()); A_t.A.TmTmult(C, B_t.A);
00723 }
00724
00725 };
00726
00727
00728
00729
00730
00731
00732 template<typename L, typename T, typename BW>
00733 inline
00734 X_read_read<L, mmu, bw_types::Matrix<T, BW> >
00735 operator * (const L & _l, const bw_types::Matrix<T, BW> & _r)
00736 {
00737 typedef bw_types::Matrix<T, BW> R;
00738 return X_read_read<L, mmu, R> (_l,_r);
00739 }
00740
00741 template<typename L, typename T, typename BW>
00742 inline
00743 X_read_read<L, mmu, transpose<bw_types::Matrix<T, BW> > >
00744 operator * (const L & _l, const transpose<bw_types::Matrix<T, BW> > & _r)
00745 {
00746 typedef transpose<bw_types::Matrix<T, BW> > R;
00747 return X_read_read<L, mmu, R> (_l,_r);
00748 }
00749
00750
00751
00752 template<typename T, typename BW>
00753 template<typename L, typename R>
00754 bw_types::Matrix<T, BW>::Matrix(const X_read_read<L, mmu, R> & AB)
00755 {
00756 AB.apply(*this);
00757 }
00758
00759
00760
00761 template<typename T, typename BW>
00762 template<typename L, typename R>
00763 bw_types::Matrix<T, BW> &
00764 bw_types::Matrix<T, BW>::operator = (const X_read_read<L, mmu, R > & AB)
00765 {
00766 AB.apply(*this);
00767 return *this;
00768 }
00769
00770
00771
00772
00773 #endif
00774