#include <ATen/record_function.h>
#include <ATen/native/Resize.h>
#include "torch_npu/csrc/framework/utils/OpPreparation.h"
#include "torch_npu/csrc/aten/CustomFunctions.h"
#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
#include "torch_npu/csrc/core/NPUBridge.h"
#include "torch_npu/csrc/core/NPUStorageImpl.h"
#include "torch_npu/csrc/core/npu/NPUGuard.h"
#include "torch_npu/csrc/core/npu/NPUWorkspaceAllocator.h"
#include "torch_npu/csrc/framework/FormatHelper.h"
#include "torch_npu/csrc/framework/InferFormat.h"
#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
#ifndef BUILD_LIBTORCH
#include "torch_npu/csrc/profiler/utils.h"
#endif
#include "torch_npu/csrc/core/npu/npu_log.h"
namespace at_npu {
namespace native {
// Infers the common dtype and shape for a tensor-tensor binary op.
// NPUTensorIterator::binary_op is only consulted when the dtypes of `a` and `b` differ;
// otherwise a default-constructed UnifiedResult is returned.
UnifiedResult OpPreparation::binary_op_check(at::Tensor &out,
                                             const at::Tensor &a,
                                             const at::Tensor &b,
                                             bool check_mem_overlap)
{
    UnifiedResult result;
    if (a.dtype() == b.dtype()) {
        // Same dtype on both sides: nothing to infer.
        return result;
    }
    std::tuple<at::ScalarType, c10::IntArrayRef> inferred =
        NPUTensorIterator::binary_op(out, a, b, check_mem_overlap);
    result.common_type = std::get<0>(inferred);
    result.common_shape = std::get<1>(inferred);
    return result;
}
// Infers the common dtype and shape for a tensor-scalar binary op via
// NPUTensorIterator::binary_op(a, b).
// `out` and `check_mem_overlap` are accepted but not consulted by this overload.
UnifiedResult OpPreparation::binary_op_check(at::Tensor &out,
                                             const at::Tensor &a,
                                             const c10::Scalar b,
                                             bool check_mem_overlap)
{
    std::tuple<at::ScalarType, c10::IntArrayRef> inferred = NPUTensorIterator::binary_op(a, b);
    UnifiedResult result;
    result.common_type = std::get<0>(inferred);
    result.common_shape = std::get<1>(inferred);
    return result;
}
// Prepares a comparison op.
// - When the input dtypes differ, the common dtype/shape come from NPUTensorIterator::comparison_op.
// - result_type_defined is set when the dtype of `out` differs from the dtype of both inputs.
UnifiedResult OpPreparation::comparison_op_check(at::Tensor &out,
                                                 const at::Tensor &a,
                                                 const at::Tensor &b,
                                                 bool check_mem_overlap)
{
    UnifiedResult result;
    const bool inputs_differ = a.dtype() != b.dtype();
    if (inputs_differ) {
        std::tuple<at::ScalarType, c10::IntArrayRef> inferred =
            NPUTensorIterator::comparison_op(out, a, b, check_mem_overlap);
        result.common_type = std::get<0>(inferred);
        result.common_shape = std::get<1>(inferred);
    }

    const auto out_dtype = out.dtype();
    if (out_dtype != a.dtype() && out_dtype != b.dtype()) {
        result.result_type_defined = true;
    }
    return result;
}
// Infers the common dtype and shape for a unary op through NPUTensorIterator::unary_op.
UnifiedResult OpPreparation::unary_op_check(at::Tensor &out, const at::Tensor &a, bool check_mem_overlap)
{
    std::tuple<at::ScalarType, c10::IntArrayRef> inferred =
        NPUTensorIterator::unary_op(out, a, check_mem_overlap);
    UnifiedResult result;
    result.common_type = std::get<0>(inferred);
    result.common_shape = std::get<1>(inferred);
    return result;
}
// Forwards `out` to NPUTensorIterator::nullary_op.
void OpPreparation::nullary_op(at::Tensor &out)
{
    NPUTensorIterator::nullary_op(out);
}
// Infers the common dtype and shape for a single-output reduction
// through NPUTensorIterator::reduce_op(out, a).
UnifiedResult OpPreparation::reduce_op_check(at::Tensor &out, const at::Tensor &a)
{
    std::tuple<at::ScalarType, c10::IntArrayRef> inferred = NPUTensorIterator::reduce_op(out, a);
    UnifiedResult result;
    result.common_type = std::get<0>(inferred);
    result.common_shape = std::get<1>(inferred);
    return result;
}
// Infers the common dtype and shape for a two-output reduction
// through NPUTensorIterator::reduce_op(out1, out2, a).
UnifiedResult OpPreparation::reduce_op_check(at::Tensor &out1, at::Tensor &out2, const at::Tensor &a)
{
    std::tuple<at::ScalarType, c10::IntArrayRef> inferred = NPUTensorIterator::reduce_op(out1, out2, a);
    UnifiedResult result;
    result.common_type = std::get<0>(inferred);
    result.common_shape = std::get<1>(inferred);
    return result;
}
// Maps an ATen scalar type to an aclDataType by delegating to CalcuOpUtil::ConvertToAclDataType.
aclDataType OpPreparation::convert_to_acl_data_type(const at::ScalarType &data_type)
{
    const aclDataType acl_type = CalcuOpUtil::ConvertToAclDataType(data_type);
    return acl_type;
}
// Maps an ATen scalar type plus a `realDataType` string to an aclDataType
// by delegating to the two-argument CalcuOpUtil::ConvertToAclDataType.
aclDataType OpPreparation::convert_to_acl_data_type(const at::ScalarType &data_type, const std::string &realDataType)
{
    const aclDataType acl_type = CalcuOpUtil::ConvertToAclDataType(data_type, realDataType);
    return acl_type;
}
// Maps an aclDataType back to an ATen scalar type by delegating to CalcuOpUtil::ConvertToScalarType.
at::ScalarType OpPreparation::convert_to_scalar_type(const aclDataType data_type)
{
    const at::ScalarType scalar_type = CalcuOpUtil::ConvertToScalarType(data_type);
    return scalar_type;
}
// Returns the tensor produced by CalcuOpUtil::CopyScalarToDevice for the given scalar and dtype.
at::Tensor OpPreparation::copy_scalar_to_device(const c10::Scalar &cpu_scalar, at::ScalarType scalar_data_type)
{
    at::Tensor device_scalar = CalcuOpUtil::CopyScalarToDevice(cpu_scalar, scalar_data_type);
    return device_scalar;
}
// Same as the two-argument overload, but executed while an NPUGuard for `device` is alive.
at::Tensor OpPreparation::copy_scalar_to_device(const c10::Scalar &cpu_scalar,
                                                at::ScalarType scalar_data_type,
                                                const c10::Device device)
{
    c10_npu::NPUGuard device_guard(device);
    at::Tensor device_scalar = copy_scalar_to_device(cpu_scalar, scalar_data_type);
    return device_scalar;
}
// Returns the tensor produced by CalcuOpUtil::CopyTensorHostToDevice for `cpu_tensor`.
at::Tensor OpPreparation::copy_tensor_host_to_device(const at::Tensor &cpu_tensor)
{
    at::Tensor device_tensor = CalcuOpUtil::CopyTensorHostToDevice(cpu_tensor);
    return device_tensor;
}
// Alias of IsCPUScalar: reports whether `tensor` is a zero-dim tensor that is not on NPU.
bool OpPreparation::is_scalar_wrapped_to_tensor(const at::Tensor &tensor)
{
    const bool wrapped = IsCPUScalar(tensor);
    return wrapped;
}
// Returns a copy of `base_sizes_` from the NPU storage descriptor of `tensor`.
c10::SmallVector<int64_t, 5> OpPreparation::get_tensor_desc_base_sizes(const at::Tensor &tensor)
{
    const auto &storage_desc = torch_npu::NPUBridge::GetNpuStorageImpl(tensor)->get_npu_desc();
    return storage_desc.base_sizes_;
}
// Returns the NPU format of `tensor` as reported by CalcuOpUtil::GetTensorNpuFormat.
int64_t OpPreparation::get_tensor_npu_format(const at::Tensor &tensor)
{
    const int64_t npu_format = CalcuOpUtil::GetTensorNpuFormat(tensor);
    return npu_format;
}
// Returns true when `dst` is the same tensor (per at::Tensor::is_same) as any entry of `src_list`.
static bool check_inplace_tensor(const std::initializer_list<at::Tensor> &src_list, at::Tensor &dst)
{
    for (const auto &candidate : src_list) {
        if (dst.is_same(candidate)) {
            return true;
        }
    }
    return false;
}
// Makes `dst` match `expect_size`.
// - Nothing happens when the sizes already agree.
// - If `dst` is also one of the sources (in-place call) a size mismatch is an error.
// - Otherwise `dst` is resized when at::native::resize_output_check allows it.
static void check_tensor_size(const std::initializer_list<at::Tensor> &src_list,
                              at::Tensor &dst,
                              c10::IntArrayRef expect_size)
{
    if (dst.sizes().equals(expect_size)) {
        return;
    }

    const bool dst_is_source = check_inplace_tensor(src_list, dst);
    TORCH_CHECK(!dst_is_source,
                "output with shape ",
                dst.sizes(),
                " doesn't match the broadcast shape ",
                expect_size,
                OPS_ERROR(ErrCode::PARAM));

    const bool needs_resize = at::native::resize_output_check(dst, expect_size);
    if (needs_resize) {
        dst.resize_(expect_size);
    }
}
void OpPreparation::check_tensor(const std::initializer_list<at::Tensor> &src_list,
at::Tensor &dst,
at::ScalarType expect_dtype,
c10::IntArrayRef expect_size)
{
check_memory(src_list, {dst});
TORCH_CHECK(torch_npu::utils::is_npu(dst),
"output with device ",
dst.device(),
" doesn't match the desired device NPU",
OPS_ERROR(ErrCode::PARAM));
TORCH_CHECK(dst.scalar_type() == expect_dtype,
"expected dtype ",
expect_dtype,
" but got dtype ",
dst.scalar_type(),
OPS_ERROR(ErrCode::TYPE));
check_tensor_size(src_list, dst, expect_size);
}
// Validates an output tensor against the expected size only (no dtype check).
// Order of checks: memory overlap with the sources, NPU device, then size (with resize).
void OpPreparation::check_tensor(const std::initializer_list<at::Tensor> &src_list,
                                 at::Tensor &dst,
                                 c10::IntArrayRef expect_size)
{
    check_memory(src_list, {dst});

    const bool dst_on_npu = torch_npu::utils::is_npu(dst);
    TORCH_CHECK(dst_on_npu,
                "output with device ", dst.device(), " doesn't match the desired device NPU",
                OPS_ERROR(ErrCode::PARAM));

    check_tensor_size(src_list, dst, expect_size);
}
void OpPreparation::check_tensor(const std::initializer_list<at::Tensor> &src_list,
at::Tensor &dst,
const at::Tensor &expect_tensor)
{
check_tensor(src_list, dst, expect_tensor.scalar_type(), expect_tensor.sizes());
}
void OpPreparation::check_tensor(const std::initializer_list<at::Tensor> &src_list,
at::Tensor &dst,
const at::Tensor &expect_tensor,
c10::IntArrayRef expect_size)
{
check_tensor(src_list, dst, expect_tensor.scalar_type(), expect_size);
}
void OpPreparation::check_memory(const std::initializer_list<at::Tensor> &inputs,
const std::initializer_list<at::Tensor> &outputs)
{
c10::SmallVector<at::Tensor, N> in = inputs;
c10::SmallVector<at::Tensor, N> out = outputs;
CalcuOpUtil::CheckMemoryOverLaps(in, out);
}
// Returns the result of custom_ops::npu_format_cast of `tensor` to the `origin_format_`
// recorded in its NPU storage descriptor.
at::Tensor OpPreparation::cast_to_ori_format(const at::Tensor &tensor)
{
    auto &storage_desc = torch_npu::NPUBridge::GetNpuStorageImpl(tensor)->npu_desc_;
    return custom_ops::npu_format_cast(tensor, storage_desc.origin_format_);
}
// In-place variant: applies npu_format_cast_ on `tensor` with the `origin_format_`
// recorded in its NPU storage descriptor, then returns the same tensor reference.
at::Tensor &OpPreparation::cast_to_ori_format(at::Tensor &tensor)
{
    auto &storage_desc = torch_npu::NPUBridge::GetNpuStorageImpl(tensor)->npu_desc_;
    NPUNativeFunctions::npu_format_cast_(tensor, storage_desc.origin_format_);
    return tensor;
}
// Creates a tensor using the sizes, options and NPU format of `src`.
at::Tensor OpPreparation::apply_tensor(const at::Tensor &src)
{
    const c10::IntArrayRef src_sizes = src.sizes();
    return apply_tensor(src, src_sizes);
}
// Creates a tensor of `sizes` using the options and NPU format of `src`.
at::Tensor OpPreparation::apply_tensor(const at::Tensor &src, c10::IntArrayRef sizes)
{
    const auto src_format = CalcuOpUtil::GetTensorNpuFormat(src);
    return apply_tensor_with_format(sizes, src.options(), src_format);
}
// Creates a tensor with the sizes and NPU format of `src` but with caller-supplied `options`.
at::Tensor OpPreparation::apply_tensor(const at::Tensor &src, const c10::TensorOptions &options)
{
    const auto src_format = CalcuOpUtil::GetTensorNpuFormat(src);
    return apply_tensor_with_format(src.sizes(), options, src_format);
}
// Creates a tensor of `sizes` with `options`; only the NPU format is taken from `src`.
at::Tensor OpPreparation::apply_tensor(c10::IntArrayRef sizes, const c10::TensorOptions &options, const at::Tensor &src)
{
    const auto src_format = CalcuOpUtil::GetTensorNpuFormat(src);
    return apply_tensor_with_format(sizes, options, src_format);
}
// Creates a tensor with the sizes and options of `src` and the requested `format`.
at::Tensor OpPreparation::apply_tensor_with_format(const at::Tensor &src, int64_t format, bool keep_format)
{
    const c10::IntArrayRef src_sizes = src.sizes();
    return apply_tensor_with_format(src, src_sizes, format, keep_format);
}
// Creates a tensor of `sizes` with the options of `src` and the requested `format`.
at::Tensor OpPreparation::apply_tensor_with_format(const at::Tensor &src,
                                                   c10::IntArrayRef sizes,
                                                   int64_t format,
                                                   bool keep_format)
{
    const c10::TensorOptions src_options = src.options();
    return apply_tensor_with_format(sizes, src_options, format, keep_format);
}
// Creates an NPU tensor of `sizes` with `options` in the requested `format`.
// - Fails unless options.device() is PrivateUse1.
// - The storage format comes from InferFormat::GuessStorageFormat.
// - For Double dtype with a non-base format a warning is logged and the base format
//   from FormatHelper::GetBaseFormat is used instead.
at::Tensor OpPreparation::apply_tensor_with_format(c10::IntArrayRef sizes,
                                                   const c10::TensorOptions &options,
                                                   int64_t format,
                                                   bool keep_format)
{
    TORCH_CHECK(options.device().type() == c10::DeviceType::PrivateUse1,
                "Expected all tensors to be on the same device. "
                "Expected NPU tensor, please check whether the input tensor device is correct.",
                OPS_ERROR(ErrCode::TYPE));

    const aclFormat requested_format = static_cast<aclFormat>(format);
    auto storage_format = InferFormat::GuessStorageFormat(sizes, requested_format);

    const bool double_with_inner_format =
        options.dtype_opt() == at::ScalarType::Double && !FormatHelper::IsBaseFormatType(requested_format);
    if (double_with_inner_format) {
        ASCEND_LOGW("NPU don't support create double dtype tensor with inner format, repalce with base format.");
        storage_format = FormatHelper::GetBaseFormat(requested_format);
    }

    return NPUNativeFunctions::unsafe_empty_with_format(sizes,
                                                        c10::optTypeMetaToScalarType(options.dtype_opt()),
                                                        options.layout_opt(),
                                                        options.device_opt(),
                                                        options.pinned_memory_opt(),
                                                        storage_format,
                                                        keep_format);
}
// Creates a tensor of `sizes` with `options`, using the format chosen by
// InferFormat::GuessBaseFormat(sizes) and NPUNativeFunctions::empty_with_format.
at::Tensor OpPreparation::apply_tensor_with_sizes(c10::IntArrayRef sizes, const c10::TensorOptions &options)
{
    auto base_format = InferFormat::GuessBaseFormat(sizes);
    return NPUNativeFunctions::empty_with_format(sizes,
                                                 c10::optTypeMetaToScalarType(options.dtype_opt()),
                                                 options.layout_opt(),
                                                 options.device_opt(),
                                                 options.pinned_memory_opt(),
                                                 base_format,
                                                 c10::nullopt);
}
void OpPreparation::CheckOut(const std::initializer_list<at::Tensor> &inputs, at::Tensor &output, at::Tensor dst)
{
CheckOut(inputs, output, CalcuOpUtil::GetTensorNpuFormat(dst), dst.scalar_type(), dst.sizes());
}
void OpPreparation::CheckOut(const std::initializer_list<at::Tensor> &inputs,
at::Tensor &output,
at::Tensor dst,
c10::IntArrayRef shape)
{
CheckOut(inputs, output, CalcuOpUtil::GetTensorNpuFormat(dst), dst.scalar_type(), shape);
}
void OpPreparation::CheckOut(const std::initializer_list<at::Tensor> &input,
at::Tensor &output,
int64_t format,
at::ScalarType dtype,
c10::IntArrayRef shape)
{
c10::SmallVector<at::Tensor, N> inputs{input};
c10::SmallVector<at::Tensor, N> outputs = {output};
CalcuOpUtil::CheckMemoryOverLaps(inputs, outputs);
TORCH_CHECK(torch_npu::utils::is_npu(output),
"output with device ",
output.device(),
" doesn't match the desired device NPU",
OPS_ERROR(ErrCode::PARAM));
TORCH_CHECK(output.scalar_type() == dtype,
"expected dtype ",
dtype,
" but got dtype ",
output.scalar_type(),
OPS_ERROR(ErrCode::TYPE));
bool is_read_write = false;
for (const auto &local_input : inputs) {
if (output.is_same(local_input)) {
is_read_write = true;
break;
}
}
if (!output.sizes().equals(shape)) {
TORCH_CHECK(!is_read_write,
"output with shape ",
output.sizes(),
" doesn't match the broadcast shape ",
shape,
OPS_ERROR(ErrCode::PARAM));
if (at::native::resize_output_check(output, shape)) {
output.resize_(shape);
}
}
if (CalcuOpUtil::GetTensorNpuFormat(output) != format) {
TORCH_CHECK(!is_read_write, "can not cast format when output is input", OPS_ERROR(ErrCode::NOT_SUPPORT));
NPUNativeFunctions::npu_format_cast_(output, format);
}
}
// Returns the result of custom_ops::npu_format_cast of `tensor` to the `origin_format_`
// stored in its NPU storage descriptor.
at::Tensor OpPreparation::CastBackToOriFormat(const at::Tensor &tensor)
{
    auto &npu_desc = torch_npu::NPUBridge::GetNpuStorageImpl(tensor)->npu_desc_;
    at::Tensor casted = custom_ops::npu_format_cast(tensor, npu_desc.origin_format_);
    return casted;
}
// In-place variant: applies npu_format_cast_ on `tensor` with the `origin_format_`
// stored in its NPU storage descriptor and returns the same tensor reference.
at::Tensor &OpPreparation::CastBackToOriFormat(at::Tensor &tensor)
{
    auto &npu_desc = torch_npu::NPUBridge::GetNpuStorageImpl(tensor)->npu_desc_;
    NPUNativeFunctions::npu_format_cast_(tensor, npu_desc.origin_format_);
    return tensor;
}
// Returns the value of CalcuOpUtil::GetCubeMathType().
int8_t OpPreparation::get_cube_math_type()
{
    const int8_t math_type = CalcuOpUtil::GetCubeMathType();
    return math_type;
}
// Returns the value of CalcuOpUtil::GetCubeMathType(allowHf32).
int8_t OpPreparation::get_cube_math_type(bool allowHf32)
{
    const int8_t math_type = CalcuOpUtil::GetCubeMathType(allowHf32);
    return math_type;
}
// Creates a contiguous tensor of `sizes` through NPUNativeFunctions::empty.
// The device from `options` is used only when it is a PrivateUse1 device;
// otherwise a default-index PrivateUse1 device is substituted.
inline at::Tensor apply_tensor_use_empty(c10::IntArrayRef sizes, const c10::TensorOptions &options)
{
    c10::optional<c10::Device> target_device = options.device_opt();
    const bool already_npu = c10::device_or_default(target_device).type() == c10::DeviceType::PrivateUse1;
    if (!already_npu) {
        target_device = at::Device(c10::DeviceType::PrivateUse1);
    }
    return NPUNativeFunctions::empty(sizes,
                                     options.dtype().toScalarType(),
                                     c10::nullopt,
                                     target_device,
                                     false,
                                     c10::MemoryFormat::Contiguous);
}
// Creates a tensor with the sizes and options of `src` via apply_tensor_use_empty.
at::Tensor OpPreparation::apply_tensor_without_format(const at::Tensor &src)
{
    const c10::IntArrayRef src_sizes = src.sizes();
    return apply_tensor_use_empty(src_sizes, src.options());
}
// Creates a tensor of `sizes` with the options of `src` via apply_tensor_use_empty.
at::Tensor OpPreparation::apply_tensor_without_format(const at::Tensor &src, c10::IntArrayRef sizes)
{
    const c10::TensorOptions src_options = src.options();
    return apply_tensor_use_empty(sizes, src_options);
}
// Creates a tensor of `sizes` with `options` via apply_tensor_use_empty.
at::Tensor OpPreparation::apply_tensor_without_format(c10::IntArrayRef sizes, const c10::TensorOptions &options)
{
    at::Tensor created = apply_tensor_use_empty(sizes, options);
    return created;
}
// Allocates `workspace_size` bytes from the allocator returned by NPUCachingAllocator::get()
// and wraps them in a byte-typed NPU tensor.
at::Tensor OpPreparation::unsafe_empty_workspace(uint64_t workspace_size)
{
#ifndef BUILD_LIBTORCH
    torch_npu::profiler::NPURecordFunction profiler_guard;
#endif
    ASCEND_LOGD("Alloc workspace %zu bytes unsafely.", workspace_size);

    c10::Allocator *caching_allocator = c10_npu::NPUCachingAllocator::get();
    c10::intrusive_ptr<c10::StorageImpl> workspace_storage =
        torch_npu::make_npu_storage_impl(c10::StorageImpl::use_byte_size_t(),
                                         c10::SymInt(workspace_size),
                                         caching_allocator->allocate(workspace_size),
                                         caching_allocator,
                                         true);

    // The type meta is computed once and reused by every later call.
    static auto byte_meta = c10::scalarTypeToTypeMeta(dtype_or_default(at::kByte));
    auto workspace = at::detail::make_tensor<torch_npu::NPUTensorImpl>(workspace_storage, byte_meta);
    workspace.unsafeGetTensorImpl()->empty_tensor_restride(c10::MemoryFormat::Contiguous);
    return workspace;
}
// Allocates `workspace_size` bytes through NPUWorkspaceAllocator::malloc_with_stream for `stream`
// and wraps them in a byte-typed NPU tensor backed by an NPUStorageImpl.
at::Tensor OpPreparation::unsafe_empty_workspace(uint64_t workspace_size, aclrtStream stream)
{
#ifndef BUILD_LIBTORCH
    torch_npu::profiler::NPURecordFunction profiler_guard;
#endif
    ASCEND_LOGD("Alloc workspace %zu bytes unsafely.", workspace_size);

    c10::Allocator *workspace_allocator = c10_npu::NPUWorkspaceAllocator::get();
    c10::intrusive_ptr<c10::StorageImpl> workspace_storage = c10::make_intrusive<torch_npu::NPUStorageImpl>(
        c10::StorageImpl::use_byte_size_t(),
        workspace_size,
        c10_npu::NPUWorkspaceAllocator::malloc_with_stream(workspace_size, stream),
        workspace_allocator,
        true);

    // The type meta is computed once and reused by every later call.
    static auto byte_meta = c10::scalarTypeToTypeMeta(dtype_or_default(at::kByte));
    auto workspace = at::detail::make_tensor<torch_npu::NPUTensorImpl>(workspace_storage, byte_meta);
    workspace.unsafeGetTensorImpl()->empty_tensor_restride(c10::MemoryFormat::Contiguous);
    return workspace;
}
// Creates a tensor using the sizes, options and NPU format of `src`.
at::Tensor OpPreparation::ApplyTensor(const at::Tensor &src)
{
    const c10::IntArrayRef src_sizes = src.sizes();
    return ApplyTensor(src, src_sizes);
}
// Creates a tensor of `sizes` using the options and NPU format of `src`.
at::Tensor OpPreparation::ApplyTensor(const at::Tensor &src, c10::IntArrayRef sizes)
{
    const auto src_format = CalcuOpUtil::GetTensorNpuFormat(src);
    return ApplyTensorWithFormat(sizes, src.options(), src_format);
}
// Creates a tensor with the sizes and NPU format of `src` but with caller-supplied `options`.
at::Tensor OpPreparation::ApplyTensor(const at::Tensor &src, const c10::TensorOptions &options)
{
    const auto src_format = CalcuOpUtil::GetTensorNpuFormat(src);
    return ApplyTensorWithFormat(src.sizes(), options, src_format);
}
// Creates a tensor of `sizes` with `options`; only the NPU format is taken from `src`.
at::Tensor OpPreparation::ApplyTensor(c10::IntArrayRef sizes, const c10::TensorOptions &options, const at::Tensor &src)
{
    const auto src_format = CalcuOpUtil::GetTensorNpuFormat(src);
    return ApplyTensorWithFormat(sizes, options, src_format);
}
// Creates a tensor with the sizes and options of `src` and the requested `format`.
at::Tensor OpPreparation::ApplyTensorWithFormat(const at::Tensor &src, int64_t format, bool keep_format)
{
    const c10::IntArrayRef src_sizes = src.sizes();
    return ApplyTensorWithFormat(src, src_sizes, format, keep_format);
}
// Creates a tensor of `sizes` with the options of `src` and the requested `format`.
at::Tensor OpPreparation::ApplyTensorWithFormat(const at::Tensor &src,
                                                c10::IntArrayRef sizes,
                                                int64_t format,
                                                bool keep_format)
{
    const c10::TensorOptions src_options = src.options();
    return ApplyTensorWithFormat(sizes, src_options, format, keep_format);
}
// CamelCase counterpart of apply_tensor_with_format(sizes, options, format, keep_format).
//
// Creates an NPU tensor of `sizes` with `options` in the requested `format`.
// Fails (TORCH_CHECK) unless options.device() is a PrivateUse1 device.
//
// This overload previously repeated the device check and storage-format inference but
// lacked the fallback that apply_tensor_with_format applies for Double dtype combined
// with a non-base format (warning + base format). Delegating keeps both spellings in
// agreement, so a Double tensor is no longer created with an inner format through this path.
at::Tensor OpPreparation::ApplyTensorWithFormat(c10::IntArrayRef sizes,
                                                const c10::TensorOptions &options,
                                                int64_t format,
                                                bool keep_format)
{
    return apply_tensor_with_format(sizes, options, format, keep_format);
}
// Creates a tensor of `sizes` with `options`, using the format chosen by
// InferFormat::GuessBaseFormat(sizes) and NPUNativeFunctions::empty_with_format.
at::Tensor OpPreparation::ApplyTensorWithSizes(c10::IntArrayRef sizes, const c10::TensorOptions &options)
{
    auto guessed_format = InferFormat::GuessBaseFormat(sizes);
    return NPUNativeFunctions::empty_with_format(sizes,
                                                 c10::optTypeMetaToScalarType(options.dtype_opt()),
                                                 options.layout_opt(),
                                                 options.device_opt(),
                                                 options.pinned_memory_opt(),
                                                 guessed_format,
                                                 c10::nullopt);
}
void OpPreparation::CheckMemory(const std::initializer_list<at::Tensor> &inputs,
const std::initializer_list<at::Tensor> &outputs)
{
c10::SmallVector<at::Tensor, N> in = inputs;
c10::SmallVector<at::Tensor, N> out = outputs;
CalcuOpUtil::CheckMemoryOverLaps(in, out);
}
// True when `tensor` is zero-dimensional and is not an NPU tensor.
bool OpPreparation::IsCPUScalar(const at::Tensor &tensor)
{
    return tensor.dim() == 0 && !torch_npu::utils::is_npu(tensor);
}
// Returns the item size this file assigns to `acl_type` (1, 2, 4, 8 or 16).
// Any type not listed below fails a TORCH_CHECK with a NOT_SUPPORT error.
int OpPreparation::GetAclDataTypeItemSize(aclDataType acl_type)
{
    int item_size = 0;
    switch (acl_type) {
        // Reported as 1 (includes the FLOAT8 and FLOAT4 variants).
        case ACL_INT8:
        case ACL_UINT8:
        case ACL_BOOL:
        case ACL_HIFLOAT8:
        case ACL_FLOAT8_E5M2:
        case ACL_FLOAT8_E4M3FN:
        case ACL_FLOAT8_E8M0:
        case ACL_FLOAT4_E2M1:
        case ACL_FLOAT4_E1M2:
            item_size = 1;
            break;
        // Reported as 2.
        case ACL_FLOAT16:
        case ACL_INT16:
        case ACL_UINT16:
        case ACL_BF16:
            item_size = 2;
            break;
        // Reported as 4.
        case ACL_FLOAT:
        case ACL_INT32:
        case ACL_UINT32:
        case ACL_COMPLEX32:
            item_size = 4;
            break;
        // Reported as 8.
        case ACL_INT64:
        case ACL_UINT64:
        case ACL_DOUBLE:
        case ACL_COMPLEX64:
            item_size = 8;
            break;
        // Reported as 16.
        case ACL_COMPLEX128:
            item_size = 16;
            break;
        default:
            TORCH_CHECK(false,
                        "Unsupported acl_type:", acl_type, PTA_ERROR(ErrCode::NOT_SUPPORT));
    }
    return item_size;
}
}
}