#include <ATen/AccumulateType.h>
#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
namespace op_api {
using npu_preparation = at_npu::native::OpPreparation;
at::Tensor npu_scaled_masked_softmax(const at::Tensor& x, const at::Tensor& mask, const at::Scalar& scale,
bool fixed_triu_mask)
{
if (c10_npu::GetSocVersion() < c10_npu::SocVersion::Ascend950) {
return acl_op::npu_scaled_masked_softmax(x, mask, scale, fixed_triu_mask);
}
DO_COMPATIBILITY(aclnnScaledMaskedSoftmax, acl_op::npu_scaled_masked_softmax(x, mask, scale, fixed_triu_mask));
double scale_value = 1.0;
if (scale.isFloatingPoint()) {
scale_value = scale.to<double>();
} else if (scale.isIntegral(true)) {
scale_value = static_cast<double>(scale.to<int64_t>());
} else {
TORCH_CHECK(false, "scaled_masked_softmax expects scale to be float or int", OPS_ERROR(ErrCode::TYPE));
}
at::Tensor result = npu_preparation::apply_tensor_without_format(x.sizes(), x.options());
EXEC_NPU_CMD(aclnnScaledMaskedSoftmax, x, mask, scale_value, fixed_triu_mask, result);
return result;
}
at::Tensor npu_scaled_masked_softmax_backward(
const at::Tensor& y_grad,
const at::Tensor& y,
const at::Tensor& mask,
const at::Scalar& scale,
bool fixed_triu_mask)
{
if (c10_npu::GetSocVersion() < c10_npu::SocVersion::Ascend950) {
return acl_op::npu_scaled_masked_softmax_backward(y_grad, y, mask, scale, fixed_triu_mask);
}
DO_COMPATIBILITY(aclnnScaledMaskedSoftmaxBackward, acl_op::npu_scaled_masked_softmax_backward(y_grad, y, mask, scale, fixed_triu_mask));
double scale_value = 1.0;
if (scale.isFloatingPoint()) {
scale_value = scale.to<double>();
} else if (scale.isIntegral(true)) {
scale_value = static_cast<double>(scale.to<int64_t>());
} else {
TORCH_CHECK(false, "scaled_masked_softmax_backward expects scale to be float or int", OPS_ERROR(ErrCode::TYPE));
}
at::Tensor result = npu_preparation::apply_tensor_without_format(y_grad.sizes(), y_grad.options());
EXEC_NPU_CMD(aclnnScaledMaskedSoftmaxBackward, y_grad, y, mask, scale_value, fixed_triu_mask, result);
return result;
}
}