#include <torch/extension.h>
#include "torch_npu/csrc/framework/utils/RandomOpAdapter.h"
#include "torch_npu/csrc/framework/utils/OpAdapter.h"
#include "torch_npu/csrc/core/npu/NPUFormat.h"
#include <torch_npu/csrc/include/ops.h>
#include "inc/aclnn_common.h"
#include "../flop_counter/flop_counter.h"
// File-level constants shared by the fused-attention entry points below.
// NOTE(review): FLASH_THRESHOLD, N and SIZE_8 are not referenced in the visible part of this
// file — confirm whether they are still needed.
const static int FLASH_THRESHOLD = 512;
// NOTE: a local variable that is also called `N` is declared inside npu_fusion_attention_v2
// and hides this constant there.
const static int N = 32;
// Extent of the trailing dimension of the softmax_max / softmax_sum tensors allocated in
// npu_fusion_attention_v2.
const static int64_t SOFTMAXMAX_LAST_DIMSHAPE = 8;
// Tolerance used by get_dropout_status when classifying keep_prob.
const static double EPSILON = 0.00000000001;
using namespace at_npu::native;
constexpr static int SIZE_8 = 8;
/**
 * Dropout regime derived from keep_prob by get_dropout_status().
 *
 * The comments describe how the mask-producing code in this file reacts to each value.
 */
enum class DropOutStatus {
    DROPOUT_NORMAL = 0, // a mask is produced by npu_dropout_gen_mask
    DROPOUT_NONE = 1,   // no mask tensor is created (it stays undefined)
    DROPOUT_ALL = 2     // an all-zero byte mask is allocated
};
/**
 * Symbolic names for the integer sparse_mode argument of the attention entry points.
 *
 * The entry points receive sparse_mode as int64_t and range-check it against these
 * enumerators: the "TND" layout accepts every value except PREFIX, all other layouts
 * accept NO_MASK through PREFIX_COMPRESS.
 */
enum class SparseMode {
    NO_MASK = 0,
    ALL_MASK = 1,
    LEFT_UP_CAUSAL = 2,
    RIGHT_DOWN_CAUSAL = 3,
    BAND = 4,
    PREFIX = 5,
    PREFIX_COMPRESS = 6,
    RIGHT_DOWN_CAUSAL_BAND = 7,
    BAND_LEFT_UP_CAUSAL = 8
};
/**
 * Classify keep_prob into one of the three dropout regimes.
 *
 * @param keep_prob Probability of keeping an element. The public entry points in this file
 *                  validate it to lie in [0, 1] before calling.
 * @return DROPOUT_ALL    when keep_prob is within EPSILON of 0,
 *         DROPOUT_NONE   when keep_prob is within EPSILON of 1,
 *         DROPOUT_NORMAL for every other value.
 */
DropOutStatus get_dropout_status(double keep_prob)
{
    if (std::abs(keep_prob) < EPSILON) {
        return DropOutStatus::DROPOUT_ALL;
    }
    // Compare the distance to 1.0. Testing `abs(keep_prob) < 1 + EPSILON` would match every
    // valid probability and make DROPOUT_NORMAL unreachable.
    if (std::abs(keep_prob - 1.0) < EPSILON) {
        return DropOutStatus::DROPOUT_NONE;
    }
    return DropOutStatus::DROPOUT_NORMAL;
}
/**
 * Bring a tensor into ACL_FORMAT_ND.
 *
 * Undefined tensors (used in this file for absent optional inputs) are handed back untouched.
 * A defined tensor must live on an NPU, otherwise TORCH_CHECK raises.
 *
 * @param at_tensor Tensor to convert; may be undefined.
 * @return The input itself when undefined, otherwise the result of npu_format_cast.
 */
at::Tensor format_trans(const at::Tensor &at_tensor)
{
    if (!at_tensor.defined()) {
        return at_tensor;
    }
    TORCH_CHECK(torch_npu::utils::is_npu(at_tensor), "only npu tensor is supported");
    return at_npu::native::npu_format_cast(at_tensor, ACL_FORMAT_ND);
}
/**
 * Build the dropout mask for the forward pass and report the RNG state that produced it.
 *
 * @param query, key        Attention inputs; their sizes determine the element count.
 * @param keep_prob         Probability of keeping an element; selects the dropout regime.
 * @param head_num          Head count, used for the "BSH" and "SBH" layouts.
 * @param input_layout      Layout string. For "BSH"/"SBH"/"BNSD"/"BSND" numels is overwritten
 *                          here; for any other value (e.g. "TND") the caller's numels is used.
 * @param gen_mask_parallel Forwarded to npu_dropout_gen_mask.
 * @param sync              Forwarded to npu_dropout_gen_mask.
 * @param seed, offset      Written only in the DROPOUT_NORMAL regime (philox engine inputs).
 * @param numels            In/out element count of the attention score matrix.
 * @return A generated mask (DROPOUT_NORMAL), an all-zero byte tensor (DROPOUT_ALL) or an
 *         undefined tensor (DROPOUT_NONE).
 */
at::Tensor dropout_gen_mask(const at::Tensor &query, const at::Tensor &key, double keep_prob, int64_t head_num, const std::string &input_layout,
    bool gen_mask_parallel, bool sync, int64_t &seed, int64_t &offset, int64_t &numels)
{
    if (input_layout == "BSH") {
        numels = query.size(0) * head_num * query.size(1) * key.size(1);
    } else if (input_layout == "SBH") {
        numels = query.size(1) * head_num * query.size(0) * key.size(0);
    } else if (input_layout == "BNSD") {
        numels = query.size(0) * query.size(1) * query.size(2) * key.size(2);
    } else if (input_layout == "BSND") {
        numels = query.size(0) * query.size(2) * query.size(1) * key.size(1);
    }

    at::Tensor mask;
    switch (get_dropout_status(keep_prob)) {
        case DropOutStatus::DROPOUT_NORMAL: {
            const auto gen = at_npu::detail::getDefaultNPUGenerator();
            auto philox_state = at::check_generator<at_npu::NPUGeneratorImpl>(gen)->philox_engine_inputs(10);
            seed = philox_state.first;
            offset = philox_state.second;
            mask = at_npu::native::npu_dropout_gen_mask(query, at::IntArrayRef{ numels }, 1 - keep_prob,
                seed, offset, gen_mask_parallel, sync);
            break;
        }
        case DropOutStatus::DROPOUT_ALL: {
            // numels rounded up to a multiple of 128, divided by 8, plus 32 extra bytes.
            // NOTE(review): presumably one bit per element plus padding — confirm.
            int64_t mask_bytes = (numels + 128 - 1) / 128 * 128 / 8;
            mask_bytes += 32;
            mask = at::zeros(at::IntArrayRef{mask_bytes}, query.options().dtype(at::kByte));
            break;
        }
        default:
            // DROPOUT_NONE: leave the mask undefined.
            break;
    }
    return mask;
}
/**
 * Launch the fused attention gradient kernel.
 *
 * All tensor inputs are converted with format_trans (absent optionals become undefined
 * tensors). When both actual_seq_qlen and actual_seq_kvlen are non-empty the
 * aclnnFlashAttentionUnpaddingScoreGradV2 kernel is used, otherwise
 * aclnnFlashAttentionScoreGradV2.
 *
 * @return (dq, dk, dv, dpse). dq/dk/dv have the sizes and options of query/key/value.
 *         dpse has the sizes of pse, or is an undefined tensor when pse is absent.
 */
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_fusion_attention_backward_v2(
    const at::Tensor &query,
    const at::Tensor &key,
    const at::Tensor &value,
    const at::Tensor &dy,
    int64_t head_num,
    const std::string &input_layout,
    const c10::optional<at::Tensor> &pse,
    const c10::optional<at::Tensor> &drop_mask,
    const c10::optional<at::Tensor> &padding_mask,
    const c10::optional<at::Tensor> &atten_mask,
    const c10::optional<at::Tensor> &softmax_max,
    const c10::optional<at::Tensor> &softmax_sum,
    const c10::optional<at::Tensor> &softmax_in,
    const c10::optional<at::Tensor> &attention_in,
    double scale_value,
    double keep_prob,
    int64_t pre_tokens,
    int64_t next_tokens,
    int64_t inner_precise,
    const c10::optional<std::vector<int64_t>> &prefix,
    const c10::optional<std::vector<int64_t>> &actual_seq_qlen,
    const c10::optional<std::vector<int64_t>> &actual_seq_kvlen,
    const c10::optional<std::vector<int64_t>> &q_start_idx,
    const c10::optional<std::vector<int64_t>> &kv_start_idx,
    int64_t sparse_mode,
    int64_t pse_type)
{
    // Local copies of the optional index lists; the IntArrayRef views below point into them.
    std::vector<int64_t> prefix_vec = prefix.value_or(std::vector<int64_t>{});
    std::vector<int64_t> seq_qlen_vec = actual_seq_qlen.value_or(std::vector<int64_t>{});
    std::vector<int64_t> seq_kvlen_vec = actual_seq_kvlen.value_or(std::vector<int64_t>{});
    std::vector<int64_t> q_start_vec = q_start_idx.value_or(std::vector<int64_t>{});
    std::vector<int64_t> kv_start_vec = kv_start_idx.value_or(std::vector<int64_t>{});
    c10::optional<at::IntArrayRef> prefix_view(prefix_vec);
    c10::optional<at::IntArrayRef> seq_qlen_view(seq_qlen_vec);
    c10::optional<at::IntArrayRef> seq_kvlen_view(seq_kvlen_vec);
    c10::optional<at::IntArrayRef> q_start_view(q_start_vec);
    c10::optional<at::IntArrayRef> kv_start_view(kv_start_vec);

    // Mandatory inputs first, then the optional ones (undefined when absent).
    at::Tensor format_query = format_trans(query);
    at::Tensor format_key = format_trans(key);
    at::Tensor format_value = format_trans(value);
    at::Tensor format_dy = format_trans(dy);
    at::Tensor format_pse = format_trans(pse.value_or(at::Tensor()));
    at::Tensor format_drop_mask = format_trans(drop_mask.value_or(at::Tensor()));
    at::Tensor format_padding_mask = format_trans(padding_mask.value_or(at::Tensor()));
    at::Tensor format_atten_mask = format_trans(atten_mask.value_or(at::Tensor()));
    at::Tensor format_softmax_max = format_trans(softmax_max.value_or(at::Tensor()));
    at::Tensor format_softmax_sum = format_trans(softmax_sum.value_or(at::Tensor()));
    at::Tensor format_softmax = format_trans(softmax_in.value_or(at::Tensor()));
    at::Tensor format_attention = format_trans(attention_in.value_or(at::Tensor()));

    // Gradient buffers mirror their forward counterparts.
    at::Tensor dq = at::empty(format_query.sizes(), format_query.options());
    at::Tensor dk = at::empty(format_key.sizes(), format_key.options());
    at::Tensor dv = at::empty(format_value.sizes(), format_value.options());
    // The kernel always receives a dpse tensor; an empty one stands in when pse is absent.
    at::Tensor dpse = format_pse.defined() ? at::empty(format_pse.sizes(), format_pse.options())
                                           : at::empty({0}, query.options());

    // The kernel launcher is given a mutable char pointer to the layout string.
    char *layout_cstr = const_cast<char *>(input_layout.c_str());

    if (seq_qlen_vec.empty() || seq_kvlen_vec.empty()) {
        ACLNN_CMD(
            aclnnFlashAttentionScoreGradV2, format_query, format_key, format_value, format_dy,
            format_pse, format_drop_mask, format_padding_mask, format_atten_mask, format_softmax_max,
            format_softmax_sum, format_softmax, format_attention, prefix_view, q_start_view, kv_start_view,
            scale_value, keep_prob, pre_tokens, next_tokens, head_num, layout_cstr, inner_precise,
            sparse_mode, pse_type, dq, dk, dv, dpse);
    } else {
        ACLNN_CMD(
            aclnnFlashAttentionUnpaddingScoreGradV2, format_query, format_key, format_value, format_dy,
            format_pse, format_drop_mask, format_padding_mask, format_atten_mask, format_softmax_max,
            format_softmax_sum, format_softmax, format_attention, prefix_view, seq_qlen_view, seq_kvlen_view,
            q_start_view, kv_start_view, scale_value, keep_prob, pre_tokens, next_tokens, head_num,
            layout_cstr, inner_precise, sparse_mode, pse_type, dq, dk, dv, dpse);
    }

    // Without a pse input there is no pse gradient: hand back an undefined tensor.
    if (!format_pse.defined()) {
        dpse = at::Tensor();
    }
#ifdef FLOP_COUNT
    FLOP_COUNT(FlopCounter::flash_attention_backward_flop, query, key, value, dy, head_num, input_layout, actual_seq_qlen, actual_seq_kvlen);
#endif
    return std::make_tuple(dq, dk, dv, dpse);
}
/**
 * Python-facing gradient entry point of the fused attention operator.
 *
 * Validates the arguments, rebuilds the dropout mask from the (seed, offset, numels) triple
 * returned by the forward pass and delegates to npu_fusion_attention_backward_v2.
 *
 * @param input_layout One of BSH/SBH/BNSD/BSND/TND, case-insensitive.
 * @param keep_prob    Must lie in [0, 1].
 * @param pse_type     Must lie in [0, 3].
 * @param sparse_mode  [0,5) or (5,8] for the TND layout, [0,6] for every other layout.
 * @param seed, offset, numels RNG state and element count used to regenerate the mask.
 * @param sync         When false and a mask was created, the mask storage is recorded on the
 *                     current NPU stream after the kernel has been queued.
 * @return (dq, dk, dv, dpse) as produced by npu_fusion_attention_backward_v2.
 * @throws c10::Error (via TORCH_CHECK) when an argument is out of range.
 */
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_fusion_attention_grad_v2(
    const at::Tensor &query,
    const at::Tensor &key,
    const at::Tensor &value,
    const at::Tensor &dy,
    int64_t head_num,
    const std::string &input_layout,
    const c10::optional<at::Tensor> &pse,
    const c10::optional<at::Tensor> &padding_mask,
    const c10::optional<at::Tensor> &atten_mask,
    const c10::optional<at::Tensor> &softmax_max,
    const c10::optional<at::Tensor> &softmax_sum,
    const c10::optional<at::Tensor> &softmax_in,
    const c10::optional<at::Tensor> &attention_in,
    double scale_value,
    double keep_prob,
    int64_t pre_tokens,
    int64_t next_tokens,
    int64_t inner_precise,
    int64_t seed,
    int64_t offset,
    int64_t numels,
    const c10::optional<std::vector<int64_t>> &prefix,
    const c10::optional<std::vector<int64_t>> &actual_seq_qlen,
    const c10::optional<std::vector<int64_t>> &actual_seq_kvlen,
    int64_t sparse_mode,
    bool gen_mask_parallel,
    bool sync,
    int64_t pse_type,
    const c10::optional<std::vector<int64_t>> &q_start_idx,
    const c10::optional<std::vector<int64_t>> &kv_start_idx)
{
    TORCH_CHECK(query.dim() == 3 || query.dim() == 4, "The shapes of the input query should be 3 or 4 dimensional, but got ",
        query.dim(), "-dimensional");
    TORCH_CHECK(key.dim() == 3 || key.dim() == 4, "The shapes of the input key should be 3 or 4 dimensional, but got ",
        key.dim(), "-dimensional");
    TORCH_CHECK(value.dim() == 3 || value.dim() == 4, "The shapes of the input value should be 3 or 4 dimensional, but got ",
        value.dim(), "-dimensional");
    TORCH_CHECK(dy.dim() == 3 || dy.dim() == 4, "The shapes of the input dy should be 3 or 4 dimensional, but got ", dy.dim(), "-dimensional");
    TORCH_CHECK(keep_prob >= 0 && keep_prob <= 1, "The keep_prob value must be in range of [0, 1], but got ", keep_prob);
    TORCH_CHECK(pse_type >= 0 && pse_type <= 3, "The pse_type value must be in range of [0, 3], but got ", pse_type);

    // Upper-case the layout BEFORE any layout-dependent validation so that a lowercase "tnd"
    // is checked against the TND sparse_mode range. The unsigned char cast keeps toupper
    // well-defined for non-ASCII bytes.
    std::string input_layout_str = std::string(input_layout);
    for (auto &c : input_layout_str) {
        c = static_cast<char>(toupper(static_cast<unsigned char>(c)));
    }
    if (input_layout_str == "TND") {
        TORCH_CHECK((sparse_mode >= static_cast<int64_t>(SparseMode::NO_MASK) &&
                     sparse_mode < static_cast<int64_t>(SparseMode::PREFIX)) ||
                    (sparse_mode > static_cast<int64_t>(SparseMode::PREFIX) &&
                     sparse_mode <= static_cast<int64_t>(SparseMode::BAND_LEFT_UP_CAUSAL)),
                    "The sparse_mode value must be in range of [0,5) or (5,8], but got ",
                    sparse_mode);
    } else {
        TORCH_CHECK(sparse_mode >= static_cast<int64_t>(SparseMode::NO_MASK) &&
                    sparse_mode <= static_cast<int64_t>(SparseMode::PREFIX_COMPRESS),
                    "The sparse_mode value must be in range of [0,6], but got ",
                    sparse_mode);
    }
    TORCH_CHECK(input_layout_str == "BSH" || input_layout_str == "SBH" || input_layout_str == "BNSD" ||
                input_layout_str == "BSND" || input_layout_str == "TND",
                "The input_layout should be BSH/SBH/BNSD/BSND/TND(case-insensitive), but got ", input_layout);

    // Recreate the dropout mask of the forward pass from the saved RNG state.
    const DropOutStatus dropout_status = get_dropout_status(keep_prob);
    at::Tensor drop_mask;
    if (dropout_status == DropOutStatus::DROPOUT_NORMAL) {
        drop_mask = at_npu::native::npu_dropout_gen_mask(query, at::IntArrayRef{ numels }, 1 - keep_prob,
            seed, offset, gen_mask_parallel, sync);
    } else if (dropout_status == DropOutStatus::DROPOUT_ALL) {
        // numels rounded up to a multiple of 128, divided by 8, plus 32 extra bytes.
        int64_t length = (numels + 128 - 1) / 128 * 128 / 8;
        length += 32;
        drop_mask = at::zeros(at::IntArrayRef{length}, query.options().dtype(at::kByte));
    }

    auto result = npu_fusion_attention_backward_v2(query,
        key, value, dy, head_num, input_layout_str, pse, drop_mask, padding_mask, atten_mask,
        softmax_max, softmax_sum, softmax_in, attention_in, scale_value, keep_prob, pre_tokens,
        next_tokens, inner_precise, prefix, actual_seq_qlen, actual_seq_kvlen, q_start_idx, kv_start_idx, sparse_mode, pse_type);

    // A mask exists for every regime except DROPOUT_NONE. In the non-sync case its storage
    // is recorded on the current stream once the kernel has been queued.
    if (!sync && dropout_status != DropOutStatus::DROPOUT_NONE) {
        c10::Device device = drop_mask.device();
        c10::impl::VirtualGuardImpl impl(device.type());
        impl.recordDataPtrOnStream(drop_mask.storage().data_ptr(), c10_npu::getCurrentNPUStream());
    }
    return result;
}
/**
 * Python-facing forward entry point of the fused attention operator.
 *
 * Validates the arguments, derives the output shapes from the layout, generates the dropout
 * mask and launches aclnnFlashAttentionVarLenScoreV2 (both sequence-length lists non-empty)
 * or aclnnFlashAttentionScoreV2 (otherwise).
 *
 * @param head_num     Must be positive.
 * @param input_layout One of BSH/SBH/BNSD/BSND/TND, case-insensitive.
 * @param keep_prob    Must lie in [0, 1].
 * @param pse_type     Must lie in [0, 3].
 * @param sparse_mode  [0,5) or (5,8] for the TND layout, [0,6] for every other layout.
 * @param actual_seq_qlen, actual_seq_kvlen
 *                     Per-entry sequence lengths, each at most 1000000. For the TND layout
 *                     both must be non-empty and of equal size.
 * @return (attention_score, softmax_max, softmax_sum, softmax_out, seed, offset, numels).
 *         seed and offset are 0 unless a random dropout mask was generated.
 * @throws c10::Error (via TORCH_CHECK) when an argument is out of range.
 */
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, int64_t, int64_t, int64_t> npu_fusion_attention_v2(
    const at::Tensor &query, const at::Tensor &key,
    const at::Tensor &value, int64_t head_num, const std::string &input_layout,
    const c10::optional<at::Tensor> &pse_opt, const c10::optional<at::Tensor> &padding_mask_opt,
    const c10::optional<at::Tensor> &atten_mask_opt,
    double scale, double keep_prob, int64_t pre_tokens, int64_t next_tokens, int64_t inner_precise,
    const c10::optional<std::vector<int64_t>> &prefix_opt, const c10::optional<std::vector<int64_t>> &actual_seq_qlen,
    const c10::optional<std::vector<int64_t>> &actual_seq_kvlen, int64_t sparse_mode, bool gen_mask_parallel, bool sync,
    int64_t pse_type, const c10::optional<std::vector<int64_t>> &q_start_idx, const c10::optional<std::vector<int64_t>> &kv_start_idx)
{
    const at::Tensor &pse = pse_opt.value_or(at::Tensor());
    const at::Tensor &padding_mask = padding_mask_opt.value_or(at::Tensor());
    const at::Tensor &atten_mask = atten_mask_opt.value_or(at::Tensor());
    // Local copies of the optional index lists; the IntArrayRef views below point into them.
    auto prefixN_tmp = prefix_opt.value_or(std::vector<int64_t>{});
    auto ac_seq_qlen_tmp = actual_seq_qlen.value_or(std::vector<int64_t>{});
    auto ac_seq_kvlen_tmp = actual_seq_kvlen.value_or(std::vector<int64_t>{});
    auto q_start_idx_val_tmp = q_start_idx.value_or(std::vector<int64_t>{});
    auto kv_start_idx_val_tmp = kv_start_idx.value_or(std::vector<int64_t>{});
    c10::optional<at::IntArrayRef> prefixN(prefixN_tmp);
    c10::optional<at::IntArrayRef> ac_seq_qlen(ac_seq_qlen_tmp);
    c10::optional<at::IntArrayRef> ac_seq_kvlen(ac_seq_kvlen_tmp);
    c10::optional<at::IntArrayRef> q_start_idx_val(q_start_idx_val_tmp);
    c10::optional<at::IntArrayRef> kv_start_idx_val(kv_start_idx_val_tmp);

    TORCH_CHECK(head_num > 0, "head_num must > 0, but got ", head_num);
    TORCH_CHECK(query.dim() == 3 || query.dim() == 4, "The shapes of the input query should be 3 or 4 dimensional, but got ",
        query.dim(), "-dimensional");
    TORCH_CHECK(key.dim() == 3 || key.dim() == 4, "The shapes of the input key should be 3 or 4 dimensional, but got ", key.dim(),
        "-dimensional");
    TORCH_CHECK(value.dim() == 3 || value.dim() == 4, "The shapes of the input value should be 3 or 4 dimensional, but got ",
        value.dim(), "-dimensional");
    TORCH_CHECK(keep_prob >= 0 && keep_prob <= 1, "The keep_prob value must be in range of [0, 1], but got ", keep_prob);
    TORCH_CHECK(pse_type >= 0 && pse_type <= 3, "The pse_type value must be in range of [0, 3], but got ", pse_type);

    // Upper-case the layout BEFORE any layout-dependent validation so that a lowercase "tnd"
    // is checked against the TND sparse_mode range. The unsigned char cast keeps toupper
    // well-defined for non-ASCII bytes.
    std::string input_layout_str = std::string(input_layout);
    for (auto &c : input_layout_str) {
        c = static_cast<char>(toupper(static_cast<unsigned char>(c)));
    }
    if (input_layout_str == "TND") {
        TORCH_CHECK((sparse_mode >= static_cast<int64_t>(SparseMode::NO_MASK) &&
                     sparse_mode < static_cast<int64_t>(SparseMode::PREFIX)) ||
                    (sparse_mode > static_cast<int64_t>(SparseMode::PREFIX) &&
                     sparse_mode <= static_cast<int64_t>(SparseMode::BAND_LEFT_UP_CAUSAL)),
                    "The sparse_mode value must be in range of [0,5) or (5,8], but got ",
                    sparse_mode);
    } else {
        TORCH_CHECK(sparse_mode >= static_cast<int64_t>(SparseMode::NO_MASK) &&
                    sparse_mode <= static_cast<int64_t>(SparseMode::PREFIX_COMPRESS),
                    "The sparse_mode value must be in range of [0,6], but got ",
                    sparse_mode);
    }
    TORCH_CHECK(input_layout_str == "BSH" || input_layout_str == "SBH" || input_layout_str == "BNSD" ||
                input_layout_str == "BSND" || input_layout_str == "TND",
                "The input_layout should be BSH/SBH/BNSD/BSND/TND(case-insensitive), but got ", input_layout);

    // The length check below pairs q[i] with kv[i]; refuse lists that would make it read
    // past the end of actual_seq_kvlen.
    TORCH_CHECK(ac_seq_qlen_tmp.size() <= ac_seq_kvlen_tmp.size(),
                "actual_seq_kvlen must provide an entry for every entry of actual_seq_qlen, but got ",
                ac_seq_kvlen_tmp.size(), " and ", ac_seq_qlen_tmp.size(), " entries");
    for (size_t i = 0; i < ac_seq_qlen_tmp.size(); i++) {
        TORCH_CHECK(ac_seq_qlen_tmp[i] <= 1000000 && ac_seq_kvlen_tmp[i] <= 1000000, "The sequence length should not greater than 1M, but got q", ac_seq_qlen_tmp[i],"kv", ac_seq_kvlen_tmp[i]);
    }

    // Value head size for the packed-hidden layouts (BSH/SBH): value.size(2) divided by the
    // number of key heads (key.size(2) / head_dim). Yields 0 instead of dividing by zero when
    // either quotient degenerates.
    const auto packed_value_head_dim = [&key, &value](int64_t head_dim) -> int64_t {
        if (head_dim == 0) {
            return 0;
        }
        const int64_t kv_heads = key.size(2) / head_dim;
        return kv_heads == 0 ? 0 : value.size(2) / kv_heads;
    };

    int64_t B = 0;
    int64_t S0 = 0;
    int64_t N = 0;
    int64_t D = 0;
    int64_t H = 0;
    int64_t T = 0;
    int64_t D2 = 0;
    c10::SmallVector<int64_t> atten_score_shape;
    if (input_layout_str == "BSH") {
        B = query.size(0);
        S0 = query.size(1);
        H = query.size(2);
        D = H / head_num;
        D2 = packed_value_head_dim(D);
        atten_score_shape = {B, S0, head_num * D2};
    } else if (input_layout_str == "SBH") {
        B = query.size(1);
        S0 = query.size(0);
        H = query.size(2);
        D = H / head_num;
        D2 = packed_value_head_dim(D);
        atten_score_shape = {S0, B, head_num * D2};
    } else if (input_layout_str == "BNSD") {
        B = query.size(0);
        N = query.size(1);
        S0 = query.size(2);
        // Not consumed below, but query.size(3) rejects a 3-dimensional query for this layout.
        D = query.size(3);
        D2 = value.size(3);
        atten_score_shape = {B, N, S0, D2};
    } else if (input_layout_str == "BSND") {
        B = query.size(0);
        N = query.size(2);
        S0 = query.size(1);
        // Not consumed below, but query.size(3) rejects a 3-dimensional query for this layout.
        D = query.size(3);
        D2 = value.size(3);
        atten_score_shape = {B, S0, N, D2};
    } else if (input_layout_str == "TND") {
        T = query.size(0);
        N = query.size(1);
        D = query.size(2);
        D2 = value.size(2);
        atten_score_shape = {T, N, D2};
    }

    // Zero-initialised: all three values are returned to the caller even when no random
    // mask is generated.
    int64_t seed = 0;
    int64_t offset = 0;
    int64_t numels = 0;
    if (input_layout_str == "TND") {
        // dropout_gen_mask cannot derive the element count for TND, so compute it here from
        // consecutive differences of the sequence-length lists.
        // NOTE(review): this treats the lists as cumulative offsets — confirm with the op spec.
        TORCH_CHECK(!ac_seq_qlen_tmp.empty() && ac_seq_qlen_tmp.size() == ac_seq_kvlen_tmp.size(),
                    "For TND layout actual_seq_qlen and actual_seq_kvlen must be non-empty and of equal size, but got ",
                    ac_seq_qlen_tmp.size(), " and ", ac_seq_kvlen_tmp.size(), " entries");
        int64_t accum = ac_seq_qlen_tmp[0] * ac_seq_kvlen_tmp[0];
        for (size_t i = 1; i < ac_seq_qlen_tmp.size(); i++) {
            accum += ((ac_seq_qlen_tmp[i] - ac_seq_qlen_tmp[i - 1]) * (ac_seq_kvlen_tmp[i] - ac_seq_kvlen_tmp[i - 1]));
        }
        numels = N * accum;
    }

    at::Tensor format_query = format_trans(query);
    at::Tensor attention_score = at::empty(atten_score_shape, query.options());
    at::Tensor format_key = format_trans(key);
    at::Tensor format_value = format_trans(value);
    at::Tensor format_pse = format_trans(pse);
    at::Tensor format_padding_mask = format_trans(padding_mask);
    at::Tensor format_atten_mask = format_trans(atten_mask);

    const DropOutStatus dropout_status = get_dropout_status(keep_prob);
    at::Tensor format_drop_mask = dropout_gen_mask(format_query, format_key, keep_prob, head_num, input_layout_str,
        gen_mask_parallel, sync, seed, offset, numels);

    at::Tensor softmax_max;
    at::Tensor softmax_sum;
    if (input_layout_str != "TND") {
        softmax_max = at::empty({B, head_num, S0, SOFTMAXMAX_LAST_DIMSHAPE}, query.options().dtype(at::kFloat));
        softmax_sum = at::empty({B, head_num, S0, SOFTMAXMAX_LAST_DIMSHAPE}, query.options().dtype(at::kFloat));
    } else {
        softmax_max = at::empty({T, N, SOFTMAXMAX_LAST_DIMSHAPE}, query.options().dtype(at::kFloat));
        softmax_sum = at::empty({T, N, SOFTMAXMAX_LAST_DIMSHAPE}, query.options().dtype(at::kFloat));
    }
    at::Tensor softmax_out = at::empty({0}, query.options());

    // The kernel launcher is given a mutable char pointer to the layout string.
    char* input_layout_ptr = const_cast<char *>(input_layout_str.c_str());
    if (!ac_seq_qlen_tmp.empty() && !ac_seq_kvlen_tmp.empty()) {
        ACLNN_CMD(
            aclnnFlashAttentionVarLenScoreV2, format_query, format_key, format_value,
            format_pse, format_drop_mask, format_padding_mask, format_atten_mask, prefixN,
            ac_seq_qlen, ac_seq_kvlen, q_start_idx_val, kv_start_idx_val, scale, keep_prob, pre_tokens, next_tokens, head_num,
            input_layout_ptr, inner_precise, sparse_mode, pse_type, softmax_max, softmax_sum,
            softmax_out, attention_score);
    } else {
        ACLNN_CMD(
            aclnnFlashAttentionScoreV2, format_query, format_key, format_value,
            format_pse, format_drop_mask, format_padding_mask, format_atten_mask, prefixN, q_start_idx_val, kv_start_idx_val,
            scale, keep_prob, pre_tokens, next_tokens, head_num, input_layout_ptr, inner_precise,
            sparse_mode, pse_type, softmax_max, softmax_sum, softmax_out, attention_score);
    }

    // A mask exists for every regime except DROPOUT_NONE. In the non-sync case its storage
    // is recorded on the current stream once the kernel has been queued.
    if (!sync && dropout_status != DropOutStatus::DROPOUT_NONE) {
        c10::Device device = format_drop_mask.device();
        c10::impl::VirtualGuardImpl impl(device.type());
        impl.recordDataPtrOnStream(format_drop_mask.storage().data_ptr(), c10_npu::getCurrentNPUStream());
    }
#ifdef FLOP_COUNT
    FLOP_COUNT(FlopCounter::flash_attention_forward_flop, query, key, value, head_num, input_layout, actual_seq_qlen, actual_seq_kvlen);
#endif
    return std::make_tuple(attention_score, softmax_max, softmax_sum, softmax_out,
        seed, offset, numels);
}
// Python bindings: registers the forward and gradient entry points on the extension module
// whose name is supplied through TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("npu_fusion_attention_v2", &npu_fusion_attention_v2, "fusion attention forward v2");
m.def("npu_fusion_attention_grad_v2", &npu_fusion_attention_grad_v2, "fusion attention backward v2");
}