* This program is free software, you can redistribute it and/or modify.
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. See LICENSE in the root of
* the software repository for the full text of the License.
*/
#ifndef OPTEST_MX_MATMUL_H
#define OPTEST_MX_MATMUL_H
#include <cstdint>
#include <limits>
#include <torch/torch.h>
#include <tiling/platform/platform_ascendc.h>
#include "catlass_kernel.h"
#include "common/run_npu_func.h"
#include "torch_utils.h"
#include "type_utils.hpp"
namespace CatlassKernelWrapper {
// Number of consecutive K elements that share one MX scale value (see ComputeMxScaleShapes).
// `inline` gives the constant a single definition across translation units that include this header.
inline constexpr uint32_t kMxScaleGroupNum = 32;
/**
 * @brief Ceiling division for unsigned 32-bit values: the smallest q such that q * b >= a.
 *
 * Computed as the truncated quotient plus a carry for a non-zero remainder, so it never wraps.
 * The previous form (a + b - 1) / b overflowed once a + b - 1 exceeded UINT32_MAX
 * (e.g. CeilDivUint32(UINT32_MAX, 2) returned 0).
 *
 * @param a Dividend.
 * @param b Divisor; must be non-zero.
 * @return ceil(a / b).
 */
inline uint32_t CeilDivUint32(uint32_t a, uint32_t b)
{
    return a / b + ((a % b != 0U) ? 1U : 0U);
}
/**
 * @brief Round an unsigned value up to the nearest even number (even inputs are returned unchanged).
 *
 * Adding the lowest bit bumps odd values by one. As with the arithmetic form (v + 1) / 2 * 2,
 * the result wraps to 0 for v == UINT32_MAX.
 */
inline uint32_t RoundUp2Uint32(uint32_t v)
{
    const uint32_t oddBit = v & 1U;
    return v + oddBit;
}
* @brief Compute MX scale tensor shapes for the default layout mapping.
*
* When A is RowMajor and B is ColumnMajor (trans_a=0, trans_b=1):
* mx_scale_a: (m, mxScaleAlignedK / 2, 2) → numel = m * mxScaleAlignedK
* mx_scale_b: (n, mxScaleAlignedK / 2, 2) → numel = n * mxScaleAlignedK
*/
inline void ComputeMxScaleShapes(
    uint32_t m, uint32_t n, uint32_t k, uint32_t& mxScaleAlignedK, int64_t& scaleANumel, int64_t& scaleBNumel)
{
    // One scale per kMxScaleGroupNum K elements, with the per-row count padded up to an even number.
    mxScaleAlignedK = RoundUp2Uint32(CeilDivUint32(k, kMxScaleGroupNum));
    // Widen before multiplying so the element counts are computed in 64-bit arithmetic.
    const int64_t alignedK = mxScaleAlignedK;
    scaleANumel = alignedK * m;
    scaleBNumel = alignedK * n;
}
/**
 * @brief Multiply three unsigned 32-bit extents in signed 64-bit arithmetic.
 *
 * The accumulator is 64-bit from the start, so each factor is widened before it is multiplied in.
 */
inline int64_t ProductInt64(uint32_t a, uint32_t b, uint32_t c)
{
    int64_t product = a;
    product *= b;
    product *= c;
    return product;
}
/**
 * @brief Throw (via TORCH_CHECK) unless `tensor` lives on the PrivateUse1 backend, which this file treats as the NPU.
 *
 * @param tensor Tensor to inspect.
 * @param name Argument name used in the error message.
 */
inline void CheckNpuTensor(const at::Tensor& tensor, const char* name)
{
    const bool onNpu = (tensor.device().type() == c10::DeviceType::PrivateUse1);
    TORCH_CHECK(onNpu, name, " must be on NPU");
}
/**
 * @brief Throw (via TORCH_CHECK) unless `tensor` is on exactly the same device as `reference`.
 *
 * @param reference Tensor whose device is authoritative.
 * @param referenceName Name of `reference` for the error message.
 * @param tensor Tensor being validated.
 * @param tensorName Name of `tensor` for the error message.
 */
inline void CheckSameDevice(const at::Tensor& reference, const char* referenceName, const at::Tensor& tensor,
    const char* tensorName)
{
    const auto expectedDevice = reference.device();
    const auto actualDevice = tensor.device();
    TORCH_CHECK(
        actualDevice == expectedDevice, tensorName, " must be on the same device as ", referenceName, " (got ",
        actualDevice, " and ", expectedDevice, ")");
}
/**
 * @brief Throw (via TORCH_CHECK) unless an MX scale tensor has dtype float8_e8m0fnu.
 *
 * @param tensor Scale tensor to inspect.
 * @param name Argument name used in the error message.
 */
inline void CheckMxScaleDType(const at::Tensor& tensor, const char* name)
{
    const auto dtype = tensor.scalar_type();
    TORCH_CHECK(dtype == torch::kFloat8_e8m0fnu, name, " must have dtype torch.float8_e8m0fnu, got ", dtype);
}
// Pointer type of the kernel entry points the wrappers below are templated on. In each Run() the arguments
// are, in order, the AIC core count, the current NPU stream, the template parameters and the runtime
// matmul parameters (presumably forwarded in that order by RUN_NPU_FUNC — confirm in common/run_npu_func.h).
using KernelFn = void (*)(
    const uint32_t /* aicCoreNum */, aclrtStream /* stream */, const CatlassKernel::TParams& /* tParams */,
    const CatlassKernel::MatmulParams& /* params */);
template <KernelFn KernelFunc>
struct MxMatmulLike {
using OutputType = at::Tensor;
static void GetKernelInfo(
const at::Tensor& mat1, const at::Tensor& mat2, const at::Tensor& mx_scale_a,
const at::Tensor& mx_scale_b, bool transA, bool transB, CatlassKernel::TParams& tParams,
CatlassKernel::MatmulParams& params)
{
CheckNpuTensor(mat1, "mat1");
CheckNpuTensor(mat2, "mat2");
CheckNpuTensor(mx_scale_a, "mx_scale_a");
CheckNpuTensor(mx_scale_b, "mx_scale_b");
CheckSameDevice(mat1, "mat1", mat2, "mat2");
CheckSameDevice(mat1, "mat1", mx_scale_a, "mx_scale_a");
CheckSameDevice(mat1, "mat1", mx_scale_b, "mx_scale_b");
CheckMxScaleDType(mx_scale_a, "mx_scale_a");
CheckMxScaleDType(mx_scale_b, "mx_scale_b");
tParams.element["A"] = TorchDtypeToAclDtype(mat1.scalar_type());
tParams.element["B"] = TorchDtypeToAclDtype(mat2.scalar_type());
tParams.element["C"] = ACL_FLOAT;
tParams.element["MX_SCALE"] = test_utils::TypeCast<std::string, aclDataType>("float8_e8m0fnu");
tParams.transpose["A"] = transA;
tParams.transpose["B"] = transB;
tParams.transpose["C"] = false;
tParams.useNz["A"] = false;
tParams.useNz["B"] = false;
tParams.useNz["C"] = false;
params.inputAddr.resize(4);
params.inputAddr[0] = static_cast<uint8_t*>(const_cast<void*>(mat1.storage().data()));
params.inputAddr[1] = static_cast<uint8_t*>(const_cast<void*>(mat2.storage().data()));
params.inputAddr[2] = static_cast<uint8_t*>(const_cast<void*>(mx_scale_a.storage().data()));
params.inputAddr[3] = static_cast<uint8_t*>(const_cast<void*>(mx_scale_b.storage().data()));
int64_t m, k1, k2, n;
if (transA) {
m = mat1.size(1);
k1 = mat1.size(0);
} else {
m = mat1.size(0);
k1 = mat1.size(1);
}
if (transB) {
k2 = mat2.size(1);
n = mat2.size(0);
} else {
k2 = mat2.size(0);
n = mat2.size(1);
}
TORCH_CHECK(k1 == k2, "mat1 and mat2 shapes cannot be multiplied (", m, "x", k1, " and ", k2, "x", n, ")");
TORCH_CHECK(mat1.is_contiguous(), "mat1 must be contiguous");
TORCH_CHECK(mat2.is_contiguous(), "mat2 must be contiguous");
TORCH_CHECK(mx_scale_a.is_contiguous(), "mx_scale_a must be contiguous");
TORCH_CHECK(mx_scale_b.is_contiguous(), "mx_scale_b must be contiguous");
params.m = static_cast<uint32_t>(m);
params.k = static_cast<uint32_t>(k1);
params.n = static_cast<uint32_t>(n);
uint32_t mxScaleAlignedK = 0;
int64_t scaleANumel = 0;
int64_t scaleBNumel = 0;
ComputeMxScaleShapes(params.m, params.n, params.k, mxScaleAlignedK, scaleANumel, scaleBNumel);
TORCH_CHECK(
mx_scale_a.numel() == scaleANumel,
"mx_scale_a must have ", scaleANumel, " elements (m=", m, ", mxScaleAlignedK=", mxScaleAlignedK,
"), got ", mx_scale_a.numel());
TORCH_CHECK(
mx_scale_b.numel() == scaleBNumel,
"mx_scale_b must have ", scaleBNumel, " elements (n=", n, ", mxScaleAlignedK=", mxScaleAlignedK,
"), got ", mx_scale_b.numel());
}
static OutputType AllocOutput(CatlassKernel::MatmulParams& params)
{
OutputType output = GetOutputTensor({params.m, params.n}, torch::kFloat32);
params.outputAddr.resize(1);
params.outputAddr[0] = static_cast<uint8_t*>(const_cast<void*>(output.storage().data()));
return output;
}
static OutputType Run(
const at::Tensor& mat1, const at::Tensor& mat2, const at::Tensor& mx_scale_a,
const at::Tensor& mx_scale_b, bool transA, bool transB)
{
CatlassKernel::TParams tParams;
CatlassKernel::MatmulParams params;
GetKernelInfo(mat1, mat2, mx_scale_a, mx_scale_b, transA, transB, tParams, params);
OutputType output = AllocOutput(params);
aclrtStream stream = c10_npu::getCurrentNPUStream().stream(false);
uint32_t aicCoreNum = platform_ascendc::PlatformAscendCManager::GetInstance()->GetCoreNumAic();
RUN_NPU_FUNC(KernelFunc, aicCoreNum, stream, tParams, params);
return output;
}
};
template <KernelFn KernelFunc>
struct MxBatchedMatmulLike {
using OutputType = at::Tensor;
static void GetKernelInfo(
const at::Tensor& mat1, const at::Tensor& mat2, const at::Tensor& mx_scale_a,
const at::Tensor& mx_scale_b, bool transA, bool transB, CatlassKernel::TParams& tParams,
CatlassKernel::MatmulParams& params)
{
TORCH_CHECK(mat1.dim() == 3, "mat1 must be a 3-D tensor with shape (batch, M, K)");
TORCH_CHECK(mat2.dim() == 3, "mat2 must be a 3-D tensor with shape (batch, N, K) when transB=True");
TORCH_CHECK(
!transA && transB,
"ascend950_fp8_mx_batch_matmul currently supports only transA=false and transB=true");
CheckNpuTensor(mat1, "mat1");
CheckNpuTensor(mat2, "mat2");
CheckNpuTensor(mx_scale_a, "mx_scale_a");
CheckNpuTensor(mx_scale_b, "mx_scale_b");
CheckSameDevice(mat1, "mat1", mat2, "mat2");
CheckSameDevice(mat1, "mat1", mx_scale_a, "mx_scale_a");
CheckSameDevice(mat1, "mat1", mx_scale_b, "mx_scale_b");
TORCH_CHECK(
mat1.scalar_type() == torch::kFloat8_e4m3fn, "mat1 must have dtype torch.float8_e4m3fn, got ",
mat1.scalar_type());
TORCH_CHECK(
mat2.scalar_type() == torch::kFloat8_e4m3fn, "mat2 must have dtype torch.float8_e4m3fn, got ",
mat2.scalar_type());
CheckMxScaleDType(mx_scale_a, "mx_scale_a");
CheckMxScaleDType(mx_scale_b, "mx_scale_b");
TORCH_CHECK(mat1.size(0) == mat2.size(0), "mat1 and mat2 batch dimensions must match");
TORCH_CHECK(mat1.is_contiguous(), "mat1 must be contiguous");
TORCH_CHECK(mat2.is_contiguous(), "mat2 must be contiguous");
TORCH_CHECK(mx_scale_a.is_contiguous(), "mx_scale_a must be contiguous");
TORCH_CHECK(mx_scale_b.is_contiguous(), "mx_scale_b must be contiguous");
tParams.element["A"] = TorchDtypeToAclDtype(mat1.scalar_type());
tParams.element["B"] = TorchDtypeToAclDtype(mat2.scalar_type());
tParams.element["C"] = ACL_BF16;
tParams.element["MX_SCALE"] = test_utils::TypeCast<std::string, aclDataType>("float8_e8m0fnu");
tParams.transpose["A"] = transA;
tParams.transpose["B"] = transB;
tParams.transpose["C"] = false;
tParams.useNz["A"] = false;
tParams.useNz["B"] = false;
tParams.useNz["C"] = false;
int64_t m, k1, k2, n;
if (transA) {
m = mat1.size(2);
k1 = mat1.size(1);
} else {
m = mat1.size(1);
k1 = mat1.size(2);
}
if (transB) {
k2 = mat2.size(2);
n = mat2.size(1);
} else {
k2 = mat2.size(1);
n = mat2.size(2);
}
TORCH_CHECK(k1 == k2, "mat1 and mat2 shapes cannot be multiplied (", m, "x", k1, " and ", k2, "x", n, ")");
params.batch = static_cast<uint32_t>(mat1.size(0));
params.m = static_cast<uint32_t>(m);
params.k = static_cast<uint32_t>(k1);
params.n = static_cast<uint32_t>(n);
uint32_t mxScaleAlignedK = 0;
int64_t scaleAPerBatch = 0;
int64_t scaleBPerBatch = 0;
ComputeMxScaleShapes(params.m, params.n, params.k, mxScaleAlignedK, scaleAPerBatch, scaleBPerBatch);
TORCH_CHECK(
mx_scale_a.numel() == static_cast<int64_t>(params.batch) * scaleAPerBatch,
"mx_scale_a must have ", static_cast<int64_t>(params.batch) * scaleAPerBatch,
" elements (batch=", params.batch, ", m=", m, ", mxScaleAlignedK=", mxScaleAlignedK,
"), got ", mx_scale_a.numel());
TORCH_CHECK(
mx_scale_b.numel() == static_cast<int64_t>(params.batch) * scaleBPerBatch,
"mx_scale_b must have ", static_cast<int64_t>(params.batch) * scaleBPerBatch,
" elements (batch=", params.batch, ", n=", n, ", mxScaleAlignedK=", mxScaleAlignedK,
"), got ", mx_scale_b.numel());
params.inputAddr.resize(4);
params.inputAddr[0] = static_cast<uint8_t*>(const_cast<void*>(mat1.storage().data()));
params.inputAddr[1] = static_cast<uint8_t*>(const_cast<void*>(mat2.storage().data()));
params.inputAddr[2] = static_cast<uint8_t*>(const_cast<void*>(mx_scale_a.storage().data()));
params.inputAddr[3] = static_cast<uint8_t*>(const_cast<void*>(mx_scale_b.storage().data()));
}
static OutputType AllocOutput(CatlassKernel::MatmulParams& params)
{
OutputType output = GetOutputTensor({params.batch, params.m, params.n}, torch::kBFloat16);
params.outputAddr.resize(1);
params.outputAddr[0] = static_cast<uint8_t*>(const_cast<void*>(output.storage().data()));
return output;
}
static OutputType Run(
const at::Tensor& mat1, const at::Tensor& mat2, const at::Tensor& mx_scale_a,
const at::Tensor& mx_scale_b, bool transA, bool transB)
{
CatlassKernel::TParams tParams;
CatlassKernel::MatmulParams params;
GetKernelInfo(mat1, mat2, mx_scale_a, mx_scale_b, transA, transB, tParams, params);
OutputType output = AllocOutput(params);
aclrtStream stream = c10_npu::getCurrentNPUStream().stream(false);
uint32_t aicCoreNum = platform_ascendc::PlatformAscendCManager::GetInstance()->GetCoreNumAic();
RUN_NPU_FUNC(KernelFunc, aicCoreNum, stream, tParams, params);
return output;
}
};
template <KernelFn KernelFunc>
struct DualLevelQuantMxBatchedMatmulLike {
using OutputType = at::Tensor;
struct ScratchBundle {
at::Tensor output;
at::Tensor scaleA1;
at::Tensor scaleA2;
at::Tensor scaleB1;
at::Tensor scaleB2;
at::Tensor workspace;
};
static void GetKernelInfo(
const at::Tensor& mat1, const at::Tensor& mat2, CatlassKernel::TParams& tParams,
CatlassKernel::MatmulParams& params)
{
TORCH_CHECK(mat1.dim() == 3, "mat1 must be a 3-D tensor with shape (batch, M, K)");
TORCH_CHECK(mat2.dim() == 3, "mat2 must be a 3-D tensor with shape (batch, N, K)");
CheckNpuTensor(mat1, "mat1");
CheckNpuTensor(mat2, "mat2");
CheckSameDevice(mat1, "mat1", mat2, "mat2");
TORCH_CHECK(
mat1.scalar_type() == torch::kFloat16 || mat1.scalar_type() == torch::kBFloat16,
"mat1 must have dtype torch.float16 or torch.bfloat16, got ", mat1.scalar_type());
TORCH_CHECK(
mat2.scalar_type() == mat1.scalar_type(), "mat2 dtype must match mat1 dtype (got ", mat2.scalar_type(),
" and ", mat1.scalar_type(), ")");
TORCH_CHECK(mat1.size(0) == mat2.size(0), "mat1 and mat2 batch dimensions must match");
TORCH_CHECK(mat1.size(2) == mat2.size(2), "mat1 K and mat2 K dimensions must match");
TORCH_CHECK(mat1.size(2) % 2 == 0, "K must be even for FP4 packing, got ", mat1.size(2));
TORCH_CHECK(mat1.is_contiguous(), "mat1 must be contiguous");
TORCH_CHECK(mat2.is_contiguous(), "mat2 must be contiguous");
tParams.element["INPUT"] = TorchDtypeToAclDtype(mat1.scalar_type());
tParams.element["C"] = ACL_BF16;
tParams.element["MX_SCALE"] = test_utils::TypeCast<std::string, aclDataType>("float8_e8m0fnu");
tParams.transpose["A"] = false;
tParams.transpose["B"] = true;
tParams.transpose["C"] = false;
tParams.useNz["A"] = false;
tParams.useNz["B"] = false;
tParams.useNz["C"] = false;
params.batch = static_cast<uint32_t>(mat1.size(0));
params.m = static_cast<uint32_t>(mat1.size(1));
params.k = static_cast<uint32_t>(mat1.size(2));
params.n = static_cast<uint32_t>(mat2.size(1));
params.inputAddr.resize(2);
params.inputAddr[0] = static_cast<uint8_t*>(const_cast<void*>(mat1.storage().data()));
params.inputAddr[1] = static_cast<uint8_t*>(const_cast<void*>(mat2.storage().data()));
}
static ScratchBundle AllocOutputAndScratch(CatlassKernel::MatmulParams& params)
{
ScratchBundle scratch;
scratch.output = GetOutputTensor({params.batch, params.m, params.n}, torch::kBFloat16);
const uint32_t scaleA1K = CeilDivUint32(params.k, 512);
const uint32_t mxScaleAlignedK = RoundUp2Uint32(CeilDivUint32(params.k, kMxScaleGroupNum));
scratch.scaleA1 = GetOutputTensor({params.batch, params.m, scaleA1K}, torch::kFloat32);
scratch.scaleA2 = GetOutputTensor({params.batch, params.m, mxScaleAlignedK}, torch::kFloat8_e8m0fnu);
scratch.scaleB1 = GetOutputTensor({params.batch, params.n, scaleA1K}, torch::kFloat32);
scratch.scaleB2 = GetOutputTensor({params.batch, params.n, mxScaleAlignedK}, torch::kFloat8_e8m0fnu);
scratch.workspace = GetOutputTensor({ProductInt64(params.batch, params.m, params.k / 2) +
ProductInt64(params.batch, params.n, params.k / 2)}, torch::kUInt8);
params.outputAddr.resize(6);
params.outputAddr[0] = static_cast<uint8_t*>(const_cast<void*>(scratch.output.storage().data()));
params.outputAddr[1] = static_cast<uint8_t*>(const_cast<void*>(scratch.scaleA1.storage().data()));
params.outputAddr[2] = static_cast<uint8_t*>(const_cast<void*>(scratch.scaleA2.storage().data()));
params.outputAddr[3] = static_cast<uint8_t*>(const_cast<void*>(scratch.scaleB1.storage().data()));
params.outputAddr[4] = static_cast<uint8_t*>(const_cast<void*>(scratch.scaleB2.storage().data()));
params.outputAddr[5] = static_cast<uint8_t*>(const_cast<void*>(scratch.workspace.storage().data()));
return scratch;
}
static OutputType Run(const at::Tensor& mat1, const at::Tensor& mat2)
{
CatlassKernel::TParams tParams;
CatlassKernel::MatmulParams params;
GetKernelInfo(mat1, mat2, tParams, params);
ScratchBundle scratch = AllocOutputAndScratch(params);
aclrtStream stream = c10_npu::getCurrentNPUStream().stream(false);
uint32_t aicCoreNum = platform_ascendc::PlatformAscendCManager::GetInstance()->GetCoreNumAic();
RUN_NPU_FUNC(KernelFunc, aicCoreNum, stream, tParams, params);
return scratch.output;
}
};
}
#endif