00001 #ifndef FullMatrixAccessor_H
00002 #define FullMatrixAccessor_H
00003
#include <iostream>

#include <lac/full_matrix.h>

#include <lac/cublas_Matrix.h>
00008
00009
00010
00011
00012
namespace bw_types
{
  // Forward declarations of the bw_types container templates. Only Matrix is
  // referenced in this header (by FullMatrixAccessor::operator=); the others
  // are declared here so that including this header is enough to name them.
  template <typename, typename> class Array;
  template <typename, typename> class Matrix;

  template <typename, typename> class VectorView;
  template <typename, typename> class SubMatrixView;
} // namespace bw_types
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034 template<typename T>
00035 class FullMatrixAccessor : protected ::FullMatrix<T>
00036 {
00037 public:
00038
00039
00040
00041 typedef ::FullMatrix<T> Base;
00042
00043 typedef typename Base::value_type value_type;
00044
00045
00046
00047
00048
00049
00050 FullMatrixAccessor (const unsigned int n=0);
00051
00052
00053 FullMatrixAccessor (const unsigned int rows,
00054 const unsigned int cols, bool is_column_major=false);
00055
00056 FullMatrixAccessor (const ::FullMatrix<T> & other);
00057
00058 FullMatrixAccessor (const ::FullMatrix<T> & other,
00059 bool transpose_copy);
00060
00061 FullMatrixAccessor (const unsigned int rows,
00062 const unsigned int cols,
00063 const T * entries);
00064
00065
00066 FullMatrixAccessor (const ::IdentityMatrix &id);
00067
00068 void reinit (size_t nr, size_t nc){
00069 this->Base::reinit(nr,nc);
00070 }
00071
00072
00073
00074
00075 T * val();
00076
00077 const T * val() const;
00078
00079 template<typename BW>
00080 FullMatrixAccessor<T> & operator = (const bw_types::Matrix<T, BW> & A_d);
00081
00082 template<typename T2>
00083 void push_to(::FullMatrix<T2> & dst) const;
00084
00085
00086 T & operator () (const unsigned int i, const unsigned int j);
00087
00088 const T & operator () (const unsigned int i, const unsigned int j) const;
00089
00090
00091 FullMatrixAccessor<T> & operator += (const FullMatrixAccessor<T> & A);
00092
00093 FullMatrixAccessor<T> & operator -= (const FullMatrixAccessor<T> & A);
00094
00095 FullMatrixAccessor<T> & operator += (const ::FullMatrix<T> & A_h);
00096
00097 FullMatrixAccessor<T> & operator -= (const ::FullMatrix<T> & A_h);
00098
00099 FullMatrixAccessor<T> & operator += (const ::IdentityMatrix & I_h);
00100
00101 FullMatrixAccessor<T> & operator -= (const ::IdentityMatrix & I_h);
00102
00103 T frobenius_norm() const;
00104
00105 bool is_column_major() const;
00106
00107 inline unsigned int n_rows() const { return this->Base::n_rows(); }
00108
00109 inline unsigned int n_cols() const { return this->Base::n_cols(); }
00110
00111 inline unsigned int n_elements () const { return this->table_size[0]*this->table_size[1]; }
00112
00113 void print() { this->::FullMatrix<T>::print(std::cout); }
00114
00115 private:
00116
00117 bool __is_col_major;
00118
00119 };
00120
00121
00122
00123
00124
00125
00126
00127
00128
00129
00130
00131
00132
00133
00134
00135
00136
00137
00138
00139
00140 template<typename T>
00141 FullMatrixAccessor<T>::FullMatrixAccessor (const unsigned int n)
00142 :
00143 Base(n),
00144 __is_col_major(false)
00145 {}
00146
00147
00148
00149 template<typename T>
00150 FullMatrixAccessor<T>::FullMatrixAccessor (const unsigned int rows,
00151 const unsigned int cols,bool is_column_major)
00152 :
00153 Base(rows, cols),
00154 __is_col_major(is_column_major)
00155 {}
00156
00157
00158
00159
00160 template<typename T>
00161 FullMatrixAccessor<T>::FullMatrixAccessor (const ::FullMatrix<T> & other)
00162 :
00163 Base(other),
00164 __is_col_major(false)
00165 {}
00166
00167
00168
00169
00170
00171
00172
00173
00174 template<typename T>
00175 FullMatrixAccessor<T>::FullMatrixAccessor (const ::FullMatrix<T> & other,
00176 bool transpose_copy)
00177 :
00178 __is_col_major(transpose_copy)
00179 {
00180 if (!transpose_copy)
00181 this->copy_from(other);
00182 else {
00183 this->reinit( other.n_rows(), other.n_cols());
00184
00185 const unsigned int this_n_rows = other.n_rows();
00186 const unsigned int this_n_cols = other.n_cols();
00187
00188 T * entries = this->val();
00189 for (unsigned int c=0;c<this_n_cols;++c)
00190 for (unsigned int r=0;r<this_n_rows;++r)
00191 entries[c*this_n_rows + r] = other(r,c);
00192 }
00193 }
00194
00195
00196
00197
00198
00199
00200
00201 template<typename T>
00202 FullMatrixAccessor<T>::FullMatrixAccessor (const unsigned int rows,
00203 const unsigned int cols,
00204 const T * entries)
00205 :
00206 Base(rows, cols, entries),
00207 __is_col_major(false)
00208 {}
00209
00210
00211
00212
00213 template<typename T>
00214 FullMatrixAccessor<T>::FullMatrixAccessor (const ::IdentityMatrix &id)
00215 :
00216 Base(id),
00217 __is_col_major(false)
00218 {}
00219
00220
00221
00222
00223
00224 template<typename T>
00225 template<typename BW>
00226 FullMatrixAccessor<T> &
00227 FullMatrixAccessor<T>::operator = (const bw_types::Matrix<T, BW> & A_d)
00228 {
00229
00230
00231 int nr = A_d.n_rows();
00232 int nc = A_d.n_cols();
00233 this->reinit(nr, nc);
00234 this->__is_col_major = true;
00235
00236 const T * src = A_d.val();
00237
00238 T * dst = this->val();
00239
00240 BW::GetMatrix(nr, nc,
00241 src, nr,
00242 dst, nr);
00243
00244 src = 0;
00245 dst = 0;
00246
00247 return *this;
00248 }
00249
00250
00251
00252
00253
00254
00255
00256
00257
00258 template <typename T>
00259 template<typename T2>
00260 void FullMatrixAccessor<T>::push_to(::FullMatrix<T2> & dst) const
00261 {
00262
00263 if (this->is_column_major())
00264 {
00265 dst.reinit(this->n_rows(), this->n_cols());
00266 for (unsigned int r = 0; r < this->n_rows(); r++)
00267 for(unsigned int c = 0; c < this->n_cols(); c++)
00268 dst(r,c) = (*this)(r,c);
00269 }
00270 else {
00271 Base & my_self = *this;
00272
00273 dst = my_self;
00274 }
00275
00276 }
00277
00278
00279
00280
00281
00282
00283
00284
00285
00286 template <typename T>
00287 inline
00288 T &
00289 FullMatrixAccessor<T>::operator () (const unsigned int r,
00290 const unsigned int c)
00291 {
00292 Assert (r < this->table_size[0],
00293 ::ExcIndexRange (r, 0, this->table_size[0]));
00294 Assert (c < this->table_size[1],
00295 ::ExcIndexRange (c, 0, this->table_size[1]));
00296 if (this->__is_col_major)
00297 return this->val()[c*this->table_size[0]+r];
00298 else
00299 return this->val()[r*this->table_size[1]+c];
00300 }
00301
00302
00303
00304
00305 template <typename T>
00306 inline
00307 const T &
00308 FullMatrixAccessor<T>::operator () (const unsigned int r,
00309 const unsigned int c) const
00310 {
00311 Assert (r < this->table_size[0],
00312 ::ExcIndexRange (r, 0, this->table_size[0]));
00313 Assert (c < this->table_size[1],
00314 ::ExcIndexRange (c, 0, this->table_size[1]));
00315 if (this->__is_col_major)
00316 return this->val()[c*this->table_size[0]+r];
00317 else
00318 return this->val()[r*this->table_size[1]+c];
00319 }
00320
00321
00322
00323
00324
00325
00326
00327
00328 template<typename T>
00329 T * FullMatrixAccessor<T>::val()
00330 {
00331 return this->Base::val;
00332 }
00333
00334 template<typename T>
00335 const T * FullMatrixAccessor<T>::val() const
00336 {
00337 return this->Base::val;
00338 }
00339
00340
00341
00342
00343
00344 template<typename T>
00345 bool
00346 FullMatrixAccessor<T>::is_column_major() const
00347 {
00348 return this->__is_col_major;
00349 }
00350
00351
00352
00353
00354
00355 template<typename T>
00356 FullMatrixAccessor<T> &
00357 FullMatrixAccessor<T>::operator += (const FullMatrixAccessor<T> & A)
00358 {
00359 Assert(this->n_rows()==A.n_rows(), ::ExcMessage("Dimension mismatch"));
00360 Assert(this->n_cols()==A.n_cols(), ::ExcMessage("Dimension mismatch"));
00361 for(int i=0; i<this->n_rows();i++){
00362 for(int j = 0; j < this->n_cols() ; j++){
00363 (*this)(i,j) += A(i,j);
00364 }
00365 }
00366 return *this;
00367 }
00368
00369
00370
00371
00372
00373 template<typename T>
00374 FullMatrixAccessor<T> &
00375 FullMatrixAccessor<T>::operator -= (const FullMatrixAccessor<T> & A)
00376 {
00377 Assert(this->n_rows()==A.n_rows(), ::ExcMessage("Dimension mismatch"));
00378 Assert(this->n_cols()==A.n_cols(), ::ExcMessage("Dimension mismatch"));
00379 for(int i=0; i<this->n_rows();i++){
00380 for(int j = 0; j < this->n_cols() ; j++){
00381 (*this)(i,j) -= A(i,j);
00382 }
00383 }
00384 return *this;
00385 }
00386
00387
00388
00389
00390
00391
00392
00393 template<typename T>
00394 FullMatrixAccessor<T> &
00395 FullMatrixAccessor<T>::operator += (const ::FullMatrix<T> & A_h)
00396 {
00397 Assert(this->n_rows()==A_h.n_rows(), ::ExcMessage("Dimension mismatch"));
00398 Assert(this->n_cols()==A_h.n_cols(), ::ExcMessage("Dimension mismatch"));
00399 for(int i=0; i<this->n_rows();i++){
00400 for(int j = 0; j < this->n_cols() ; j++){
00401 (*this)(i,j) += A_h(i,j);
00402 }
00403 }
00404 return *this;
00405 }
00406
00407
00408
00409
00410
00411 template<typename T>
00412 FullMatrixAccessor<T> &
00413 FullMatrixAccessor<T>::operator -= (const ::FullMatrix<T> & A_h)
00414 {
00415 Assert(this->n_rows()==A_h.n_rows(), ::ExcMessage("Dimension mismatch"));
00416 Assert(this->n_cols()==A_h.n_cols(), ::ExcMessage("Dimension mismatch"));
00417 for(int i=0; i<this->n_rows();i++){
00418 for(int j = 0; j < this->n_cols() ; j++){
00419 (*this)(i,j) -= A_h(i,j);
00420 }
00421 }
00422 return *this;
00423 }
00424
00425
00426
00427
00428
00429
00430 template<typename T>
00431 FullMatrixAccessor<T> &
00432 FullMatrixAccessor<T>::operator += (const ::IdentityMatrix & I_h)
00433 {
00434 Assert(this->n_rows()==I_h.n(), ::ExcMessage("Dimension mismatch"));
00435 Assert(this->n_cols()==I_h.n(), ::ExcMessage("Dimension mismatch"));
00436 for(int i=0; i<this->n_rows();i++){
00437 (*this)(i,i) += 1;
00438 }
00439 return *this;
00440 }
00441
00442
00443
00444
00445
00446 template<typename T>
00447 FullMatrixAccessor<T> &
00448 FullMatrixAccessor<T>::operator -= (const ::IdentityMatrix & I_h)
00449 {
00450 Assert(this->n_rows()==I_h.n(), ::ExcMessage("Dimension mismatch"));
00451 Assert(this->n_cols()==I_h.n(), ::ExcMessage("Dimension mismatch"));
00452 for(int i=0; i<this->n_rows();i++){
00453 (*this)(i,i) -= 1;
00454 }
00455 return *this;
00456 }
00457
00458
00459
00460
00461
00462
00463 template<typename T>
00464 T
00465 FullMatrixAccessor<T>::frobenius_norm() const
00466 {
00467 return this->Base::frobenius_norm();
00468 }
00469
00470
00471
00472
00473 #endif // FULLMATRIXACCESSOR_H
00474