#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/custom_functions/atb/AtbCommon.h"
#include "op_plugin/utils/custom_functions/atb/Utils.h"
namespace atb {
namespace {
// Lookup tables that translate the optional string mode arguments of the MLA
// entry points into integer codes. get_mla_mode() resolves each argument
// against its table through atb::utils::get_op_mode(), and the resulting codes
// are passed on to the AtbMLA operator.

// Recognised `mask_type` strings and their integer codes.
std::unordered_map<c10::string_view, int> mask_type_map{
    {"undefined", 0},
    {"mask_type_spec", 1},
    {"mask_type_free", 2},
};

// Recognised `calc_type` strings and their integer codes.
std::unordered_map<c10::string_view, int> calc_type_map{
    {"calc_type_undefined", 0},
    {"calc_type_spec", 1},
    {"calc_type_ring", 2},
};

// Recognised `cache_mode` strings and their integer codes (no entry maps to 0).
std::unordered_map<c10::string_view, int> cache_mode_map{
    {"krope_ctkv", 1},
    {"int8_nzcache", 2},
    {"nzcache", 3},
};
/// Resolves the three optional mode strings to their integer codes.
///
/// Every argument is looked up with atb::utils::get_op_mode() in its own table.
/// The third argument of each call names a key of that table ("undefined",
/// "calc_type_undefined", "krope_ctkv") -- presumably the fallback used when the
/// optional is empty; confirm against get_op_mode. The fourth argument is the
/// parameter's name.
///
/// @return (mask_type, calc_type, cache_mode) integer codes, in that order.
std::tuple<int, int, int> get_mla_mode(c10::optional<c10::string_view> mask_type_opt,
                                       c10::optional<c10::string_view> calc_type_opt,
                                       c10::optional<c10::string_view> cache_mode_opt)
{
    // Lookups are kept as separate statements so they run in a fixed order.
    const int mask_code =
        atb::utils::get_op_mode(mask_type_map, mask_type_opt, "undefined", "mask_type");
    const int calc_code =
        atb::utils::get_op_mode(calc_type_map, calc_type_opt, "calc_type_undefined", "calc_type");
    const int cache_code =
        atb::utils::get_op_mode(cache_mode_map, cache_mode_opt, "krope_ctkv", "cache_mode");

    return std::tuple<int, int, int>(mask_code, calc_code, cache_code);
}
}
/// Functional variant: runs AtbMLA and returns a newly allocated output tensor.
///
/// The output takes its sizes from `q_nope` and its tensor options from `q_rope`.
/// `context_lens` and, when supplied, `qseqlen` are converted to at::kInt tensors
/// before being handed to the operator. The operator's `lse` slot receives an
/// undefined tensor, and that tensor is not returned.
///
/// @return the output tensor passed to AtbMLA.
at::Tensor npu_multi_head_latent_attention(const at::Tensor & q_nope, const at::Tensor & q_rope, const at::Tensor & ctkv, const at::Tensor & k_rope, const at::Tensor & block_tables,
    c10::SymIntArrayRef context_lens, int64_t head_num, double qk_scale, int64_t kv_headnum, const c10::optional<at::Tensor> & mask, c10::OptionalArrayRef<c10::SymInt> qseqlen,
    const c10::optional<at::Tensor> & qk_descale, const c10::optional<at::Tensor> & pv_descale, c10::optional<c10::string_view> mask_type_opt, c10::optional<c10::string_view> calc_type_opt,
    c10::optional<c10::string_view> cache_mode_opt)
{
    const c10::OptionalDeviceGuard device_guard(device_of(q_nope));

    at::Tensor output = at::empty(q_nope.sizes(), q_rope.options());
    at::Tensor lse;  // intentionally left undefined in this variant
    float qkScale_float = static_cast<float>(qk_scale);

    // Plain ints (not structured bindings) because they are forwarded to a macro.
    int mask_type = 0;
    int calc_type = 0;
    int cache_mode = 0;
    std::tie(mask_type, calc_type, cache_mode) = get_mla_mode(mask_type_opt, calc_type_opt, cache_mode_opt);

    at::Tensor context_lens_tensor = at::tensor(c10::asIntArrayRefUnchecked(context_lens), at::kInt);

    // Stays an undefined tensor when the caller did not pass qseqlen.
    at::Tensor qseqlen_tensor;
    if (qseqlen.has_value()) {
        qseqlen_tensor = at::tensor(c10::asIntArrayRefUnchecked(qseqlen.value()), at::kInt);
    }

    EXEC_ATB_CMD(AtbMLA,
                 q_nope, q_rope, ctkv, k_rope, block_tables, context_lens_tensor, mask, qseqlen_tensor,
                 qk_descale, pv_descale, head_num, qkScale_float, kv_headnum,
                 mask_type, calc_type, cache_mode, output, lse);
    return output;
}
/// Out variant: runs AtbMLA writing into the caller-provided `output` tensor.
///
/// `output` is passed to the operator as-is; this function neither allocates nor
/// resizes it. `context_lens` and, when supplied, `qseqlen` are converted to
/// at::kInt tensors first. The operator's `lse` slot receives an undefined tensor.
///
/// @return a reference to `output`.
at::Tensor& npu_multi_head_latent_attention_out(const at::Tensor & q_nope, const at::Tensor & q_rope, const at::Tensor & ctkv, const at::Tensor & k_rope, const at::Tensor & block_tables,
    c10::SymIntArrayRef context_lens, int64_t head_num, double qk_scale, int64_t kv_headnum, const c10::optional<at::Tensor> & mask, c10::OptionalArrayRef<c10::SymInt> qseqlen,
    const c10::optional<at::Tensor> & qk_descale, const c10::optional<at::Tensor> & pv_descale, c10::optional<c10::string_view> mask_type_opt, c10::optional<c10::string_view> calc_type_opt,
    c10::optional<c10::string_view> cache_mode_opt, at::Tensor & output)
{
    const c10::OptionalDeviceGuard device_guard(device_of(q_nope));

    at::Tensor lse;  // intentionally left undefined in this variant
    float qkScale_float = static_cast<float>(qk_scale);

    // Plain ints (not structured bindings) because they are forwarded to a macro.
    int mask_type = 0;
    int calc_type = 0;
    int cache_mode = 0;
    std::tie(mask_type, calc_type, cache_mode) = get_mla_mode(mask_type_opt, calc_type_opt, cache_mode_opt);

    at::Tensor context_lens_tensor = at::tensor(c10::asIntArrayRefUnchecked(context_lens), at::kInt);

    // Stays an undefined tensor when the caller did not pass qseqlen.
    at::Tensor qseqlen_tensor;
    if (qseqlen.has_value()) {
        qseqlen_tensor = at::tensor(c10::asIntArrayRefUnchecked(qseqlen.value()), at::kInt);
    }

    EXEC_ATB_CMD(AtbMLA,
                 q_nope, q_rope, ctkv, k_rope, block_tables, context_lens_tensor, mask, qseqlen_tensor,
                 qk_descale, pv_descale, head_num, qkScale_float, kv_headnum,
                 mask_type, calc_type, cache_mode, output, lse);
    return output;
}
/// Functional variant that can additionally return the `lse` tensor.
///
/// The output is allocated with the sizes of `q_nope` and the options of `q_rope`.
/// `lse` is allocated (shape {q_nope.size(0), q_nope.size(1), 1}, options of
/// `q_rope`) only when `calc_type` resolves to "calc_type_ring"; otherwise an
/// undefined tensor is passed to the operator.
///
/// @param return_lse when false, the second tuple element is an undefined tensor.
/// @param workspace  optional pre-allocated workspace; when present and defined
///                   the operator is launched through EXEC_ATB_CMD_WITH_WORKSPACE,
///                   otherwise through EXEC_ATB_CMD.
/// @return (output, lse) -- `lse` may be undefined (see above).
std::tuple<at::Tensor, at::Tensor> npu_multi_head_latent_attention_lse(const at::Tensor & q_nope, const at::Tensor & q_rope, const at::Tensor & ctkv, const at::Tensor & k_rope, const at::Tensor & block_tables,
    c10::SymIntArrayRef context_lens, int64_t head_num, double qk_scale, int64_t kv_headnum, bool return_lse, const c10::optional<at::Tensor> & mask, c10::OptionalArrayRef<c10::SymInt> qseqlen,
    const c10::optional<at::Tensor> & qk_descale, const c10::optional<at::Tensor> & pv_descale, c10::optional<c10::string_view> mask_type_opt, c10::optional<c10::string_view> calc_type_opt,
    c10::optional<c10::string_view> cache_mode_opt, const c10::optional<at::Tensor> &workspace)
{
    const c10::OptionalDeviceGuard device_guard(device_of(q_nope));
    at::Tensor output = at::empty(q_nope.sizes(), q_rope.options());
    at::Tensor lse;
    float qkScale_float = static_cast<float>(qk_scale);
    auto mode = get_mla_mode(mask_type_opt, calc_type_opt, cache_mode_opt);
    int mask_type = std::get<0>(mode);
    int calc_type = std::get<1>(mode);
    // at() rather than operator[]: a mistyped key must fail loudly instead of
    // silently inserting a zero entry into the shared table.
    if (calc_type == calc_type_map.at("calc_type_ring")) {
        // size(dim) validates the dimension index against q_nope's rank, whereas
        // sizes()[i] would read out of bounds for a tensor of rank < 2.
        lse = at::empty({q_nope.size(0), q_nope.size(1), 1}, q_rope.options());
    }
    int cache_mode = std::get<2>(mode);
    at::Tensor context_lens_tensor = at::tensor(c10::asIntArrayRefUnchecked(context_lens), at::kInt);
    // Undefined tensor when the caller did not pass qseqlen.
    at::Tensor qseqlen_tensor = qseqlen.has_value() ? at::tensor(c10::asIntArrayRefUnchecked(qseqlen.value()), at::kInt) : at::Tensor();
    if (workspace.has_value() && workspace.value().defined()) {
        auto workspace_tensor = workspace.value();
        EXEC_ATB_CMD_WITH_WORKSPACE(AtbMLA, workspace_tensor, q_nope, q_rope, ctkv, k_rope, block_tables, context_lens_tensor, mask, qseqlen_tensor, qk_descale, pv_descale, head_num, qkScale_float, kv_headnum, mask_type, calc_type, cache_mode, output, lse);
    } else {
        EXEC_ATB_CMD(AtbMLA, q_nope, q_rope, ctkv, k_rope, block_tables, context_lens_tensor, mask, qseqlen_tensor, qk_descale, pv_descale, head_num, qkScale_float, kv_headnum, mask_type, calc_type, cache_mode, output, lse);
    }
    if (return_lse) {
        return std::make_tuple(output, lse);
    }
    // Caller did not ask for lse: hand back an undefined tensor in its place.
    return std::make_tuple(output, at::Tensor());
}
/// Allocates a workspace tensor for the `lse` variant of the MLA operator.
///
/// Builds the same operator arguments as npu_multi_head_latent_attention_lse
/// (including freshly allocated `output` and, for "calc_type_ring", `lse`),
/// obtains a byte count from EXEC_GET_ATB_MAX_WORKSPACE_CMD and returns an
/// uninitialised at::kByte tensor of that many elements using `q_nope`'s options.
///
/// @param return_lse accepted for signature parity with the `lse` variant; not
///                   read by this function.
/// @return 1-D at::kByte tensor of `workspace_size` elements.
at::Tensor _npu_multi_head_latent_attention_lse_get_workspace(const at::Tensor & q_nope, const at::Tensor & q_rope, const at::Tensor & ctkv, const at::Tensor & k_rope, const at::Tensor & block_tables,
    c10::SymIntArrayRef context_lens, int64_t head_num, double qk_scale, int64_t kv_headnum, bool return_lse, const c10::optional<at::Tensor> & mask, c10::OptionalArrayRef<c10::SymInt> qseqlen,
    const c10::optional<at::Tensor> & qk_descale, const c10::optional<at::Tensor> & pv_descale, c10::optional<c10::string_view> mask_type_opt, c10::optional<c10::string_view> calc_type_opt,
    c10::optional<c10::string_view> cache_mode_opt)
{
    const c10::OptionalDeviceGuard device_guard(device_of(q_nope));
    at::Tensor output = at::empty(q_nope.sizes(), q_rope.options());
    at::Tensor lse;
    float qkScale_float = static_cast<float>(qk_scale);
    auto mode = get_mla_mode(mask_type_opt, calc_type_opt, cache_mode_opt);
    int mask_type = std::get<0>(mode);
    int calc_type = std::get<1>(mode);
    // at() rather than operator[]: a mistyped key must fail loudly instead of
    // silently inserting a zero entry into the shared table.
    if (calc_type == calc_type_map.at("calc_type_ring")) {
        // size(dim) validates the dimension index against q_nope's rank, whereas
        // sizes()[i] would read out of bounds for a tensor of rank < 2.
        lse = at::empty({q_nope.size(0), q_nope.size(1), 1}, q_rope.options());
    }
    int cache_mode = std::get<2>(mode);
    at::Tensor context_lens_tensor = at::tensor(c10::asIntArrayRefUnchecked(context_lens), at::kInt);
    // Undefined tensor when the caller did not pass qseqlen.
    at::Tensor qseqlen_tensor = qseqlen.has_value() ? at::tensor(c10::asIntArrayRefUnchecked(qseqlen.value()), at::kInt) : at::Tensor();
    uint64_t workspace_size = EXEC_GET_ATB_MAX_WORKSPACE_CMD(AtbMLA, q_nope, q_rope, ctkv, k_rope, block_tables, context_lens_tensor, mask, qseqlen_tensor, qk_descale, pv_descale, head_num, qkScale_float, kv_headnum, mask_type, calc_type, cache_mode, output, lse);
    // Tensor sizes are int64_t; a braced list may not narrow uint64_t implicitly
    // (ill-formed, rejected by some compilers), so convert explicitly.
    at::Tensor workspace_tensor = at::empty({static_cast<int64_t>(workspace_size)}, q_nope.options().dtype(at::kByte));
    return workspace_tensor;
}
/// Out variant of the `lse` overload: runs AtbMLA writing into the caller's
/// `output` and `lse` tensors.
///
/// Neither out tensor is allocated or resized here; both are forwarded to the
/// operator unchanged.
///
/// @param return_lse kept for schema compatibility. It does not affect the
///                   returned references: an out variant must return references
///                   to its out arguments, and `lse` is handed to the operator
///                   regardless of this flag.
/// @param workspace  optional pre-allocated workspace; when present and defined
///                   the operator is launched through EXEC_ATB_CMD_WITH_WORKSPACE,
///                   otherwise through EXEC_ATB_CMD.
/// @return references to the caller's `output` and `lse`.
std::tuple<at::Tensor&, at::Tensor&> npu_multi_head_latent_attention_lse_out(const at::Tensor & q_nope, const at::Tensor & q_rope, const at::Tensor & ctkv, const at::Tensor & k_rope, const at::Tensor & block_tables,
    c10::SymIntArrayRef context_lens, int64_t head_num, double qk_scale, int64_t kv_headnum, bool return_lse, const c10::optional<at::Tensor> & mask, c10::OptionalArrayRef<c10::SymInt> qseqlen,
    const c10::optional<at::Tensor> & qk_descale, const c10::optional<at::Tensor> & pv_descale, c10::optional<c10::string_view> mask_type_opt, c10::optional<c10::string_view> calc_type_opt,
    c10::optional<c10::string_view> cache_mode_opt, const c10::optional<at::Tensor> &workspace, at::Tensor & output, at::Tensor & lse)
{
    const c10::OptionalDeviceGuard device_guard(device_of(q_nope));
    float qkScale_float = static_cast<float>(qk_scale);
    auto mode = get_mla_mode(mask_type_opt, calc_type_opt, cache_mode_opt);
    int mask_type = std::get<0>(mode);
    int calc_type = std::get<1>(mode);
    int cache_mode = std::get<2>(mode);
    at::Tensor context_lens_tensor = at::tensor(c10::asIntArrayRefUnchecked(context_lens), at::kInt);
    // Undefined tensor when the caller did not pass qseqlen.
    at::Tensor qseqlen_tensor = qseqlen.has_value() ? at::tensor(c10::asIntArrayRefUnchecked(qseqlen.value()), at::kInt) : at::Tensor();
    if (workspace.has_value() && workspace.value().defined()) {
        auto workspace_tensor = workspace.value();
        EXEC_ATB_CMD_WITH_WORKSPACE(AtbMLA, workspace_tensor, q_nope, q_rope, ctkv, k_rope, block_tables, context_lens_tensor, mask, qseqlen_tensor, qk_descale, pv_descale, head_num, qkScale_float, kv_headnum, mask_type, calc_type, cache_mode, output, lse);
    } else {
        EXEC_ATB_CMD(AtbMLA, q_nope, q_rope, ctkv, k_rope, block_tables, context_lens_tensor, mask, qseqlen_tensor, qk_descale, pv_descale, head_num, qkScale_float, kv_headnum, mask_type, calc_type, cache_mode, output, lse);
    }
    // Always return the caller-owned tensors. Binding the second reference to a
    // temporary at::Tensor() (as a `return_lse == false` branch once did) leaves
    // the caller with a dangling reference once this function returns.
    (void)return_lse;
    return std::tuple<at::Tensor&, at::Tensor&>(output, lse);
}
namespace {
// Declares the operator schemas of the multi-head latent attention family in
// the `atb` library. Implementations are bound separately via TORCH_LIBRARY_IMPL.
// In every schema the arguments after `*` are keyword-only and default to None
// (except the trailing out tensors of the `.out` / `.lse_out` overloads).
TORCH_LIBRARY_FRAGMENT(atb, m)
{
// Functional overload: returns a single Tensor.
m.def("npu_multi_head_latent_attention(Tensor q_nope, Tensor q_rope, Tensor ctkv, Tensor k_rope, Tensor block_tables, SymInt[] context_lens, int q_headnum, float qk_scale, int kv_headnum,"
" *, Tensor? mask=None, SymInt[]? qseqlen=None, Tensor? qk_descale=None, Tensor? pv_descale=None, str? mask_type=None, str? calc_type=None, str? cache_mode=None) -> Tensor");
// `.out` overload: writes into and returns the mutable `output` argument.
m.def("npu_multi_head_latent_attention.out(Tensor q_nope, Tensor q_rope, Tensor ctkv, Tensor k_rope, Tensor block_tables, SymInt[] context_lens, int q_headnum, float qk_scale, int kv_headnum,"
" *, Tensor? mask=None, SymInt[]? qseqlen=None, Tensor? qk_descale=None, Tensor? pv_descale=None, str? mask_type=None, str? calc_type=None, str? cache_mode=None, Tensor(a!) output) -> Tensor(a!)");
// `.lse` overload: adds the positional `return_lse` flag and an optional
// `workspace` tensor; returns a pair of tensors.
m.def("npu_multi_head_latent_attention.lse(Tensor q_nope, Tensor q_rope, Tensor ctkv, Tensor k_rope, Tensor block_tables, SymInt[] context_lens, int q_headnum, float qk_scale, int kv_headnum, bool return_lse,"
" *, Tensor? mask=None, SymInt[]? qseqlen=None, Tensor? qk_descale=None, Tensor? pv_descale=None, str? mask_type=None, str? calc_type=None, str? cache_mode=None, Tensor? workspace=None) -> (Tensor, Tensor)");
// Workspace helper for the `.lse` overload: same arguments minus `workspace`,
// returns a single Tensor.
m.def("_npu_multi_head_latent_attention_get_workspace.lse(Tensor q_nope, Tensor q_rope, Tensor ctkv, Tensor k_rope, Tensor block_tables, SymInt[] context_lens, int q_headnum, float qk_scale, int kv_headnum, bool return_lse,"
" *, Tensor? mask=None, SymInt[]? qseqlen=None, Tensor? qk_descale=None, Tensor? pv_descale=None, str? mask_type=None, str? calc_type=None, str? cache_mode=None) -> Tensor");
// `.lse_out` overload: writes into and returns the mutable `output` and `lse`
// arguments.
m.def("npu_multi_head_latent_attention.lse_out(Tensor q_nope, Tensor q_rope, Tensor ctkv, Tensor k_rope, Tensor block_tables, SymInt[] context_lens, int q_headnum, float qk_scale, int kv_headnum, bool return_lse,"
" *, Tensor? mask=None, SymInt[]? qseqlen=None, Tensor? qk_descale=None, Tensor? pv_descale=None, str? mask_type=None, str? calc_type=None, str? cache_mode=None, Tensor? workspace=None, Tensor(a!) output, Tensor(b!) lse) -> (Tensor(a!), Tensor(b!))");
}
}
namespace {
// Binds each schema declared in the TORCH_LIBRARY_FRAGMENT(atb, m) block to its
// C++ implementation for the PrivateUse1 dispatch key (presumably the key this
// plugin's NPU backend is registered under -- confirm against the backend setup).
// The string passed to m.impl must match the schema name, including the
// overload suffix after the dot.
TORCH_LIBRARY_IMPL(atb, PrivateUse1, m)
{
m.impl("npu_multi_head_latent_attention", TORCH_FN(atb::npu_multi_head_latent_attention));
m.impl("npu_multi_head_latent_attention.out", TORCH_FN(atb::npu_multi_head_latent_attention_out));
m.impl("npu_multi_head_latent_attention.lse", TORCH_FN(atb::npu_multi_head_latent_attention_lse));
// Note the name order differs: schema `..._get_workspace.lse` maps to the
// function `..._lse_get_workspace`.
m.impl("_npu_multi_head_latent_attention_get_workspace.lse", TORCH_FN(atb::_npu_multi_head_latent_attention_lse_get_workspace));
m.impl("npu_multi_head_latent_attention.lse_out", TORCH_FN(atb::npu_multi_head_latent_attention_lse_out));
}
}
}