#include "torch_npu/csrc/profiler/utils.h"
namespace torch_npu {
namespace profiler {
// Positions of aten::conv2d arguments inside RecordFunction::inputs() (read by saveExtraArgs()).
// NOTE(review): assumed to follow the aten::conv2d schema order
// (input, weight, bias, stride, padding, dilation, groups) — confirm if the schema changes.
static constexpr int kConv2dStride = 3;
static constexpr int kConv2dPadding = 4;
static constexpr int kConv2dDilation = 5;
static constexpr int kConv2dGroups = 6;

// Operator names for which extra arguments are saved and FLOPs are estimated.
static constexpr const char *kConv2dOp = "aten::conv2d";
static constexpr const char *kGemmOp = "aten::mm";
static constexpr const char *kMulOp = "aten::mul";
static constexpr const char *kAddOp = "aten::add";

// Keys of the extra-argument map written by saveExtraArgs() and read by computeFlops().
static constexpr const char *kInputSize = "input_size";
static constexpr const char *kWeightSize = "weight_size";
static constexpr const char *kGroups = "groups";
static constexpr const char *kPadding = "padding";
static constexpr const char *kStride = "stride";
static constexpr const char *kDilation = "dilation";
static constexpr const char *kMatSize = "mat_size";
static constexpr const char *kMat1Size = "mat1_size";
static constexpr const char *kMat2Size = "mat2_size";
// Out-of-class definition of NPURecordFunction's static flag; it starts out false.
bool NPURecordFunction::use_npu_simple{false};
// Checks that an op's recorded inputs can be used for FLOPs bookkeeping.
//   op_name          - operator name, only used in the warning text.
//   min_size         - minimum number of recorded inputs required.
//   inputs           - the recorded inputs of the op.
//   should_be_tensor - positions in `inputs` that must hold tensors.
// Returns true when every check passes; otherwise emits one warning describing
// the first violation and returns false.
static bool validateInput(
    const std::string &op_name,
    size_t min_size,
    c10::ArrayRef<const c10::IValue> inputs,
    const c10::ArrayRef<int> &should_be_tensor)
{
    if (inputs.size() < min_size) {
        std::stringstream ss;
        ss << "Failed to save extra arguments for flops computation of op " << op_name << ", min size: " << min_size <<
            ", actual size: " << inputs.size();
        TORCH_NPU_WARN(ss.str());
        return false;
    }
    for (const int index : should_be_tensor) {
        // min_size and the tensor positions are chosen independently by the caller, so guard
        // the subscript instead of trusting that every position lies inside `inputs`.
        const bool in_range = index >= 0 && static_cast<size_t>(index) < inputs.size();
        if (!in_range || !inputs[static_cast<size_t>(index)].isTensor()) {
            std::stringstream ss;
            ss << "Failed to save extra arguments for flops computation of op " << op_name << ", input[" << index <<
                "] must be a tensor.";
            TORCH_NPU_WARN(ss.str());
            return false;
        }
    }
    return true;
}
// Snapshots the arguments of a profiled op that computeFlops() needs later.
// Only aten::conv2d, aten::mm, aten::mul and aten::add are handled. Any other op, an op
// without recorded inputs, or inputs that fail validation yield an empty map.
std::unordered_map<std::string, c10::IValue> saveExtraArgs(const at::RecordFunction &fn)
{
    std::unordered_map<std::string, c10::IValue> extra;
    const auto args = fn.inputs();
    const std::string op(fn.name());
    if (args.empty()) {
        return extra;
    }

    if (op == kConv2dOp) {
        // Positions 0 and 1 must be tensors; kConv2dGroups is the highest position read below.
        const std::vector<int> tensor_slots{ 0, 1 };
        if (!validateInput(op, kConv2dGroups + 1, args, tensor_slots)) {
            return extra;
        }
        const at::Tensor activation = args[0].toTensor();
        const at::Tensor kernel = args[1].toTensor();
        if (kernel.sizes().size() != 4) {
            TORCH_NPU_WARN("Failed to compute flops for op aten::conv2d because it requires a 4D kernel tensor.");
            return extra;
        }
        extra[kInputSize] = at::IValue(activation.sizes());
        extra[kWeightSize] = at::IValue(kernel.sizes());
        extra[kStride] = args[kConv2dStride];
        extra[kPadding] = args[kConv2dPadding];
        extra[kDilation] = args[kConv2dDilation];
        extra[kGroups] = args[kConv2dGroups];
    } else if (op == kGemmOp) {
        const std::vector<int> tensor_slots{ 0, 1 };
        if (!validateInput(op, 2, args, tensor_slots)) {
            return extra;
        }
        const at::Tensor lhs = args[0].toTensor();
        const at::Tensor rhs = args[1].toTensor();
        extra[kMat1Size] = at::IValue(lhs.sizes());
        extra[kMat2Size] = at::IValue(rhs.sizes());
    } else if (op == kMulOp || op == kAddOp) {
        // Element-wise ops: only the shape of the first operand is recorded.
        const std::vector<int> tensor_slots{ 0 };
        if (!validateInput(op, 1, args, tensor_slots)) {
            return extra;
        }
        const at::Tensor operand = args[0].toTensor();
        extra[kMatSize] = at::IValue(operand.sizes());
    }
    return extra;
}
uint64_t computeFlops(const std::string &op_name, const std::unordered_map<std::string, c10::IValue> &extra_args)
{
if (op_name == kConv2dOp) {
if (extra_args.find(kInputSize) == extra_args.end() || extra_args.find(kWeightSize) == extra_args.end() ||
extra_args.find(kGroups) == extra_args.end() || extra_args.find(kPadding) == extra_args.end() ||
extra_args.find(kStride) == extra_args.end() || extra_args.find(kDilation) == extra_args.end()) {
TORCH_NPU_WARN("Calculating flops for aten::conv2d requires groups, padding, stride, dilation, input_size, "
"and weight_size in saved arguments.");
return 0;
}
auto input_sizes_ref = extra_args.at(kInputSize);
auto kernel_sizes_ref = extra_args.at(kWeightSize);
auto groups_ref = extra_args.at(kGroups);
auto padding_ref = extra_args.at(kPadding);
auto stride_ref = extra_args.at(kStride);
auto dilation_ref = extra_args.at(kDilation);
if (!input_sizes_ref.isIntList() || !kernel_sizes_ref.isIntList()) {
TORCH_NPU_WARN(
"Failed to compute flops for op aten::conv2d because it requires input and weight tensor sizes.");
return 0;
}
if (!padding_ref.isIntList() || !stride_ref.isIntList() || !dilation_ref.isIntList()) {
TORCH_NPU_WARN("Failed to compute flops for op aten::conv2d because it requires padding, stride, and "
"dilation values.");
return 0;
}
const std::vector<int64_t> input_sizes = input_sizes_ref.toIntVector();
const std::vector<int64_t> kernel_sizes = kernel_sizes_ref.toIntVector();
const uint64_t groups = (uint64_t)groups_ref.toInt();
const std::vector<int64_t> padding = padding_ref.toIntVector();
const std::vector<int64_t> stride = stride_ref.toIntVector();
const std::vector<int64_t> dilation = dilation_ref.toIntVector();
if (input_sizes.size() != 4 || kernel_sizes.size() != 4) {
TORCH_NPU_WARN("Failed to compute flops for op aten::conv2d because both input and weight must be size 4.");
return 0;
}
if (!groups) {
TORCH_NPU_WARN("Failed to compute flops for op aten::conv2d because group size must not be 0.");
return 0;
}
if (padding.size() != 2 || dilation.size() != 2) {
TORCH_NPU_WARN(
"Failed to compute flops for op aten::conv2d because both padding and dilation must be size 2.");
return 0;
}
if (stride.size() != 2 || (stride[0] * stride[1] == 0)) {
TORCH_NPU_WARN(
"Failed to compute flops for op aten::conv2d because stride must be size 2 and cannot be 0.");
return 0;
}
uint64_t minibatch = 0;
uint64_t in_channels = 0;
uint64_t input_h = 0;
uint64_t input_w = 0;
uint64_t out_channels = 0;
uint64_t kernel_h = 0;
uint64_t kernel_w = 0;
const uint64_t conv2d_multiply_factor = 2;
std::tie(minibatch, in_channels, input_h, input_w) =
std::make_tuple(input_sizes[0], input_sizes[1], input_sizes[2], input_sizes[3]);
std::tie(out_channels, std::ignore, kernel_h, kernel_w) =
std::make_tuple(kernel_sizes[0], kernel_sizes[1], kernel_sizes[2], kernel_sizes[3]);
uint64_t output_h = (input_h + 2 * padding[0] - dilation[0] * (kernel_h - 1) - 1) / stride[0] + 1;
uint64_t output_w = (input_w + 2 * padding[1] - dilation[1] * (kernel_w - 1) - 1) / stride[1] + 1;
if (groups == 0) {
TORCH_CHECK(false, "groups can not be 0.", PTA_ERROR(ErrCode::VALUE));
}
return conv2d_multiply_factor * minibatch * output_h * output_w * kernel_h * kernel_w * in_channels *
out_channels / groups;
} else if (op_name == kGemmOp) {
if (extra_args.find(kMat1Size) == extra_args.end() || extra_args.find(kMat2Size) == extra_args.end()) {
TORCH_NPU_WARN("Calculating flops for aten::mm requires mat1_size and mat2_size in saved arguments.");
return 0;
}
auto mat1_sizes_ref = extra_args.at(kMat1Size);
auto mat2_sizes_ref = extra_args.at(kMat2Size);
if (!mat1_sizes_ref.isIntList() || !mat2_sizes_ref.isIntList()) {
TORCH_NPU_WARN(
"Failed to compute flops for op aten::mm because it requires mat1_size and mat2_size to be IntList.");
return 0;
}
std::vector<int64_t> mat1_size = mat1_sizes_ref.toIntVector();
std::vector<int64_t> mat2_size = mat2_sizes_ref.toIntVector();
if (mat1_size.size() == 0) {
return 0;
} else {
int64_t overlap_dim = mat1_size.back();
const uint64_t gemm_multiply_factor = 2;
uint64_t flops = 1;
for (int64_t dim : mat1_size) {
flops *= (uint64_t)dim;
}
if (overlap_dim == 0) {
TORCH_CHECK(false, "overlap_dim can not be 0.", PTA_ERROR(ErrCode::VALUE));
}
flops /= (uint64_t)overlap_dim;
for (int64_t dim : mat2_size) {
flops *= (uint64_t)dim;
}
flops *= gemm_multiply_factor;
return flops;
}
} else if (op_name == kMulOp) {
if (extra_args.find(kMatSize) == extra_args.end()) {
TORCH_NPU_WARN("Calculating flops for aten::mul.Tensor requires mat_size in saved arguments.");
return 0;
}
auto mat_sizes = extra_args.at(kMatSize);
if (!mat_sizes.isIntList()) {
TORCH_NPU_WARN("Failed to compute flops for op aten::mul because it requires mat_size to be IntList.");
return 0;
}
std::vector<int64_t> mat_size = mat_sizes.toIntVector();
uint64_t flops = 1;
for (int64_t dim : mat_size) {
flops *= (uint64_t)dim;
}
return flops;
} else if (op_name == kAddOp) {
if (extra_args.find(kMatSize) == extra_args.end()) {
TORCH_NPU_WARN("Calculating flops for aten::add.Tensor requires mat_size in saved arguments.");
return 0;
}
auto mat_sizes = extra_args.at(kMatSize);
if (!mat_sizes.isIntList()) {
TORCH_NPU_WARN("Failed to compute flops for op aten::add because it requires mat_size to be IntList.");
return 0;
}
std::vector<int64_t> mat_size = mat_sizes.toIntVector();
uint64_t flops = 1;
for (int64_t dim : mat_size) {
flops *= (uint64_t)dim;
}
return flops;
}
return 0;
}
}
}