00001 #ifndef cublas_VectorView_H
00002 #define cublas_VectorView_H
00003
00004
00005
00006
00007
00008 struct vmu_view;
00009
00010
00011
00012 namespace bw_types {
00013
00014
00015
00016
00017
00018
00019
00020
00021 template<typename T, typename T_src>
00022 class VectorView {
00023
00024 public:
00025 typedef typename T_src::blas_wrapper_type BW;
00026
00027 friend class Vector<T, BW>;
00028
00029 friend class Matrix<T, BW>;
00030
00031 friend class SubMatrixView<T, BW>;
00032
00033 template<typename, typename> friend class ColVectorView;
00034
00035
00036
00037
00038
00039
00040
00041 static const bool is_vector_view = true;
00042
00043 VectorView(T_src & src,
00044 int r_begin, int c);
00045
00046 const T * val() const;
00047
00048 T l2_norm() const;
00049
00050 template<typename VECTOR>
00051 T dot(const VECTOR & other) const;
00052
00053 void print() const;
00054
00055 VectorView<T, T_src> & operator = (const Vector<T, BW>& col);
00056
00057 int r_begin() const { return __r_begin; }
00058
00059 int c_begin() const { return __col; }
00060
00061 int size() const { return __n_el; }
00062
00063 template<typename T2_src>
00064 VectorView & operator += (const VectorView <T, T2_src> &other);
00065
00066 VectorView & operator -= (const VectorView <T, T_src> &other);
00067 VectorView & operator *= (const T alpha);
00068 VectorView & operator /= (const T alpha);
00069
00070
00071
00072 T * val();
00073
00074 VectorView & operator = (const VectorView<T, T_src> & other)
00075 {
00076
00077
00078
00079 Assert(this->__n_el == other.__n_el,
00080 ::ExcMessage("Cannot copy subarrays of different lengths"));
00081
00082
00083
00084
00085
00086 int inc_src = other._stride;
00087 int inc_dst = this->_stride;
00088
00089 BW::copy(this->__n_el, other.val(), inc_src, this->val(), inc_dst);
00090
00091 return *this;
00092 }
00093 private:
00094 VectorView() {}
00095
00096 VectorView(const VectorView<T, T_src> & ) {}
00097
00098
00099
00100
00101
00102
00103
00104 ::SmartPointer<T_src> __src;
00105
00106
00107 int __r_begin;
00108 int __col;
00109
00110 protected:
00111 int __n_el;
00112
00113 private:
00114 int __view_begin;
00115
00116 protected:
00117 bool _is_col;
00118
00119 int _stride;
00120 };
00121
00122
00123
00124
00125
00126
00127 template<typename T, typename T_src>
00128 class ColVectorView : public VectorView<T, T_src> {
00129
00130 public:
00131 typedef VectorView<T, T_src> Base;
00132
00133 ColVectorView(T_src & src,
00134 int r_begin, int c=0) : Base(src, r_begin, c) { this->_stride = 1; }
00135
00136
00137 template<typename M>
00138 ColVectorView(const
00139 X_read_read<M, vmu_view, bw_types::ColVectorView<T, T_src> >
00140 & Ax);
00141
00142
00143 template<typename M>
00144 ColVectorView<T, T_src> & operator = (const
00145 X_read_read<M, vmu_view, ColVectorView<T, T_src> >
00146 & Ax)
00147 {
00148 Ax.apply(*this);
00149 return *this;
00150 }
00151
00152 ColVectorView<T, T_src> & operator = (const Vector<T, typename Base::BW>& col)
00153 {
00154 Assert(this->__n_el == col.size(),
00155 ::ExcMessage("Dimension mismatch"));
00156
00157 int incx = 1;
00158 int incy = this->_stride;
00159 Base::BW::copy(this->__n_el, col.val(), incx, this->val(), incy);
00160
00161 return *this;
00162 }
00163
00164
00165
00166
00167
00168
00169
00170
00171
00172
00173
00174
00175
00176
00177
00178
00179 template<typename T2_src>
00180 ColVectorView<T, T_src> & operator += (const ColVectorView<T, T2_src> &other)
00181 {
00182 Base & self = *this;
00183 const typename ColVectorView<T, T2_src>::Base & o = other;
00184 self += o;
00185
00186 return *this;
00187 }
00188
00189 void reset(int r_begin, int c=0)
00190 {
00191 this->__r_begin = r_begin;
00192 this->__col = c;
00193 this->__n_el = this->__src->n_rows() - r_begin;
00194 this->__view_begin = c*this->__src->n_rows() + r_begin;
00195 }
00196
00197
00198 ColVectorView & operator = (const ColVectorView<T, T_src> & other)
00199 {
00200 Base & self = *this;
00201 const Base & src = other;
00202 self = src;
00203
00204 return *this;
00205 }
00206 };
00207
00208
00209
00210
00211
00212
00213 template<typename T, typename T_src>
00214 class RowVectorView : public VectorView<T, T_src> {
00215
00216 public:
00217 typedef VectorView<T, T_src> Base;
00218
00219 RowVectorView(T_src & src,
00220 int r_begin, int c) : Base(src, r_begin, c)
00221 {
00222 this->__n_el = src.n_cols() - c;
00223 this->_stride = src.n_rows(); this->_is_col = false;
00224 }
00225
00226 };
00227 }
00228
00229
00230
00231
00232
00233
00234
00235
00236
00237
00238
00239 template<typename T, typename T_src>
00240 bw_types::VectorView<T, T_src >::VectorView(T_src & src,
00241 int r_begin, int c)
00242 :
00243 __src(&src),
00244 __r_begin(r_begin),
00245 __col (c),
00246 __n_el(src.n_rows() - r_begin),
00247 __view_begin(c*src.n_rows() + r_begin),
00248 _is_col(true),
00249 _stride(1)
00250 {}
00251
00252
00253
00254
00255
00256
00257
00258
00259 template<typename T, typename T_src>
00260 const T *
00261 bw_types::VectorView<T, T_src>::val() const
00262 {
00263
00264 return &(this->__src->val()[__view_begin]);
00265 }
00266
00267
00268 template<typename T, typename T_src>
00269 T *
00270 bw_types::VectorView<T, T_src>::val()
00271 {
00272
00273 return &(this->__src->val()[__view_begin]);
00274 }
00275
00276
00277
00278
00279
00280 template<typename T, typename T_src>
00281 T
00282 bw_types::VectorView<T, T_src>::l2_norm() const
00283 {
00284 T result = BW::nrm2(this->__n_el,
00285 this->val(), this->_stride);
00286 return result;
00287 }
00288
00289
00290
00291
00292
00293
00294 template<typename T, typename T_src>
00295 template<typename VECTOR>
00296 T
00297 bw_types::VectorView<T, T_src>::dot(const VECTOR & other) const
00298 {
00299
00300 Assert(this->__n_el == other.size(),
00301 ::ExcMessage("Dimension mismatch"));
00302
00303 int incx = this->_stride;
00304 int incy = other._stride;
00305 T result = BW::dot(this->__n_el,
00306 this->val(), incx, other.val(), incy);
00307
00308 return result;
00309 }
00310
00311
00312
00313
00314
00315
00316
00317 template<typename T, typename T_src>
00318 template<typename T2_src>
00319 bw_types::VectorView<T, T_src> &
00320 bw_types::VectorView<T, T_src> ::operator +=(const bw_types::VectorView <T, T2_src> &other)
00321 {
00322 Assert(this->__n_el == other.size(),
00323 ::ExcMessage("Dimension mismatch"));
00324
00325 int incx = other._stride;
00326 int incy = this->_stride;
00327
00328 BW::axpy(this->__n_el, 1, other.val(), incx,this->val(), incy);
00329
00330 return *this;
00331 }
00332
00333
00334
00335
00336
00337
00338 template<typename T, typename T_src>
00339 bw_types::VectorView<T, T_src> &
00340 bw_types::VectorView<T, T_src> ::operator -=(const bw_types::VectorView <T, T_src> &other)
00341 {
00342 Assert(this->__n_el == other.size(),
00343 ::ExcMessage("Dimension mismatch"));
00344
00345 int incx = other._stride;
00346 int incy = this->_stride;
00347
00348 BW::axpy(this->__n_el, -1, other.val(), incx,this->val(), incy);
00349
00350 return *this;
00351 }
00352
00353
00354
00355
00356
00357 template<typename T, typename T_src>
00358 bw_types::VectorView<T, T_src> &
00359 bw_types::VectorView<T, T_src> ::operator *=(const T alpha)
00360 {
00361 int incx = this->_stride;
00362
00363 BW::scal(this->__n_el, alpha, this->val(), incx);
00364
00365 return *this;
00366 }
00367
00368
00369
00370
00371
00372 template<typename T, typename T_src>
00373 bw_types::VectorView<T, T_src> &
00374 bw_types::VectorView<T, T_src> ::operator /=(const T alpha)
00375 {
00376 #ifdef DEBUG
00377 Assert(alpha,
00378 ::ExcMessage("Div/0"));
00379 #else
00380 AssertThrow(alpha,
00381 ::ExcMessage("Div/0"));
00382 #endif
00383
00384 int incx = this->_stride;
00385
00386 BW::scal(this->__n_el, (1/alpha), this->val(), incx);
00387
00388 return *this;
00389 }
00390
00391
00392
00393
00394
00395 template<typename T, typename T_src>
00396 bw_types::VectorView<T, T_src> &
00397 bw_types::VectorView<T, T_src>::operator = (const Vector<T, BW>& col)
00398 {
00399 Assert(this->__n_el == col.size(),
00400 ::ExcMessage("Dimension mismatch"));
00401
00402 int incx = 1;
00403 int incy = this->_stride;
00404 BW::copy(this->__n_el, col.val(), incx, this->val(), incy);
00405
00406 return *this;
00407 }
00408
00409
00410
00411
00412
00413 template<typename T, typename T_src>
00414 void
00415 bw_types::VectorView<T, T_src>::print() const
00416 {
00417 std::vector<T> tmp(this->__n_el);
00418 int inc_src = this->_stride;
00419 int inc_dst = 1;
00420
00421 T * tmp_ptr = &tmp[0];
00422 BW::GetVector(this->__n_el, this->val(), inc_src, tmp_ptr, inc_dst);
00423
00424 if (this->_is_col)
00425 for (int i = 0; i < this->__n_el; ++i)
00426 std::cout << tmp[i] << std::endl;
00427 else
00428 for (int i = 0; i < this->__n_el; ++i)
00429 std::cout << tmp[i] << " ";
00430 std::cout << std::endl;
00431
00432 }
00433
00434
00435
00436
00437
00438
00439
00440
00441
00442
00443
00444 struct vmu_view
00445 {
00446 template<typename T, typename BW, typename T_src>
00447 static void apply( bw_types::ColVectorView<T, T_src> & b,
00448 const bw_types::Matrix<T, BW> & A,
00449 const bw_types::ColVectorView<T, T_src> & x)
00450 {
00451 A.vmult(b,x);
00452 }
00453
00454
00455 template<typename T, typename BW, typename T_src>
00456 static void apply( bw_types::ColVectorView<T, T_src> & b,
00457 const transpose<bw_types::Matrix<T, BW> > & A_t,
00458 const bw_types::ColVectorView<T, T_src> & x)
00459 {
00460 A_t.A.Tvmult(b,x);
00461 }
00462
00463
00464
00465
00466
00467
00468
00469
00470
00471
00472
00473
00474
00475
00476
00477
00478
00479
00480
00481
00482
00483
00484
00485
00486
00487
00488
00489
00490
00491
00492 };
00493
00494
00495
00496
00497
00498
00499
00500 template<typename L, typename T,
00501 typename BW
00502 >
00503 inline
00504 X_read_read<L, vmu_view, bw_types::ColVectorView<T, bw_types::Matrix<T, BW> >
00505 >
00506 operator * (const L & _l, const bw_types::ColVectorView<T,
00507 bw_types::Matrix<T, BW> >
00508 & _r)
00509 {
00510 typedef bw_types::Matrix<T, BW> T_src;
00511 return X_read_read<L, vmu_view, bw_types::ColVectorView<T, T_src>
00512 > (_l,_r);
00513 }
00514
00515
00516
00517
00518 template<typename T, typename T_src>
00519 template<typename M>
00520 bw_types::ColVectorView<T, T_src>::ColVectorView(const
00521 X_read_read<M, vmu_view, bw_types::ColVectorView<T, T_src> >
00522 & Ax)
00523 {
00524 Ax.apply(*this);
00525 }
00526
00527
00528
00529
00530
00531
00532
00533
00534
00535
00536
00537
00538
00539
00540
00541
00542
00543
00544
00545
00546
00547
00548
00549
00550
00551
00552
00553
00554
00555
00556
00557
00558
00559
00560
00561 #endif
00562