SymmMatrix.h

Go to the documentation of this file.
00001 //
00002 //  Copyright (C) 2004-2006 Rational Discovery LLC
00003 //
00004 //   @@ All Rights Reserved  @@
00005 //
00006 #ifndef __RD_SYMM_MATRIX_H__
00007 #define __RD_SYMM_MATRIX_H__
00008 
00009 #include "Matrix.h"
00010 #include "SquareMatrix.h"
00011 #include <boost/smart_ptr.hpp>
00012 
00013 //#ifndef INVARIANT_SILENT_METHOD
00014 //#define INVARIANT_SILENT_METHOD
00015 //#endif
00016 namespace RDNumeric {
00017   //! A symmetric matrix class
00018   /*! 
00019     The data is stored as the lower triangle, so
00020      A[i,j] = data[i*(i+1) + j] when i >= j and
00021      A[i,j] = data[j*(j+1) + i] when i < j
00022   */
00023   template <class TYPE> class SymmMatrix {
00024   public:
00025     typedef boost::shared_array<TYPE> DATA_SPTR;
00026 
00027     explicit SymmMatrix(unsigned int N) : 
00028       d_size(N), d_dataSize(N*(N+1)/2)  {
00029       TYPE *data = new TYPE[d_dataSize];
00030       memset(static_cast<void *>(data),0,d_dataSize*sizeof(TYPE));
00031       d_data.reset(data);
00032     }
00033 
00034     SymmMatrix(unsigned int N, TYPE val) : 
00035       d_size(N), d_dataSize(N*(N+1)/2)  {
00036       TYPE *data = new TYPE[d_dataSize];
00037       unsigned int i;
00038       for (i = 0; i < d_dataSize; i++) {
00039         data[i] = val;
00040       }
00041       d_data.reset(data);
00042     }
00043     
00044     SymmMatrix(unsigned int N, DATA_SPTR data) :
00045       d_size(N), d_dataSize(N*(N+1)/2)  {
00046       d_data = data;
00047     }
00048     
00049     SymmMatrix(const SymmMatrix<TYPE> &other) :
00050       d_size(other.numRows()), d_dataSize(other.getDataSize())  {
00051       TYPE *data = new TYPE[d_dataSize];
00052       const TYPE *otherData = other.getData();
00053 
00054       memcpy(static_cast<void *>(data), static_cast<const void *>(otherData),
00055              d_dataSize*sizeof(TYPE));
00056       d_data.reset(data);
00057     }
00058 
00059     ~SymmMatrix() {}
00060     
00061     //! returns the number of rows
00062     inline unsigned int numRows() const {
00063       return d_size;
00064     }
00065 
00066     //! returns the number of columns
00067     inline unsigned int numCols() const {
00068       return d_size;
00069     }
00070 
00071     inline unsigned int getDataSize() const {
00072       return d_dataSize;
00073     }
00074 
00075     void setToIdentity() {
00076       TYPE *data = d_data.get();
00077       memset(static_cast<void *>(data), 0, d_dataSize*sizeof(TYPE));
00078       for (unsigned int i = 0; i < d_size; i++) {
00079         data[i*(i+3)/2] = (TYPE)1.0;
00080       }
00081     }
00082 
00083     TYPE getVal(unsigned int i, unsigned int j) const {
00084       RANGE_CHECK(0, i, d_size-1);
00085       RANGE_CHECK(0, j, d_size-1);
00086       unsigned int id;
00087       if (i >= j) {
00088         id = i*(i+1)/2 + j;
00089       } else {
00090         id = j*(j+1)/2 + i;
00091       }
00092       return d_data[id];
00093     }
00094 
00095     void setVal(unsigned int i, unsigned int j, TYPE val) {
00096       RANGE_CHECK(0, i, d_size-1);
00097       RANGE_CHECK(0, j, d_size-1);
00098       unsigned int id;
00099       if (i >= j) {
00100         id = i*(i+1)/2 + j;
00101       } else {
00102         id = j*(j+1)/2 + i;
00103       }
00104       d_data[id] = val;
00105     }
00106 
00107     void getRow(unsigned int i, Vector<TYPE> &row) { 
00108       CHECK_INVARIANT(d_size == row.size(), "");
00109       TYPE *rData  = row.getData(); 
00110       TYPE *data = d_data.get();
00111       for (unsigned int j = 0; j < d_size; j++) {
00112         unsigned int id;
00113         if (j <= i) {
00114           id = i*(i+1)/2 + j;
00115         } else {
00116           id = j*(j+1)/2 + i;
00117         }
00118         rData[j] = data[id];
00119       }
00120     }
00121      
00122     void getCol(unsigned int i, Vector<TYPE> &col) { 
00123       CHECK_INVARIANT(d_size == col.size(), "");
00124       TYPE *rData  = col.getData();
00125       TYPE *data = d_data.get();
00126       for (unsigned int j = 0; j < d_size; j++) {
00127         unsigned int id;
00128         if (i <= j) {
00129           id = j*(j+1)/2 + i;
00130         } else {
00131           id = i*(i+1)/2 + j;
00132         }
00133         rData[j] = data[id];
00134       }
00135     }
00136 
00137     //! returns a pointer to our data array
00138     inline TYPE *getData() {
00139       return d_data.get();
00140     }
00141     
00142     //! returns a const pointer to our data array
00143     inline const TYPE *getData() const {
00144       return d_data.get();
00145     }
00146 
00147     SymmMatrix<TYPE>& operator*=(TYPE scale) {
00148       TYPE *data = d_data.get();
00149       for (unsigned int i = 0; i < d_dataSize; i++) {
00150         data[i] *= scale;
00151       }
00152       return *this;
00153     }
00154 
00155     SymmMatrix<TYPE>& operator/=(TYPE scale) {
00156       TYPE *data = d_data.get();
00157       for (unsigned int i = 0; i < d_dataSize; i++) {
00158         data[i] /= scale;
00159       }
00160       return *this;
00161     }
00162 
00163     SymmMatrix<TYPE>& operator+=(const SymmMatrix<TYPE> &other) {
00164       CHECK_INVARIANT(d_size == other.numRows(), "Sizes don't match in the addition");
00165       const TYPE *oData = other.getData();
00166       TYPE *data = d_data.get();
00167       for (unsigned int i = 0; i < d_dataSize; i++) {
00168         data[i] += oData[i];
00169       }
00170       return *this;
00171     }
00172 
00173     SymmMatrix<TYPE>& operator-=(const SymmMatrix<TYPE> &other) {
00174       CHECK_INVARIANT(d_size == other.numRows(), "Sizes don't match in the addition");
00175       const TYPE *oData = other.getData();
00176       TYPE *data = d_data.get();
00177       for (unsigned int i = 0; i < d_dataSize; i++) {
00178         data[i] -= oData[i];
00179       }
00180       return *this;
00181     }
00182 
00183     //! in-place matrix multiplication
00184     SymmMatrix<TYPE>& operator*=(const SymmMatrix<TYPE> &B) {
00185       CHECK_INVARIANT(d_size == B.numRows(), "Size mismatch during multiplication");
00186       TYPE *cData = new TYPE[d_dataSize];
00187       const TYPE *bData = B.getData();
00188       TYPE *data = d_data.get();
00189       for (unsigned int i = 0; i < d_size; i++) {
00190         unsigned int idC = i*(i+1)/2;
00191         for (unsigned int j = 0; j < i+1; j++) {
00192           unsigned int idCt = idC + j;
00193           cData[idCt] = (TYPE)0.0;
00194           for (unsigned int k = 0; k < d_size; k++) {
00195             unsigned int idA,idB;
00196             if (k <= i) {
00197               idA = i*(i+1)/2 + k;
00198             } else {
00199               idA = k*(k+1)/2 + i;
00200             } 
00201             if (k <= j) {
00202               idB = j*(j+1)/2 + k;
00203             } else {
00204               idB = k*(k+1)/2 + j;
00205             }
00206             cData[idCt] += (data[idA]*bData[idB]);
00207           }
00208         }
00209       }
00210       
00211       for (unsigned int i = 0; i < d_dataSize; i++) {
00212         data[i] = cData[i];
00213       }
00214       delete [] cData;
00215       return (*this);
00216     }
00217 
00218     /* Transpose will basically return a copy of itself
00219      */
00220     SymmMatrix<TYPE>& transpose(SymmMatrix<TYPE> &transpose) const { 
00221       CHECK_INVARIANT(d_size == transpose.numRows(), "Size mismatch during transposing");
00222       TYPE *tData = transpose.getData(); 
00223       TYPE *data = d_data.get();
00224       for (unsigned int i = 0; i < d_dataSize; i++) {
00225         tData[i] = data[i];
00226       }
00227       return transpose;
00228     }
00229 
00230     SymmMatrix<TYPE> &transposeInplace() {
00231       // nothing to be done we are symmetric
00232       return (*this);
00233     }
00234 
00235   protected: 
00236     
00237     SymmMatrix() : d_size(0), d_dataSize(0), d_data(0){};
00238     unsigned int d_size;
00239     unsigned int d_dataSize;
00240     DATA_SPTR d_data;
00241 
00242   private:
00243     SymmMatrix<TYPE>& operator=(const SymmMatrix<TYPE> &other);
00244   };
00245   
00246   //! SymmMatrix-SymmMatrix multiplication 
00247   /*!
00248     Multiply SymmMatrix A with a second SymmMatrix B 
00249     and write the result to C = A*B
00250 
00251     \param A  the first SymmMatrix 
00252     \param B  the second SymmMatrix to multiply 
00253     \param C  SymmMatrix to use for the results
00254     
00255     \return the results of multiplying A by B.
00256     This is just a reference to C.
00257     
00258     This method is reimplemented here for efficiency reasons
00259     (we basically don't want to use getter and setter functions)
00260     
00261   */
00262   template <class TYPE>
00263     SymmMatrix<TYPE>& multiply(const SymmMatrix<TYPE> &A,
00264                                const SymmMatrix<TYPE> &B, 
00265                                SymmMatrix<TYPE> &C) {
00266     unsigned int aSize = A.numRows();
00267     CHECK_INVARIANT(B.numRows() == aSize, "Size mismatch in matric multiplication");
00268     CHECK_INVARIANT(C.numRows() == aSize, "Size mismatch in matric multiplication");
00269     TYPE *cData = C.getData();
00270     const TYPE *aData = A.getData();
00271     const TYPE *bData = B.getData();
00272     for (unsigned int i = 0; i < aSize; i++) {
00273       unsigned int idC = i*(i+1)/2;
00274       for (unsigned int j = 0; j < i+1; j++) {
00275         unsigned int idCt = idC + j;
00276         cData[idCt] = (TYPE)0.0;
00277         for (unsigned int k = 0; k < aSize; k++) {
00278           unsigned int idA,idB;
00279           if (k <= i) {
00280             idA = i*(i+1)/2 + k;
00281           } else {
00282             idA = k*(k+1)/2 + i;
00283           } 
00284           if (k <= j) {
00285             idB = j*(j+1)/2 + k;
00286           } else {
00287             idB = k*(k+1)/2 + j;
00288           }
00289           cData[idCt] += (aData[idA]*bData[idB]);
00290         }
00291       }
00292     }
00293     return C;
00294   }
00295 
00296   //! SymmMatrix-Vector multiplication
00297   /*!
00298     Multiply a SymmMatrix A with a Vector x
00299     so the result is y = A*x
00300     
00301     \param A  the SymmMatrix for multiplication 
00302     \param x  the Vector by which to multiply
00303     \param y  Vector to use for the results
00304     
00305     \return the results of multiplying x by A
00306     This is just a reference to y.
00307     
00308     This method is reimplemented here for efficiency reasons
00309     (we basically don't want to use getter and setter functions)
00310     
00311   */
00312   template <class TYPE>
00313     Vector<TYPE>& multiply(const SymmMatrix<TYPE> &A, const Vector<TYPE> &x, 
00314                                    Vector<TYPE> &y) {
00315     unsigned int aSize = A.numRows();
00316     CHECK_INVARIANT(aSize == x.size(), "Size mismatch during multiplication");
00317     CHECK_INVARIANT(aSize == y.size(), "Size mismatch during multiplication");
00318     const TYPE *xData = x.getData();
00319     const TYPE *aData = A.getData();
00320     TYPE *yData = y.getData();
00321     for (unsigned int i = 0; i < aSize; i++) {
00322       yData[i] = (TYPE)(0.0);
00323       unsigned int idA = i*(i+1)/2;
00324       for (unsigned int j = 0; j < i+1; j++) {
00325         //idA = i*(i+1)/2 + j;
00326         yData[i] += (aData[idA]*xData[j]);
00327         idA++;
00328       }
00329       idA--;
00330       for (unsigned int j = i+1; j < aSize; j++) {
00331         //idA = j*(j+1)/2 + i;
00332         idA += j;
00333         yData[i] += (aData[idA]*xData[j]);
00334       }
00335     }
00336     return y;
00337   }
00338 
00339   typedef SymmMatrix<double> DoubleSymmMatrix;
00340   typedef SymmMatrix<int> IntSymmMatrix;
00341   typedef SymmMatrix<unsigned int> UintSymmMatrix;
00342 }
00343 
00344 //! ostream operator for Matrix's
00345 template <class TYPE > std::ostream & operator<<(std::ostream& target, 
00346                                                  const RDNumeric::SymmMatrix<TYPE> &mat) {
00347   unsigned int nr = mat.numRows();
00348   unsigned int nc = mat.numCols();
00349   target << "Rows: " << mat.numRows() << " Columns: " << mat.numCols() << "\n";
00350 
00351   for (unsigned int i = 0; i < nr; i++) {
00352     for (unsigned int j = 0; j < nc; j++) {
00353       target << std::setw(7) << std::setprecision(3) << mat.getVal(i, j);
00354     }
00355     target << "\n";
00356   }
00357   return target;
00358 }
00359 
00360 #endif
00361     

Generated on Sat May 24 08:36:32 2008 for RDCode by  doxygen 1.5.3