00001 #ifndef cublas_Vector_H
00002 #define cublas_Vector_H
00003
00004 #include <lac/expression_template.h>
00005
00006
00007 struct vmu;
00008
00009 #include <lac/cublas_Array.h>
00010
00011 template<typename> class FullMatrixAccessor;
00012
00013 template<typename T> struct PrecisionTraits;
00014
00015 template <>
00016 struct PrecisionTraits<cuComplex> {
00017
00018 typedef float NumberType;
00019
00020 typedef cuComplex CudaComplex;
00021 };
00022
00023 template <>
00024 struct PrecisionTraits<cuDoubleComplex> {
00025
00026 typedef double NumberType;
00027
00028 typedef cuDoubleComplex CudaComplex;
00029 };
00030
00031
00032 template <>
00033 struct PrecisionTraits<float> {
00034
00035 typedef float NumberType;
00036
00037 typedef cuComplex CudaComplex;
00038 };
00039
00040
00041 template <>
00042 struct PrecisionTraits<double> {
00043
00044 typedef double NumberType;
00045
00046 typedef cuDoubleComplex CudaComplex;
00047 };
00048
00049
00050
// Generic fallback used for every type that has no explicit specialization:
// the number type is simply T. Unlike the specializations for float, double,
// cuComplex and cuDoubleComplex, no CudaComplex typedef is provided here.
template <typename T>
struct PrecisionTraits
{
    // Number type associated with T (T itself).
    typedef T NumberType;
};
00057
// Function object producing the value +1 or -1 in the number type T.
// The call operator is only declared here; its definitions follow as a
// generic version and explicit specializations.
template<typename T>
struct One
{
    // The number type the unit value is produced in.
    typedef T Type;

    // Returns +1 if plus is true (the default), otherwise -1.
    T operator()(bool plus = true);
};
00064
00065
00066 template<>
00067 inline cuComplex One<cuComplex>::operator ()(bool plus)
00068 {
00069 Type result; result.x = (plus ? 1. : -1); result.y = 0.; return result;
00070 }
00071
00072 template<>
00073 inline cuDoubleComplex One<cuDoubleComplex>::operator ()(bool plus)
00074 {
00075 Type result; result.x = (plus ? 1. : -1); result.y = 0.; return result;
00076 }
00077
00078 template<typename T>
00079 inline T One<T>::operator ()(bool plus) { return (plus ? 1. : -1); }
00080
00081
00082
00083 std::ostream& operator << (std::ostream& out, const cuComplex& c)
00084 {
00085 out << "(" << c.x << ", " << c.y << ")";
00086 return out;
00087 }
00088
00089
00090 std::ostream& operator << (std::ostream& out, const cuDoubleComplex& c)
00091 {
00092 out << "(" << c.x << ", " << c.y << ")";
00093 return out;
00094 }
00095
00096
// Forward declarations of the class templates in namespace bw_types that
// the Vector class refers to (as base class, friends or argument types).
namespace bw_types {

    // Storage class and dense matrix.
    template<typename T, typename BW> class Array;
    template<typename T, typename BW> class Matrix;

    // Views onto vectors.
    template<typename T, typename T_src> class VectorView;
    template<typename T, typename T_src> class ColVectorView;

    // View onto a part of a matrix.
    template<typename T, typename BW> class SubMatrixView;
}
00107
00108 namespace bw_types {
00109
00110
00111
00112
00113
00114
00115 template<typename T, typename BW>
00116 class Vector : public ::Subscriptor, protected Array<T, BW> {
00117
00118
00119
00120 friend class bw_types::Matrix<T, BW>;
00121
00122 friend class bw_types::SubMatrixView<T, BW>;
00123
00124 template<typename, typename> friend class VectorView;
00125
00126 template<typename, typename> friend class ColVectorView;
00127
00128
00129 Vector(const Vector<T, BW> & ) {}
00130
00131 public:
00132
00133 typedef BW blas_wrapper_type;
00134
00135
00136
00137
00138
00139
00140 static const bool is_vector_view = false;
00141
00142
00143
00144 Vector();
00145
00146 Vector(int n_elements);
00147
00148 Vector(int n_elements, const Array<T, BW> & raw_data);
00149
00150 Vector(const FullMatrixAccessor<T> & src,
00151 int r_begin, int c);
00152
00153 Vector(const Matrix<T, BW> & src,
00154 int r_begin, int c);
00155
00156 template<typename M, typename Op>
00157 Vector(const X_read_read<M, Op, Vector<T, BW> > & Ax);
00158
00159 template<typename M, typename Op,
00160 typename T_src>
00161 Vector(const
00162 X_read_read<M, Op, bw_types::ColVectorView<T, T_src> >
00163 & Ax);
00164
00165
00166
00167
00168
00169 inline Array<T, BW> & array() { return *this; }
00170
00171 inline const Array<T, BW> & array() const { return *this; }
00172
00173 Vector<T, BW> & operator = (const Vector<T, BW> & other);
00174 Vector<T, BW> & operator = (const std::vector<T> & other);
00175
00176
00177 template<typename T2>
00178 void push_to(std::vector<T2> & dst) const;
00179
00180 Vector<T, BW> & operator = (const T value);
00181
00182 template<typename T_src>
00183 Vector<T, BW> & operator = (const VectorView<T, T_src > & other);
00184
00185 template<typename M, typename Op>
00186 Vector<T, BW> & operator = (const X_read_read<M, Op, Vector<T, BW> > & Ax);
00187
00188 template<typename M, typename T_src>
00189 Vector<T, BW> & operator =(const
00190 X_read_read<M, vmu, ColVectorView<T, T_src> >
00191 & Ax);
00192
00193
00194 Vector<T, BW> & operator += (const Vector<T, BW> & other);
00195
00196 template<typename T_src>
00197 Vector<T, BW> & operator += (const VectorView<T, T_src > & other);
00198
00199 Vector<T, BW> & operator -= (const Vector<T, BW> & other);
00200
00201 template<typename T_src>
00202 Vector<T, BW> & operator -= (const VectorView<T, T_src > & other);
00203
00204 Vector<T, BW> & operator *= (const T scale);
00205
00206 T operator * (const Vector<T, BW> & other);
00207
00208 Vector<T, BW> & operator /= (const T scale);
00209
00210 template<typename VECTOR>
00211 T dot(const VECTOR & other) const;
00212
00213 void sadd (T alpha, const Vector<T, BW> & other);
00214
00215 void reinit(int new_size);
00216
00217 void print() const;
00218
00219 T l2_norm() const;
00220
00221 typename PrecisionTraits<T>::NumberType sum() const;
00222
00223 int size() const { return this->__n; }
00224
00225 int n_rows() const { return this->__n; }
00226
00227 int n_cols() const { return this->__n; }
00228
00229 T operator () (int k) const;
00230
00231 void set(int k,const T value);
00232
00233 void add(int k,const T value);
00234
00235 protected:
00236
00237
00238 static const int _stride = 1;
00239
00240 };
00241 }
00242
00243
00244
00245
00246
00247
00248
00249
00250 template<typename T, typename BW>
00251 bw_types::Vector<T, BW>::Vector()
00252 :
00253 Array<T, BW>()
00254 {}
00255
00256
00257
00258
00259
00260
00261 template<typename T, typename BW>
00262 bw_types::Vector<T, BW>::Vector(int n_elements)
00263 :
00264 Array<T, BW>(n_elements)
00265 {
00266 this->reinit(n_elements);
00267 }
00268
00269
00270
00271
00272
00273
00274
00275 template<typename T, typename BW>
00276 bw_types::Vector<T, BW>::Vector(int n_elements,
00277 const Array<T, BW> & raw_data)
00278 :
00279 Array<T, BW>(n_elements)
00280 {
00281 this->reinit(n_elements);
00282 *this = raw_data;
00283 }
00284
00285
00286
00287
00288
00289
00290
00291
00292
00293 template<typename T, typename BW>
00294 bw_types::Vector<T, BW>::Vector(const FullMatrixAccessor<T> & src,
00295 int r_begin, int c)
00296 :
00297 Array<T, BW>(src.n_rows() - r_begin)
00298 {
00299
00300 int n_el = src.n_rows() - r_begin;
00301
00302 this->reinit(n_el);
00303
00304 int src_begin = c*src.n_rows() + r_begin;
00305 const T * src_val = src.val();
00306
00307 int inc_src = 1;
00308 if (!src.is_column_major())
00309 {
00310 inc_src = src.n_cols();
00311 src_begin = r_begin*inc_src + c;
00312 }
00313
00314 const T * src_ptr = &(src_val[src_begin]);
00315
00316 BW::SetVector(n_el, src_ptr, inc_src, this->val(), 1);
00317 }
00318
00319
00320
00321
00322
00323
00324
00325
00326
00327 template<typename T, typename BW>
00328 bw_types::Vector<T, BW>::Vector(const Matrix<T, BW> & src,
00329 int r_begin, int c)
00330 :
00331 Array<T, BW>(src.n_rows() - r_begin)
00332 {
00333 int n_el = src.n_rows() - r_begin;
00334
00335 this->reinit(n_el);
00336
00337
00338 int inc_src = 1;
00339 int inc_this = 1;
00340
00341 const T * col = &(src.val()[c*src.n_rows() + r_begin]);
00342
00343
00344 BW::copy(this->__n, col, inc_src, this->val(), inc_this);
00345
00346 }
00347
00348
00349
00350
00351
00352
00353
00354
00355 template<typename T, typename BW>
00356 void bw_types::Vector<T, BW>::reinit(int new_size)
00357 {
00358 this->Array<T, BW>::reinit(new_size);
00359
00360
00361 }
00362
00363
00364
00365
00366
00367
00368
00369 template<typename T, typename BW>
00370 T bw_types::Vector<T, BW>::l2_norm() const
00371 {
00372 return BW::nrm2(this->__n, &(this->val()[0]), 1);
00373 }
00374
00375
00376
00377
00378
00379 template<typename T, typename BW>
00380 typename PrecisionTraits<T>::NumberType bw_types::Vector<T, BW>::sum() const
00381 {
00382 return BW::asum(this->__n, &(this->val()[0]), 1);
00383 }
00384
00385
00386
00387
00388
00389
00390
00391 template<typename T, typename BW>
00392 void
00393 bw_types::Vector<T, BW>::print() const
00394 {
00395
00396 std::vector<T> tmp(this->__n);
00397 T * dst_ptr = &tmp[0];
00398
00399 const T * src_ptr = this->val();
00400
00401 BW::GetVector(this->__n, src_ptr, 1, dst_ptr, 1);
00402
00403 for (int i = 0; i < this->__n; ++i)
00404 std::cout << tmp[i] << std::endl;
00405 }
00406
00407
00408
00409
00410
00411
00412
00413 template<typename T, typename BW>
00414 T
00415 bw_types::Vector<T, BW>::operator () (int k) const
00416 {
00417
00418 std::vector<T> tmp(this->__n);
00419 T * dst_ptr = &tmp[0];
00420
00421 BW::GetVector(1, this->val()+k, 1, dst_ptr, 1);
00422 return tmp[0];
00423 }
00424
00425
00426
00427
00428
00429
00430 template<typename T, typename BW>
00431 void
00432 bw_types::Vector<T, BW>::set(int k,const T value)
00433 {
00434 int inc_src = 1;
00435 int inc_dst = 1;
00436
00437 BW::SetVector(1, &value, inc_src, &(this->val()[k]), inc_dst);
00438 }
00439
00440
00441
00442
00443
00444
00445 template<typename T, typename BW>
00446 void
00447 bw_types::Vector<T, BW>::add(int k,const T value)
00448 {
00449 Assert((k >= 0) && (k < this->size()),
00450 ::ExcMessage("Index out of range") );
00451
00452
00453 Vector<T, BW> tmp_d(1);
00454 tmp_d.set(0,value);
00455
00456
00457 BW::axpy(1, 1, tmp_d.val(), 1,
00458 &(this->val()[k]), 1);
00459 }
00460
00461
00462
00463
00464
00465
00466
00467
00468 template<typename T, typename BW>
00469 bw_types::Vector<T, BW> &
00470 bw_types::Vector<T, BW>::operator = (const Vector<T, BW> & other)
00471 {
00472
00473 if(this->__n != other.__n) this->reinit(other.__n);
00474
00475
00476 int inc_src = 1;
00477 int inc_this = 1;
00478
00479 BW::copy(this->__n, other.val(), inc_src, this->val(), inc_this);
00480
00481 return *this;
00482 }
00483
00484
00485
00486
00487
00488
00489
00490 template<typename T, typename BW>
00491 bw_types::Vector<T, BW> &
00492 bw_types::Vector<T, BW>::operator = (const std::vector<T> & other)
00493 {
00494 this->reinit(other.size());
00495
00496 const T * tmp_ptr = &other[0];
00497
00498 BW::SetVector(size(), tmp_ptr, 1, this->data(), 1);
00499
00500 return *this;
00501 }
00502
00503 template<typename T, typename BW>
00504 template<typename T2>
00505 void
00506 bw_types::Vector<T, BW>::push_to(std::vector<T2> & dst) const
00507 {
00508 dst.resize(this->size());
00509
00510 const T * const src_ptr = this->val();
00511
00512 T * dst_ptr = &dst[0];
00513
00514 static int inc_src = 1;
00515
00516 static int inc_dst = 1;
00517
00518 BW::GetVector(this->size(), src_ptr, inc_src, dst_ptr, inc_dst);
00519 }
00520
00521
00522
00523
00524
00525
00526
00527
00528
00529
00530
00531 template<typename T, typename BW>
00532 template<typename T_src>
00533 bw_types::Vector<T, BW> &
00534 bw_types::Vector<T, BW>::operator = (const VectorView<T, T_src > & other)
00535 {
00536 Assert(this->size() >= other.size(),
00537 ::ExcMessage("Dimension mismatch") );
00538
00539 int incx = other._stride;
00540 int incy = 1;
00541 BW::copy(other.size(), other.val(), incx,
00542 &(this->val()[(other._is_col ?
00543 other.r_begin() : other.c_begin())
00544 ]), incy);
00545
00546 return *this;
00547 }
00548
00549
00550
00551
00552
00553 template<typename T, typename BW>
00554 bw_types::Vector<T, BW> &
00555 bw_types::Vector<T, BW>::operator = (const T value)
00556 {
00557 Vector<T, BW> tmp(1);
00558 tmp.set(0, value);
00559 int incx = 0;
00560 int incy = 1;
00561 BW::copy(tmp.size(), tmp.val(), incx, this->val(), incy);
00562
00563 return *this;
00564 }
00565
00566
00567
00568
00569
00570
00571
00572
00573 template<typename T, typename BW>
00574 bw_types::Vector<T, BW> &
00575 bw_types::Vector<T, BW>::operator += (const Vector<T, BW> & other)
00576 {
00577 One<T> one;
00578 BW::axpy(this->__n, one(), other.val(), 1, this->val(), 1);
00579
00580 return *this;
00581 }
00582
00583
00584
00585
00586
00587
00588 template<typename T, typename BW>
00589 template<typename T_src>
00590 bw_types::Vector<T, BW> &
00591 bw_types::Vector<T, BW>::operator += (const VectorView<T, T_src > & other)
00592 {
00593 Assert(this->size() >= other.size(),
00594 ::ExcMessage("Dimension mismatch") );
00595
00596 BW::axpy(other.size(), 1, other.val(), 1,
00597 &(this->val()[other.r_begin()]), 1);
00598
00599 return *this;
00600 }
00601
00602
00603
00604
00605
00606
00607
00608
00609 template<typename T, typename BW>
00610 bw_types::Vector<T, BW> &
00611 bw_types::Vector<T, BW>::operator *= (const T scale)
00612 {
00613 int elem_dist = 1;
00614
00615 BW::scal(this->__n, scale, &(this->val()[0]), elem_dist);
00616
00617 return *this;
00618 }
00619
00620
00621
00622
00623
00624
00625
00626 template<typename T, typename BW>
00627 bw_types::Vector<T, BW> &
00628 bw_types::Vector<T, BW>::operator -= (const Vector<T, BW> & other)
00629 {
00630 One<T> one;
00631 bool plus = true;
00632
00633 BW::axpy(this->__n, one(!plus), other.val(), 1, this->val(), 1 );
00634
00635 return *this;
00636 }
00637
00638
00639
00640
00641
00642
00643
00644 template<typename T, typename BW>
00645 template<typename T_src>
00646 bw_types::Vector<T, BW> &
00647 bw_types::Vector<T, BW>::operator -= (const VectorView<T, T_src > & other)
00648 {
00649 Assert(this->size() >= other.size(),
00650 ::ExcMessage("Dimension mismatch") );
00651
00652 BW::axpy(other.size(), -1, other.val(), 1,
00653 &(this->val()[other.r_begin()]), 1);
00654
00655 return *this;
00656 }
00657
00658
00659
00660
00661
00662
00663
00664
00665 template<typename T, typename BW>
00666 T
00667 bw_types::Vector<T, BW>::operator * (const Vector<T, BW> & other)
00668 {
00669 Assert(this->size() == other.size(),
00670 ::ExcMessage("Dimension mismatch") );
00671 return this->dot(other);
00672 }
00673
00674
00675
00676
00677
00678
00679
00680
00681 template<typename T, typename BW>
00682 bw_types::Vector<T, BW> &
00683 bw_types::Vector<T, BW>::operator /= (const T scale)
00684 {
00685 int elem_dist = 1;
00686
00687 Assert(scale,
00688 ::ExcMessage("Division by Zero") );
00689
00690 BW::scal(this->__n, 1/scale, this->val(), elem_dist);
00691
00692 return *this;
00693 }
00694
00695
00696
00697
00698
00699
00700
00701 template<typename T, typename BW>
00702 template<typename VECTOR>
00703 T
00704 bw_types::Vector<T, BW>::dot(const VECTOR & other) const
00705 {
00706
00707
00708
00709
00710
00711 Assert(this->size() == other.size(),
00712 ::ExcMessage("Dimension mismatch"));
00713
00714 int incx = 1;
00715 int incy = 1;
00716 T result = BW::dot(this->__n,
00717 this->val(), incx, other.val(), incy);
00718
00719 return result;
00720 }
00721
00722
00723
00724
00725
00726
00727
00728
00729 template<typename T, typename BW>
00730 void
00731 bw_types::Vector<T, BW>::sadd (T alpha, const Vector<T, BW> & other)
00732 {
00733 int incx = 1;
00734 BW::scal(this->__n, alpha, this->val(), incx);
00735
00736
00737 int incy = 1;
00738 BW::axpy(this->__n, 1., other.dev_ptr, 1, this->val(), 1);
00739 }
00740
00741
00742
00743
00744
// Thin wrapper marking an object of type M as "to be used transposed".
// Only a reference is stored, so the wrapped object must outlive the
// wrapper.
template<typename M>
struct transpose
{
    // Wraps m; no copy of m is made.
    transpose(const M & m) : A(m) {}

    // The wrapped object.
    const M & A;
};
00753
00754
00755
00756
00757
00758
00759
00760
00761 struct vmu
00762 {
00763 template<typename T, typename BW>
00764 static void apply( bw_types::Vector<T, BW> & b,
00765 const bw_types::Matrix<T, BW> & A,
00766 const bw_types::Vector<T, BW> & x)
00767 {
00768 b.reinit(A.n_rows()); A.vmult(b,x);
00769 }
00770
00771 template<typename T, typename BW>
00772 static void apply( bw_types::Vector<T, BW> & b,
00773 const bw_types::SubMatrixView<T, BW> & A,
00774 const bw_types::Vector<T, BW> & x)
00775 {
00776 b.reinit(A.r_end() - A.r_begin()); A.vmult(b,x);
00777 }
00778
00779
00780 template<typename T, typename BW>
00781 static void apply( bw_types::Vector<T, BW> & b,
00782 const transpose<bw_types::Matrix<T, BW> > & A_t,
00783 const bw_types::Vector<T, BW> & x)
00784 {
00785 b.reinit(A_t.A.n_cols()); A_t.A.Tvmult(b,x);
00786 }
00787
00788
00789
00790 template<typename T, typename BW>
00791 static void apply( bw_types::Vector<T, BW> & b,
00792 const transpose<bw_types::SubMatrixView<T, BW> > & A_t,
00793 const bw_types::Vector<T, BW> & x)
00794 {
00795 b.reinit(A_t.A.matrix().n_rows()); A_t.A.Tvmult(b,x);
00796 }
00797
00798
00799 template<typename T, typename BW, typename T_src>
00800 static void apply( bw_types::Vector<T, BW> & b,
00801 const bw_types::Matrix<T, BW> & A,
00802 const bw_types::ColVectorView<T, T_src> & x)
00803 {
00804 b.reinit(A.n_rows()); A.vmult(b,x);
00805 }
00806 };
00807
00808
00809
00810
00811
00812
00813 template<typename L, typename T, typename BW>
00814 inline
00815 X_read_read<L, vmu, bw_types::Vector<T, BW> >
00816 operator * (const L & _l, const bw_types::Vector<T, BW> & _r)
00817 {
00818 return X_read_read<L, vmu, bw_types::Vector<T, BW> > (_l,_r);
00819 }
00820
00821 template<typename L, typename T,
00822 typename BW, typename T_src>
00823 inline
00824 X_read_read<L, vmu, bw_types::ColVectorView<T, T_src>
00825 >
00826 operator * (const L & _l, const
00827 bw_types::ColVectorView<T, T_src>
00828 & _r)
00829 {
00830 return X_read_read<L, vmu,
00831 bw_types::ColVectorView<T, T_src>
00832 > (_l,_r);
00833 }
00834
00835
00836
00837
00838 template<typename T, typename BW>
00839 template<typename M, typename Op>
00840 bw_types::Vector<T, BW>::Vector(const X_read_read<M, Op, Vector<T, BW> > & Ax)
00841 {
00842 Ax.apply(*this);
00843 }
00844
00845
00846 template<typename T, typename BW>
00847 template<typename M, typename Op
00848 , typename T_src
00849 >
00850 bw_types::Vector<T, BW>::Vector(const
00851 X_read_read<M, Op, bw_types::ColVectorView<T, T_src> >
00852 & Ax)
00853 {
00854 Ax.apply(*this);
00855 }
00856
00857
00858
00859
00860 template<typename T, typename BW>
00861 template<typename M, typename Op>
00862 bw_types::Vector<T, BW> &
00863 bw_types::Vector<T, BW>::operator = (const X_read_read<M, Op, Vector<T, BW> > & Ax)
00864 {
00865 Ax.apply(*this);
00866
00867 return *this;
00868 }
00869
00870
00871
00872 template<typename T, typename BW>
00873 template<typename M,
00874 typename T_src>
00875 bw_types::Vector<T, BW> &
00876 bw_types::Vector<T, BW>::operator =(const
00877 X_read_read<M, vmu, bw_types::ColVectorView<T, T_src> >
00878 & Ax)
00879 {
00880 Ax.apply(*this);
00881
00882 return *this;
00883 }
00884
00885
00886 #endif // cublas_Vector_H
00887
00888