* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
* MindIE is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include <torch/library.h>
#include "torch_npu/csrc/framework/utils/OpAdapter.h"
#include "torch_npu/csrc/core/npu/NPUFormat.h"
#include "pytorch_npu_helper.h"
#include "la.h"
using namespace at;
using npu_preparation = at_npu::native::OpPreparation;
namespace {
constexpr std::string_view LASERATTENTIONOP_NAME = "aclnnLaserAttention";
}
std::tuple<at::Tensor, at::Tensor> la_mindie_sd_impl_npu(
const at::Tensor &query, const at::Tensor &key, const at::Tensor &value,
const c10::optional<at::Tensor> &atten_mask_opt,
const c10::optional<at::Tensor> &alibi_mask_opt,
const c10::optional<at::Tensor> &drop_mask_opt,
double scale_value, int64_t head_num, std::string input_layout,
double keep_prob, int64_t pre_tokens, int64_t next_tokens, bool is_highPrecision)
{
size_t query_dim = query.sizes().size();
if (query_dim != 4) {
throw std::invalid_argument("The first input dimension of la must be 4 but got " + str(query_dim));
}
const at::Tensor& atten_mask = c10::value_or_else(atten_mask_opt, [] {return at::Tensor();});
const at::Tensor& alibi_mask = c10::value_or_else(alibi_mask_opt, [] {return at::Tensor();});
const at::Tensor& drop_mask = c10::value_or_else(drop_mask_opt, [] {return at::Tensor();});
at::Tensor softmax_log_max_sum =
at_npu::native::empty_with_format({query.sizes()[0], query.sizes()[1], query.sizes()[2]},
query.options().dtype(at::kFloat),
at_npu::native::get_npu_format(query)
);
at::Tensor attention_out = at_npu::native::empty_with_format(query.sizes(),
query.options().dtype(at::kFloat),
at_npu::native::get_npu_format(query)
);
EXEC_NPU_CMD<LASERATTENTIONOP_NAME>(query, key, value, atten_mask, alibi_mask, drop_mask,
scale_value, head_num, input_layout, keep_prob, pre_tokens, next_tokens,
is_highPrecision, softmax_log_max_sum, attention_out);
return std::tuple<at::Tensor, at::Tensor>(softmax_log_max_sum, attention_out);
}