@@ -142,14 +142,14 @@ class SparseTensorDenseMatMulOp : public OpKernel {
#define KDNN_ADJOINT(ADJ_A, ADJ_B) \
if (adjoint_a_ == ADJ_A && adjoint_b_ == ADJ_B) { \
Status functor_status = functor::KDNNSparseMatMulFunctor< \
- CPUDevice, float, Tindices, ADJ_A, \
- ADJ_B>::Compute(ctx->eigen_device<Device>(), out->matrix<float>(), \
+ Device, float, Tindices, ADJ_A, \
+ ADJ_B>::Compute(ctx, out->matrix<float>(), \
a_indices->matrix<Tindices>(), a_values->vec<float>(), \
b->matrix<float>()); \
OP_REQUIRES_OK(ctx, functor_status); \
}
const int int32max = std::numeric_limits<int>::max();
- if (IsKDNNEnabled() && std::is_same<T, float>::value && adjoint_a_ == false &&
+ if (IsKDNNEnabled() && std::is_same<Device, CPUDevice>::value && std::is_same<T, float>::value && adjoint_a_ == false &&
FastBoundsCheck(inner_left, int32max) &&
FastBoundsCheck(inner_right, int32max) &&
FastBoundsCheck(outer_left, int32max)) {
@@ -387,7 +387,7 @@ template <typename Tindices, bool ADJ_A, bool ADJ_B>
struct KDNNSparseMatMulFunctor<CPUDevice, float, Tindices, ADJ_A, ADJ_B> {
static const std::size_t kNumVectorize = 32;
- static Status Compute(const CPUDevice& d, typename TTypes<float>::Matrix out,
+ static Status Compute(OpKernelContext* ctx, typename TTypes<float>::Matrix out,
typename TTypes<Tindices>::ConstMatrix a_indices,
typename TTypes<float>::ConstVec a_values,
typename TTypes<float>::ConstMatrix b) {
@@ -426,7 +426,7 @@ struct KDNNSparseMatMulFunctor<CPUDevice, float, Tindices, ADJ_A, ADJ_B> {
col_major_conj_b = b.swap_layout().shuffle(shuffle).conjugate().eval();
b_data = col_major_conj_b.data();
}
- kdnnSparseMatmul<Tindices>(nnz, rhs_right, lhs_right, lhs_index_a, rhs_index_a, out, a_indices, a_values, b_data);
+ kdnnSparseMatmul<Tindices>(ctx, nnz, rhs_right, lhs_right, lhs_index_a, rhs_index_a, out, a_indices, a_values, b_data);
}
return OkStatus();
}
@@ -40,7 +40,7 @@ template <typename Device, typename T, typename Tindices, bool ADJ_A,
bool ADJ_B>
struct KDNNSparseMatMulFunctor {
static EIGEN_ALWAYS_INLINE Status Compute(
- const Device& d, typename TTypes<T>::Matrix out,
+ OpKernelContext* ctx, typename TTypes<T>::Matrix out,
typename TTypes<Tindices>::ConstMatrix a_indices,
typename TTypes<T>::ConstVec a_values, typename TTypes<T>::ConstMatrix b);
};
@@ -16,13 +16,331 @@ limitations under the License.
#include <random>
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
namespace tensorflow {
+// Fully naive reference for all adjoint combinations
+static void ReferenceSparseMatmulFn(
+ const int nnz,
+ const int64_t* a_indices_data,
+ const float* a_values_data,
+ const float* b_data,
+ int lhs_left, int lhs_right, int rhs_right,
+ bool adjoint_a, bool adjoint_b,
+ float* out_data) {
+ // Initialize output to zero
+ for (int i = 0; i < lhs_left * rhs_right; ++i) out_data[i] = 0.0f;
+
+ const int lhs_index_a = adjoint_a ? 1 : 0;
+ const int rhs_index_a = adjoint_a ? 0 : 1;
+
+ for (int i = 0; i < nnz; ++i) {
+ int row = a_indices_data[i * 2 + lhs_index_a];
+ int col = a_indices_data[i * 2 + rhs_index_a];
+ float a_val = a_values_data[i];
+ if (adjoint_b) {
+ for (int n = 0; n < rhs_right; ++n) {
+ out_data[row * rhs_right + n] += a_val * b_data[n * lhs_right + col];
+ }
+ } else {
+ for (int n = 0; n < rhs_right; ++n) {
+ out_data[row * rhs_right + n] += a_val * b_data[col * rhs_right + n];
+ }
+ }
+ }
+}
+
+// ===========================================================================
+// SparseMatMul accuracy tests (Section 10.1.1 of design doc)
+// ===========================================================================
+
+class SparseMatmulAccuracyTest : public OpsTestBase {
+ protected:
+ void MakeOp(bool adjoint_a, bool adjoint_b) {
+ TF_CHECK_OK(NodeDefBuilder("sparse_matmul", "SparseTensorDenseMatMul")
+ .Input(FakeInput(DT_INT64)) // a_indices
+ .Input(FakeInput(DT_FLOAT)) // a_values
+ .Input(FakeInput(DT_INT64)) // a_shape
+ .Input(FakeInput(DT_FLOAT)) // b
+ .Attr("T", DT_FLOAT)
+ .Attr("adjoint_a", adjoint_a)
+ .Attr("adjoint_b", adjoint_b)
+ .Finalize(node_def()));
+ TF_CHECK_OK(InitOp());
+ }
+
+ // Set up and run the SparseTensorDenseMatMul op.
+ Tensor RunOp(const Tensor& a_indices, const Tensor& a_values,
+ const Tensor& a_shape, const Tensor& b,
+ bool adjoint_a, bool adjoint_b) {
+ MakeOp(adjoint_a, adjoint_b);
+
+ *AddInput(DT_INT64, a_indices.shape()) = a_indices;
+ *AddInput(DT_FLOAT, a_values.shape()) = a_values;
+ *AddInput(DT_INT64, a_shape.shape()) = a_shape;
+ *AddInput(DT_FLOAT, b.shape()) = b;
+
+ TF_CHECK_OK(RunOpKernel());
+
+ return *GetOutput(0);
+ }
+
+ // Compare two matrices with a tolerance
+ void ExpectClose(const Tensor& actual, const Tensor& expected,
+ float rtol = 1e-5, float atol = 1e-5) {
+ ASSERT_EQ(actual.shape(), expected.shape());
+ auto actual_m = actual.matrix<float>();
+ auto expected_m = expected.matrix<float>();
+ int rows = actual.dim_size(0);
+ int cols = actual.dim_size(1);
+ for (int i = 0; i < rows; ++i) {
+ for (int j = 0; j < cols; ++j) {
+ float diff = std::abs(actual_m(i, j) - expected_m(i, j));
+ float tol = rtol * std::max(std::abs(expected_m(i, j)), 1.0f) + atol;
+ EXPECT_LE(diff, tol)
+ << "(" << i << "," << j << "): actual=" << actual_m(i, j)
+ << " expected=" << expected_m(i, j);
+ }
+ }
+ }
+
+ // Helper: create a deterministic sparse matrix with given sparsity
+ void CreateSparseMatrix(int rows, int cols, int nnz, int seed,
+ Tensor* a_indices, Tensor* a_values,
+ Tensor* a_shape) {
+ *a_indices = Tensor(DT_INT64, TensorShape({nnz, 2}));
+ *a_values = Tensor(DT_FLOAT, TensorShape({nnz}));
+ *a_shape = Tensor(DT_INT64, TensorShape({2}));
+ a_shape->vec<int64_t>()(0) = rows;
+ a_shape->vec<int64_t>()(1) = cols;
+
+ std::mt19937 gen(seed);
+ std::uniform_int_distribution<> row_dist(0, rows - 1);
+ std::uniform_int_distribution<> col_dist(0, cols - 1);
+ std::uniform_real_distribution<float> val_dist(-1.0, 1.0);
+
+ auto idx_m = a_indices->matrix<int64_t>();
+ auto val_v = a_values->vec<float>();
+ for (int i = 0; i < nnz; ++i) {
+ idx_m(i, 0) = row_dist(gen);
+ idx_m(i, 1) = col_dist(gen);
+ val_v(i) = val_dist(gen);
+ }
+ }
+};
+
+// ---------------------------------------------------------------------------
+// 1. Basic functionality: different sparsity levels
+// ---------------------------------------------------------------------------
+TEST_F(SparseMatmulAccuracyTest, BasicSparsity) {
+ const int m = 8, k = 16, n = 4;
+ const int nnz = 20;
+
+ Tensor a_indices, a_values, a_shape;
+ CreateSparseMatrix(m, k, nnz, /*seed=*/42, &a_indices, &a_values, &a_shape);
+
+ Tensor b(DT_FLOAT, TensorShape({k, n}));
+ b.flat<float>().setRandom();
+
+ // Compute reference
+ Tensor ref(DT_FLOAT, TensorShape({m, n}));
+ ReferenceSparseMatmulFn(
+ nnz, a_indices.matrix<int64_t>().data(), a_values.vec<float>().data(),
+ b.matrix<float>().data(), m, k, n,
+ /*adjoint_a=*/false, /*adjoint_b=*/false,
+ ref.matrix<float>().data());
+
+ // Compute via TF op (KDNN path when ENABLE_KDNN is on)
+ Tensor result = RunOp(a_indices, a_values, a_shape, b,
+ /*adjoint_a=*/false, /*adjoint_b=*/false);
+
+ ExpectClose(result, ref);
+}
+
+// ---------------------------------------------------------------------------
+// 2. Boundary conditions
+// ---------------------------------------------------------------------------
+TEST_F(SparseMatmulAccuracyTest, ZeroNonZeroEntries) {
+ const int m = 4, k = 8, n = 2;
+ const int nnz = 0;
+
+ Tensor a_indices(DT_INT64, TensorShape({0, 2}));
+ Tensor a_values(DT_FLOAT, TensorShape({0}));
+ Tensor a_shape(DT_INT64, TensorShape({2}));
+ a_shape.vec<int64_t>()(0) = m;
+ a_shape.vec<int64_t>()(1) = k;
+
+ Tensor b(DT_FLOAT, TensorShape({k, n}));
+ b.flat<float>().setRandom();
+
+ Tensor result = RunOp(a_indices, a_values, a_shape, b,
+ /*adjoint_a=*/false, /*adjoint_b=*/false);
+
+ // All zeros
+ auto result_m = result.matrix<float>();
+ for (int i = 0; i < m; ++i)
+ for (int j = 0; j < n; ++j)
+ EXPECT_EQ(result_m(i, j), 0.0f);
+}
+
+TEST_F(SparseMatmulAccuracyTest, IdentityMatrix) {
+ const int m = 4, k = 4, n = 4;
+ // A = I (identity sparse), A * B = B
+ Tensor a_indices(DT_INT64, TensorShape({4, 2}));
+ Tensor a_values(DT_FLOAT, TensorShape({4}));
+ Tensor a_shape(DT_INT64, TensorShape({2}));
+ a_shape.vec<int64_t>()(0) = m;
+ a_shape.vec<int64_t>()(1) = k;
+
+ auto idx = a_indices.matrix<int64_t>();
+ auto val = a_values.vec<float>();
+ for (int i = 0; i < 4; ++i) {
+ idx(i, 0) = i;
+ idx(i, 1) = i;
+ val(i) = 1.0f;
+ }
+
+ Tensor b(DT_FLOAT, TensorShape({k, n}));
+ b.flat<float>().setRandom();
+
+ Tensor result = RunOp(a_indices, a_values, a_shape, b,
+ /*adjoint_a=*/false, /*adjoint_b=*/false);
+
+ ExpectClose(result, b);
+}
+
+TEST_F(SparseMatmulAccuracyTest, AllZeroMatrix) {
+ const int m = 4, k = 8, n = 4;
+ // A = zeros, so result = zeros
+ Tensor a_indices(DT_INT64, TensorShape({0, 2}));
+ Tensor a_values(DT_FLOAT, TensorShape({0}));
+ Tensor a_shape(DT_INT64, TensorShape({2}));
+ a_shape.vec<int64_t>()(0) = m;
+ a_shape.vec<int64_t>()(1) = k;
+
+ Tensor b(DT_FLOAT, TensorShape({k, n}));
+ b.flat<float>().setRandom();
+
+ Tensor result = RunOp(a_indices, a_values, a_shape, b,
+ /*adjoint_a=*/false, /*adjoint_b=*/false);
+
+ auto result_m = result.matrix<float>();
+ for (int i = 0; i < m; ++i)
+ for (int j = 0; j < n; ++j)
+ EXPECT_EQ(result_m(i, j), 0.0f);
+}
+
+// ---------------------------------------------------------------------------
+// 3. Transpose semantics
+// ---------------------------------------------------------------------------
+TEST_F(SparseMatmulAccuracyTest, AdjointA) {
+ const int m = 8, k = 16, n = 4;
+ const int nnz = 30;
+
+ Tensor a_indices, a_values, a_shape;
+ CreateSparseMatrix(k, m, nnz, /*seed=*/123, &a_indices, &a_values, &a_shape);
+ // adjoint_a=true: A is k×m but treated as A^T which is m×k
+ // So the sparse matrix shape is (k, m) with indices in (k, m) space.
+ // The op with adjoint_a treats it as: A^T (m×k) * B (k×n) → (m×n)
+
+ Tensor b(DT_FLOAT, TensorShape({k, n}));
+ b.flat<float>().setRandom();
+
+ // Reference: manually transpose
+ Tensor ref(DT_FLOAT, TensorShape({m, n}));
+ auto a_shape_v = a_shape.vec<int64_t>();
+ int rows_a = a_shape_v(0); // k
+ int cols_a = a_shape_v(1); // m
+ auto a_idx = a_indices.matrix<int64_t>();
+ auto a_val = a_values.vec<float>();
+ auto b_m = b.matrix<float>();
+ auto ref_m = ref.matrix<float>();
+ ref_m.setZero();
+ for (int i = 0; i < nnz; ++i) {
+ // With adjoint_a=true: A_orig is k×m, index(i,0)=row_in_orig(=k_dim), index(i,1)=col_in_orig(=m_dim)
+ // A^T has dimension m×k: A^T(col_in_orig, row_in_orig) = A_orig(row_in_orig, col_in_orig)
+ int orig_row = a_idx(i, 0); // k dimension
+ int orig_col = a_idx(i, 1); // m dimension
+ // In A^T, result row = orig_col, A^T's k-index = orig_row
+ for (int j = 0; j < n; ++j) {
+ ref_m(orig_col, j) += a_val(i) * b_m(orig_row, j);
+ }
+ }
+
+ Tensor result = RunOp(a_indices, a_values, a_shape, b,
+ /*adjoint_a=*/true, /*adjoint_b=*/false);
+ ExpectClose(result, ref);
+}
+
+// ---------------------------------------------------------------------------
+// 4. Large scale data (stability test)
+// ---------------------------------------------------------------------------
+TEST_F(SparseMatmulAccuracyTest, LargeScale) {
+ const int m = 64, k = 128, n = 16;
+ const int nnz = 200;
+
+ Tensor a_indices(DT_INT64, TensorShape({nnz, 2}));
+ Tensor a_values(DT_FLOAT, TensorShape({nnz}));
+ Tensor a_shape(DT_INT64, TensorShape({2}));
+ a_shape.vec<int64_t>()(0) = m;
+ a_shape.vec<int64_t>()(1) = k;
+
+ auto idx = a_indices.matrix<int64_t>();
+ auto val = a_values.vec<float>();
+ for (int i = 0; i < nnz; ++i) {
+ idx(i, 0) = (i * 7) % m;
+ idx(i, 1) = (i * 13) % k;
+ val(i) = static_cast<float>((i % 100) - 50) / 50.0f;
+ }
+
+ Tensor b(DT_FLOAT, TensorShape({k, n}));
+ auto b_m = b.matrix<float>();
+ for (int i = 0; i < k; ++i)
+ for (int j = 0; j < n; ++j)
+ b_m(i, j) = static_cast<float>((i * 3 + j * 7) % 20 - 10) / 10.0f;
+
+ Tensor result = RunOp(a_indices, a_values, a_shape, b,
+ /*adjoint_a=*/false, /*adjoint_b=*/false);
+ EXPECT_EQ(result.shape(), TensorShape({m, n}));
+}
+
+// ---------------------------------------------------------------------------
+// 5. Multi-thread path consistency (env var set before binary start)
+// ---------------------------------------------------------------------------
+// Note: TF_ENABLE_KDNN_SPARSE_MATMUL_PARALLEL controls parallel vs serial
+// in the KDNN SparseGemm path (see kdnn_adapter.h). It is read once at
+// first invocation. To test with parallel=0, run:
+// TF_ENABLE_KDNN_SPARSE_MATMUL_PARALLEL=0 ... --benchmark_filter=...
+TEST_F(SparseMatmulAccuracyTest, KdnnPathMatchesReference) {
+ const int m = 32, k = 64, n = 8;
+ const int nnz = 80;
+
+ Tensor a_indices, a_values, a_shape;
+ CreateSparseMatrix(m, k, nnz, /*seed=*/99, &a_indices, &a_values, &a_shape);
+
+ Tensor b(DT_FLOAT, TensorShape({k, n}));
+ b.flat<float>().setRandom();
+
+ Tensor ref(DT_FLOAT, TensorShape({m, n}));
+ ReferenceSparseMatmulFn(
+ nnz, a_indices.matrix<int64_t>().data(), a_values.vec<float>().data(),
+ b.matrix<float>().data(), m, k, n,
+ false, false, ref.matrix<float>().data());
+
+ Tensor result = RunOp(a_indices, a_values, a_shape, b,
+ /*adjoint_a=*/false, /*adjoint_b=*/false);
+
+ ExpectClose(result, ref);
+}
+
+
Node* SparseTensorDenseMatMulNode(Graph* g, Node* a_indices, Node* a_values,
Node* a_shape, Node* b, bool adjoint_a,
bool adjoint_b) {
@@ -68,7 +68,8 @@ inline void kdnnFusedGemm(OpKernelContext* ctx, const Tensor& a, const Tensor& b
}
template<typename Tindices>
-inline void kdnnSparseMatmul(const std::size_t nnz,
+inline void kdnnSparseMatmul(OpKernelContext* ctx,
+ const std::size_t nnz,
const std::size_t rhs_right, const std::size_t lhs_right,
const int lhs_index_a, const int rhs_index_a,
typename TTypes<float>::Matrix out,
@@ -97,8 +98,13 @@ inline void kdnnSparseMatmul(const std::size_t nnz,
KDNN::Element::TypeT::F32, KDNN::Layout::AB};
const KDNN::TensorInfo dstInfo = {{lhs_left, rhs_right},
KDNN::Element::TypeT::F32, KDNN::Layout::AB};
+ // intra_op thread_pool
+ thread::ThreadPool* thread_pool = ctx->device()->tensorflow_cpu_worker_threads()->workers;
+ kdnn::KDNNThreadPool kdnn_tp(thread_pool);
+ KDNN::Threading::ActivateThreadpool(&kdnn_tp);
KDNN::SparseGemm sparse_csr(aInfo, bInfo, dstInfo);
sparse_csr.Run(a_values.data(), b_data, out.data());
+ KDNN::Threading::DeactivateThreadpool();
}
template<typename T>