* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "interp_api.h"
#include "utils/ops_base.h"
#include "utils/op_desc.h"
#include "utils/aspb_status.h"
#include "utils/assert.h"
#include "log/log.h"
#include "mki/tensor.h"
#include "mki/utils/rt/rt.h"
#include "mki/utils/assert/assert.h"
#include "acl/acl.h"
#include "aclnn/acl_meta.h"
#include "interpbycoeff.h"
using namespace AsdSip;
using namespace Mki;
// Fixed workspace size reported by asdInterpWithCoeffGetWorkspaceSize (65536 * 20 * 2 = 2621440).
// NOTE(review): the meaning of the individual factors is not visible here — confirm against the kernel.
constexpr uint32_t INTERP_WORKSPACE_SIZE{65536 * 20 * 2};
// Required rank of both the input tensor and the coefficient tensor.
constexpr uint64_t INTERP_DIMS_THREE{3};
// Largest accepted value of the input tensor's dim 0 (batch).
constexpr int64_t MAX_BATCH{1024};
// The only two accepted values of the input tensor's dim 1 (rsNum).
constexpr int64_t INTERP_RS_TWO{2};
constexpr int64_t INTERP_RS_FOUR{4};
// Largest accepted value of the input tensor's dim 2 (totalSubcarrier).
constexpr int64_t MAX_TOTAL_SUBCARRIER{32760};
// Largest accepted value of the coefficient tensor's dim 1 (interpLength).
// NOTE(review): "SINGAL" looks like a typo for "SIGNAL"; the name is kept because other code refers to it.
constexpr int64_t MAX_SINGAL_NUM{14};
// Index of the third entry of a shape array.
constexpr int DIMS_TWO{2};
namespace AsdSip {
/**
 * Reads and validates the storage shape of the input tensor x.
 *
 * Accepted shape: exactly three dimensions [batch, rsNum, totalSubcarrier] with
 *   0 < batch <= MAX_BATCH,
 *   rsNum == INTERP_RS_TWO or rsNum == INTERP_RS_FOUR,
 *   0 < totalSubcarrier <= MAX_TOTAL_SUBCARRIER.
 *
 * @param x  input tensor whose storage shape is queried via aclGetStorageShape.
 * @return   the shape array (three int64_t values) on success. The caller owns it and must
 *           release it with delete[] (e.g. via cleanAcl). Returns nullptr on any failure,
 *           after logging ACL_ERROR_OP_INPUT_NOT_MATCH; nothing is leaked in that case.
 */
int64_t *getInputShape(const aclTensor *x)
{
    int64_t *storageDims = nullptr;
    uint64_t storageDimsNum = 0;
    auto ret = aclGetStorageShape(x, &storageDims, &storageDimsNum);
    // Validate the call result, the pointer and the rank BEFORE reading any element.
    // Reading storageDims[0] first (as before) is an out-of-bounds or null read for rank-0 shapes.
    if (ret != ACL_SUCCESS || storageDims == nullptr || storageDimsNum != INTERP_DIMS_THREE ||
        storageDims[0] <= 0) {
        delete[] storageDims;  // delete[] on nullptr is a no-op
        ASDSIP_ELOG(ErrorType::ACL_ERROR_OP_INPUT_NOT_MATCH) << "interpByCoeff get wrong input tensor.";
        return nullptr;
    }
    // Rank is now known to be 3, so indices 0..2 are valid.
    const bool batchTooLarge = storageDims[0] > MAX_BATCH;
    const bool badRsNum = storageDims[1] != INTERP_RS_TWO && storageDims[1] != INTERP_RS_FOUR;
    // A non-positive subcarrier count is rejected just like a non-positive batch.
    const bool badSubcarrier = storageDims[DIMS_TWO] <= 0 || storageDims[DIMS_TWO] > MAX_TOTAL_SUBCARRIER;
    if (batchTooLarge || badRsNum || badSubcarrier) {
        delete[] storageDims;
        ASDSIP_ELOG(ErrorType::ACL_ERROR_OP_INPUT_NOT_MATCH) << "interpByCoeff do not support input tensor shape.";
        return nullptr;
    }
    return storageDims;
}
/**
 * Reads and validates the storage shape of the coefficient tensor.
 *
 * Accepted shape: exactly three dimensions with dim 0 > 0 and 0 <= dim 1 <= MAX_SINGAL_NUM.
 * dim 1 is later used as the interpolation length. dims 0 and 2 are cross-checked against
 * the input tensor by the caller.
 *
 * @param coefficient  tensor whose storage shape is queried via aclGetStorageShape.
 * @return             the shape array (three int64_t values) on success. The caller owns it and
 *                     must release it with delete[] (e.g. via cleanAcl). Returns nullptr on any
 *                     failure, after logging ACL_ERROR_OP_INPUT_NOT_MATCH; nothing is leaked.
 */
int64_t *getCoeffShape(const aclTensor *coefficient)
{
    int64_t *coeffDims = nullptr;
    uint64_t coeffDimsNum = 0;
    auto ret = aclGetStorageShape(coefficient, &coeffDims, &coeffDimsNum);
    // Check the result, the pointer and the rank before dereferencing: the previous order read
    // coeffDims[0] even when the shape had fewer than three entries (or none).
    if (ret != ACL_SUCCESS || coeffDims == nullptr || coeffDimsNum != INTERP_DIMS_THREE ||
        coeffDims[0] <= 0 || coeffDims[1] < 0 || coeffDims[1] > MAX_SINGAL_NUM) {
        delete[] coeffDims;  // delete[] on nullptr is a no-op
        ASDSIP_ELOG(ErrorType::ACL_ERROR_OP_INPUT_NOT_MATCH) << "interpByCoeff get wrong coefficient tensor.";
        return nullptr;
    }
    return coeffDims;
}
/**
 * Releases the shape arrays handed out by getInputShape / getCoeffShape.
 *
 * @param storageDims  array allocated with new[] (or nullptr).
 * @param coeffDims    array allocated with new[] (or nullptr).
 * Either argument may be nullptr: delete[] on a null pointer is a well-defined no-op,
 * so no explicit guards are needed. The caller's pointers are not modified.
 */
void cleanAcl(int64_t *storageDims, int64_t *coeffDims)
{
    delete[] storageDims;
    delete[] coeffDims;
}
/**
 * Validates x and coefficient, then launches the "InterpByCoeffOperation" op.
 *
 * Shape contract (enforced by getInputShape / getCoeffShape and the check below):
 *   x           : [batch, rsNum, totalSubcarrier]
 *   coefficient : [batch, interpLength, rsNum]
 *
 * @param x            input tensor.
 * @param coefficient  coefficient tensor; dim 0 must equal x's dim 0 and dim 2 must equal x's dim 1.
 * @param output       output tensor passed to the op.
 * @param stream       stream handle forwarded to RunAsdOpsV2.
 * @param workSpace    scratch buffer forwarded to RunAsdOpsV2; presumably sized via
 *                     asdInterpWithCoeffGetWorkspaceSize — confirm.
 * @return ACL_SUCCESS on success.
 *         ACL_ERROR_OP_INPUT_NOT_MATCH if any shape check fails.
 *         ACL_ERROR_INTERNAL_ERROR if RunAsdOpsV2 reports failure.
 */
AspbStatus asdInterpWithCoeff(const aclTensor *x, const aclTensor *coefficient, aclTensor *output,
    void *stream, void *workSpace)
{
    int64_t *storageDims = getInputShape(x);
    // getInputShape frees its own array on failure, so nothing needs releasing here.
    ASDSIP_ECHECK(storageDims != nullptr, "InterpWithCoeff failed.", ErrorType::ACL_ERROR_OP_INPUT_NOT_MATCH);
    int64_t *coeffDims = getCoeffShape(coefficient);
    if (coeffDims == nullptr) {
        // storageDims is owned by this function; free it before the early return.
        // The previous ASDSIP_ECHECK here leaked it.
        cleanAcl(storageDims, nullptr);
        ASDSIP_ELOG(ErrorType::ACL_ERROR_OP_INPUT_NOT_MATCH) << "InterpWithCoeff failed.";
        return ErrorType::ACL_ERROR_OP_INPUT_NOT_MATCH;
    }
    if (storageDims[0] != coeffDims[0] || storageDims[1] != coeffDims[DIMS_TWO]) {
        cleanAcl(storageDims, coeffDims);
        ASDSIP_ELOG(ErrorType::ACL_ERROR_OP_INPUT_NOT_MATCH) << "input and coefficient do not match.";
        return ErrorType::ACL_ERROR_OP_INPUT_NOT_MATCH;
    }
    OpDesc opDesc;
    opDesc.opName = "InterpByCoeffOperation";
    AsdSip::OpParam::InterpByCoeff param;
    param.batch = storageDims[0];
    param.rsNum = storageDims[1];
    param.totalSubcarrier = storageDims[DIMS_TWO];
    param.interpLength = coeffDims[1];
    opDesc.specificParam = param;
    // The shape values have been copied into param; the arrays are no longer needed.
    cleanAcl(storageDims, coeffDims);
    // RunAsdOpsV2 takes non-const tensor pointers. NOTE(review): the const_cast assumes the op
    // does not modify its inputs — confirm.
    SVector<aclTensor *> inTensors{const_cast<aclTensor *>(x), const_cast<aclTensor *>(coefficient)};
    SVector<aclTensor *> outTensors{output};
    Mki::Status status = RunAsdOpsV2(stream, opDesc, inTensors, outTensors, static_cast<uint8_t *>(workSpace));
    ASDSIP_ECHECK(status.Ok(), status.Message(), ErrorType::ACL_ERROR_INTERNAL_ERROR);
    // The former `output = outTensors.at(0);` only overwrote the by-value parameter and had
    // no effect visible to the caller, so it was dropped.
    return ErrorType::ACL_SUCCESS;
}
/**
 * Reports the workspace size to allocate for the interpolation op.
 *
 * The size is the fixed constant INTERP_WORKSPACE_SIZE; it does not depend on tensor
 * shapes (presumably bytes, for the workSpace argument of asdInterpWithCoeff — confirm).
 *
 * @param workspaceSize  out-parameter receiving the size.
 * @return ACL_SUCCESS (this function cannot fail).
 */
AspbStatus asdInterpWithCoeffGetWorkspaceSize(size_t &workspaceSize)
{
    constexpr size_t fixedWorkspace = INTERP_WORKSPACE_SIZE;
    workspaceSize = fixedWorkspace;
    return ErrorType::ACL_SUCCESS;
}
}