#include <acl/acl.h>
#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/custom_functions/atb/AtbCommon.h"
using namespace std;
namespace atb {
using PagedAttentionParam = atb::infer::PagedAttentionParam;
/**
 * @brief Runs the ATB "PagedAttentionOperation" and returns `out`.
 *
 * A PagedAttentionParam is filled from the scalar arguments (all quantisation,
 * compression and calc-type options are left at their *_UNDEFINED values, the
 * scale type is SCALE_TYPE_TOR and the layout is TYPE_BSND), the tensors are
 * bound to a ParamSetter, and the cached operation is executed.
 *
 * Only two mask modes are wired up:
 *   - PagedAttentionParam::UNDEFINED       : inputs are query, key_cache,
 *                                            value_cache, block_table, context_lens.
 *   - PagedAttentionParam::MASK_TYPE_ALIBI : same inputs followed by `mask`.
 * Any other `mask_type` is rejected with an error instead of launching the
 * operation with no tensors bound.
 *
 * @param query         First input of the operation; also selects the device guard.
 * @param key_cache     Second input of the operation.
 * @param block_table   Fourth input of the operation.
 * @param context_lens  Converted to a kInt tensor and passed as the fifth input.
 * @param value_cache   Optional third input of the operation.
 * @param mask          Optional sixth input, bound only for MASK_TYPE_ALIBI.
 * @param num_kv_heads  Stored in PagedAttentionParam::kvHeadNum.
 * @param num_heads     Stored in PagedAttentionParam::headNum.
 * @param scale_value   Stored in PagedAttentionParam::qkScale.
 * @param mask_type     Integer value of PagedAttentionParam::MaskType.
 * @param workspace     When present and defined, the operation is run through
 *                      RunAtbCmdWithWorkspace with this tensor; otherwise RunAtbCmd is used.
 * @param out           Output tensor bound to the operation.
 * @return Reference to `out`.
 */
at::Tensor& _npu_paged_attention_v2(
    const at::Tensor &query,
    const at::Tensor &key_cache,
    const at::Tensor &block_table,
    c10::SymIntArrayRef context_lens,
    const c10::optional<at::Tensor> &value_cache,
    const c10::optional<at::Tensor> &mask,
    int64_t num_kv_heads,
    int64_t num_heads,
    double scale_value,
    int64_t mask_type,
    const c10::optional<at::Tensor> &workspace,
    at::Tensor &out)
{
    // Validate on the raw integer before it is converted to the enum, so an
    // unsupported value never reaches the operation with an empty tensor list.
    const bool without_mask = mask_type == static_cast<int64_t>(PagedAttentionParam::UNDEFINED);
    const bool with_alibi_mask = mask_type == static_cast<int64_t>(PagedAttentionParam::MASK_TYPE_ALIBI);
    TORCH_CHECK(without_mask || with_alibi_mask,
        "_npu_paged_attention_v2: unsupported mask_type ", mask_type,
        ", expected ", static_cast<int64_t>(PagedAttentionParam::UNDEFINED),
        " (no mask) or ", static_cast<int64_t>(PagedAttentionParam::MASK_TYPE_ALIBI),
        " (ALIBI mask).");

    const c10::OptionalDeviceGuard device_guard(device_of(query));
    OpParamCache<PagedAttentionParam>& pagedAttentionParamCache = OpParamCache<PagedAttentionParam>::getInstance();

    PagedAttentionParam pagedparam;
    pagedparam.headNum = num_heads;
    pagedparam.qkScale = scale_value;
    pagedparam.kvHeadNum = num_kv_heads;
    pagedparam.maskType = static_cast<PagedAttentionParam::MaskType>(mask_type);
    pagedparam.batchRunStatusEnable = false;
    pagedparam.quantType = PagedAttentionParam::TYPE_QUANT_UNDEFINED;
    pagedparam.outDataType = ACL_DT_UNDEFINED;
    pagedparam.hasQuantOffset = false;
    pagedparam.compressType = PagedAttentionParam::COMPRESS_TYPE_UNDEFINED;
    pagedparam.calcType = PagedAttentionParam::CALC_TYPE_UNDEFINED;
    pagedparam.scaleType = PagedAttentionParam::SCALE_TYPE_TOR;
    pagedparam.inputLayout = atb::infer::TYPE_BSND;
    pagedparam.mlaVHeadSize = 0;

    // NOTE(review): asIntArrayRefUnchecked does not verify that the SymInts are
    // concrete integers - confirm callers never pass symbolic lengths here.
    // The tensor is a local so that it stays alive until the operation has run.
    at::Tensor context_lens_tensor = at::tensor(c10::asIntArrayRefUnchecked(context_lens), at::kInt);

    ParamSetter paramsetter;
    paramsetter.Input(query)
        .Input(key_cache)
        .Input(value_cache)
        .Input(block_table)
        .Input(context_lens_tensor);
    if (with_alibi_mask) {
        // The mask is an additional trailing input in ALIBI mode only.
        paramsetter.Input(mask);
    }
    paramsetter.Output(out);

    auto opPaged = pagedAttentionParamCache.getOperation(pagedparam, "PagedAttentionOperation");
    if (workspace.has_value() && workspace.value().defined()) {
        RunAtbCmdWithWorkspace(opPaged, paramsetter, "PagedAttentionOperation", workspace.value());
    } else {
        RunAtbCmd(opPaged, paramsetter, "PagedAttentionOperation");
    }
    return out;
}
/**
 * @brief Allocates a workspace tensor sized for the ATB "PagedAttentionOperation".
 *
 * The PagedAttentionParam and the tensor bindings are built exactly as for the
 * execution entry point (same fixed *_UNDEFINED options, SCALE_TYPE_TOR,
 * TYPE_BSND), but instead of running the operation the required workspace size
 * is queried through GetAtbWorkspaceSizeCmd and an uninitialised byte tensor of
 * that size is returned.
 *
 * Only PagedAttentionParam::UNDEFINED (no mask input) and
 * PagedAttentionParam::MASK_TYPE_ALIBI (`mask` bound as trailing input) are
 * wired up; any other `mask_type` is rejected with an error instead of querying
 * the operation with no tensors bound.
 *
 * @param query         First input of the operation; selects the device guard and
 *                      supplies the tensor options for the returned workspace.
 * @param key_cache     Second input of the operation.
 * @param block_table   Fourth input of the operation.
 * @param context_lens  Converted to a kInt tensor and passed as the fifth input.
 * @param value_cache   Optional third input of the operation.
 * @param mask          Optional sixth input, bound only for MASK_TYPE_ALIBI.
 * @param num_kv_heads  Stored in PagedAttentionParam::kvHeadNum.
 * @param num_heads     Stored in PagedAttentionParam::headNum.
 * @param scale_value   Stored in PagedAttentionParam::qkScale.
 * @param mask_type     Integer value of PagedAttentionParam::MaskType.
 * @param out           Output tensor bound to the operation for the size query.
 * @return A one-dimensional kByte tensor created with `query.options()`, whose
 *         length is the reported workspace size.
 */
at::Tensor _npu_paged_attention_v2_get_workspace(
    const at::Tensor &query,
    const at::Tensor &key_cache,
    const at::Tensor &block_table,
    c10::SymIntArrayRef context_lens,
    const c10::optional<at::Tensor> &value_cache,
    const c10::optional<at::Tensor> &mask,
    int64_t num_kv_heads,
    int64_t num_heads,
    double scale_value,
    int64_t mask_type,
    at::Tensor &out)
{
    // Validate on the raw integer before it is converted to the enum, so an
    // unsupported value never reaches the operation with an empty tensor list.
    const bool without_mask = mask_type == static_cast<int64_t>(PagedAttentionParam::UNDEFINED);
    const bool with_alibi_mask = mask_type == static_cast<int64_t>(PagedAttentionParam::MASK_TYPE_ALIBI);
    TORCH_CHECK(without_mask || with_alibi_mask,
        "_npu_paged_attention_v2_get_workspace: unsupported mask_type ", mask_type,
        ", expected ", static_cast<int64_t>(PagedAttentionParam::UNDEFINED),
        " (no mask) or ", static_cast<int64_t>(PagedAttentionParam::MASK_TYPE_ALIBI),
        " (ALIBI mask).");

    const c10::OptionalDeviceGuard device_guard(device_of(query));
    OpParamCache<PagedAttentionParam>& pagedAttentionParamCache = OpParamCache<PagedAttentionParam>::getInstance();

    PagedAttentionParam pagedparam;
    pagedparam.headNum = num_heads;
    pagedparam.qkScale = scale_value;
    pagedparam.kvHeadNum = num_kv_heads;
    pagedparam.maskType = static_cast<PagedAttentionParam::MaskType>(mask_type);
    pagedparam.batchRunStatusEnable = false;
    pagedparam.quantType = PagedAttentionParam::TYPE_QUANT_UNDEFINED;
    pagedparam.outDataType = ACL_DT_UNDEFINED;
    pagedparam.hasQuantOffset = false;
    pagedparam.compressType = PagedAttentionParam::COMPRESS_TYPE_UNDEFINED;
    pagedparam.calcType = PagedAttentionParam::CALC_TYPE_UNDEFINED;
    pagedparam.scaleType = PagedAttentionParam::SCALE_TYPE_TOR;
    pagedparam.inputLayout = atb::infer::TYPE_BSND;
    pagedparam.mlaVHeadSize = 0;

    // NOTE(review): asIntArrayRefUnchecked does not verify that the SymInts are
    // concrete integers - confirm callers never pass symbolic lengths here.
    // The tensor is a local so that it stays alive until the size query is done.
    at::Tensor context_lens_tensor = at::tensor(c10::asIntArrayRefUnchecked(context_lens), at::kInt);

    ParamSetter paramsetter;
    paramsetter.Input(query)
        .Input(key_cache)
        .Input(value_cache)
        .Input(block_table)
        .Input(context_lens_tensor);
    if (with_alibi_mask) {
        // The mask is an additional trailing input in ALIBI mode only.
        paramsetter.Input(mask);
    }
    paramsetter.Output(out);

    auto opPaged = pagedAttentionParamCache.getOperation(pagedparam, "PagedAttentionOperation");
    uint64_t workspace_size = GetAtbWorkspaceSizeCmd(opPaged, paramsetter, "PagedAttentionOperation");

    // Tensor sizes are int64_t; convert explicitly because a uint64_t inside a
    // braced size list is a narrowing conversion.
    const int64_t workspace_bytes = static_cast<int64_t>(workspace_size);
    at::Tensor workspace_tensor = at::empty({workspace_bytes}, query.options().dtype(at::kByte));
    return workspace_tensor;
}
// Operator schema declarations for the `atb` library.
// Each schema string is parsed at registration time, so the argument names,
// order and defaults below must stay in sync with the C++ signatures of the
// functions bound to them.
namespace {
TORCH_LIBRARY_FRAGMENT(atb, m)
{
// Out-variant: `out` is annotated as mutated (Tensor(a!)) and is also the returned alias.
m.def("_npu_paged_attention_v2.out(Tensor query, Tensor key_cache, Tensor block_table, SymInt[] context_lens, *, Tensor? value_cache=None, Tensor? mask=None, int num_kv_heads=0, int num_heads=0, float scale_value=1.0, int mask_type=0, Tensor? workspace=None, Tensor(a!) out) -> Tensor(a!)");
// Workspace query: same arguments minus `workspace`; returns a separate Tensor.
m.def("_npu_paged_attention_v2_get_workspace(Tensor query, Tensor key_cache, Tensor block_table, SymInt[] context_lens, *, Tensor? value_cache=None, Tensor? mask=None, int num_kv_heads=0, int num_heads=0, float scale_value=1.0, int mask_type=0, Tensor(a!) out) -> Tensor");
}
}
// Binds the `atb` operator schemas to their implementations for the
// PrivateUse1 dispatch key.
namespace {
TORCH_LIBRARY_IMPL(atb, PrivateUse1, m)
{
// The operator names here must match the names used in the m.def() schemas.
m.impl("_npu_paged_attention_v2.out", TORCH_FN(atb::_npu_paged_attention_v2));
m.impl("_npu_paged_attention_v2_get_workspace", TORCH_FN(atb::_npu_paged_attention_v2_get_workspace));
}
}
}