add addition operation for two csr matrices

This commit is contained in:
Irlan 2018-05-19 21:13:09 -03:00
parent 1ccb411fb6
commit a251a9e180

View File

@ -20,9 +20,10 @@
#define B3_SPARSE_MAT_33_H
#include <bounce/common/math/mat33.h>
#include <bounce/dynamics/cloth/diag_mat33.h>
#include <bounce/dynamics/cloth/dense_vec3.h>
// A static sparse matrix stored in Compressed Sparse Row (CSR) format.
// It's efficient when using in iterative solvers such as CG method, where
// the coefficient matrix must be multiplied with a vector at each iteration.
// See https://en.wikipedia.org/wiki/Sparse_matrix
@ -30,16 +31,16 @@ struct b3SparseMat33
{
b3SparseMat33() { }
b3SparseMat33(u32 _M, u32 _N,
b3SparseMat33(u32 _M, u32 _N,
u32 _valueCount, b3Mat33* _values,
u32* _row_ptrs, u32* _cols)
{
M = _M;
N = _N;
values = _values;
valueCount = _valueCount;
row_ptrs = _row_ptrs;
cols = _cols;
}
@ -48,7 +49,7 @@ struct b3SparseMat33
{
}
// Output any given row of the original matrix.
// The given buffer must have a size greater than or equal to M.
void AssembleRow(b3Mat33* out, u32 row) const;
@ -61,10 +62,7 @@ struct b3SparseMat33
// Output the block diagonal part of the original matrix.
// This matrix must be a square matrix.
// The given buffer must have a size greater than or equal to M.
void AssembleDiagonal(b3Mat33* out) const;
// Multiplies this matrix with a given (compatible) vector.
void Mul(b3DenseVec3& out, const b3DenseVec3& v) const;
void AssembleDiagonal(b3DiagMat33& out) const;
// Dimensions of the original 2D matrix
u32 M;
@ -73,7 +71,7 @@ struct b3SparseMat33
// Non-zero values
b3Mat33* values;
u32 valueCount;
// Sparsity structure
u32* row_ptrs; // pointers to the first non-zero value of each row (size is M + 1)
u32* cols; // column indices for each non-zero value (size is valueCount)
@ -105,16 +103,21 @@ inline void b3SparseMat33::AssembleMatrix(b3Mat33* out) const
}
}
inline void b3SparseMat33::AssembleDiagonal(b3Mat33* out) const
inline void b3SparseMat33::AssembleDiagonal(b3DiagMat33& out) const
{
B3_ASSERT(M == N);
for (u32 row = 0; row < M; ++row)
{
out[row].SetZero();
for (u32 row_ptr = row_ptrs[row]; row_ptr < row_ptrs[row + 1]; ++row_ptr)
{
if (cols[row_ptr] > row)
{
break;
}
if (cols[row_ptr] == row)
{
out[row] = values[row_ptr];
@ -124,19 +127,19 @@ inline void b3SparseMat33::AssembleDiagonal(b3Mat33* out) const
}
}
inline void b3SparseMat33::Mul(b3DenseVec3& out, const b3DenseVec3& v) const
inline void b3Mul(b3DenseVec3& out, const b3SparseMat33& A, const b3DenseVec3& v)
{
B3_ASSERT(N == out.n);
B3_ASSERT(A.N == out.n);
for (u32 i = 0; i < N; ++i)
for (u32 row = 0; row < A.N; ++row)
{
out[i].SetZero();
out[row].SetZero();
for (u32 j = row_ptrs[i]; j < row_ptrs[i + 1]; ++j)
for (u32 j = A.row_ptrs[row]; j < A.row_ptrs[row + 1]; ++j)
{
u32 col = cols[j];
out[i] += values[j] * v[col];
u32 col = A.cols[j];
out[row] += A.values[j] * v[col];
}
}
}
@ -144,8 +147,70 @@ inline void b3SparseMat33::Mul(b3DenseVec3& out, const b3DenseVec3& v) const
// Returns the product A * v as a new dense vector.
// The result has one entry per row of A.
inline b3DenseVec3 operator*(const b3SparseMat33& A, const b3DenseVec3& v)
{
	// Size the result by the row count of A; v.n is the column count and only
	// coincides with it for square matrices.
	b3DenseVec3 result(A.M);
	b3Mul(result, A, v);
	return result;
}
// Computes out = A + B for two CSR matrices with identical dimensions.
//
// Requirements on the caller:
// - out must be a different object than A and B, because out is written
//   while A and B are still being read.
// - out.values and out.cols must have room for A.valueCount + B.valueCount
//   entries (the worst case, when A and B share no non-zero position).
// - out.row_ptrs must have room for M + 1 entries.
// - The column indices inside each row of A and B must be sorted in
//   ascending order; the rows of out are produced in the same order.
//
// On return out.valueCount holds the number of non-zero blocks written and
// out.row_ptrs/out.cols describe the merged sparsity structure.
inline void b3Add(b3SparseMat33& out, b3SparseMat33& A, const b3SparseMat33& B)
{
	B3_ASSERT(A.M == B.M);
	B3_ASSERT(A.N == B.N);
	B3_ASSERT(A.M == out.M);
	B3_ASSERT(A.N == out.N);
	B3_ASSERT(&out != &A);
	B3_ASSERT(&out != &B);

	u32 count = 0;
	out.row_ptrs[0] = 0;

	for (u32 row = 0; row < A.M; ++row)
	{
		u32 a = A.row_ptrs[row];
		const u32 a_end = A.row_ptrs[row + 1];

		u32 b = B.row_ptrs[row];
		const u32 b_end = B.row_ptrs[row + 1];

		// Merge the two sorted column lists of this row.
		while (a < a_end && b < b_end)
		{
			const u32 col_A = A.cols[a];
			const u32 col_B = B.cols[b];

			if (col_A < col_B)
			{
				// Only A has a non-zero block at this column.
				out.values[count] = A.values[a];
				out.cols[count] = col_A;
				++a;
			}
			else if (col_B < col_A)
			{
				// Only B has a non-zero block at this column.
				out.values[count] = B.values[b];
				out.cols[count] = col_B;
				++b;
			}
			else
			{
				// Both matrices store a block here: add them.
				// Values are indexed by storage position, not by column index.
				out.values[count] = A.values[a];
				out.values[count] += B.values[b];
				out.cols[count] = col_A;
				++a;
				++b;
			}

			++count;
		}

		// Copy whatever is left of A's row.
		for (; a < a_end; ++a)
		{
			out.values[count] = A.values[a];
			out.cols[count] = A.cols[a];
			++count;
		}

		// Copy whatever is left of B's row.
		for (; b < b_end; ++b)
		{
			out.values[count] = B.values[b];
			out.cols[count] = B.cols[b];
			++count;
		}

		// The next row starts right after the last block written for this one.
		out.row_ptrs[row + 1] = count;
	}

	out.valueCount = count;

	// The merged matrix can never hold more blocks than both inputs together.
	B3_ASSERT(out.valueCount <= A.valueCount + B.valueCount);
}
#endif