* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <atomic>
#include <cstring>
#include "graph/types.h"
#include "aclnn_mla_prolog.h"
#include "opdev/platform.h"
#include "opdev/make_op_executor.h"
#include "opdev/op_dfx.h"
#include "opdev/op_executor.h"
#include "opdev/tensor_view_utils.h"
#include "opdev/op_def.h"
#include "opdev/op_log.h"
#include "opdev/shape_utils.h"
#include "opdev/common_types.h"
#include "opdev/data_type_utils.h"
#include "opdev/format_utils.h"
using namespace op;
#ifdef __cplusplus
extern "C" {
#endif
// The inner entry points below are defined in another translation unit. They, and the public
// entry points that follow, are deliberately declared at ordinary namespace scope rather than
// inside an unnamed namespace: members of an unnamed namespace have internal linkage in standard
// C++, which would stop these declarations from binding to the external definitions and would
// keep the public API from being exported. (Compilers also disagree on how extern "C" interacts
// with unnamed namespaces, so relying on that combination is not portable.)
extern aclnnStatus aclnnInnerMlaPrologGetWorkspaceSize(
    const aclTensor *tokenX, const aclTensor *weightDq, const aclTensor *weightUqQr, const aclTensor *weightUk,
    const aclTensor *weightDkvKr, const aclTensor *rmsnormGammaCq, const aclTensor *rmsnormGammaCkv,
    const aclTensor *ropeSin, const aclTensor *ropeCos, const aclTensor *cacheIndex, aclTensor *kvCacheRef,
    aclTensor *krCacheRef, const aclTensor *dequantScaleXOptional, const aclTensor *dequantScaleWDqOptional,
    const aclTensor *dequantScaleWUqQrOptional, const aclTensor *dequantScaleWDkvKrOptional,
    const aclTensor *quantScaleCkvOptional, const aclTensor *quantScaleCkrOptional,
    const aclTensor *smoothScalesCqOptional, double rmsnormEpsilonCq, double rmsnormEpsilonCkv, char *cacheModeOptional,
    const aclTensor *queryOut, const aclTensor *queryRopeOut, uint64_t *workspaceSize, aclOpExecutor **executor);
extern aclnnStatus aclnnInnerMlaProlog(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor,
                                       aclrtStream stream);

/**
 * @brief Deprecated workspace-size query for the MlaProlog operator.
 *
 * Rejects the DAV_3510 NPU architecture (reported as Ascend950) with ACLNN_ERR_RUNTIME_ERROR.
 * Otherwise it logs a deprecation warning on the first call only and forwards every argument
 * unchanged to aclnnInnerMlaPrologGetWorkspaceSize. See aclnn_mla_prolog.h for the meaning of the
 * individual tensors and attributes.
 *
 * @param workspaceSize  Output: forwarded to aclnnInnerMlaPrologGetWorkspaceSize.
 * @param executor       Output: forwarded to aclnnInnerMlaPrologGetWorkspaceSize.
 * @return ACLNN_ERR_RUNTIME_ERROR on DAV_3510; otherwise the status returned by
 *         aclnnInnerMlaPrologGetWorkspaceSize.
 */
aclnnStatus aclnnMlaPrologGetWorkspaceSize(
    const aclTensor *tokenX, const aclTensor *weightDq, const aclTensor *weightUqQr, const aclTensor *weightUk,
    const aclTensor *weightDkvKr, const aclTensor *rmsnormGammaCq, const aclTensor *rmsnormGammaCkv,
    const aclTensor *ropeSin, const aclTensor *ropeCos, const aclTensor *cacheIndex, aclTensor *kvCacheRef,
    aclTensor *krCacheRef, const aclTensor *dequantScaleXOptional, const aclTensor *dequantScaleWDqOptional,
    const aclTensor *dequantScaleWUqQrOptional, const aclTensor *dequantScaleWDkvKrOptional,
    const aclTensor *quantScaleCkvOptional, const aclTensor *quantScaleCkrOptional,
    const aclTensor *smoothScalesCqOptional, double rmsnormEpsilonCq, double rmsnormEpsilonCkv, char *cacheModeOptional,
    const aclTensor *queryOut, const aclTensor *queryRopeOut, uint64_t *workspaceSize, aclOpExecutor **executor)
{
    if (op::GetCurrentPlatformInfo().GetCurNpuArch() == NpuArch::DAV_3510) {
        OP_LOGE(ACLNN_ERR_RUNTIME_ERROR, "Interface aclnnMlaPrologGetWorkspaceSize "
                "are no longer supported on Ascend950.");
        return ACLNN_ERR_RUNTIME_ERROR;
    }
    // Atomic so that concurrent first calls from several threads neither race on the flag
    // (undefined behavior with a plain bool) nor log the warning more than once.
    static std::atomic<bool> isFirstCall{true};
    if (isFirstCall.exchange(false)) {
        // The replacement named here is the GetWorkspaceSize half of the same API that
        // aclnnMlaProlog points to (aclnnMlaPrologV3WeightNz).
        OP_LOGW("aclnnMlaPrologGetWorkspaceSize is scheduled to be deprecated in December 2026, "
                "and will be replaced by the aclnnMlaPrologV3WeightNzGetWorkspaceSize. "
                "We apologize for any inconvenience caused and appreciate your timely migration to the new interface.");
    }
    return aclnnInnerMlaPrologGetWorkspaceSize(
        tokenX, weightDq, weightUqQr, weightUk, weightDkvKr, rmsnormGammaCq, rmsnormGammaCkv, ropeSin, ropeCos,
        cacheIndex, kvCacheRef, krCacheRef, dequantScaleXOptional, dequantScaleWDqOptional, dequantScaleWUqQrOptional,
        dequantScaleWDkvKrOptional, quantScaleCkvOptional, quantScaleCkrOptional, smoothScalesCqOptional,
        rmsnormEpsilonCq, rmsnormEpsilonCkv, cacheModeOptional, queryOut, queryRopeOut, workspaceSize, executor);
}

/**
 * @brief Deprecated execution entry point for the MlaProlog operator.
 *
 * Rejects the DAV_3510 NPU architecture (reported as Ascend950) with ACLNN_ERR_RUNTIME_ERROR.
 * Otherwise it logs a deprecation warning on the first call only and forwards all arguments
 * unchanged to aclnnInnerMlaProlog.
 *
 * @param workspace      Forwarded to aclnnInnerMlaProlog.
 * @param workspaceSize  Forwarded to aclnnInnerMlaProlog.
 * @param executor       Forwarded to aclnnInnerMlaProlog; presumably the executor produced by
 *                       aclnnMlaPrologGetWorkspaceSize (TODO confirm against the header contract).
 * @param stream         Forwarded to aclnnInnerMlaProlog.
 * @return ACLNN_ERR_RUNTIME_ERROR on DAV_3510; otherwise the status returned by aclnnInnerMlaProlog.
 */
aclnnStatus aclnnMlaProlog(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream)
{
    if (op::GetCurrentPlatformInfo().GetCurNpuArch() == NpuArch::DAV_3510) {
        OP_LOGE(ACLNN_ERR_RUNTIME_ERROR, "Interface aclnnMlaProlog are no longer supported on Ascend950.");
        return ACLNN_ERR_RUNTIME_ERROR;
    }
    // Atomic for the same reason as in aclnnMlaPrologGetWorkspaceSize: thread-safe one-time warning.
    static std::atomic<bool> isFirstCall{true};
    if (isFirstCall.exchange(false)) {
        OP_LOGW("aclnnMlaProlog is scheduled to be deprecated in December 2026, "
                "and will be replaced by the aclnnMlaPrologV3WeightNz. "
                "We apologize for any inconvenience caused and appreciate your timely migration to the new interface.");
    }
    return aclnnInnerMlaProlog(workspace, workspaceSize, executor, stream);
}
#ifdef __cplusplus
}
#endif