00001 #ifndef __MATRIX_H
00002 #define __MATRIX_H
00003
00004 #include <valarray>
00005 #include<numeric>
00006
00011 namespace Teem
00012 {
00013 template<typename T> class SliceIter;
00014
00017 template<typename T> bool operator==(const SliceIter<T>& p, const SliceIter<T>& q)
00018 {
00019 return p.curr == q.curr && p.s.stride() == q.s.stride() && p.s.start() == q.s.start();
00020 }
00021
00024 template<typename T> bool operator!=(const SliceIter<T>& p, const SliceIter<T>& q)
00025 {
00026 return !(p==q);
00027 }
00028
00031 template<typename T> bool operator<(const SliceIter<T>& p, const SliceIter<T>& q)
00032 {
00033 return p.curr < q.curr && p.s.stride() == q.s.stride() && p.s.start() == q.s.start();
00034 }
00035
00038 template<typename T> class SliceIter
00039 {
00040 public:
00042 SliceIter(std::valarray<T> *vv, const std::slice &ss) : v(vv), s(ss), curr(0) { }
00043
00045 SliceIter end() const
00046 {
00047 SliceIter t = *this;
00048 t.curr = s.size();
00049 return t;
00050 }
00051
00053 SliceIter& operator++() { curr++; return *this; }
00055 SliceIter operator++(int) { SliceIter t = *this; curr++; return t; }
00056
00058 T& operator[](size_t i) { return ref(i); }
00060 T& operator()(size_t i) { return ref(i); }
00062 T& operator*() { return ref(curr); }
00063
00065 friend bool operator== <>(const SliceIter& p, const SliceIter& q);
00067 friend bool operator!= <>(const SliceIter& p, const SliceIter& q);
00069 friend bool operator< <>(const SliceIter& p, const SliceIter& q);
00070
00071 protected:
00072 std::valarray<T> *v;
00073 const std::slice s;
00074 size_t curr;
00075
00077 T& ref(size_t i) const { return (*v)[s.start() + i * s.stride()]; }
00078 };
00079
00080
00081 template<typename T> class ConstSliceIter;
00082
00085 template<typename T> bool operator==(const ConstSliceIter<T>& p, const ConstSliceIter<T>& q)
00086 {
00087 return p.curr == q.curr && p.s.stride() == q.s.stride() && p.s.start() == q.s.start();
00088 }
00089
00092 template<typename T> bool operator!=(const ConstSliceIter<T>& p, const ConstSliceIter<T>& q)
00093 {
00094 return !(p==q);
00095 }
00096
00099 template<typename T> bool operator<(const ConstSliceIter<T>& p, const ConstSliceIter<T>& q)
00100 {
00101 return p.curr < q.curr && p.s.stride() == q.s.stride() && p.s.start() == q.s.start();
00102 }
00103
00106 template<typename T> class ConstSliceIter
00107 {
00108 public:
00110 ConstSliceIter(const std::valarray<T> *vv, const std::slice &ss) : v(vv), s(ss), curr(0) { }
00111
00113 ConstSliceIter end() const
00114 {
00115 ConstSliceIter t = *this;
00116 t.curr = s.size();
00117 return t;
00118 }
00119
00121 ConstSliceIter& operator++() { curr++; return *this; }
00123 ConstSliceIter operator++(int) { ConstSliceIter t = *this; curr++; return t; }
00124
00126 const T& operator[](size_t i) const { return ref(i); }
00128 const T& operator()(size_t i) const { return ref(i); }
00130 const T& operator*() const { return ref(curr); }
00131
00133 friend bool operator== <>(const ConstSliceIter& p, const ConstSliceIter& q);
00135 friend bool operator!= <>(const ConstSliceIter& p, const ConstSliceIter& q);
00137 friend bool operator< <>(const ConstSliceIter& p, const ConstSliceIter& q);
00138
00139 protected:
00140 const std::valarray<T> *v;
00141 const std::slice s;
00142 size_t curr;
00143
00145 const T& ref(size_t i) const { return (*v)[s.start() + i * s.stride()]; }
00146 };
00147
00150 template<typename T> class Matrix
00151 {
00152 public:
00154 Matrix(size_t nx = 0, size_t ny = 0, T c = T()) : items(c, nx*ny), xDim(nx), yDim(ny) { }
00156 ~Matrix() { }
00157
00159 size_t size() const { return xDim * yDim; }
00161 size_t columnNum() const { return xDim; }
00163 size_t rowNum() const { return yDim; }
00164
00166 std::valarray<T>& flat() { return items; }
00167
00169 const std::valarray<T>& const_flat() const { return items; }
00170
00172 void resize(size_t nx, size_t ny, T c = T()) { items.resize(nx*ny, c); xDim = nx; yDim = ny; }
00173
00175 SliceIter<T> row(size_t i) { return SliceIter<T>(&items, std::slice(i, xDim, yDim)); }
00177 ConstSliceIter<T> row(size_t i) const { return ConstSliceIter<T>(&items, std::slice(i, xDim, yDim)); }
00178
00180 SliceIter<T> column(size_t i) { return SliceIter<T>(&items, std::slice(i*yDim, yDim, 1)); }
00182 ConstSliceIter<T> column(size_t i) const { return ConstSliceIter<T>(&items, std::slice(i*yDim, yDim, 1)); }
00183
00185 T& operator() (size_t x, size_t y) { return column(x)[y]; }
00187 const T& operator() (size_t x, size_t y) const { return column(x)[y]; }
00188
00190 SliceIter<T> operator() (size_t i) { return column(i); }
00192 ConstSliceIter<T> operator() (size_t i) const { return column(i); }
00193
00194 protected:
00195 std::valarray<T> items;
00196 size_t xDim;
00197 size_t yDim;
00198 };
00199
00202 template<typename T> T operator*(const ConstSliceIter<T> &v1, const std::valarray<T> &v2)
00203 {
00204 T res = 0;
00205 for(size_t i = 0; i < v2.size(); i++)
00206 res += v1[i] * v2[i];
00207 return res;
00208 }
00209
00215 template<typename T> std::valarray<T> operator*(const Matrix<T>& m, const std::valarray<T>& v)
00216 {
00217 assert(m.columnNum() == v.size());
00218
00219 std::valarray<T> res(m.rowNum());
00220 for(size_t i = 0; i < m.rowNum(); i++)
00221 res[i] = m.row(i) * v;
00222 return res;
00223 }
00224
00230 template<typename T> std::valarray<T> operator*(const std::valarray<T>& v, const Matrix<T>& m)
00231 {
00232 assert(m.rowNum() == v.size());
00233
00234 std::valarray<T> res(m.columnNum());
00235 for(size_t i = 0; i < m.columnNum(); i++)
00236 res[i] = m.column(i) * v;
00237 return res;
00238 }
00239 }
00240
00241 #endif