/**

 * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.

 * MindIE is licensed under Mulan PSL v2.

 * You can use this software according to the terms and conditions of the Mulan PSL v2.

 * You may obtain a copy of Mulan PSL v2 at:

 *          http://license.coscl.org.cn/MulanPSL2

 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,

 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,

 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.

 * See the Mulan PSL v2 for more details.

 */



#include <torch/library.h>

#include <iostream>



#include "torch_npu/csrc/framework/utils/OpAdapter.h"

#include "torch_npu/csrc/core/npu/NPUFormat.h"

#include "pytorch_npu_helper.h"

#include "la_preprocess.h"



using namespace at;



namespace {

// Rank every input tensor (query / key / value) is required to have.
constexpr int EXPECTED_TENSOR_DIMENSION{4};

// Operator name handed to EXEC_NPU_CMD to select the kernel to launch.
constexpr std::string_view LAPREPROCESS_NAME{"aclnnLaPreprocess"};

}  // namespace



/**
 * @brief Allocates padded output buffers and runs the "aclnnLaPreprocess" NPU command
 *        on query / key / value.
 *
 * Each input must be a 4-D tensor. Dimension 1 is treated as the sequence length and is
 * rounded up to the next multiple of @p align_len; dimensions 0, 2 and 3 are read from
 * @p query as batch size, head count and head dimension. Each output is allocated with
 * shape {dim0, dim2, padded_dim1, dim3}, half-precision dtype, and the NPU storage
 * format of @p query.
 *
 * NOTE(review): the actual data transformation is performed by the aclnnLaPreprocess
 * kernel, which is not visible here; presumably it pads and re-lays-out the inputs into
 * the allocated buffers — confirm against the kernel documentation.
 *
 * @param query     4-D query tensor; its options and NPU format seed all three outputs.
 * @param key       4-D key tensor; must agree with @p query in dimensions 0, 2 and 3.
 * @param value     4-D value tensor; must agree with @p query in dimensions 0, 2 and 3.
 * @param align_len Positive alignment applied to each sequence length.
 * @return Tuple (out_query, out_key, out_value) of the freshly allocated output tensors.
 * @throws c10::Error (via TORCH_CHECK) if @p align_len is not positive, if any input is
 *         not 4-D, or if key / value disagree with query in dimensions 0, 2 or 3.
 */
std::tuple<at::Tensor, at::Tensor, at::Tensor> la_preprocess_mindie_sd_impl_npu(
    const at::Tensor &query,
    const at::Tensor &key,
    const at::Tensor &value,
    int64_t align_len)
{
    TORCH_CHECK(align_len > 0, "align_len must be positive, but got ", align_len);
    TORCH_CHECK(query.dim() == EXPECTED_TENSOR_DIMENSION, "Query must be 4D tensor");
    TORCH_CHECK(key.dim() == EXPECTED_TENSOR_DIMENSION, "Key must be 4D tensor");
    TORCH_CHECK(value.dim() == EXPECTED_TENSOR_DIMENSION, "Value must be 4D tensor");

    const int64_t batch_size = query.size(0);
    const int64_t head_num = query.size(2);
    const int64_t head_dim = query.size(3);

    // The key/value output buffers below are sized from query's batch/head extents, so
    // reject inputs that disagree instead of handing the kernel a mis-shaped buffer.
    TORCH_CHECK(key.size(0) == batch_size && key.size(2) == head_num && key.size(3) == head_dim,
        "Key must match query in dims 0, 2 and 3, but got key ", key.sizes(),
        " and query ", query.sizes());
    TORCH_CHECK(value.size(0) == batch_size && value.size(2) == head_num && value.size(3) == head_dim,
        "Value must match query in dims 0, 2 and 3, but got value ", value.sizes(),
        " and query ", query.sizes());

    // Round a non-negative length up to the next multiple of align_len. The remainder
    // form avoids the intermediate `len + align_len - 1`, which can overflow int64_t
    // when align_len is very large.
    const auto align_up = [align_len](int64_t len) -> int64_t {
        const int64_t remainder = len % align_len;
        return remainder == 0 ? len : len + (align_len - remainder);
    };

    // Sequence lengths may differ between the three inputs, so each is padded on its own.
    const int64_t q_padded_seq_len = align_up(query.size(1));
    const int64_t k_padded_seq_len = align_up(key.size(1));
    const int64_t v_padded_seq_len = align_up(value.size(1));

    // Outputs are always half precision, regardless of the input dtype.
    const auto options = query.options().dtype(at::kHalf);
    const auto format = at_npu::native::get_npu_format(query);

    at::Tensor out_query = at_npu::native::empty_with_format(
        {batch_size, head_num, q_padded_seq_len, head_dim}, options, format);
    at::Tensor out_key = at_npu::native::empty_with_format(
        {batch_size, head_num, k_padded_seq_len, head_dim}, options, format);
    at::Tensor out_value = at_npu::native::empty_with_format(
        {batch_size, head_num, v_padded_seq_len, head_dim}, options, format);

    EXEC_NPU_CMD<LAPREPROCESS_NAME>(query, key, value, align_len,
        out_query, out_key, out_value);
    return std::make_tuple(out_query, out_key, out_value);
}