diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc
index 50544ade4..3dccaf5bd 100644
--- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc
+++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc
@@ -142,14 +142,14 @@ class SparseTensorDenseMatMulOp : public OpKernel {
 #define KDNN_ADJOINT(ADJ_A, ADJ_B)                                             \
   if (adjoint_a_ == ADJ_A && adjoint_b_ == ADJ_B) {                            \
     Status functor_status = functor::KDNNSparseMatMulFunctor<                  \
         CPUDevice, float, Tindices, ADJ_A,                                     \
-        ADJ_B>::Compute(ctx->eigen_device<Device>(), out->matrix<float>(),     \
+        ADJ_B>::Compute(ctx, out->matrix<float>(),                             \
                          a_indices->matrix<Tindices>(), a_values->vec<float>(), \
                          b->matrix<float>());                                   \
     OP_REQUIRES_OK(ctx, functor_status);                                       \
   }
     const int int32max = std::numeric_limits<int>::max();
-    if (IsKDNNEnabled() && std::is_same<T, float>::value && adjoint_a_ == false &&
+    if (IsKDNNEnabled() && std::is_same<Device, CPUDevice>::value && std::is_same<T, float>::value && adjoint_a_ == false &&
         FastBoundsCheck(inner_left, int32max) &&
         FastBoundsCheck(inner_right, int32max) &&
         FastBoundsCheck(outer_left, int32max)) {
@@ -387,7 +387,7 @@ template <typename Tindices, bool ADJ_A, bool ADJ_B>
 struct KDNNSparseMatMulFunctor<CPUDevice, float, Tindices, ADJ_A, ADJ_B> {
   static const std::size_t kNumVectorize = 32;
 
-  static Status Compute(const CPUDevice& d, typename TTypes<float>::Matrix out,
+  static Status Compute(OpKernelContext* ctx, typename TTypes<float>::Matrix out,
                         typename TTypes<Tindices>::ConstMatrix a_indices,
                         typename TTypes<float>::ConstVec a_values,
                         typename TTypes<float>::ConstMatrix b) {
@@ -426,7 +426,7 @@ struct KDNNSparseMatMulFunctor<CPUDevice, float, Tindices, ADJ_A, ADJ_B> {
         col_major_conj_b = b.swap_layout().shuffle(shuffle).conjugate().eval();
         b_data = col_major_conj_b.data();
       }
-      kdnnSparseMatmul<Tindices>(nnz, rhs_right, lhs_right, lhs_index_a, rhs_index_a, out, a_indices, a_values, b_data);
+      kdnnSparseMatmul<Tindices>(ctx, nnz, rhs_right, lhs_right, lhs_index_a, rhs_index_a, out, a_indices, a_values, b_data);
     }
     return OkStatus();
   }
diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h
index 3de9474a1..a7713c679 100644
--- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h
+++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h
@@ -40,7 +40,7 @@ template <typename Device, typename T, typename Tindices, bool ADJ_A,
           bool ADJ_B>
 struct KDNNSparseMatMulFunctor {
   static EIGEN_ALWAYS_INLINE Status Compute(
-      const Device& d, typename TTypes<T>::Matrix out,
+      OpKernelContext* ctx, typename TTypes<T>::Matrix out,
       typename TTypes<Tindices>::ConstMatrix a_indices,
       typename TTypes<T>::ConstVec a_values, typename TTypes<T>::ConstMatrix b);
 };
diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_test.cc b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_test.cc
index fe856e6b6..6db36cf03 100644
--- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_test.cc
+++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_test.cc
@@ -16,13 +16,331 @@ limitations under the License.
 #include <random>
 
 #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/test_benchmark.h"
 
 namespace tensorflow {
 
+// Naive reference over COO entries; honors both adjoint flags, row-major data.
+static void ReferenceSparseMatmulFn(
+    const int nnz,
+    const int64_t* a_indices_data,
+    const float* a_values_data,
+    const float* b_data,
+    int lhs_left, int lhs_right, int rhs_right,
+    bool adjoint_a, bool adjoint_b,
+    float* out_data) {
+  // Zero the lhs_left x rhs_right output before accumulating into it
+  for (int i = 0; i < lhs_left * rhs_right; ++i) out_data[i] = 0.0f;
+
+  const int lhs_index_a = adjoint_a ? 1 : 0;  // index column used as out row
+  const int rhs_index_a = adjoint_a ? 0 : 1;  // index column contracted with b
+
+  for (int i = 0; i < nnz; ++i) {
+    int row = a_indices_data[i * 2 + lhs_index_a];
+    int col = a_indices_data[i * 2 + rhs_index_a];
+    float a_val = a_values_data[i];
+    if (adjoint_b) {  // b is read transposed: element (n, col), stride lhs_right
+      for (int n = 0; n < rhs_right; ++n) {
+        out_data[row * rhs_right + n] += a_val * b_data[n * lhs_right + col];
+      }
+    } else {  // b is read as stored: element (col, n), stride rhs_right
+      for (int n = 0; n < rhs_right; ++n) {
+        out_data[row * rhs_right + n] += a_val * b_data[col * rhs_right + n];
+      }
+    }
+  }
+}
+
+// ===========================================================================
+// SparseMatMul accuracy tests (Section 10.1.1 of design doc)
+// ===========================================================================
+
+class SparseMatmulAccuracyTest : public OpsTestBase {
+ protected:
+  void MakeOp(bool adjoint_a, bool adjoint_b) {  // defines + initializes the op
+    TF_CHECK_OK(NodeDefBuilder("sparse_matmul", "SparseTensorDenseMatMul")
+                     .Input(FakeInput(DT_INT64))    // a_indices
+                     .Input(FakeInput(DT_FLOAT))    // a_values
+                     .Input(FakeInput(DT_INT64))    // a_shape
+                     .Input(FakeInput(DT_FLOAT))    // b
+                     .Attr("T", DT_FLOAT)
+                     .Attr("adjoint_a", adjoint_a)
+                     .Attr("adjoint_b", adjoint_b)
+                     .Finalize(node_def()));
+    TF_CHECK_OK(InitOp());
+  }
+
+  // Builds the op, feeds the four inputs, runs it, and returns output 0.
+  Tensor RunOp(const Tensor& a_indices, const Tensor& a_values,
+               const Tensor& a_shape, const Tensor& b,
+               bool adjoint_a, bool adjoint_b) {
+    MakeOp(adjoint_a, adjoint_b);
+
+    *AddInput(DT_INT64, a_indices.shape()) = a_indices;
+    *AddInput(DT_FLOAT, a_values.shape()) = a_values;
+    *AddInput(DT_INT64, a_shape.shape()) = a_shape;
+    *AddInput(DT_FLOAT, b.shape()) = b;
+
+    TF_CHECK_OK(RunOpKernel());
+
+    return *GetOutput(0);
+  }
+
+  // Checks |actual - expected| <= rtol * max(|expected|, 1) + atol per element.
+  void ExpectClose(const Tensor& actual, const Tensor& expected,
+                   float rtol = 1e-5, float atol = 1e-5) {
+    ASSERT_EQ(actual.shape(), expected.shape());
+    auto actual_m = actual.matrix<float>();
+    auto expected_m = expected.matrix<float>();
+    int rows = actual.dim_size(0);
+    int cols = actual.dim_size(1);
+    for (int i = 0; i < rows; ++i) {
+      for (int j = 0; j < cols; ++j) {
+        float diff = std::abs(actual_m(i, j) - expected_m(i, j));
+        float tol = rtol * std::max(std::abs(expected_m(i, j)), 1.0f) + atol;
+        EXPECT_LE(diff, tol)
+            << "(" << i << "," << j << "): actual=" << actual_m(i, j)
+            << " expected=" << expected_m(i, j);
+      }
+    }
+  }
+
+  // Helper: fills `nnz` seeded-random (row, col, value) entries; may repeat.
+  void CreateSparseMatrix(int rows, int cols, int nnz, int seed,
+                          Tensor* a_indices, Tensor* a_values,
+                          Tensor* a_shape) {
+    *a_indices = Tensor(DT_INT64, TensorShape({nnz, 2}));
+    *a_values = Tensor(DT_FLOAT, TensorShape({nnz}));
+    *a_shape = Tensor(DT_INT64, TensorShape({2}));
+    a_shape->vec<int64_t>()(0) = rows;
+    a_shape->vec<int64_t>()(1) = cols;
+
+    std::mt19937 gen(seed);
+    std::uniform_int_distribution<> row_dist(0, rows - 1);
+    std::uniform_int_distribution<> col_dist(0, cols - 1);
+    std::uniform_real_distribution<float> val_dist(-1.0, 1.0);
+
+    auto idx_m = a_indices->matrix<int64_t>();
+    auto val_v = a_values->vec<float>();
+    for (int i = 0; i < nnz; ++i) {
+      idx_m(i, 0) = row_dist(gen);
+      idx_m(i, 1) = col_dist(gen);
+      val_v(i) = val_dist(gen);
+    }
+  }
+};
+
+// ---------------------------------------------------------------------------
+// 1. Basic functionality: one random 8x16 operand with 20 stored entries
+// ---------------------------------------------------------------------------
+TEST_F(SparseMatmulAccuracyTest, BasicSparsity) {
+  const int m = 8, k = 16, n = 4;
+  const int nnz = 20;
+
+  Tensor a_indices, a_values, a_shape;
+  CreateSparseMatrix(m, k, nnz, /*seed=*/42, &a_indices, &a_values, &a_shape);
+
+  Tensor b(DT_FLOAT, TensorShape({k, n}));
+  b.flat<float>().setRandom();
+
+  // Compute the expected product with the naive reference
+  Tensor ref(DT_FLOAT, TensorShape({m, n}));
+  ReferenceSparseMatmulFn(
+      nnz, a_indices.matrix<int64_t>().data(), a_values.vec<float>().data(),
+      b.matrix<float>().data(), m, k, n,
+      /*adjoint_a=*/false, /*adjoint_b=*/false,
+      ref.matrix<float>().data());
+
+  // Compute via TF op (takes the KDNN path when IsKDNNEnabled() is true)
+  Tensor result = RunOp(a_indices, a_values, a_shape, b,
+                        /*adjoint_a=*/false, /*adjoint_b=*/false);
+
+  ExpectClose(result, ref);
+}
+
+// ---------------------------------------------------------------------------
+// 2. Boundary conditions
+// ---------------------------------------------------------------------------
+TEST_F(SparseMatmulAccuracyTest, ZeroNonZeroEntries) {
+  const int m = 4, k = 8, n = 2;
+  const int nnz = 0;  // empty sparse operand: nothing is stored
+
+  Tensor a_indices(DT_INT64, TensorShape({nnz, 2}));
+  Tensor a_values(DT_FLOAT, TensorShape({nnz}));
+  Tensor a_shape(DT_INT64, TensorShape({2}));
+  a_shape.vec<int64_t>()(0) = m;
+  a_shape.vec<int64_t>()(1) = k;
+
+  Tensor b(DT_FLOAT, TensorShape({k, n}));
+  b.flat<float>().setRandom();
+
+  Tensor result = RunOp(a_indices, a_values, a_shape, b,
+                        /*adjoint_a=*/false, /*adjoint_b=*/false);
+
+  // With no stored entries every output element must be exactly zero
+  auto result_m = result.matrix<float>();
+  for (int i = 0; i < m; ++i)
+    for (int j = 0; j < n; ++j)
+      EXPECT_EQ(result_m(i, j), 0.0f);
+}
+
+TEST_F(SparseMatmulAccuracyTest, IdentityMatrix) {
+  const int m = 4, k = 4, n = 4;
+  // A = I stored as m diagonal entries of value 1, so A * B must equal B
+  Tensor a_indices(DT_INT64, TensorShape({m, 2}));
+  Tensor a_values(DT_FLOAT, TensorShape({m}));
+  Tensor a_shape(DT_INT64, TensorShape({2}));
+  a_shape.vec<int64_t>()(0) = m;
+  a_shape.vec<int64_t>()(1) = k;
+
+  auto idx = a_indices.matrix<int64_t>();
+  auto val = a_values.vec<float>();
+  for (int i = 0; i < m; ++i) {
+    idx(i, 0) = i;
+    idx(i, 1) = i;
+    val(i) = 1.0f;
+  }
+
+  Tensor b(DT_FLOAT, TensorShape({k, n}));
+  b.flat<float>().setRandom();
+
+  Tensor result = RunOp(a_indices, a_values, a_shape, b,
+                        /*adjoint_a=*/false, /*adjoint_b=*/false);
+
+  ExpectClose(result, b);
+}
+
+TEST_F(SparseMatmulAccuracyTest, AllZeroMatrix) {
+  const int m = 4, k = 8, n = 4;
+  const int nnz = 6;
+  // A holds `nnz` explicitly stored zeros (ZeroNonZeroEntries covers the case
+  // where nothing is stored), so the product must still be all zeros.
+  Tensor a_indices, a_values, a_shape;
+  CreateSparseMatrix(m, k, nnz, /*seed=*/7, &a_indices, &a_values, &a_shape);
+  a_values.vec<float>().setZero();
+
+  Tensor b(DT_FLOAT, TensorShape({k, n}));
+  b.flat<float>().setRandom();
+
+  Tensor result = RunOp(a_indices, a_values, a_shape, b,
+                        /*adjoint_a=*/false, /*adjoint_b=*/false);
+
+  auto result_m = result.matrix<float>();
+  for (int i = 0; i < m; ++i)
+    for (int j = 0; j < n; ++j)
+      EXPECT_EQ(result_m(i, j), 0.0f);
+}
+
+// ---------------------------------------------------------------------------
+// 3. Transpose semantics
+// ---------------------------------------------------------------------------
+TEST_F(SparseMatmulAccuracyTest, AdjointA) {
+  const int m = 8, k = 16, n = 4;
+  const int nnz = 30;
+
+  Tensor a_indices, a_values, a_shape;
+  CreateSparseMatrix(k, m, nnz, /*seed=*/123, &a_indices, &a_values, &a_shape);
+  // adjoint_a=true: A is k×m but treated as A^T which is m×k
+  // So the sparse matrix shape is (k, m) with indices in (k, m) space.
+  // The op with adjoint_a treats it as: A^T (m×k) * B (k×n) → (m×n)
+
+  Tensor b(DT_FLOAT, TensorShape({k, n}));
+  b.flat<float>().setRandom();
+
+  // Reference: manually transpose
+  Tensor ref(DT_FLOAT, TensorShape({m, n}));
+  // Sanity-check the stored shape: (k, m), i.e. A before transposition.
+  ASSERT_EQ(a_shape.vec<int64_t>()(0), k);
+  ASSERT_EQ(a_shape.vec<int64_t>()(1), m);
+  auto a_idx = a_indices.matrix<int64_t>();
+  auto a_val = a_values.vec<float>();
+  auto b_m = b.matrix<float>();
+  auto ref_m = ref.matrix<float>();
+  ref_m.setZero();
+  for (int i = 0; i < nnz; ++i) {
+    // Entry (r, c) of the stored k×m matrix is entry (c, r) of A^T (m×k):
+    // r runs over the k dimension, c over the m dimension.
+    int orig_row = a_idx(i, 0);  // k dimension
+    int orig_col = a_idx(i, 1);  // m dimension
+    // In A^T, result row = orig_col, A^T's k-index = orig_row
+    for (int j = 0; j < n; ++j) {
+      ref_m(orig_col, j) += a_val(i) * b_m(orig_row, j);
+    }
+  }
+
+  Tensor result = RunOp(a_indices, a_values, a_shape, b,
+                        /*adjoint_a=*/true, /*adjoint_b=*/false);
+  ExpectClose(result, ref);
+}
+
+// ---------------------------------------------------------------------------
+// 4. Larger shapes: output values are checked against the naive reference
+// ---------------------------------------------------------------------------
+TEST_F(SparseMatmulAccuracyTest, LargeScale) {
+  const int m = 64, k = 128, n = 16, nnz = 200;
+  Tensor a_indices(DT_INT64, TensorShape({nnz, 2}));
+  Tensor a_values(DT_FLOAT, TensorShape({nnz}));
+  Tensor a_shape(DT_INT64, TensorShape({2}));
+  a_shape.vec<int64_t>()(0) = m;
+  a_shape.vec<int64_t>()(1) = k;
+  auto idx = a_indices.matrix<int64_t>();
+  auto val = a_values.vec<float>();
+  for (int i = 0; i < nnz; ++i) {
+    idx(i, 0) = (i * 7) % m;
+    idx(i, 1) = (i * 13) % k;
+    val(i) = static_cast<float>((i % 100) - 50) / 50.0f;
+  }
+  Tensor b(DT_FLOAT, TensorShape({k, n}));
+  auto b_m = b.matrix<float>();
+  for (int i = 0; i < k; ++i)
+    for (int j = 0; j < n; ++j)
+      b_m(i, j) = static_cast<float>((i * 3 + j * 7) % 20 - 10) / 10.0f;
+
+  Tensor ref(DT_FLOAT, TensorShape({m, n}));
+  ReferenceSparseMatmulFn(nnz, idx.data(), val.data(), b_m.data(), m, k, n,
+                          /*adjoint_a=*/false, /*adjoint_b=*/false,
+                          ref.matrix<float>().data());
+  Tensor result = RunOp(a_indices, a_values, a_shape, b,
+                        /*adjoint_a=*/false, /*adjoint_b=*/false);
+  ExpectClose(result, ref);  // also asserts that the shapes match
+}
+
+// ---------------------------------------------------------------------------
+// 5. Multi-thread path consistency (env var set before binary start)
+// ---------------------------------------------------------------------------
+// Note: TF_ENABLE_KDNN_SPARSE_MATMUL_PARALLEL controls parallel vs serial
+// in the KDNN SparseGemm path (see kdnn_adapter.h). It is read once at
+// first invocation. To test with parallel=0, run:
+//   TF_ENABLE_KDNN_SPARSE_MATMUL_PARALLEL=0 ... --gtest_filter='*KdnnPath*'
+TEST_F(SparseMatmulAccuracyTest, KdnnPathMatchesReference) {
+  const int m = 32, k = 64, n = 8;
+  const int nnz = 80;
+
+  Tensor a_indices, a_values, a_shape;
+  CreateSparseMatrix(m, k, nnz, /*seed=*/99, &a_indices, &a_values, &a_shape);
+
+  Tensor b(DT_FLOAT, TensorShape({k, n}));
+  b.flat<float>().setRandom();
+
+  Tensor ref(DT_FLOAT, TensorShape({m, n}));
+  ReferenceSparseMatmulFn(
+      nnz, a_indices.matrix<int64_t>().data(), a_values.vec<float>().data(),
+      b.matrix<float>().data(), m, k, n,
+      /*adjoint_a=*/false, /*adjoint_b=*/false, ref.matrix<float>().data());
+
+  Tensor result = RunOp(a_indices, a_values, a_shape, b,
+                        /*adjoint_a=*/false, /*adjoint_b=*/false);
+
+  ExpectClose(result, ref);
+}
+
+
 Node* SparseTensorDenseMatMulNode(Graph* g, Node* a_indices, Node* a_values,
                                   Node* a_shape, Node* b, bool adjoint_a,
                                   bool adjoint_b) {
diff --git a/third_party/KDNN/kdnn_adapter.h b/third_party/KDNN/kdnn_adapter.h
index c72e96347..64f471d04 100644
--- a/third_party/KDNN/kdnn_adapter.h
+++ b/third_party/KDNN/kdnn_adapter.h
@@ -68,7 +68,8 @@ inline void kdnnFusedGemm(OpKernelContext* ctx, const Tensor& a, const Tensor& b
 }
 
 template<typename Tindices>
-inline void kdnnSparseMatmul(const std::size_t nnz,
+inline void kdnnSparseMatmul(OpKernelContext* ctx,
+                      const std::size_t nnz,
                       const std::size_t rhs_right, const std::size_t lhs_right,
                       const int lhs_index_a, const int rhs_index_a,
                       typename TTypes<float>::Matrix out,
@@ -97,8 +98,13 @@ inline void kdnnSparseMatmul(const std::size_t nnz,
         KDNN::Element::TypeT::F32, KDNN::Layout::AB};
     const KDNN::TensorInfo dstInfo = {{lhs_left, rhs_right},
         KDNN::Element::TypeT::F32, KDNN::Layout::AB};
+    // Use TF's intra-op worker pool for SparseGemm; deactivated after Run().
+    thread::ThreadPool* thread_pool = ctx->device()->tensorflow_cpu_worker_threads()->workers;
+    kdnn::KDNNThreadPool kdnn_tp(thread_pool);
+    KDNN::Threading::ActivateThreadpool(&kdnn_tp);
     KDNN::SparseGemm sparse_csr(aInfo, bInfo, dstInfo);
     sparse_csr.Run(a_values.data(), b_data, out.data());
+    KDNN::Threading::DeactivateThreadpool();
 }
 
 template<typename T>