#include <c10/util/accumulate.h>
#include "torch_npu/csrc/framework/StorageDescHelper.h"
#include "torch_npu/csrc/framework/FormatHelper.h"
#include "torch_npu/csrc/framework/InferFormat.h"
#include "torch_npu/csrc/core/NPUBridge.h"
#include "torch_npu/csrc/core/NPUStorageImpl.h"
namespace at_npu {
namespace native {
// Reports whether the tensor's own sizes and strides both equal the
// base_sizes_ / base_strides_ recorded in its NPU storage descriptor.
// `tensor` is dereferenced unconditionally, so it must not be null.
bool StorageDescHelper::MetaDataAreMatch(const at::Tensor *tensor)
{
    const auto &storage_desc = torch_npu::NPUBridge::GetNpuStorageImplDesc(*tensor);
    // Strides are only compared when the sizes already agree.
    if (!IsSameSize(storage_desc.base_sizes_, tensor->sizes())) {
        return false;
    }
    return IsSameSize(storage_desc.base_strides_, tensor->strides());
}
// Compares two storage descriptors.
// If origin_format_ or npu_format_ differ, the descriptors can still be
// considered equal only when FormatHelper::IsBaseFormatType accepts the
// npu_format_ of both. In every case base sizes, base strides and storage
// sizes must all be equal.
bool StorageDescHelper::IsSameDesc(const torch_npu::NPUStorageDesc &a, const torch_npu::NPUStorageDesc &b)
{
    const bool formats_agree = (a.origin_format_ == b.origin_format_) && (a.npu_format_ == b.npu_format_);
    if (!formats_agree) {
        const bool both_base_formats =
            FormatHelper::IsBaseFormatType(a.npu_format_) && FormatHelper::IsBaseFormatType(b.npu_format_);
        if (!both_base_formats) {
            return false;
        }
    }
    if (!(a.base_sizes_ == b.base_sizes_)) {
        return false;
    }
    if (!(a.base_strides_ == b.base_strides_)) {
        return false;
    }
    return a.storage_sizes_ == b.storage_sizes_;
}
// Tensor-level convenience overload: looks up the NPU storage descriptor of
// each tensor and compares them with the descriptor overload of IsSameDesc.
bool StorageDescHelper::IsSameDesc(const at::Tensor &a, const at::Tensor &b)
{
    const torch_npu::NPUStorageDesc &lhs_desc = torch_npu::NPUBridge::GetNpuStorageImplDesc(a);
    const torch_npu::NPUStorageDesc &rhs_desc = torch_npu::NPUBridge::GetNpuStorageImplDesc(b);
    const bool descriptors_match = IsSameDesc(lhs_desc, rhs_desc);
    return descriptors_match;
}
// Element-wise equality between a small vector and an array ref.
// Returns false when the lengths differ; two empty sequences compare equal.
bool StorageDescHelper::IsSameSize(const c10::SmallVector<int64_t, 5> &a, const c10::IntArrayRef &b)
{
    if (a.size() != b.size()) {
        return false;
    }
    // Lengths are equal here, so walking `b` cannot run past the end of `a`.
    return std::equal(b.begin(), b.end(), a.begin());
}
// Rewrites the size-related fields of `npuDesc`.
//
// - base_sizes_: whichever of new_data_sizes / new_shape_sizes has the larger
//   element count (new_shape_sizes wins on a tie).
// - base_strides_: row-major strides of that shape, where a dimension of
//   length 0 is treated as length 1 when accumulating the stride product.
// - storage_sizes_: the chosen shape when the current npu_format_ is
//   NCDHW/NDHWC but the shape has fewer than 5 dims; otherwise whatever
//   FormatHelper::GetStorageSizes(npuDesc) returns.
// - If the data shape was the larger one, storage_sizes_ is forced to the
//   chosen shape and npu_format_ is re-derived through
//   InferFormat::GuessStorageFormat(storage_sizes_, ACL_FORMAT_ND).
void StorageDescHelper::UpdateDesc(torch_npu::NPUStorageDesc &npuDesc, const c10::IntArrayRef &new_data_sizes,
    const c10::IntArrayRef &new_shape_sizes)
{
    const int64_t data_numel = c10::multiply_integers(new_data_sizes);
    const int64_t shape_numel = c10::multiply_integers(new_shape_sizes);
    const bool data_is_larger = data_numel > shape_numel;
    const c10::IntArrayRef &chosen_sizes = data_is_larger ? new_data_sizes : new_shape_sizes;
    npuDesc.base_sizes_ = chosen_sizes;

    int64_t rank = static_cast<int64_t>(chosen_sizes.size());
    c10::SmallVector<int64_t, 5> contiguous_strides(rank);
    if (rank > 0) {
        // Innermost dimension has unit stride; walk outwards from there.
        contiguous_strides[rank - 1] = 1;
        for (int64_t axis = rank - 1; axis > 0; --axis) {
            contiguous_strides[axis - 1] = contiguous_strides[axis] * std::max<int64_t>(chosen_sizes[axis], 1);
        }
    }
    npuDesc.base_strides_ = contiguous_strides;

    constexpr unsigned int kNcdhwOrNdhwcDim = 5;
    const bool is_five_dim_format =
        (npuDesc.npu_format_ == ACL_FORMAT_NCDHW) || (npuDesc.npu_format_ == ACL_FORMAT_NDHWC);
    if (is_five_dim_format && chosen_sizes.size() < kNcdhwOrNdhwcDim) {
        npuDesc.storage_sizes_ = chosen_sizes;
    } else {
        // Must run after base_sizes_/base_strides_ are updated: the helper reads npuDesc.
        npuDesc.storage_sizes_ = FormatHelper::GetStorageSizes(npuDesc);
    }

    if (data_is_larger) {
        npuDesc.storage_sizes_ = chosen_sizes;
        npuDesc.npu_format_ = InferFormat::GuessStorageFormat(npuDesc.storage_sizes_, ACL_FORMAT_ND);
    }
}
// Computes row-major strides for `shape`: the last dimension gets stride 1
// and every other dimension gets shape[i + 1] * stride[i + 1].
//
// Returns a FormatShape with the same number of entries as `shape`.
// An empty shape yields an empty result; without the guard below,
// `shape.size() - 1` would wrap around (size() is unsigned) and the write
// to compute_stride[...] would be out of bounds.
FormatShape StorageDescHelper::ComputeStrideFromShape(const FormatShape &shape)
{
    FormatShape compute_stride = shape;
    if (shape.empty()) {
        return compute_stride;
    }
    compute_stride[shape.size() - 1] = 1;
    // Unsigned index: iterate on i and write to i - 1 so the loop ends at i == 0.
    for (auto i = shape.size() - 1; i > 0; i--) {
        compute_stride[i - 1] = shape[i] * compute_stride[i];
    }
    return compute_stride;
}
// Replaces the NPU storage descriptor of `dst` with the one produced by
// SetDesc(dtype) for the tensor's dtype.
void StorageDescHelper::SetDesc(at::Tensor &dst)
{
    torch_npu::NPUStorageDesc fresh_desc = SetDesc(dst.dtype());
    torch_npu::NPUBridge::GetNpuStorageImpl(dst)->npu_desc_ = std::move(fresh_desc);
}
// Replaces the NPU storage descriptor of `dst` with one built from the
// tensor's dtype plus the given base `size` and `strides`.
void StorageDescHelper::SetDesc(at::Tensor &dst, const c10::IntArrayRef& size, const c10::IntArrayRef& strides)
{
    torch_npu::NPUStorageDesc fresh_desc = SetDesc(dst.dtype(), size, strides);
    torch_npu::NPUBridge::GetNpuStorageImpl(dst)->npu_desc_ = std::move(fresh_desc);
}
// Replaces the NPU storage descriptor of `dst` with one built from the
// tensor's dtype, the given base `size` / `strides`, and the requested
// `format`.
void StorageDescHelper::SetDesc(at::Tensor &dst, const c10::IntArrayRef& size, const c10::IntArrayRef& strides,
    aclFormat format)
{
    torch_npu::NPUStorageDesc fresh_desc = SetDesc(dst.dtype(), size, strides, format);
    torch_npu::NPUBridge::GetNpuStorageImpl(dst)->npu_desc_ = std::move(fresh_desc);
}
// Replaces the NPU storage descriptor of `dst` with one built from the
// tensor's dtype, the given base size, an explicitly supplied storage size,
// the strides and the requested `format`.
void StorageDescHelper::SetDesc(at::Tensor &dst, const c10::IntArrayRef &base_size,
    const c10::IntArrayRef &storage_size, const c10::IntArrayRef &strides, aclFormat format)
{
    torch_npu::NPUStorageDesc fresh_desc = SetDesc(dst.dtype(), base_size, storage_size, strides, format);
    torch_npu::NPUBridge::GetNpuStorageImpl(dst)->npu_desc_ = std::move(fresh_desc);
}
// Returns true when the descriptor attached to `storage` has an
// origin_format_ other than ACL_FORMAT_UNDEFINED.
bool StorageDescHelper::CheckDescInit(const c10::Storage &storage)
{
    const auto &storage_desc = torch_npu::NPUBridge::GetNpuStorageImpl(storage.unsafeGetStorageImpl())->npu_desc_;
    const bool origin_is_undefined = (storage_desc.origin_format_ == ACL_FORMAT_UNDEFINED);
    return !origin_is_undefined;
}
// Encodes the tensor's storage descriptor as keys of `desc_map` (each mapped
// to true). Nothing is written when the storage's data_ptr() is null.
//
// Keys produced:
//   "base_sizes_/<d0>/<d1>/.../"     from desc.base_sizes_
//   "base_strides_/<s0>/<s1>/.../"   from desc.base_strides_
//   "storage_sizes_/<...>/"          from FormatHelper::GetStorageSizes
//                                    evaluated for ACL_FORMAT_ND (not the
//                                    descriptor's stored storage_sizes_)
//   "origin_format_/<name>" and "npu_format_/<name>"
//                                    using FormatHelper::GetFormatName
void StorageDescHelper::GetDescForSerialization(const at::Tensor &tensor,
    std::unordered_map<std::string, bool> &desc_map)
{
    if (torch_npu::NPUBridge::GetNpuStorageImpl(tensor)->data_ptr() == nullptr) {
        return;
    }
    auto &desc = torch_npu::NPUBridge::GetNpuStorageImplDesc(tensor);

    // "<key>/<v0>/<v1>/.../" - every value, including the last, is followed by '/'.
    auto encode_dims = [](const std::string &key, c10::SmallVector<int64_t, 5> dims) -> std::string {
        std::string encoded = key;
        for (const int64_t dim : dims) {
            encoded += '/';
            encoded += std::to_string(dim);
        }
        encoded += '/';
        return encoded;
    };
    const std::string base_sizes_key = encode_dims("base_sizes_", desc.base_sizes_);
    const std::string base_strides_key = encode_dims("base_strides_", desc.base_strides_);
    const std::string storage_sizes_key = encode_dims(
        "storage_sizes_",
        FormatHelper::GetStorageSizes(ACL_FORMAT_ND, desc.base_sizes_, desc.data_type_));
    desc_map[base_sizes_key] = true;
    desc_map[base_strides_key] = true;
    desc_map[storage_sizes_key] = true;

    // "<key>/<format name>" - no trailing separator.
    auto encode_format = [](const std::string &key, aclFormat format) -> std::string {
        std::string encoded = key;
        encoded += '/';
        encoded += FormatHelper::GetFormatName(format);
        return encoded;
    };
    const std::string origin_format_key = encode_format("origin_format_", desc.origin_format_);
    const std::string npu_format_key = encode_format("npu_format_", desc.npu_format_);
    desc_map[origin_format_key] = true;
    desc_map[npu_format_key] = true;
}
// Restores size information into the tensor's storage descriptor from keys
// of `desc_map` shaped like "<name>/<v0>/<v1>/.../".
//
// Only keys containing "base_sizes_", "base_strides_" or "storage_sizes_" are
// consulted; format keys are not read here. If the decoded descriptor already
// matches the current one (per IsSameDesc) nothing is changed. Otherwise the
// three size fields are overwritten and data_type_ is set to the tensor dtype.
// std::stoll is used for decoding, so a malformed number propagates its
// exception.
void StorageDescHelper::SetDescForSerialization(const at::Tensor &tensor,
    std::unordered_map<std::string, bool> &desc_map)
{
    auto &cur_desc = torch_npu::NPUBridge::GetNpuStorageImplDesc(tensor);
    torch_npu::NPUStorageDesc load_desc;

    // Skips everything up to and including the first '/', then converts each
    // '/'-terminated token to an integer. Text after the last '/' is ignored.
    auto parse_dims = [](const std::string &encoded) -> c10::SmallVector<int64_t, 5> {
        c10::SmallVector<int64_t, 5> dims;
        const std::string::size_type key_end = encoded.find('/');
        if (key_end == std::string::npos) {
            return dims;
        }
        std::string::size_type token_begin = key_end + 1;
        std::string::size_type token_end = encoded.find('/', token_begin);
        while (token_end != std::string::npos) {
            dims.emplace_back(std::stoll(encoded.substr(token_begin, token_end - token_begin)));
            token_begin = token_end + 1;
            token_end = encoded.find('/', token_begin);
        }
        return dims;
    };

    for (auto &entry : desc_map) {
        const std::string &key = entry.first;
        if (key.find("base_sizes_") != std::string::npos) {
            load_desc.base_sizes_ = parse_dims(key);
        } else if (key.find("base_strides_") != std::string::npos) {
            load_desc.base_strides_ = parse_dims(key);
        } else if (key.find("storage_sizes_") != std::string::npos) {
            load_desc.storage_sizes_ = parse_dims(key);
        }
    }

    if (!IsSameDesc(cur_desc, load_desc)) {
        cur_desc.base_sizes_ = load_desc.base_sizes_;
        cur_desc.base_strides_ = load_desc.base_strides_;
        cur_desc.storage_sizes_ = load_desc.storage_sizes_;
        cur_desc.data_type_ = tensor.dtype();
    }
}
// Copies the storage descriptor of `src`'s storage onto `dst`.
void StorageDescHelper::CopyDesc(at::Tensor &dst, const at::Tensor &src)
{
    const c10::Storage &source_storage = src.storage();
    CopyDesc(dst, source_storage);
}
// Copies the descriptor held by the NPU storage impl behind `src` onto `dst`.
void StorageDescHelper::CopyDesc(at::Tensor &dst, const c10::Storage &src)
{
    const torch_npu::NPUStorageDesc &source_desc =
        torch_npu::NPUBridge::GetNpuStorageImpl(src.unsafeGetStorageImpl())->npu_desc_;
    CopyDesc(dst, source_desc);
}
// Overwrites the storage descriptor of `dst` with a copy of `src_desc`.
// `dst` is taken by const reference, but the descriptor reached through its
// storage impl is modified.
void StorageDescHelper::CopyDesc(const at::Tensor &dst, const torch_npu::NPUStorageDesc &src_desc)
{
    torch_npu::NPUBridge::GetNpuStorageImpl(dst)->npu_desc_ = src_desc;
}
// Resets the size fields of `src`'s storage descriptor from the tensor's own
// metadata: base and storage sizes both become src.sizes(), base strides
// become src.strides(). Format and dtype fields are left untouched.
void StorageDescHelper::ReflushDescBySelf(const at::Tensor &src)
{
    auto &storage_desc = torch_npu::NPUBridge::GetNpuStorageImpl(src)->npu_desc_;
    const c10::IntArrayRef view_sizes = src.sizes();
    storage_desc.base_sizes_ = view_sizes;
    storage_desc.base_strides_ = src.strides();
    storage_desc.storage_sizes_ = view_sizes;
}
// Builds a descriptor for `dtype` with a base size of {0} (one dimension of
// length zero) and no strides, delegating to the (dtype, size, strides)
// overload.
torch_npu::NPUStorageDesc StorageDescHelper::SetDesc(const caffe2::TypeMeta &dtype)
{
    // A named array keeps the backing storage alive for the whole call.
    const int64_t zero_length_dim[] = {0};
    const c10::IntArrayRef default_sizes(zero_length_dim);
    const c10::IntArrayRef no_strides;
    return SetDesc(dtype, default_sizes, no_strides);
}
// Builds a descriptor for `dtype`, `size` and `strides`, using the format
// that InferFormat::GuessBaseFormat derives from `size`.
torch_npu::NPUStorageDesc StorageDescHelper::SetDesc(const caffe2::TypeMeta &dtype, const c10::IntArrayRef& size,
    const c10::IntArrayRef& strides)
{
    const auto guessed_format = InferFormat::GuessBaseFormat(size);
    return SetDesc(dtype, size, strides, guessed_format);
}
// Builds a complete storage descriptor.
// InferFormat::GuessFormatUnit(size, format) supplies the (origin, npu)
// format pair; storage sizes are then derived from the npu format, `size`
// and `dtype` via FormatHelper::GetStorageSizes.
torch_npu::NPUStorageDesc StorageDescHelper::SetDesc(const caffe2::TypeMeta &dtype, const c10::IntArrayRef& size,
    const c10::IntArrayRef& strides, aclFormat format)
{
    torch_npu::NPUStorageDesc result;
    result.data_type_ = dtype;
    result.base_sizes_ = size;
    result.base_strides_ = strides;

    aclFormat origin_format;
    aclFormat storage_format;
    std::tie(origin_format, storage_format) = InferFormat::GuessFormatUnit(size, format);

    result.origin_format_ = origin_format;
    result.npu_format_ = storage_format;
    result.storage_sizes_ = FormatHelper::GetStorageSizes(storage_format, size, dtype);
    return result;
}
// Builds a storage descriptor whose storage sizes are given by the caller
// (`storage_size`) instead of being derived from the format.
// The (origin, npu) format pair still comes from
// InferFormat::GuessFormatUnit(base_size, format).
torch_npu::NPUStorageDesc StorageDescHelper::SetDesc(const caffe2::TypeMeta &dtype, const c10::IntArrayRef& base_size,
    const c10::IntArrayRef& storage_size, const c10::IntArrayRef& strides, aclFormat format)
{
    torch_npu::NPUStorageDesc result;
    result.data_type_ = dtype;
    result.base_sizes_ = base_size;
    result.base_strides_ = strides;

    aclFormat origin_format;
    aclFormat storage_format;
    std::tie(origin_format, storage_format) = InferFormat::GuessFormatUnit(base_size, format);

    result.origin_format_ = origin_format;
    result.npu_format_ = storage_format;
    result.storage_sizes_ = storage_size;
    return result;
}
// Returns the product of the storage sizes that FormatHelper::GetStorageSizes
// reports for `dst`. This is an element count: no element-size factor is
// applied here.
int64_t StorageDescHelper::GetMemorySize(const torch_npu::NPUStorageDesc &dst)
{
    return c10::multiply_integers(FormatHelper::GetStorageSizes(dst));
}
// Tensor overload: evaluates GetMemorySize on the descriptor stored in the
// tensor's NPU storage impl.
int64_t StorageDescHelper::GetMemorySize(const at::Tensor &dst)
{
    const torch_npu::NPUStorageDesc &storage_desc = torch_npu::NPUBridge::GetNpuStorageImpl(dst)->npu_desc_;
    return GetMemorySize(storage_desc);
}
// Returns the product of the storage sizes that FormatHelper::GetStorageSizes
// computes for the given logical `size`, `format` and `dtype`.
int64_t StorageDescHelper::GetMemorySize(const c10::IntArrayRef& size, aclFormat format, caffe2::TypeMeta dtype)
{
    return c10::multiply_integers(FormatHelper::GetStorageSizes(format, size, dtype));
}
// Returns 1 + sum over all dimensions of (size(d) - 1) * stride(d).
// The value is expressed in stride units: no element-size factor is applied
// here. A 0-dim tensor yields 1.
// NOTE(review): a dimension of length 0 contributes -stride(d) to the sum;
// confirm that callers never pass empty tensors or that this is intended.
int64_t StorageDescHelper::GetValidMemorySize(const at::Tensor &tensor)
{
    int64_t max_offset = 0;
    int64_t axis = tensor.dim();
    // Walk from the innermost dimension outwards.
    while (axis > 0) {
        --axis;
        const int64_t extent = tensor.size(axis);
        max_offset += (extent - 1) * tensor.stride(axis);
    }
    return max_offset + 1;
}
}
}