#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/utils/OpAdapter.h"
namespace acl_op {
// Shorthand for the preparation helper used in this file through
// apply_tensor, apply_tensor_with_format and get_tensor_npu_format.
using npu_preparation = at_npu::native::OpPreparation;
// Eight-tensor result of the forward operator, in this order:
// (y, dropout_mask, query_res, key_res, value_res, attn_scores, attn_res, context).
using tuple_tensor =
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>;
// Eleven-tensor result of the backward operator, in this order:
// (query_weight_grad, key_weight_grad, value_weight_grad, out_proj_weight_grad,
//  query_grad, key_grad, value_grad,
//  query_bias_grad, key_bias_grad, value_bias_grad, out_proj_bias_grad).
using tuple_tensors = std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor,
at::Tensor, at::Tensor, at::Tensor, at::Tensor>;
namespace {
// Alignment that tgt_len, src_len, attn_head_num and attn_dim_per_head must be
// a multiple of; this copy is the one checked by the backward entry point.
constexpr int64_t FZ_ALIGN_NUM1 = 16;
// Number of bias slots (query, key, value, out_proj) described by the
// backward operator's "bias_grad_mask" attribute.
constexpr size_t BIAS_BUM1 = 4;
// Same alignment requirement, checked by the forward helper.
constexpr int64_t FZ_ALIGN_NUM = 16;
/**
 * @brief Allocates the output tensors for, and launches, the "MultiHeadAttention" operator.
 *
 * The required inputs are passed in the order query, key, value, query_weight,
 * key_weight, value_weight, attn_mask, out_proj_weight. The optional tensors
 * (query_bias, key_bias, value_bias, out_proj_bias, mask) are appended, in that
 * order, only when they are defined.
 *
 * @param query, key, value   Input tensors; each must have at least one dimension.
 *                            `batch` is computed as query.size(0) / tgt_len (integer division).
 * @param attn_head_num       Must be > 0 and a multiple of FZ_ALIGN_NUM.
 * @param attn_dim_per_head   Must be > 0 and a multiple of FZ_ALIGN_NUM.
 * @param src_len, tgt_len    Must be > 0 and a multiple of FZ_ALIGN_NUM.
 * @param dropout_prob        Forwarded to the operator as keep_prob = 1 - dropout_prob (as float).
 * @param softmax_use_float   When true, attn_scores is allocated as float instead of query's dtype.
 * @return (y, dropout_mask, query_res, key_res, value_res, attn_scores, attn_res, context).
 * @throws Whatever TORCH_CHECK raises when one of the argument checks fails.
 */
tuple_tensor multi_head_attention_nocheck(const at::Tensor &query, const at::Tensor &key, const at::Tensor &value,
                                          const at::Tensor &query_weight, const at::Tensor &key_weight,
                                          const at::Tensor &value_weight, const at::Tensor &attn_mask,
                                          const at::Tensor &out_proj_weight, const at::Tensor &query_bias,
                                          const at::Tensor &key_bias, const at::Tensor &value_bias,
                                          const at::Tensor &out_proj_bias, const at::Tensor &mask,
                                          int64_t attn_head_num, int64_t attn_dim_per_head, int64_t src_len,
                                          int64_t tgt_len, double dropout_prob, bool softmax_use_float)
{
    // The check rejects negative values as well as zero, so the message says
    // "greater than zero"; the error code matches the sibling value checks.
    TORCH_CHECK(tgt_len > 0 && src_len > 0 && attn_head_num > 0 && attn_dim_per_head > 0,
                "tgt_len, src_len, attn_head_num, attn_dim_per_head should be greater than zero."
                + OPS_ERROR(ErrCode::VALUE));
    TORCH_CHECK(tgt_len % FZ_ALIGN_NUM == 0 && src_len % FZ_ALIGN_NUM == 0 && attn_head_num % FZ_ALIGN_NUM == 0 &&
                attn_dim_per_head % FZ_ALIGN_NUM == 0,
                "tgt_len, src_len, attn_head_num, attn_dim_per_head should align to 16."
                + OPS_ERROR(ErrCode::VALUE));
    // Must precede the query_shape[0] access below.
    TORCH_CHECK(query.dim() >= 1 && key.dim() >= 1 && value.dim() >= 1, "query, key, value should be at least 1d."
                + OPS_ERROR(ErrCode::PARAM));

    const auto query_shape = query.sizes();
    const int64_t batch = query_shape[0] / tgt_len;
    const int64_t weight_col = attn_head_num * attn_dim_per_head;
    auto query_options = query.options();
    const auto query_format = npu_preparation::get_tensor_npu_format(query);
    // attn_scores is the only output whose dtype may differ from query's.
    auto scores_options = softmax_use_float ? query.options().dtype(at::kFloat) : query_options;

    at::Tensor y = npu_preparation::apply_tensor_with_format({query_shape[0], weight_col}, query_options, query_format);
    // 1-D byte tensor with one element per 8 attention scores
    // (presumably a bit-packed dropout mask -- TODO confirm against the operator spec).
    at::Tensor dropout_mask = npu_preparation::apply_tensor_with_format(
        {batch * attn_head_num * tgt_len * src_len / 8}, query.options().dtype(at::kByte), ACL_FORMAT_ND);
    at::Tensor query_res = npu_preparation::apply_tensor_with_format(
        {batch, attn_head_num, tgt_len, attn_dim_per_head}, query_options, query_format);
    at::Tensor key_res = npu_preparation::apply_tensor_with_format(
        {batch, attn_head_num, src_len, attn_dim_per_head}, query_options, query_format);
    at::Tensor value_res = npu_preparation::apply_tensor_with_format(
        {batch, attn_head_num, src_len, attn_dim_per_head}, query_options, query_format);
    at::Tensor attn_scores = npu_preparation::apply_tensor_with_format(
        {batch, attn_head_num, tgt_len, src_len}, scores_options, query_format);
    at::Tensor attn_res = npu_preparation::apply_tensor_with_format(
        {batch, attn_head_num, tgt_len, src_len}, query_options, query_format);
    at::Tensor context =
        npu_preparation::apply_tensor_with_format({query_shape[0], weight_col}, query_options, query_format);

    at_npu::native::OpCommand cmd;
    cmd.Name("MultiHeadAttention")
        .Input(query)
        .Input(key)
        .Input(value)
        .Input(query_weight)
        .Input(key_weight)
        .Input(value_weight)
        .Input(attn_mask)
        .Input(out_proj_weight);
    // Optional inputs are appended in this fixed order, skipping undefined tensors.
    for (const at::Tensor *optional_input : {&query_bias, &key_bias, &value_bias, &out_proj_bias, &mask}) {
        if (optional_input->defined()) {
            cmd.Input(*optional_input);
        }
    }
    cmd.Output(y)
        .Output(dropout_mask)
        .Output(query_res)
        .Output(key_res)
        .Output(value_res)
        .Output(attn_scores)
        .Output(attn_res)
        .Output(context)
        .Attr("attn_head_num", attn_head_num)
        .Attr("attn_dim_per_head", attn_dim_per_head)
        .Attr("src_len", src_len)
        .Attr("tgt_len", tgt_len)
        .Attr("keep_prob", static_cast<float>(1 - dropout_prob))
        .Attr("softmax_use_float", softmax_use_float)
        .Run();
    return std::make_tuple(y, dropout_mask, query_res, key_res, value_res, attn_scores, attn_res, context);
}
}
/**
 * @brief Allocates the gradient tensors for, and launches, the "MultiHeadAttentionGrad" operator.
 *
 * Absent optional biases are replaced by default-constructed tensors. The
 * "bias_grad_mask" attribute carries one flag per bias (query, key, value,
 * out_proj) that is 1 when the bias is defined and 0 otherwise. dropout_mask
 * is passed as an input only when dropout_prob > 0.
 *
 * @param query, key, value   Forward inputs; each must have at least one dimension.
 *                            `batch` is computed as query.size(0) / tgt_len (integer division).
 * @param query_res, key_res, value_res, attn_scores, attn_res, context
 *                            Intermediate tensors passed straight through to the operator.
 * @param y_grad              Gradient tensor passed as the operator's last required input.
 * @param attn_head_num, attn_dim_per_head, src_len, tgt_len
 *                            Must each be > 0 and a multiple of FZ_ALIGN_NUM1.
 * @param dropout_prob        Forwarded to the operator as keep_prob = 1 - dropout_prob (as float).
 * @param softmax_use_float   Forwarded unchanged as an operator attribute.
 * @return (query_weight_grad, key_weight_grad, value_weight_grad, out_proj_weight_grad,
 *          query_grad, key_grad, value_grad,
 *          query_bias_grad, key_bias_grad, value_bias_grad, out_proj_bias_grad).
 * @throws Whatever TORCH_CHECK raises when one of the argument checks fails.
 */
tuple_tensors npu_multi_head_attention_backward(
    const at::Tensor &query, const at::Tensor &key, const at::Tensor &value, const at::Tensor &query_weight,
    const at::Tensor &key_weight, const at::Tensor &value_weight, const at::Tensor &out_proj_weight,
    const c10::optional<at::Tensor> &query_bias_opt, const c10::optional<at::Tensor> &key_bias_opt,
    const c10::optional<at::Tensor> &value_bias_opt, const c10::optional<at::Tensor> &out_proj_bias_opt,
    const at::Tensor &query_res, const at::Tensor &key_res, const at::Tensor &value_res, const at::Tensor &attn_scores,
    const at::Tensor &attn_res, const at::Tensor &context, const at::Tensor &y_grad, const at::Tensor &dropout_mask,
    int64_t attn_head_num, int64_t attn_dim_per_head, int64_t src_len, int64_t tgt_len, double dropout_prob,
    bool softmax_use_float)
{
    // Maps an absent optional to a default-constructed tensor.
    const auto tensor_or_default = [](const c10::optional<at::Tensor> &maybe_tensor) -> at::Tensor {
        return c10::value_or_else(maybe_tensor, [] { return at::Tensor(); });
    };
    const at::Tensor query_bias = tensor_or_default(query_bias_opt);
    const at::Tensor key_bias = tensor_or_default(key_bias_opt);
    const at::Tensor value_bias = tensor_or_default(value_bias_opt);
    const at::Tensor out_proj_bias = tensor_or_default(out_proj_bias_opt);

    // The check rejects negative values as well as zero, so the message says "greater than zero".
    TORCH_CHECK(tgt_len > 0 && src_len > 0 && attn_head_num > 0 && attn_dim_per_head > 0,
                "tgt_len, src_len, attn_head_num, attn_dim_per_head should be greater than zero."
                + OPS_ERROR(ErrCode::VALUE));
    TORCH_CHECK(tgt_len % FZ_ALIGN_NUM1 == 0 && src_len % FZ_ALIGN_NUM1 == 0 && attn_head_num % FZ_ALIGN_NUM1 == 0 &&
                attn_dim_per_head % FZ_ALIGN_NUM1 == 0,
                "tgt_len, src_len, attn_head_num, attn_dim_per_head should align to 16."
                + OPS_ERROR(ErrCode::VALUE));
    // Must precede the query_shape[0] access below.
    TORCH_CHECK(query.dim() >= 1 && key.dim() >= 1 && value.dim() >= 1, "query, key, value should be at least 1d."
                + OPS_ERROR(ErrCode::PARAM));

    const auto query_shape = query.sizes();
    const int64_t batch = query_shape[0] / tgt_len;
    const int64_t weight_col = attn_head_num * attn_dim_per_head;

    // Each gradient is allocated through apply_tensor with its forward counterpart as the source tensor.
    at::Tensor query_weight_grad = npu_preparation::apply_tensor(query_weight, {weight_col, weight_col});
    at::Tensor key_weight_grad = npu_preparation::apply_tensor(key_weight, {weight_col, weight_col});
    at::Tensor value_weight_grad = npu_preparation::apply_tensor(value_weight, {weight_col, weight_col});
    at::Tensor out_proj_weight_grad = npu_preparation::apply_tensor(out_proj_weight, {weight_col, weight_col});
    at::Tensor query_grad = npu_preparation::apply_tensor(query, {query_shape[0], weight_col});
    at::Tensor key_grad = npu_preparation::apply_tensor(key, {batch * src_len, weight_col});
    at::Tensor value_grad = npu_preparation::apply_tensor(value, {batch * src_len, weight_col});
    // NOTE(review): when a bias optional is empty, the source tensor passed here is
    // default-constructed -- confirm apply_tensor accepts such a source.
    at::Tensor query_bias_grad = npu_preparation::apply_tensor(query_bias, {1, weight_col});
    at::Tensor key_bias_grad = npu_preparation::apply_tensor(key_bias, {1, weight_col});
    at::Tensor value_bias_grad = npu_preparation::apply_tensor(value_bias, {1, weight_col});
    at::Tensor out_proj_bias_grad = npu_preparation::apply_tensor(out_proj_bias, {1, weight_col});

    // One flag per bias, in the order query, key, value, out_proj.
    // grad_mask must outlive bias_grad_mask, which only views its storage.
    std::vector<uint8_t> grad_mask;
    grad_mask.reserve(BIAS_BUM1);
    grad_mask.push_back(query_bias.defined());
    grad_mask.push_back(key_bias.defined());
    grad_mask.push_back(value_bias.defined());
    grad_mask.push_back(out_proj_bias.defined());
    at::ArrayRef<uint8_t> bias_grad_mask(grad_mask);

    at_npu::native::OpCommand cmd;
    cmd.Name("MultiHeadAttentionGrad")
        .Input(query)
        .Input(key)
        .Input(value)
        .Input(query_weight)
        .Input(key_weight)
        .Input(value_weight)
        .Input(out_proj_weight)
        .Input(query_res)
        .Input(key_res)
        .Input(value_res)
        .Input(attn_scores)
        .Input(attn_res)
        .Input(context)
        .Input(y_grad);
    // The dropout mask is passed only when dropout is enabled.
    if (dropout_prob > 0) {
        cmd.Input(dropout_mask);
    }
    cmd.Output(query_weight_grad)
        .Output(key_weight_grad)
        .Output(value_weight_grad)
        .Output(out_proj_weight_grad)
        .Output(query_grad)
        .Output(key_grad)
        .Output(value_grad)
        .Output(query_bias_grad)
        .Output(key_bias_grad)
        .Output(value_bias_grad)
        .Output(out_proj_bias_grad)
        .Attr("attn_head_num", attn_head_num)
        .Attr("attn_dim_per_head", attn_dim_per_head)
        .Attr("src_len", src_len)
        .Attr("tgt_len", tgt_len)
        .Attr("keep_prob", static_cast<float>(1 - dropout_prob))
        .Attr("softmax_use_float", softmax_use_float)
        .Attr("bias_grad_mask", bias_grad_mask)
        .Run();
    return std::make_tuple(query_weight_grad, key_weight_grad, value_weight_grad, out_proj_weight_grad, query_grad,
                           key_grad, value_grad, query_bias_grad, key_bias_grad, value_bias_grad, out_proj_bias_grad);
}
/**
 * @brief Forward entry point: unwraps the optional tensors and delegates to
 *        multi_head_attention_nocheck.
 *
 * Every optional argument (the four biases and the dropout mask) that holds no
 * value is replaced by a default-constructed at::Tensor before the call; all
 * other arguments are forwarded unchanged.
 *
 * @return The eight-tensor tuple produced by multi_head_attention_nocheck.
 */
tuple_tensor npu_multi_head_attention(
    const at::Tensor &query, const at::Tensor &key, const at::Tensor &value, const at::Tensor &query_weight,
    const at::Tensor &key_weight, const at::Tensor &value_weight, const at::Tensor &attn_mask,
    const at::Tensor &out_proj_weight, const c10::optional<at::Tensor> &query_bias_opt,
    const c10::optional<at::Tensor> &key_bias_opt, const c10::optional<at::Tensor> &value_bias_opt,
    const c10::optional<at::Tensor> &out_proj_bias_opt, const c10::optional<at::Tensor> &dropout_mask_opt,
    int64_t attn_head_num, int64_t attn_dim_per_head, int64_t src_len, int64_t tgt_len, double dropout_prob,
    bool softmax_use_float)
{
    // Returns the held tensor, or a default-constructed one when the optional is empty.
    const auto unwrap = [](const c10::optional<at::Tensor> &maybe_tensor) -> at::Tensor {
        return c10::value_or_else(maybe_tensor, [] { return at::Tensor(); });
    };

    const at::Tensor q_bias = unwrap(query_bias_opt);
    const at::Tensor k_bias = unwrap(key_bias_opt);
    const at::Tensor v_bias = unwrap(value_bias_opt);
    const at::Tensor proj_bias = unwrap(out_proj_bias_opt);
    const at::Tensor drop_mask = unwrap(dropout_mask_opt);

    return multi_head_attention_nocheck(query, key, value, query_weight, key_weight, value_weight, attn_mask,
                                        out_proj_weight, q_bias, k_bias, v_bias, proj_bias, drop_mask,
                                        attn_head_num, attn_dim_per_head, src_len, tgt_len, dropout_prob,
                                        softmax_use_float);
}
}