/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
 * \file torch_tensor_converter.cpp
* \brief Implementation of PyTorch tensor to DeviceTensorData conversion.
*/
#include "bindings/torch_tensor_converter.h"
#include "tilefwk/error.h"
#include "interface/utils/error.h"
#include <optional>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
using namespace npu::tile_fwk;
namespace {
/** Format code returned by torch_npu.get_npu_format that ReadTensorFormat maps to TileOpFormat::TILEOP_NZ. */
constexpr int kNpuFormatNz = 29;
/**
 * \brief Device placement of a torch tensor, as read from its `device` attribute.
 */
struct TensorDeviceInfo {
    /** Device type string taken from `device.type` (compared against "cpu" and "npu" in this file). */
    std::string type;
    /** Device ordinal; 0 unless a value is supplied. */
    int index = 0;
};
/**
 * \brief Fetch the address reported by the tensor's `data_ptr()` method.
 * \param torchTensor Python object exposing a callable `data_ptr` attribute.
 * \return The reported value, read as a signed 64-bit integer and converted to uintptr_t.
 */
uintptr_t ReadTensorDataPtr(const py::object& torchTensor)
{
    const py::object rawAddress = torchTensor.attr("data_ptr")();
    const auto signedAddress = rawAddress.cast<int64_t>();
    return static_cast<uintptr_t>(signedAddress);
}
/**
 * \brief Read the tensor's `shape` attribute into a vector of dimensions.
 *
 * For DT_FP4_E1M2 and DT_FP4_E2M1 the innermost dimension is doubled
 * (presumably because the torch tensor stores two FP4 values per element -- confirm with callers).
 *
 * \param torchTensor Python object exposing an iterable `shape` attribute.
 * \param dtype       Element type already resolved for this tensor.
 * \return The dimensions in order; empty for a 0-dim tensor.
 */
std::vector<int64_t> ReadTensorShape(const py::object& torchTensor, DataType dtype)
{
    py::object tensorShape = torchTensor.attr("shape");
    std::vector<int64_t> shape;
    shape.reserve(static_cast<size_t>(py::len(tensorShape)));
    for (auto dim : tensorShape) {
        shape.push_back(py::cast<int64_t>(dim));
    }
    const bool isFp4 = dtype == DataType::DT_FP4_E1M2 || dtype == DataType::DT_FP4_E2M1;
    // A 0-dim tensor yields an empty shape; back() on an empty vector is undefined behaviour,
    // so there is no innermost dimension to scale in that case.
    if (isFp4 && !shape.empty()) {
        shape.back() *= 2;
    }
    return shape;
}
/**
 * \brief Read the device type and ordinal from the tensor's `device` attribute.
 *
 * The ordinal is 0 for "cpu" devices (the `index` attribute is not consulted) and for any
 * device whose `index` is None. Mapping None to 0 keeps unsupported device types from failing
 * here with an opaque cast error, so the later device validation can report them instead.
 *
 * \param torchTensor Python object exposing a `device` attribute with `type` and `index`.
 * \return The device type string and ordinal.
 */
TensorDeviceInfo ReadDeviceInfo(const py::object& torchTensor)
{
    const py::object device = torchTensor.attr("device");
    std::string type = py::getattr(device, "type").cast<std::string>();
    int index = 0;
    if (type != "cpu") {
        const py::object rawIndex = py::getattr(device, "index");
        // `index` may be None for devices without an ordinal; casting None to int would throw.
        if (!rawIndex.is_none()) {
            index = rawIndex.cast<int>();
        }
    }
    return TensorDeviceInfo{std::move(type), index};
}
/**
 * \brief Translate a torch dtype object into the framework's DataType.
 *
 * The lookup key is the dtype's Python string form (e.g. "torch.float16").
 *
 * \param torchDtype Python dtype object.
 * \return The matching DataType.
 * \throws std::runtime_error if the dtype string is not in the table.
 */
DataType TorchDtypeToDataType(const py::object& torchDtype)
{
    static const std::unordered_map<std::string, DataType> kSupportedDtypes = {
        {"torch.float16", DataType::DT_FP16},
        {"torch.bfloat16", DataType::DT_BF16},
        {"torch.float32", DataType::DT_FP32},
        {"torch.float64", DataType::DT_DOUBLE},
        {"torch.int8", DataType::DT_INT8},
        {"torch.uint8", DataType::DT_UINT8},
        {"torch.int16", DataType::DT_INT16},
        {"torch.uint16", DataType::DT_UINT16},
        {"torch.int32", DataType::DT_INT32},
        {"torch.uint32", DataType::DT_UINT32},
        {"torch.int64", DataType::DT_INT64},
        {"torch.uint64", DataType::DT_UINT64},
        {"torch.bool", DataType::DT_BOOL},
        {"torch.float8_e4m3fn", DataType::DT_FP8E4M3},
        {"torch.float8_e5m2", DataType::DT_FP8E5M2},
        {"torch.float8_e8m0fnu", DataType::DT_FP8E8M0},
        {"torch.float4_e2m1fn_x2", DataType::DT_FP4_E2M1X2},
    };

    const std::string dtypeName = py::str(torchDtype).cast<std::string>();
    if (const auto found = kSupportedDtypes.find(dtypeName); found != kSupportedDtypes.end()) {
        return found->second;
    }
    throw std::runtime_error("Input torch.dtype is not supported. Got " + dtypeName);
}
/**
 * \brief Resolve the element type for one input tensor.
 *
 * \param tensorDef   Tensor definition; its `explicit_dtype` attribute selects the source.
 * \param torchTensor Torch tensor whose `dtype` is used when no explicit dtype is set.
 * \param baseTensor  Tensor whose data type is used when `explicit_dtype` is not None.
 * \return The resolved DataType.
 */
DataType ReadTensorDataType(const py::object& tensorDef, const py::object& torchTensor, Tensor& baseTensor)
{
    if (tensorDef.attr("explicit_dtype").is_none()) {
        // No explicit dtype on the definition: derive it from the torch tensor.
        return TorchDtypeToDataType(torchTensor.attr("dtype"));
    }
    return baseTensor.GetDataType();
}
/**
 * \brief Resolve the tile-op format for one input tensor.
 *
 * Precedence: an explicit format on the definition wins; tensors on any device other than
 * "npu" are ND; otherwise torch_npu.get_npu_format decides between NZ and ND.
 *
 * \param tensorDef   Tensor definition; `explicit_format` selects the source.
 * \param torchTensor Torch tensor passed to torch_npu.get_npu_format.
 * \param baseTensor  Tensor whose format is used when `explicit_format` is not None.
 * \param deviceInfo  Device placement of torchTensor.
 * \param torch_npu   Module handle; imported on first use if still empty, and kept for later calls.
 * \return The resolved TileOpFormat.
 */
TileOpFormat ReadTensorFormat(
    const py::object& tensorDef, const py::object& torchTensor, Tensor& baseTensor, const TensorDeviceInfo& deviceInfo,
    py::module_& torch_npu)
{
    const bool hasExplicitFormat = !tensorDef.attr("explicit_format").is_none();
    if (hasExplicitFormat) {
        return baseTensor.Format();
    }

    if (deviceInfo.type == "npu") {
        // The import is deferred so torch_npu is only required when an NPU tensor is actually seen.
        if (torch_npu.ptr() == nullptr) {
            torch_npu = py::module::import("torch_npu");
        }
        const py::object reported = torch_npu.attr("get_npu_format")(torchTensor);
        const int formatCode = py::cast<int>(reported);
        if (formatCode == kNpuFormatNz) {
            return TileOpFormat::TILEOP_NZ;
        }
    }
    return TileOpFormat::TILEOP_ND;
}
/**
 * \brief Return the framework Tensor stored in the definition's `_base` attribute.
 *
 * A missing `_base` falls back to None, which then fails the type assertion below.
 *
 * \param tensorDef Python tensor definition object.
 * \return Reference to the Tensor held by `_base` (presumably valid only while that Python
 *         object stays alive -- confirm against callers).
 */
Tensor& ReadTensorDefBase(const py::object& tensorDef)
{
auto baseObj = py::getattr(tensorDef, "_base", py::none());
// Reject anything that is not a bound Tensor before casting to a C++ reference.
FE_ASSERT(FeError::INVALID_TYPE, py::isinstance<Tensor>(baseObj))
<< "the '_base' attribute must be a Tensor type";
return baseObj.cast<Tensor&>();
}
/**
 * \brief Check whether two device descriptions refer to the same device.
 * \return true when both the device type and the ordinal match.
 */
bool IsSameDevice(const TensorDeviceInfo& lhs, const TensorDeviceInfo& rhs)
{
    if (lhs.index != rhs.index) {
        return false;
    }
    return lhs.type == rhs.type;
}
/**
 * \brief Ensure an input object is an instance of the given torch tensor type.
 * \param torchTensor     Object to check.
 * \param torchTensorType Python type to test against.
 * \param index           Zero-based position of the input, used in the error message.
 * \throws std::runtime_error if torchTensor is not an instance of torchTensorType.
 */
void ValidateTorchTensorType(const py::object& torchTensor, const py::object& torchTensorType, size_t index)
{
    const bool isTensor = py::isinstance(torchTensor, torchTensorType);
    if (!isTensor) {
        // The message carries both the 1-based ordinal and the 0-based index.
        std::string message = "Input ";
        message += std::to_string(index + 1);
        message += " (index ";
        message += std::to_string(index);
        message += ") is not a torch.Tensor";
        throw std::runtime_error(message);
    }
}
/**
 * \brief Return the cached `torch.Tensor` type object, importing torch on first use.
 * \return Reference to a process-lifetime handle to `torch.Tensor`.
 */
const py::object& GetTorchTensorType()
{
    // Intentionally leaked: a function-local static py::object would be destroyed during static
    // destruction at process exit, which can run after the Python interpreter has been finalized
    // and would then drop a reference on a dead interpreter. Keeping the handle on the heap and
    // never deleting it avoids running that destructor.
    static const py::object* const torchTensorType = new py::object(py::module::import("torch").attr("Tensor"));
    return *torchTensorType;
}
/**
 * \brief Build the DeviceTensorData for one torch tensor and report where the tensor lives.
 *
 * \param torchTensor Torch tensor to convert.
 * \param tensorDef   Matching tensor definition (source of `_base`, `explicit_dtype`, `explicit_format`).
 * \param torch_npu   Module handle forwarded to ReadTensorFormat, which may import it on demand.
 * \param out         Receives the constructed DeviceTensorData.
 * \return Device placement of torchTensor.
 */
TensorDeviceInfo ConvertSingleTensor(
    const py::object& torchTensor, const py::object& tensorDef, py::module_& torch_npu,
    npu::tile_fwk::dynamic::DeviceTensorData& out)
{
    using npu::tile_fwk::dynamic::DeviceTensorData;

    // The reads below stay in this order so the first failing step is the one that surfaces.
    TensorDeviceInfo placement = ReadDeviceInfo(torchTensor);
    Tensor& definitionBase = ReadTensorDefBase(tensorDef);
    const DataType elementType = ReadTensorDataType(tensorDef, torchTensor, definitionBase);
    const uintptr_t address = ReadTensorDataPtr(torchTensor);
    std::vector<int64_t> dims = ReadTensorShape(torchTensor, elementType);
    const TileOpFormat layout = ReadTensorFormat(tensorDef, torchTensor, definitionBase, placement, torch_npu);

    out = DeviceTensorData(elementType, address, dims, layout);
    return placement;
}
/**
 * \brief Check that the tensors' device matches the configured run mode and pick the device index.
 *
 * When CFG_RUN_MODE equals CFG_RUN_MODE_SIM the device must be "cpu" and the result is 0;
 * otherwise the device must be "npu" and the result is its ordinal.
 *
 * \param deviceInfo Device shared by the input tensors.
 * \return 0 when CFG_RUN_MODE is CFG_RUN_MODE_SIM, otherwise deviceInfo.index.
 * \throws std::runtime_error if the device type does not match the run mode.
 */
int ValidateDeviceAndReturnIndex(const TensorDeviceInfo& deviceInfo)
{
    const bool simMode = config::GetRuntimeOption<int64_t>(CFG_RUN_MODE) == CFG_RUN_MODE_SIM;
    const bool onRequiredDevice = simMode ? (deviceInfo.type == "cpu") : (deviceInfo.type == "npu");
    if (!onRequiredDevice) {
        throw std::runtime_error(simMode ? "Not cpu device" : "Not npu device");
    }
    return simMode ? 0 : deviceInfo.index;
}
}
namespace pypto {
/**
 * \brief Convert a sequence of torch tensors into DeviceTensorData entries.
 *
 * Each tensor is paired with the definition at the same position. All tensors must be
 * torch.Tensor instances on one common device, and that device must match the run mode.
 *
 * \param tensors      Non-empty sequence of torch tensors.
 * \param tensor_defs  Sequence of tensor definitions; must hold at least one entry per tensor.
 * \param tensors_data Output vector; one entry is appended per input tensor.
 * \return The device index chosen by ValidateDeviceAndReturnIndex.
 * \throws std::runtime_error for a non-tensor input, tensors on different devices, or a
 *         device that does not match the run mode. Empty or too-short inputs fail a CHECK.
 */
int TorchTensorConverter::Convert(
    py::sequence& tensors, py::sequence& tensor_defs,
    std::vector<npu::tile_fwk::dynamic::DeviceTensorData>& tensors_data)
{
    const size_t n = static_cast<size_t>(py::len(tensors));
    CHECK(FeError::INVALID_VAL, n != 0) << "Empty tensor list";

    // Fail before touching tensors_data: indexing past the end of tensor_defs would otherwise
    // raise mid-loop and leave the output partially filled.
    const size_t defCount = static_cast<size_t>(py::len(tensor_defs));
    CHECK(FeError::INVALID_VAL, defCount >= n)
        << "Input length mismatch: tensors(" << n << ") vs tensor_defs(" << defCount << ")";

    // Entries are appended, so account for anything the caller already stored.
    tensors_data.reserve(tensors_data.size() + n);

    // Left empty here; ReadTensorFormat imports torch_npu only if an NPU tensor needs it.
    py::module torch_npu;
    const py::object& torchTensorType = GetTorchTensorType();
    std::optional<TensorDeviceInfo> commonDeviceInfo;
    for (size_t i = 0; i < n; i++) {
        py::int_ index(i);
        py::object torchTensor = tensors[index];
        py::object tensorDef = tensor_defs[index];
        ValidateTorchTensorType(torchTensor, torchTensorType, i);
        tensors_data.emplace_back();
        TensorDeviceInfo tensorDeviceInfo = ConvertSingleTensor(torchTensor, tensorDef, torch_npu, tensors_data.back());
        if (!commonDeviceInfo.has_value()) {
            // The first tensor fixes the device every later tensor must share.
            commonDeviceInfo.emplace(std::move(tensorDeviceInfo));
        } else if (!IsSameDevice(*commonDeviceInfo, tensorDeviceInfo)) {
            throw std::runtime_error("All input tensors must be on the same device");
        }
    }
    return ValidateDeviceAndReturnIndex(*commonDeviceInfo);
}
/**
 * \brief Check that the tensor and definition sequences are non-empty and equally long.
 * \param tensors    Sequence of input tensors.
 * \param tensorDefs Sequence of tensor definitions.
 * \return The number of tensors.
 */
size_t ValidateInputs(py::sequence& tensors, py::sequence& tensorDefs)
{
size_t n = static_cast<size_t>(py::len(tensors));
// The length mismatch is checked before emptiness, so two empty sequences report "Empty tensor list".
CHECK(FeError::INVALID_VAL, n == static_cast<size_t>(py::len(tensorDefs)))
<< "Input length mismatch: tensors(" << n << ") vs tensor_defs(" << py::len(tensorDefs) << ")";
CHECK(FeError::INVALID_VAL, n != 0) << "Empty tensor list";
return n;
}
}