00001 #ifndef cublas_SubMatrixView_H
00002 #define cublas_SubMatrixView_H
00003
#include <iomanip>
#include <iostream>
#include <vector>

#include <lac/cublas_Vector.h>
00005
00006 namespace bw_types {
00007
00008 template<typename, typename> class SubMatrixView ;
00009
00010 template<typename T, typename BW> struct SMSMmult {
00011 const SubMatrixView<T, BW> & l;
00012 const SubMatrixView<T, BW> & r;
00013 SMSMmult (const SubMatrixView<T, BW> & A, const SubMatrixView<T, BW> & B): l(A), r(B){}
00014 };
00015
00016 template<typename T, typename BW> struct SMSMTmult {
00017 const SubMatrixView<T, BW> & l;
00018 const transpose<SubMatrixView<T, BW> > & r;
00019 SMSMTmult (const SubMatrixView<T, BW> & A, const transpose<SubMatrixView<T, BW> > & B): l(A), r(B){}
00020 };
00021
00022
00023 template<typename T, typename BW> struct SMTSMmult {
00024 const transpose<SubMatrixView<T, BW> > & l;
00025 const SubMatrixView<T, BW> & r;
00026 SMTSMmult (const transpose<SubMatrixView<T, BW> > & A, const SubMatrixView<T, BW> & B): l(A), r(B){}
00027 };
00028
00029
00030
00031
00032
00033
00034 template<typename T, typename BW>
00035 inline SMSMmult<T, BW> operator * (const SubMatrixView<T, BW> & A, const SubMatrixView<T, BW> & B) {
00036 return SMSMmult<T, BW>(A, B);
00037 }
00038
00039
00040
00041
00042
00043
00044 template<typename T, typename BW>
00045 inline SMSMTmult<T, BW> operator * (const SubMatrixView<T, BW> & A,
00046 const transpose<SubMatrixView<T, BW> > & B) {
00047 return SMSMTmult<T, BW>(A, B);
00048 }
00049
00050
00051
00052
00053
00054
00055
00056
00057 template<typename T, typename BW>
00058 inline SMTSMmult<T, BW> operator * (const transpose<SubMatrixView<T, BW> > & A,
00059 const SubMatrixView<T, BW> & B) {
00060 return SMTSMmult<T, BW>(A, B);
00061 }
00062
00063
00064
00065
00066 template<typename T, typename BW>
00067 class SubMatrixView {
00068
00069 public:
00070 SubMatrixView(Matrix<T, BW> & src, int r_begin, int c_begin);
00071
00072 SubMatrixView(Matrix<T, BW> & src, int r_begin, int r_end,
00073 int c_begin, int c_end);
00074
00075 void print() const;
00076
00077 SubMatrixView & operator = (const Matrix<T, BW>& col);
00078
00079
00080 SubMatrixView<T, BW> & operator = (const SMSMmult<T, BW> & AB);
00081
00082 SubMatrixView<T, BW> & operator += (const SMSMmult<T, BW> & AB);
00083
00084 SubMatrixView<T, BW> & operator += (const SMSMTmult<T, BW> & AB);
00085
00086 SubMatrixView<T, BW> & operator += (const SMTSMmult<T, BW> & AB);
00087
00088 SubMatrixView<T, BW> & operator -= (const SMSMmult<T, BW> & AB);
00089
00090 const T * val() const { return __src->array().val() + __view_begin; }
00091
00092 T * val() { return __src->array().val() + __view_begin; }
00093
00094 template<typename VECTOR1, typename VECTOR2>
00095 void vmult(VECTOR1& dst, const VECTOR2& src) const;
00096
00097
00098 template<typename VECTOR1, typename VECTOR2>
00099 void Tvmult(VECTOR1& dst, const VECTOR2& src) const;
00100
00101 void add_scaled_outer_product(T alpha,
00102 const Vector<T, BW>& x,
00103 const Vector<T, BW>& y);
00104
00105 private:
00106 SubMatrixView() {}
00107
00108 SubMatrixView(const SubMatrixView<T, BW>& other) {}
00109
00110
00111 SubMatrixView & operator = (const SubMatrixView<T, BW>& col) {}
00112
00113
00114
00115
00116
00117
00118
00119 ::SmartPointer<Matrix<T, BW> >__src;
00120
00121 int __r_begin;
00122 int __c_begin;
00123
00124
00125 int __r_end;
00126 int __c_end;
00127
00128
00129 int __n_el;
00130 int __view_begin;
00131 int __leading_dim;
00132
00133 public:
00134
00135
00136
00137
00138 inline int leading_dim() const{
00139 return __leading_dim;
00140
00141 }
00142
00143
00144 inline int r_begin() const {
00145 return __r_begin;
00146 }
00147
00148 inline int c_begin() const {
00149 return __c_begin;
00150 }
00151
00152 inline int r_end() const {
00153 return __r_end;
00154 }
00155
00156 inline int c_end() const {
00157 return __c_end;
00158 }
00159
00160 inline const Matrix<T, BW>& matrix() const {
00161 return *__src;
00162 }
00163
00164 inline Array<T, BW>& array() {
00165 return __src->array();
00166 }
00167
00168 inline const Array<T, BW>& array() const {
00169 return __src->array();
00170 }
00171
00172 void shift(int m_r, int m_c);
00173 void reset(int new_r_begin, int new_r_end, int new_c_begin, int new_c_end);
00174 void reinit(int new_r_begin, int new_r_end, int new_c_begin, int new_c_end);
00175 };
00176 }
00177
00178
00179
00180
00181
00182
00183
00184 template<typename T, typename BW>
00185 void bw_types::SubMatrixView<T, BW>::shift(int m_r, int m_c) {
00186
00187 reset(__r_begin + m_r, __r_end + m_r, __c_begin + m_c, __c_end + m_c);
00188
00189 }
00190
00191
00192
00193
00194
00195
00196
00197 template<typename T, typename BW>
00198 void bw_types::SubMatrixView<T, BW>::reset(int new_r_begin, int new_r_end,
00199 int new_c_begin, int new_c_end) {
00200 #ifdef DEBUG
00201 Assert(new_r_begin >= 0, ::ExcMessage("View out of matrix bounds."));
00202 Assert(new_r_begin < new_r_end, ::ExcMessage("View out of matrix bounds."));
00203 Assert(new_r_end <= __src->n_rows(), ::ExcMessage("View out of matrix bounds."));
00204
00205 Assert(new_c_begin >= 0, ::ExcMessage("View out of matrix bounds."));
00206 Assert(new_c_begin < new_c_end, ::ExcMessage("View out of matrix bounds."));
00207 Assert(new_c_end <= __src->n_cols(), ::ExcMessage("View out of matrix bounds."));
00208 #else
00209 AssertThrow(new_r_begin >= 0, ::ExcMessage("View out of matrix bounds."));
00210 AssertThrow(new_r_begin < new_r_end, ::ExcMessage("View out of matrix bounds."));
00211 AssertThrow(new_r_end <= __src->n_rows(), ::ExcMessage("View out of matrix bounds."));
00212
00213 AssertThrow(new_c_begin >= 0, ::ExcMessage("View out of matrix bounds."));
00214 AssertThrow(new_c_begin < new_c_end, ::ExcMessage("View out of matrix bounds."));
00215 AssertThrow(new_c_end <= __src->n_cols(), ::ExcMessage("View out of matrix bounds."));
00216 #endif
00217
00218 __r_begin = new_r_begin;
00219 __r_end = new_r_end;
00220 __c_begin = new_c_begin;
00221 __c_end = new_c_end;
00222
00223 __n_el = ((__r_end - __r_begin)*(__c_end - __c_begin));
00224 __view_begin = (__c_begin*__src->n_rows() + __r_begin);
00225 }
00226
00227
00228
00229
00230
00231
00232
00233
00234 template<typename T, typename BW>
00235 void bw_types::SubMatrixView<T, BW>::reinit(int new_r_begin, int new_r_end,
00236 int new_c_begin, int new_c_end) {
00237 reset(new_r_begin, new_r_end, new_c_begin, new_c_end);
00238 }
00239
00240
00241
00242
00243
00244
00245
00246
00247 template<typename T, typename BW>
00248 bw_types::SubMatrixView<T, BW>::SubMatrixView(Matrix<T, BW> & src,
00249 int r_begin, int c_begin)
00250 :
00251 __src(&src),
00252 __r_begin(r_begin),
00253 __c_begin (c_begin),
00254 __r_end(src.n_rows()),
00255 __c_end(src.n_cols()),
00256 __n_el((__r_end - r_begin)*(__c_end - c_begin)),
00257 __view_begin(c_begin*src.n_rows() + r_begin),
00258 __leading_dim(src.n_rows())
00259 {
00260 Assert ((r_begin >= 0) && (r_begin < src.n_rows()),
00261 ::ExcIndexRange (r_begin, 0, src.n_rows()));
00262 Assert ((c_begin >= 0) && (c_begin < src.n_cols()),
00263 ::ExcIndexRange (c_begin, 0, src.n_cols()));
00264
00265 }
00266
00267
00268
00269
00270
00271
00272
00273 template<typename T, typename BW>
00274 bw_types::SubMatrixView<T, BW>::SubMatrixView(Matrix<T, BW> & src,
00275 int r_begin, int r_end,
00276 int c_begin, int c_end)
00277 :
00278 __src(&src),
00279 __r_begin(r_begin),
00280 __c_begin (c_begin),
00281 __r_end(r_end),
00282 __c_end(c_end),
00283 __n_el((r_end - r_begin)*(c_end - c_begin)),
00284 __view_begin(c_begin*src.n_rows() + r_begin),
00285 __leading_dim(src.n_rows())
00286 {
00287
00288 Assert ((r_begin >= 0) && (r_begin < src.n_rows()),
00289 ::ExcIndexRange (r_begin, 0, src.n_rows()));
00290 Assert ((c_begin >= 0) && (c_begin < src.n_cols()),
00291 ::ExcIndexRange (c_begin, 0, src.n_cols()));
00292
00293 Assert ((r_end > r_begin) && (r_end <= src.n_rows()),
00294 ::ExcIndexRange (r_end, 1, src.n_rows()+1));
00295 Assert ((c_end > c_begin) && (c_end <= src.n_cols()),
00296 ::ExcIndexRange (c_end, 1, src.n_cols()+1));
00297
00298 }
00299
00300
00301
00302
00303
00304
00305 template<typename T, typename BW>
00306 bw_types::SubMatrixView<T, BW> &
00307 bw_types::SubMatrixView<T, BW>::operator = (const Matrix<T, BW>& col)
00308 {
00309 Assert(this->__n_el <= col.n_elements(),
00310 ::ExcMessage("Dimension mismatch"));
00311
00312 int incx = 1;
00313 int incy = 1;
00314 int n_rows_2_copy = this->__r_end - this->__r_begin;
00315 int n_cols_2_copy = this->__c_end - this->__c_begin;
00316
00317 std::cout << "n_rows_2_copy : " << n_rows_2_copy << ", "
00318 << "n_cols_2_copy : " << n_cols_2_copy << std::endl;
00319
00320
00321 for (int c = __c_begin; c < this->__c_end; ++c)
00322 BW::copy(n_rows_2_copy, (col.val() + c*(col.__n_rows) + this->__r_begin), incx,
00323 this->__src->val() + this->__r_begin + c*this->__leading_dim,
00324 incy);
00325
00326 return *this;
00327
00328 }
00329
00330
00331
00332
00333
00334
00335 template<typename T, typename BW>
00336 bw_types::SubMatrixView<T, BW> &
00337 bw_types::SubMatrixView<T, BW>::operator = (const SMSMmult<T, BW> & AB)
00338 {
00339 const T * A = &(AB.l.__src->val()[AB.l.__view_begin]) ;
00340 const T * B = &(AB.r.__src->val()[AB.r.__view_begin]) ;
00341 T * C = &(this->__src->val()[this->__view_begin]) ;
00342
00343 T alpha = +1;
00344 T beta = 0.;
00345
00346 int lda = AB.l.__leading_dim;
00347 int ldb = AB.r.__leading_dim;
00348 int ldc = this->__leading_dim;
00349
00350 int m = AB.l.__r_end - AB.l.__r_begin ;
00351 int n = AB.r.__c_end - AB.r.__c_begin ;
00352 int k = AB.l.__c_end - AB.l.__c_begin ;
00353
00354 BW::gemm('n', 'n',
00355 m, n, k,
00356 alpha,
00357 A, lda,
00358 B, ldb,
00359 beta,
00360 C, ldc);
00361
00362 return *this;
00363 }
00364
00365
00366
00367
00368
00369
00370
00371
00372
00373
00374
00375 template<typename T, typename BW>
00376 bw_types::SubMatrixView<T, BW> &
00377 bw_types::SubMatrixView<T, BW>::operator += (const SMSMmult<T, BW> & AB)
00378 {
00379 const T * A = &(AB.l.__src->val()[AB.l.__view_begin]) ;
00380 const T * B = &(AB.r.__src->val()[AB.r.__view_begin]) ;
00381 T * C = &(this->__src->val()[this->__view_begin]) ;
00382
00383 T alpha = +1;
00384 T beta = 1;
00385
00386 int lda = AB.l.__leading_dim;
00387 int ldb = AB.r.__leading_dim;
00388 int ldc = this->__leading_dim;
00389
00390 int m = AB.l.__r_end - AB.l.__r_begin ;
00391 int n = AB.r.__c_end - AB.r.__c_begin ;
00392 int k = AB.l.__c_end - AB.l.__c_begin ;
00393
00394 BW::gemm('n', 'n',
00395 m, n, k,
00396 alpha,
00397 A, lda,
00398 B, ldb,
00399 beta,
00400 C, ldc);
00401
00402 return *this;
00403 }
00404
00405
00406
00407
00408
00409
00410 template<typename T, typename BW>
00411 bw_types::SubMatrixView<T, BW> &
00412 bw_types::SubMatrixView<T, BW>::operator += (const SMSMTmult<T, BW> & AB)
00413 {
00414
00415
00416 const T * A = &(AB.l.__src->val()[AB.l.__view_begin]) ;
00417 const T * B = &(AB.r.A.__src->val()[AB.r.A.__view_begin]) ;
00418 T * C = &(this->__src->val()[this->__view_begin]) ;
00419
00420 T alpha = +1;
00421 T beta = 1;
00422
00423 int lda = AB.l.__leading_dim;
00424 int ldb = AB.r.A.__leading_dim;
00425 int ldc = this->__leading_dim;
00426
00427 int m = AB.l.__r_end - AB.l.__r_begin ;
00428 int n = AB.r.A.__c_end - AB.r.A.__c_begin ;
00429 int k = AB.l.__c_end - AB.l.__c_begin ;
00430
00431 BW::gemm('n', 't',
00432 m, n, k,
00433 alpha,
00434 A, lda,
00435 B, ldb,
00436 beta,
00437 C, ldc);
00438
00439 return *this;
00440 }
00441
00442
00443
00444
00445
00446 template<typename T, typename BW>
00447 bw_types::SubMatrixView<T, BW> &
00448 bw_types::SubMatrixView<T, BW>::operator += (const SMTSMmult<T, BW> & AB)
00449 {
00450
00451 const T * A = &(AB.l.A.__src->val()[AB.l.A.__view_begin]) ;
00452 const T * B = &(AB.r.__src->val()[AB.r.__view_begin]) ;
00453 T * C = &(this->__src->val()[this->__view_begin]) ;
00454
00455 T alpha = +1;
00456 T beta = 1;
00457
00458 int lda = AB.l.A.__leading_dim;
00459 int ldb = AB.r.__leading_dim;
00460 int ldc = this->__leading_dim;
00461
00462 int m = AB.l.A.__c_end - AB.l.A.__c_begin ;
00463 int n = AB.r.__c_end - AB.r.__c_begin ;
00464 int k = AB.l.A.__c_end - AB.l.A.__c_begin ;
00465
00466
00467
00468
00469
00470
00471
00472
00473
00474
00475
00476 BW::gemm('t', 'n',
00477 m, n, k,
00478 alpha,
00479 A, lda,
00480 B, ldb,
00481 beta,
00482 C, ldc);
00483
00484 return *this;
00485 }
00486
00487
00488
00489
00490
00491
00492
00493 template<typename T, typename BW>
00494 bw_types::SubMatrixView<T, BW> &
00495 bw_types::SubMatrixView<T, BW>::operator -= (const SMSMmult<T, BW> & AB)
00496 {
00497 const T * A = &(AB.l.__src->val()[AB.l.__view_begin]) ;
00498 const T * B = &(AB.r.__src->val()[AB.r.__view_begin]) ;
00499 T * C = &(this->__src->val()[this->__view_begin]) ;
00500
00501 T alpha = -1;
00502 T beta = 1;
00503
00504 int lda = AB.l.__leading_dim;
00505 int ldb = AB.r.__leading_dim;
00506 int ldc = this->__leading_dim;
00507
00508 int m = AB.l.__r_end - AB.l.__r_begin ;
00509 int n = AB.r.__c_end - AB.r.__c_begin ;
00510 int k = AB.l.__c_end - AB.l.__c_begin ;
00511
00512 BW::gemm('n', 'n',
00513 m, n, k,
00514 alpha,
00515 A, lda,
00516 B, ldb,
00517 beta,
00518 C, ldc);
00519
00520 return *this;
00521 }
00522
00523
00524
00525
00526
00527
00528 template<typename T, typename BW>
00529 template<typename VECTOR1, typename VECTOR2>
00530 void
00531 bw_types::SubMatrixView<T, BW>::vmult(VECTOR1& dst, const VECTOR2& src) const
00532 {
00533
00534 T alpha = 1.;
00535 T beta = 0.;
00536 int n_rows = __r_end - __r_begin;
00537 int n_cols = __c_end - __c_begin;
00538
00539 Assert(src.size() >= n_cols, ::ExcMessage("Dimension mismatch"));
00540 Assert(dst.size() >= n_rows, ::ExcMessage("Dimension mismatch"));
00541
00542 const int dst_val_begin = (VECTOR1::is_vector_view ? 0 : this->__r_begin );
00543 T *dst_val_ptr = dst.val() + dst_val_begin;
00544
00545
00546 const int src_val_begin = (VECTOR2::is_vector_view ? 0 : this->__c_begin );
00547 const T * const src_val_ptr = src.val() + src_val_begin;
00548
00549 BW::gemv('n', n_rows, n_cols, alpha, __src->val() + __view_begin,
00550 this->__leading_dim, src_val_ptr, 1, beta, dst_val_ptr, 1);
00551 }
00552
00553
00554
00555
00556
00557
00558
00559 template<typename T, typename BW>
00560 template<typename VECTOR1, typename VECTOR2>
00561 void
00562 bw_types::SubMatrixView<T, BW>::Tvmult(VECTOR1& dst, const VECTOR2& src) const
00563 {
00564
00565 T alpha = 1.;
00566 T beta = 0.;
00567 int n_rows = __r_end - __r_begin;
00568 int n_cols = __c_end - __c_begin;
00569
00570 Assert(src.size() >= n_cols, ::ExcMessage("Dimension mismatch"));
00571 Assert(dst.size() >= n_rows, ::ExcMessage("Dimension mismatch"));
00572
00573 const int dst_val_begin = (VECTOR1::is_vector_view ? 0 : this->__c_begin );
00574 T *dst_val_ptr = dst.val() + dst_val_begin;
00575
00576 const int src_val_begin = (VECTOR2::is_vector_view ? 0 :this->__r_begin );
00577 const T * const src_val_ptr = src.val() + src_val_begin;
00578
00579
00580
00581 BW::gemv('t', n_rows, n_cols, alpha, __src->val() + __view_begin,
00582 this->__leading_dim, src_val_ptr, 1, beta, dst_val_ptr, 1);
00583 }
00584
00585
00586
00587
00588
00589
00590
00591
00592
00593
00594
00595 template<typename T, typename BW>
00596 void
00597 bw_types::SubMatrixView<T, BW>::add_scaled_outer_product(T alpha,
00598 const Vector<T, BW>& x,
00599 const Vector<T, BW>& y)
00600 {
00601
00602
00603
00604 int m = __r_end - __r_begin;
00605 int n = __c_end - __c_begin;
00606
00607
00608 int lda = this->__leading_dim;
00609
00610
00611
00612
00613
00614
00615 int incx = 1;
00616 int incy = 1;
00617
00618
00619 const int x_val_begin = this->__r_begin ;
00620 const T * const x_val_ptr = x.val() + x_val_begin;
00621
00622
00623 const int y_val_begin = this->__c_begin;
00624 const T * const y_val_ptr = y.val() + y_val_begin;
00625
00626
00627
00628 BW::ger(m, n, alpha,
00629 x_val_ptr, incx,
00630 y_val_ptr, incy,
00631 __src->val() + __view_begin,
00632 lda);
00633 }
00634
00635
00636
00637
00638
00639 template<typename T, typename BW>
00640 void
00641 bw_types::SubMatrixView<T, BW>::print() const
00642 {
00643
00644 int n_rows_2_copy = this->__r_end - this->__r_begin;
00645 int n_cols_2_copy = this->__c_end - this->__c_begin;
00646
00647
00648 int n_el = this->__n_el;
00649 T * tmp = new T[n_el];
00650
00651 int lda = this->__src->n_rows();
00652 int ldb = n_rows_2_copy;
00653 BW::GetMatrix(n_rows_2_copy, n_cols_2_copy,
00654 &(__src->val()[__view_begin]), lda,
00655 tmp, ldb);
00656
00657 for (int r = 0; r < n_rows_2_copy; ++r)
00658 {
00659 for (int c = 0; c < n_cols_2_copy; ++c)
00660 std::cout << std::setprecision(4)
00661 << std::setw(15) <<
00662
00663 tmp[c*n_rows_2_copy + r]
00664
00665 << " ";
00666 std::cout << std::endl;
00667 }
00668
00669 delete [] tmp;
00670 }
00671
00672
00673
00674
00675
00676
00677 #endif
00678