#ifndef OPTEST_MATMUL_EXTRA_H
#define OPTEST_MATMUL_EXTRA_H
#include <torch/torch.h>
#include <tiling/platform/platform_ascendc.h>
#include "catlass_kernel_jit.h"
#include "common/run_npu_func.h"
#include "torch_utils.h"
#include "type_utils.hpp"
namespace CatlassKernelWrapper {
using KernelFn = void (*)(const uint32_t, aclrtStream, const CatlassKernel::TParams&, const CatlassKernel::MatmulParams&);
* @brief Unified adapter for matmul ops with an extra input tensor.
*
* Two modes selected by the compile-time ``IsBias`` flag:
*
* - ``IsBias == false`` (MatmulAdd, in-place):
* ``extra`` is the addend C, shape ``(M,N)``, dtype = outDType.
* C and output D share the same buffer — C is used directly as the output,
* no separate allocation or copy. The kernel reads C as addend and
* accumulates ``A @ B`` into C. Returns C (modified in-place).
* ``inputAddr = [A, B]``, ``outputAddr = [C]``.
*
* - ``IsBias == true`` (MatmulBias, out-of-place):
* ``extra`` is a 1-D bias vector, shape ``(N,)``, separate from output.
* Allocates a fresh output tensor.
* ``inputAddr = [A, B, bias]``, ``outputAddr = [output]``.
*/
template <KernelFn KernelFunc, bool IsBias>
struct MatmulExtraLike {
using OutputType = at::Tensor;
static void GetKernelInfo(
const at::Tensor& mat1, const at::Tensor& mat2,
const c10::ScalarType& outDType, bool transA, bool transB,
bool formatA, bool formatB,
CatlassKernel::TParams& tParams,
CatlassKernel::MatmulParams& params)
{
auto aclType = TorchDtypeToAclDtype(mat1.scalar_type());
tParams.element["A"] = aclType;
tParams.element["B"] = aclType;
tParams.element["C"] = TorchDtypeToAclDtype(outDType);
tParams.transpose["A"] = transA;
tParams.transpose["B"] = transB;
tParams.transpose["C"] = false;
tParams.useNz["A"] = formatA;
tParams.useNz["B"] = formatB;
tParams.useNz["C"] = false;
params.inputAddr.resize(2);
params.inputAddr[0] = static_cast<uint8_t*>(const_cast<void*>(mat1.storage().data()));
params.inputAddr[1] = static_cast<uint8_t*>(const_cast<void*>(mat2.storage().data()));
int64_t m, k1, k2, n;
if (transA) {
m = mat1.size(1); k1 = mat1.size(0);
} else {
m = mat1.size(0); k1 = mat1.size(1);
}
if (transB) {
k2 = mat2.size(1); n = mat2.size(0);
} else {
k2 = mat2.size(0); n = mat2.size(1);
}
TORCH_CHECK(k1 == k2, "mat1 and mat2 shapes cannot be multiplied (",
m, "x", k1, " and ", k2, "x", n, ")");
params.m = static_cast<uint32_t>(m);
params.k = static_cast<uint32_t>(k1);
params.n = static_cast<uint32_t>(n);
}
static OutputType Run(
const at::Tensor& mat1, const at::Tensor& mat2, const at::Tensor& extra,
const c10::ScalarType& outDType, bool transA, bool transB,
bool formatA, bool formatB)
{
CatlassKernel::TParams tParams;
CatlassKernel::MatmulParams params;
GetKernelInfo(mat1, mat2, outDType, transA, transB, formatA, formatB, tParams, params);
if constexpr (IsBias) {
TORCH_CHECK(extra.dim() == 1 && extra.size(0) == static_cast<int64_t>(params.n),
"bias must be a 1-D tensor of size N=", params.n,
", got ", extra.dim(), "-D of size ",
extra.dim() > 0 ? extra.size(0) : 0);
params.inputAddr.resize(3);
params.inputAddr[2] = static_cast<uint8_t*>(const_cast<void*>(extra.storage().data()));
params.outputAddr.resize(1);
auto output = GetOutputTensor(
{params.m, params.n}, static_cast<torch::Dtype>(outDType));
params.outputAddr[0] = static_cast<uint8_t*>(const_cast<void*>(output.storage().data()));
aclrtStream stream = c10_npu::getCurrentNPUStream().stream(false);
uint32_t aicCoreNum = platform_ascendc::PlatformAscendCManager::GetInstance()->GetCoreNumAic();
RUN_NPU_FUNC(KernelFunc, aicCoreNum, stream, tParams, params);
return output;
} else {
TORCH_CHECK(extra.dim() == 2 &&
extra.size(0) == static_cast<int64_t>(params.m) &&
extra.size(1) == static_cast<int64_t>(params.n),
"addend must have shape (", params.m, ", ", params.n, ")");
TORCH_CHECK(extra.scalar_type() == outDType,
"addend dtype ", extra.scalar_type(), " must match outDType ", outDType);
params.outputAddr.resize(1);
params.outputAddr[0] = static_cast<uint8_t*>(const_cast<void*>(extra.storage().data()));
aclrtStream stream = c10_npu::getCurrentNPUStream().stream(false);
uint32_t aicCoreNum = platform_ascendc::PlatformAscendCManager::GetInstance()->GetCoreNumAic();
RUN_NPU_FUNC(KernelFunc, aicCoreNum, stream, tParams, params);
return extra;
}
}
};
}
#endif