#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
namespace op_api {
// Runs the fused matmul + reduce-scatter kernel on `self` and `x2`.
//
// Preconditions (each enforced with TORCH_CHECK):
//   * world_size is one of 2, 4, 8, 16, 32;
//   * self and x2 are both 2-D and self.size(1) == x2.size(0);
//   * self.size(0) is divisible by world_size.
//
// Parameters:
//   self, x2      matmul operands.
//   hcom          forwarded to the kernel as a C string
//                 (presumably the communication group name -- TODO confirm).
//   world_size    divides the M axis of the output.
//   reduce_op     forwarded to the kernel as a C string.
//   bias          optional; an undefined tensor is forwarded when absent.
//   x1_scale,
//   x2_scale      optional; forwarded only to aclnnMatmulReduceScatterV2. The presence of
//                 x2_scale is what this function treats as "quantized".
//   comm_turn     forwarded to the kernel unchanged.
//   output_dtype  consulted only when x2_scale is present and is not int64.
//   comm_mode     defaults to "ai_cpu" on SoCs below Ascend950 and to "" otherwise.
//
// Returns a tensor of shape [self.size(0) / world_size, x2.size(1)]. Its dtype is
// self's dtype, unless x2_scale is given: then it is Half when x2_scale is int64,
// otherwise output_dtype (BFloat16 when output_dtype is not set).
//
// Dispatch: on SoCs below Ascend950 with comm_mode "ai_cpu" the call goes to
// aclnnMatmulReduceScatter (quantization is rejected on that path); every other
// combination goes to aclnnMatmulReduceScatterV2.
at::Tensor npu_mm_reduce_scatter_base(const at::Tensor & self, const at::Tensor & x2, c10::string_view hcom,
    int64_t world_size, c10::string_view reduce_op, const c10::optional<at::Tensor> & bias,
    const c10::optional<at::Tensor> & x1_scale, const c10::optional<at::Tensor> & x2_scale,
    int64_t comm_turn, c10::optional<at::ScalarType> output_dtype, c10::optional<c10::string_view> comm_mode)
{
    TORCH_CHECK(world_size == 2 || world_size == 4 || world_size == 8 || world_size == 16 || world_size == 32,
        "world_size should be in [2, 4, 8, 16, 32], but the actual value is ", world_size,
        "." + OPS_ERROR(ErrCode::VALUE));
    TORCH_CHECK(self.dim() == 2 && x2.dim() == 2,
        "Both inputs of mm are required to be 2D, but the actual inputs are ",
        self.dim(), "D and ", x2.dim(), "D." + OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(self.size(1) == x2.size(0),
        "The K-axis in the two inputs of Matmul must be equal, but in reality, the K-axis of x1 is ",
        self.size(1), " and the K-axis of x2 is ", x2.size(0), "." + OPS_ERROR(ErrCode::PARAM));
    // world_size is known to be non-zero here because of the first check.
    TORCH_CHECK(self.size(0) % world_size == 0,
        "The M-axis in input of Matmul should be divisible by world_size." + OPS_ERROR(ErrCode::PARAM));

    bool isSocBelowAscend950 = (c10_npu::GetSocVersion() < c10_npu::SocVersion::Ascend950);

    // The kernels take plain char* arguments, but string_view::data() carries no
    // NUL-termination guarantee. Keep NUL-terminated copies alive for the whole call.
    auto to_owned = [](c10::string_view view) { return std::string(view.data(), view.size()); };
    std::string hcom_str = to_owned(hcom);
    std::string reduce_op_str = to_owned(reduce_op);
    std::string comm_mode_str = isSocBelowAscend950 ? "ai_cpu" : "";
    if (comm_mode.has_value()) {
        comm_mode_str = to_owned(comm_mode.value());
    }
    char *hcom_ptr = const_cast<char *>(hcom_str.c_str());
    char *reduce_op_ptr = const_cast<char *>(reduce_op_str.c_str());
    char *comm_mode_ptr = const_cast<char *>(comm_mode_str.c_str());

    // Output dtype follows self unless a x2_scale is supplied (quantized path).
    bool has_quant = x2_scale.has_value();
    auto result_dtype = self.scalar_type();
    if (has_quant) {
        result_dtype = x2_scale.value().scalar_type() == at::kLong ? at::kHalf : output_dtype.value_or(at::kBFloat16);
    }

    auto output_size = {self.size(0) / world_size, x2.size(1)};
    auto result = at_npu::native::OpPreparation::apply_tensor_without_format(
        output_size, self.options().dtype(result_dtype));

    const at::Tensor &bias_real = bias.value_or(at::Tensor());
    int64_t stream_mode = ACL_STOP_ON_FAILURE;

    if (isSocBelowAscend950 && comm_mode_str == "ai_cpu") {
        TORCH_CHECK(!has_quant, "When comm_mode is ai_cpu, quantization not supported." + OPS_ERROR(ErrCode::PARAM));
        EXEC_NPU_CMD(aclnnMatmulReduceScatter, self, x2, bias_real, hcom_ptr, reduce_op_ptr, comm_turn, stream_mode,
            result);
    } else {
        // Arguments this wrapper does not expose; the V2 kernel receives fixed values.
        int64_t block_size = 0;
        int64_t group_size = 0;
        at::Tensor quant_scale;
        at::Tensor amax_out;
        EXEC_NPU_CMD(aclnnMatmulReduceScatterV2, self, x2, bias_real, x1_scale, x2_scale, quant_scale, block_size,
            hcom_ptr, reduce_op_ptr, comm_turn, stream_mode, group_size, comm_mode_ptr, result, amax_out);
    }

    FLOP_COUNT(FlopCounter::mm_flop, self, x2);
    return result;
}
}