#include <string>

#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
namespace op_api {
using npu_preparation = at_npu::native::OpPreparation;
/**
 * @brief Out-of-place wrapper around aclnnScatterPaKvCache.
 *
 * Clones key_cache and value_cache into contiguous tensors, launches
 * aclnnScatterPaKvCache with the clones in the cache argument slots, and
 * returns the clones. The key_cache / value_cache arguments themselves are
 * never handed to the kernel.
 *
 * @param key                  Forwarded to the kernel unchanged.
 * @param value                Forwarded to the kernel unchanged.
 * @param key_cache            Source tensor for the key-cache clone.
 * @param value_cache          Source tensor for the value-cache clone.
 * @param slot_mapping         Forwarded to the kernel unchanged.
 * @param compress_lens        Optional tensor, forwarded unchanged.
 * @param compress_seq_offsets Optional tensor, forwarded unchanged.
 * @param seq_lens             Optional tensor, forwarded unchanged.
 * @param cache_mode           Optional mode string; passed to the kernel as a
 *                             NUL-terminated C string, or nullptr when absent.
 * @return Tuple of (key-cache clone, value-cache clone) after the kernel call.
 */
std::tuple<at::Tensor, at::Tensor> npu_scatter_pa_kv_cache_functional(
    const at::Tensor& key,
    const at::Tensor& value,
    const at::Tensor& key_cache,
    const at::Tensor& value_cache,
    const at::Tensor& slot_mapping,
    const c10::optional<at::Tensor>& compress_lens,
    const c10::optional<at::Tensor>& compress_seq_offsets,
    const c10::optional<at::Tensor>& seq_lens,
    c10::optional<c10::string_view> cache_mode)
{
    // A string_view's buffer is not guaranteed to be NUL-terminated, so copy it
    // into an owning std::string before exposing it as a C string.
    std::string cache_mode_str;
    char* cache_mode_ptr = nullptr;
    if (cache_mode.has_value()) {
        cache_mode_str.assign(cache_mode.value().begin(), cache_mode.value().end());
        cache_mode_ptr = const_cast<char *>(cache_mode_str.c_str());
    }

    // String literals are const char[]; the cast only adapts to the char*
    // argument type. NOTE(review): assumes the kernel does not write through it.
    char* scatter_mode = const_cast<char *>("None");

    // Fixed strides {1, 1} and offsets {0, 0} are always passed to the kernel.
    c10::SmallVector<int64_t, op_infer::SIZE> strides_size = {1, 1};
    at::IntArrayRef strides = at::IntArrayRef(strides_size);
    c10::SmallVector<int64_t, op_infer::SIZE> offsets_size = {0, 0};
    at::IntArrayRef offsets = at::IntArrayRef(offsets_size);

    // The kernel receives the clones (presumably updating them in place), so
    // the caller's cache tensors are not passed to it.
    auto keyCacheClone = key_cache.clone(at::MemoryFormat::Contiguous);
    auto valueCacheClone = value_cache.clone(at::MemoryFormat::Contiguous);

    EXEC_NPU_NO_FORMAT_CHECK_CMD(aclnnScatterPaKvCache, key, keyCacheClone, slot_mapping, value,
        valueCacheClone, compress_lens, compress_seq_offsets, seq_lens, cache_mode_ptr, scatter_mode, strides, offsets);
    return std::make_tuple(keyCacheClone, valueCacheClone);
}
/**
 * @brief In-place wrapper around aclnnScatterPaKvCache.
 *
 * Launches aclnnScatterPaKvCache with the caller's key_cache and value_cache
 * tensors in the cache argument slots (presumably updated in place by the
 * kernel). Nothing is returned.
 *
 * @param key                  Forwarded to the kernel unchanged.
 * @param value                Forwarded to the kernel unchanged.
 * @param key_cache            Mutable key-cache tensor handed to the kernel.
 * @param value_cache          Mutable value-cache tensor handed to the kernel.
 * @param slot_mapping         Forwarded to the kernel unchanged.
 * @param compress_lens        Optional tensor, forwarded unchanged.
 * @param compress_seq_offsets Optional tensor, forwarded unchanged.
 * @param seq_lens             Optional tensor, forwarded unchanged.
 * @param cache_mode           Optional mode string; passed to the kernel as a
 *                             NUL-terminated C string, or nullptr when absent.
 */
void npu_scatter_pa_kv_cache(
    const at::Tensor& key,
    const at::Tensor& value,
    at::Tensor& key_cache,
    at::Tensor& value_cache,
    const at::Tensor& slot_mapping,
    const c10::optional<at::Tensor>& compress_lens,
    const c10::optional<at::Tensor>& compress_seq_offsets,
    const c10::optional<at::Tensor>& seq_lens,
    c10::optional<c10::string_view> cache_mode)
{
    // A string_view's buffer is not guaranteed to be NUL-terminated, so copy it
    // into an owning std::string before exposing it as a C string.
    std::string cache_mode_str;
    char* cache_mode_ptr = nullptr;
    if (cache_mode.has_value()) {
        cache_mode_str.assign(cache_mode.value().begin(), cache_mode.value().end());
        cache_mode_ptr = const_cast<char *>(cache_mode_str.c_str());
    }

    // String literals are const char[]; the cast only adapts to the char*
    // argument type. NOTE(review): assumes the kernel does not write through it.
    char* scatter_mode = const_cast<char *>("None");

    // Fixed strides {1, 1} and offsets {0, 0} are always passed to the kernel.
    c10::SmallVector<int64_t, op_infer::SIZE> strides_size = {1, 1};
    at::IntArrayRef strides = at::IntArrayRef(strides_size);
    c10::SmallVector<int64_t, op_infer::SIZE> offsets_size = {0, 0};
    at::IntArrayRef offsets = at::IntArrayRef(offsets_size);

    EXEC_NPU_NO_FORMAT_CHECK_CMD(aclnnScatterPaKvCache, key, key_cache, slot_mapping, value, value_cache,
        compress_lens, compress_seq_offsets, seq_lens, cache_mode_ptr, scatter_mode, strides, offsets);
}
}