#include "torch_npu/csrc/flopcount/FlopCounter.h"

#include <cstdint>
#include <functional>
#include <initializer_list>
#include <limits>
#include <numeric>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "torch_npu/csrc/core/npu/NPUException.h"
// FLOPs of a torch.matmul-style product of tensor1 and tensor2 (each of rank >= 1).
//
// Result = 2 * K * prod(broadcast batch dims) * M * N, where
//   K = tensor1.size(-1),
//   M = tensor1.size(-2)  (only included when tensor1 has rank >= 2),
//   N = tensor2.size(-1)  (only included when tensor2 has rank >= 2),
// and the batch dims are the leading dims of each operand beyond the trailing two,
// broadcast together with at::infer_size (which throws if they are incompatible).
// The contraction sizes of the two operands are not cross-checked here.
int64_t FlopCounter::mm_flop(const at::Tensor &tensor1, const at::Tensor &tensor2)
{
    const auto rank1 = tensor1.dim();
    const auto rank2 = tensor2.dim();
    TORCH_CHECK(rank1 > 0 && rank2 > 0, "matmul got error dimentions: ", "(", rank1, ", ",
        rank2, ")", OPS_ERROR(ErrCode::PARAM));

    // Count of leading batch dims on each side: zero for 1-D and 2-D operands.
    const auto lead1 = rank1 > 2 ? rank1 - 2 : 0;
    const auto lead2 = rank2 > 2 ? rank2 - 2 : 0;
    const auto batch_shape = at::infer_size(at::IntArrayRef(tensor1.sizes().data(), lead1),
                                            at::IntArrayRef(tensor2.sizes().data(), lead2));

    // Two operations (multiply + add) per element of the contraction.
    int64_t flop = 2 * tensor1.size(-1);
    for (const auto extent : batch_shape) {
        flop *= extent;
    }
    if (rank1 >= 2) {
        flop *= tensor1.size(-2);
    }
    if (rank2 >= 2) {
        flop *= tensor2.size(-1);
    }
    return flop;
}
// FLOPs of an all-gather + matmul: the local self @ mat2 count, multiplied by
// world_size when gather_index == 0 and returned unchanged for any other gather_index.
// NOTE(review): presumably gather_index == 0 means `self` is the gathered operand; confirm with callers.
int64_t FlopCounter::all_gather_mm_flop(const at::Tensor &self, const at::Tensor &mat2, int64_t world_size, int64_t gather_index)
{
    const int64_t local_flops = mm_flop(self, mat2);
    if (gather_index != 0) {
        return local_flops;
    }
    return local_flops * world_size;
}
// FLOPs of addmm: only the mat1 @ mat2 product is counted (delegates to mm_flop);
// the added input term is not part of the count.
int64_t FlopCounter::addmm_flop(const at::Tensor &mat1, const at::Tensor &mat2)
{
    const int64_t product_flops = mm_flop(mat1, mat2);
    return product_flops;
}
// FLOPs of a batched matmul self (b, m, k) @ mat2 (b, k, n): 2 * b * m * n * k.
// Throws (TORCH_CHECK) if the batch or contraction sizes disagree.
int64_t FlopCounter::bmm_flop(const at::Tensor &self, const at::Tensor &mat2)
{
    // size() calls are kept in the same order as before so out-of-range errors match.
    const int64_t batch = self.size(0);
    const int64_t rows = self.size(1);
    const int64_t inner = self.size(2);
    const int64_t batch_rhs = mat2.size(0);
    const int64_t inner_rhs = mat2.size(1);
    const int64_t cols = mat2.size(2);
    const bool shapes_match = batch == batch_rhs && inner == inner_rhs;
    TORCH_CHECK(shapes_match, "The tensor dimension is incorrect", PTA_ERROR(ErrCode::VALUE));
    return batch * rows * cols * 2 * inner;
}
// FLOPs of baddbmm: only the batch1 @ batch2 product is counted (delegates to bmm_flop).
int64_t FlopCounter::baddbmm_flop(const at::Tensor &batch1, const at::Tensor &batch2)
{
    const int64_t product_flops = bmm_flop(batch1, batch2);
    return product_flops;
}
// FLOP count of a convolution, following torch.utils.flop_counter.conv_flop_count:
//   prod(spatial) * prod(filter) * batch * c_out * c_in * 2
// where `spatial` is the input's spatial extent (x_shape[2:]) for a transposed
// convolution and the output's spatial extent (out_shape[2:]) otherwise.
// Bias and groups are not accounted for.
//
// x_shape:   input shape (N, C, spatial...)
// w_shape:   filter shape (w_shape[0], w_shape[1], kernel...)
// out_shape: output shape (N, C, spatial...)
// All products are accumulated in int64_t so large extents do not truncate to int.
int64_t conv_flop_count(std::vector<int64_t> x_shape, std::vector<int64_t> w_shape, std::vector<int64_t> out_shape, bool transposed)
{
    const int64_t batch_size = x_shape[0];
    // A transposed conv visits every *input* position; a regular conv every *output* position.
    const std::vector<int64_t> &spatial_src = transposed ? x_shape : out_shape;
    const int64_t conv_size = std::accumulate(spatial_src.begin() + 2, spatial_src.end(), int64_t{1},
                                              std::multiplies<int64_t>());
    const int64_t c_out = w_shape[0];
    const int64_t c_in = w_shape[1];
    const int64_t filter_size = std::accumulate(w_shape.begin() + 2, w_shape.end(), int64_t{1},
                                                std::multiplies<int64_t>());
    return conv_size * filter_size * batch_size * c_out * c_in * 2;
}
// FLOPs of a (possibly transposed) convolution: copies the input, weight and output
// shapes into plain vectors and delegates the arithmetic to conv_flop_count.
int64_t FlopCounter::conv_flop(const at::Tensor &input, const at::Tensor &weight, bool transposed, at::Tensor output)
{
    const auto input_sizes = input.sizes();
    const auto weight_sizes = weight.sizes();
    const auto output_sizes = output.sizes();
    return conv_flop_count(
        std::vector<int64_t>(input_sizes.begin(), input_sizes.end()),
        std::vector<int64_t>(weight_sizes.begin(), weight_sizes.end()),
        std::vector<int64_t>(output_sizes.begin(), output_sizes.end()),
        transposed);
}
// Returns `shape` with its first two axes swapped, keeping all trailing axes:
// [a, b, rest...] -> [b, a, rest...]. This matches the `t` helper in
// torch.utils.flop_counter.conv_backward_flop and works for conv1d/2d/3d shapes
// (the previous version assumed exactly 4 dims). Shapes of rank < 2 are returned unchanged.
std::vector<int64_t> t(std::vector<int64_t> shape)
{
    if (shape.size() >= 2) {
        std::swap(shape[0], shape[1]);
    }
    return shape;
}
// FLOPs of a convolution backward pass, following torch.utils.flop_counter.conv_backward_flop.
//   output_mask[0]: add the input-gradient cost, counted as a convolution of grad_output
//                   with weight in the opposite (transposed vs. regular) direction.
//   output_mask[1]: add the weight-gradient cost, counted as a regular convolution of the
//                   batch/channel-swapped (see `t`) input and grad_output shapes.
// Other mask entries are ignored.
int64_t FlopCounter::conv_backward_flop(const at::Tensor &grad_output, const at::Tensor &input,
    const at::Tensor &weight, bool transposed, ::std::array<bool, 3> output_mask,
    const at::Tensor &gradInput, const at::Tensor &gradeWeight)
{
    const auto to_vec = [](const at::Tensor &tensor) {
        return std::vector<int64_t>(tensor.sizes().begin(), tensor.sizes().end());
    };
    const std::vector<int64_t> dy = to_vec(grad_output);
    const std::vector<int64_t> w = to_vec(weight);
    const std::vector<int64_t> x = to_vec(input);

    int64_t grad_input_flops = 0;
    if (output_mask[0]) {
        grad_input_flops = conv_flop_count(dy, w, to_vec(gradInput), !transposed);
    }

    int64_t grad_weight_flops = 0;
    if (output_mask[1]) {
        const std::vector<int64_t> dw = to_vec(gradeWeight);
        // Operand order flips between transposed and regular convolutions.
        grad_weight_flops = transposed
            ? conv_flop_count(t(dy), t(x), t(dw), false)
            : conv_flop_count(t(x), t(dy), t(dw), false);
    }
    return grad_input_flops + grad_weight_flops;
}
std::vector<std::tuple<std::vector<int64_t>, std::vector<int64_t>, std::vector<int64_t>, std::vector<int64_t>>> _unpack_flash_attention_nested_shapes(std::vector<int64_t> query,
std::vector<int64_t> key, std::vector<int64_t> value, int64_t head_num, std::vector<int64_t> grad_out,
c10::ArrayRef<int64_t> cum_seq_q, c10::ArrayRef<int64_t> cum_seq_k, std::string input_layer_str)
{
TORCH_CHECK(head_num != 0, "Divisor head_num may be 0, please check it.")
std::vector<std::tuple<std::vector<int64_t>, std::vector<int64_t>, std::vector<int64_t>, std::vector<int64_t>>> result;
int64_t q_1 = query[1];
int64_t q_2 = query[2];
int64_t k_1 = key[1];
int64_t k_2 = key[2];
int64_t v_1 = value[1];
int64_t v_2 = value[2];
if (input_layer_str == "SBH" || input_layer_str == "BSH" || input_layer_str == "BSND") {
if (q_2 != k_2 && q_2 != v_2) {
k_2 = q_2;
v_2 = q_2;
}
} else {
if (q_1 != k_1 && q_1 != v_1) {
k_1 = q_1;
v_1 = q_1;
}
}
if (input_layer_str == "BSH") {
std::vector<int64_t> new_query_shape = {query[0], head_num, q_1, q_2 / head_num};
std::vector<int64_t> new_key_shape = {key[0], head_num, k_1, k_2 / head_num};
std::vector<int64_t> new_value_shape = {value[0], head_num, v_1, v_2 / head_num};
std::vector<int64_t> new_grad_out_shape;
if (!grad_out.empty()) {
new_grad_out_shape = new_query_shape;
}
result.emplace_back(new_query_shape, new_key_shape, new_value_shape, new_grad_out_shape);
} else if (input_layer_str == "SBH") {
std::vector<int64_t> new_query_shape = {q_1, head_num, query[0], q_2 / head_num};
std::vector<int64_t> new_key_shape = {k_1, head_num, key[0], k_2 / head_num};
std::vector<int64_t> new_value_shape = {v_1, head_num, value[0], v_2 / head_num};
std::vector<int64_t> new_grad_out_shape;
if (!grad_out.empty()) {
new_grad_out_shape = new_query_shape;
}
result.emplace_back(new_query_shape, new_key_shape, new_value_shape, new_grad_out_shape);
} else if (input_layer_str == "BNSD") {
std::vector<int64_t> new_grad_out_shape;
if (!grad_out.empty()) {
new_grad_out_shape = query;
}
result.emplace_back(query, key, value, new_grad_out_shape);
} else if (input_layer_str == "BSND") {
std::vector<int64_t> new_query_shape = {query[0], q_2, q_1, query[3]};
std::vector<int64_t> new_key_shape = {key[0], k_2, k_1, key[3]};
std::vector<int64_t> new_value_shape = {value[0], v_2, v_1, value[3]};
std::vector<int64_t> new_grad_out_shape;
if (!grad_out.empty()) {
new_grad_out_shape = new_query_shape;
}
result.emplace_back(new_query_shape, new_key_shape, new_value_shape, new_grad_out_shape);
} else if (input_layer_str == "TND") {
TORCH_CHECK(!cum_seq_q.empty(), "The actual_seq_qlen should not be empty when TND");
TORCH_CHECK(!cum_seq_k.empty(), "The actual_seq_kvlen should not be empty when TND");
TORCH_CHECK(cum_seq_q.size() == cum_seq_k.size(), "The size of actual_seq_qlen should be equal to actual_seq_kvlen when TND");
size_t sizeValue = cum_seq_q.size();
TORCH_CHECK(sizeValue <= static_cast<size_t>(std::numeric_limits<int64_t>::max()), "cum_seq_q.size() is too large to be represented as an int64_t", OPS_ERROR(ErrCode::PARAM));
int64_t b = static_cast<int64_t>(sizeValue);
TORCH_CHECK(b != 0, "Divisor b may be 0, please check it.")
std::vector<int64_t> new_query_shape = {b, q_1, query[0] / b, q_2};
std::vector<int64_t> new_key_shape = {b, k_1, key[0] / b, k_2};
std::vector<int64_t> new_value_shape = {b, v_1, value[0] / b, v_2};
std::vector<int64_t> new_grad_out_shape;
if (!grad_out.empty()) {
new_grad_out_shape = new_query_shape;
}
result.emplace_back(new_query_shape, new_key_shape, new_value_shape, new_grad_out_shape);
}
return result;
}
inline int64_t safe_multiply(const std::vector<int64_t>& dims)
{
int64_t product = 1;
for (auto dim : dims) {
if (dim == 0) {
return 0;
}
TORCH_CHECK(product <= INT64_MAX / dim, "Integer overflow in multiply.", OPS_ERROR(ErrCode::PARAM))
product *= dim;
}
return product;
}
// Sum of `values`, throwing (TORCH_CHECK) instead of overflowing int64_t.
// Both directions are guarded. The bound is chosen per sign so that the guard
// expression itself can never overflow; the old `INT64_MAX - val` was UB for val < 0.
inline int64_t safe_sum(const std::initializer_list<int64_t>& values)
{
    int64_t sum = 0;
    for (const int64_t val : values) {
        const bool overflows = val > 0 ? sum > INT64_MAX - val : sum < INT64_MIN - val;
        TORCH_CHECK(!overflows, "Integer overflow in sum.", OPS_ERROR(ErrCode::PARAM));
        sum += val;
    }
    return sum;
}
int64_t sdpa_flop_count(const std::vector<int64_t> query_shape, const std::vector<int64_t> key_shape, const std::vector<int64_t> value_shape)
{
int64_t b;
int64_t h;
int64_t s_q;
int64_t d_q;
int64_t _b2;
int64_t _h2;
int64_t s_k;
int64_t _d2;
int64_t _b3;
int64_t _h3;
int64_t _s3;
int64_t d_v;
b = query_shape[0];
h = query_shape[1];
s_q = query_shape[2];
d_q = query_shape[3];
_b2 = key_shape[0];
_h2 = key_shape[1];
s_k = key_shape[2];
_d2 = key_shape[3];
_b3 = value_shape[0];
_h3 = value_shape[1];
_s3 = value_shape[2];
d_v = value_shape[3];
TORCH_CHECK(b == _b2 && b == _b3, "the dim of 0 is not equal between q and kv");
TORCH_CHECK(h == _h2 && h == _h3, "the dim of 1 is not equal between q and kv");
TORCH_CHECK(s_k == _s3, "the dim of 2 is not equal between k and v");
TORCH_CHECK(d_q == _d2, "the dim of 3 is not equal between q and k");
int64_t total_flops = safe_sum({
safe_multiply({b, h, s_q, d_q, s_k}),
safe_multiply({b, h, s_q, s_k, d_v})
});
return total_flops;
}
int64_t sdpa_backward_flop_count(const std::vector<int64_t> query_shape, const std::vector<int64_t> key_shape, const std::vector<int64_t> value_shape, const std::vector<int64_t> grad_out_shape)
{
int64_t b;
int64_t h;
int64_t s_q;
int64_t d_q;
int64_t _b2;
int64_t _h2;
int64_t s_k;
int64_t _d2;
int64_t _b3;
int64_t _h3;
int64_t _s3;
int64_t d_v;
int64_t _b4;
int64_t _h4;
int64_t _s4;
int64_t d_4;
b = query_shape[0];
h = query_shape[1];
s_q = query_shape[2];
d_q = query_shape[3];
_b2 = key_shape[0];
_h2 = key_shape[1];
s_k = key_shape[2];
_d2 = key_shape[3];
_b3 = value_shape[0];
_h3 = value_shape[1];
_s3 = value_shape[2];
d_v = value_shape[3];
_b4 = grad_out_shape[0];
_h4 = grad_out_shape[1];
_s4 = grad_out_shape[2];
d_4 = grad_out_shape[3];
TORCH_CHECK(b == _b2 && b == _b3 && b == _b4, "the dim of 0 is not equal between qkv and grad");
TORCH_CHECK(h == _h2 && h == _h3 && h == _h4, "the dim of 1 is not equal between qkv and grad");
TORCH_CHECK(s_k == _s3, "the dim of 2 is not equal between k and v");
TORCH_CHECK(s_q == _s4, "the dim of 2 is not equal between q and grad");
TORCH_CHECK(d_q == _d2, "the dim of 3 is not equal between q and k");
TORCH_CHECK(d_v == d_4, "the dim of 3 is not equal between v and grad");
int64_t total_flops = safe_sum({
safe_multiply({b, h, s_q, d_v, s_k}),
safe_multiply({b, h, s_k, s_q, d_v}),
safe_multiply({b, h, s_q, s_k, d_q}),
safe_multiply({b, h, d_q, s_q, s_k})
});
return total_flops;
}
// FLOPs of a flash-attention forward call. It converts the query/key/value shapes in
// `input_layout` to (B, N, S, D) via _unpack_flash_attention_nested_shapes, then sums
// sdpa_flop_count over the resulting groups. Missing actual_seq_qlen/actual_seq_kvlen
// are passed on as empty lists.
int64_t FlopCounter::flash_attention_forward_flop(
    const at::Tensor &query, const at::Tensor &key, const at::Tensor &value, int64_t head_num,
    const std::string &input_layout, const c10::OptionalIntArrayRef &actual_seq_qlen,
    const c10::OptionalIntArrayRef &actual_seq_kvlen)
{
    // Forward has no incoming gradient, so an empty grad_out shape is supplied.
    const std::vector<int64_t> no_grad_out;
    const c10::ArrayRef<int64_t> empty_seq{};
    const auto unpacked = _unpack_flash_attention_nested_shapes(
        std::vector<int64_t>(query.sizes().begin(), query.sizes().end()),
        std::vector<int64_t>(key.sizes().begin(), key.sizes().end()),
        std::vector<int64_t>(value.sizes().begin(), value.sizes().end()),
        head_num, no_grad_out,
        actual_seq_qlen.value_or(empty_seq), actual_seq_kvlen.value_or(empty_seq), input_layout);

    int64_t total_flops = 0;
    for (const auto &shapes : unpacked) {
        total_flops += sdpa_flop_count(std::get<0>(shapes), std::get<1>(shapes), std::get<2>(shapes));
    }
    return total_flops;
}
// FLOPs of a flash-attention backward call. It converts the query/key/value/dy shapes in
// `input_layout` to (B, N, S, D) via _unpack_flash_attention_nested_shapes, then sums
// sdpa_backward_flop_count over the resulting groups.
//
// The gradient shape now comes from `dy` itself. The previous code copied query's
// shape, ignoring dy, which is wrong whenever dy's last dim (D_v) differs from query's
// (D_q). If dy is undefined, query's shape is used as before, so existing callers keep working.
int64_t FlopCounter::flash_attention_backward_flop(
    const at::Tensor &query, const at::Tensor &key, const at::Tensor &value, const at::Tensor &dy, int64_t head_num,
    const std::string &input_layout, const c10::OptionalIntArrayRef &actual_seq_qlen,
    const c10::OptionalIntArrayRef &actual_seq_kvlen)
{
    const at::Tensor &grad_src = dy.defined() ? dy : query;
    std::vector<int64_t> dy_shape(grad_src.sizes().begin(), grad_src.sizes().end());
    std::vector<int64_t> query_shape(query.sizes().begin(), query.sizes().end());
    std::vector<int64_t> key_shape(key.sizes().begin(), key.sizes().end());
    std::vector<int64_t> value_shape(value.sizes().begin(), value.sizes().end());
    auto ac_seq_qlen_tmp = actual_seq_qlen.value_or(c10::ArrayRef<int64_t>{});
    auto ac_seq_kvlen_tmp = actual_seq_kvlen.value_or(c10::ArrayRef<int64_t>{});
    auto sizes = _unpack_flash_attention_nested_shapes(query_shape, key_shape, value_shape, head_num, dy_shape, ac_seq_qlen_tmp, ac_seq_kvlen_tmp, input_layout);
    int64_t total_flops = 0;
    for (const auto& [query_shape_new, key_shape_new, value_shape_new, grad_out_shape] : sizes) {
        total_flops += sdpa_backward_flop_count(query_shape_new, key_shape_new, value_shape_new, grad_out_shape);
    }
    return total_flops;
}