* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "cmatinv_batched.h"
#include "cgetri_batched.h"
#include "blas_common.h"
#include "blasplan/include/blasplan/BlasCmatinvBatchedPlan.h"
#include "blasplan/include/blasplan/BlasCgetriBatchedPlan.h"
using namespace AsdSip;
using namespace Mki;
// Indices into a tensor's storage-dimension array; DIM_3 doubles as the expected
// number of storage dimensions in the shape check.
static constexpr uint32_t DIM_0{0};
static constexpr uint32_t DIM_1{1};
static constexpr uint32_t DIM_2{2};
static constexpr uint32_t DIM_3{3};

// Matrix orders below this limit are dispatched to the CmatinvBatched path,
// all other orders to the CgetriBatched path.
static constexpr int32_t MATRIX_SHAPE_LIMIT{32};

// Largest matrix order and batch count accepted by the execute entry points.
static constexpr int32_t MAX_MATRIX_SHAPE{256};
static constexpr int32_t MAX_MATRIX_BATCH{3000};

// dtype codes forwarded to the operator parameters: DTYPE_COMPLEX32 is passed by
// asdBlasHCmatinvBatched, DTYPE_COMPLEX64 by asdBlasCmatinvBatched.
static constexpr int64_t DTYPE_COMPLEX32{0};
static constexpr int64_t DTYPE_COMPLEX64{1};
namespace AsdSip {
// Bundle of arguments shared by the internal batched matrix-inversion helpers.
// It is aggregate-initialised positionally by the public entry points, so the
// member order below is part of its contract.
struct CmatinvBatchedImplParam {
// dtype code forwarded to the operator parameters (DTYPE_COMPLEX32 or DTYPE_COMPLEX64).
int64_t dtype;
// Matrix order; the shape check compares storage dimensions 1 and 2 against it.
int64_t n;
// Number of matrices; the shape check compares storage dimension 0 against it.
int64_t batchSize;
// Tensor passed to the operator as its first input.
aclTensor *A;
// Tensor passed to the operator as its only output.
aclTensor *Ainv;
};
/**
 * Checks that the storage shape of one tensor is exactly [batchSize, n, n].
 *
 * The dimension array returned through aclGetStorageShape is released with delete[]
 * exactly once, before the verdict is reported. A null dimension array is treated as
 * a mismatch instead of being indexed.
 *
 * @param tensor       tensor whose storage shape is inspected
 * @param batchSize    expected extent of storage dimension 0
 * @param n            expected extent of storage dimensions 1 and 2
 * @param mismatchMsg  text logged when the shape does not match
 * @return ErrorType::ACL_SUCCESS when the shape matches,
 *         ErrorType::ACL_ERROR_OP_INPUT_NOT_MATCH otherwise; a failing
 *         aclGetStorageShape call is handled by CHECK_STATUS_WITH_ACL_RETURN.
 */
static AspbStatus cmatinvBatchedCheckStorageShape(aclTensor *tensor, int64_t batchSize, int64_t n,
                                                  const char *mismatchMsg)
{
    int64_t *storageDims = nullptr;
    uint64_t storageDimsNum = 0;
    CHECK_STATUS_WITH_ACL_RETURN(aclGetStorageShape(tensor, &storageDims, &storageDimsNum),
                                 "asdBlasCmatinvBatched: aclGetStorageShape");

    // Decide first, free second: the array must not be read after delete[], and the
    // null test short-circuits before any element is accessed.
    const bool shapeMatches = storageDims != nullptr && storageDimsNum == DIM_3 &&
                              storageDims[DIM_0] == batchSize &&
                              storageDims[DIM_1] == n &&
                              storageDims[DIM_2] == n;

    // Single release point for every exit path; delete[] on nullptr is a no-op.
    delete[] storageDims;
    storageDims = nullptr;

    if (!shapeMatches) {
        ASDSIP_ELOG(ErrorType::ACL_ERROR_OP_INPUT_NOT_MATCH) << mismatchMsg;
        return ErrorType::ACL_ERROR_OP_INPUT_NOT_MATCH;
    }
    return ErrorType::ACL_SUCCESS;
}

/**
 * Validates that both the input tensor A and the output tensor Ainv have the
 * storage shape [batchSize, n, n] described by implParam.
 *
 * A is checked first; Ainv is only inspected when A passed.
 *
 * @param implParam  call parameters carrying the tensors and the expected sizes
 * @return ErrorType::ACL_SUCCESS when both tensors match, otherwise the status of
 *         the first failing check.
 */
AspbStatus asdBlasCmatinvBatchedShapeCheck(CmatinvBatchedImplParam implParam)
{
    AspbStatus inputStatus = cmatinvBatchedCheckStorageShape(
        implParam.A, implParam.batchSize, implParam.n,
        "blas asdBlasCmatinvBatched get wrong A tensor dim num.");
    if (inputStatus != ErrorType::ACL_SUCCESS) {
        return inputStatus;
    }
    return cmatinvBatchedCheckStorageShape(
        implParam.Ainv, implParam.batchSize, implParam.n,
        "blas asdBlasCmatinvBatched get wrong Ainv tensor dim num.");
}
/**
 * Launches "CmatinvBatchedOperation" using the plan cached for `handle`
 * (the path taken when n is below MATRIX_SHAPE_LIMIT).
 *
 * @param handle     BLAS handle whose cached plan is looked up
 * @param implParam  dtype / n / batchSize plus the input and output tensors
 * @return ErrorType::ACL_SUCCESS when the operator ran, otherwise
 *         ErrorType::ACL_ERROR_INTERNAL_ERROR (plan of another type, plan not
 *         initialised, or operator launch failure).
 */
AspbStatus runCmatinvBatchedOps(asdBlasHandle handle, CmatinvBatchedImplParam implParam)
{
    try {
        // A reference dynamic_cast throws std::bad_cast when the cached plan has a
        // different dynamic type; the handler below turns that into an error status.
        auto &cachedPlan = dynamic_cast<AsdSip::BlasCmatinvBatchedPlan &>(BlasPlanCache::getPlan(handle));
        ASDSIP_ECHECK(cachedPlan.IsInitialized(), "CmatinvBatched plan init Error!.",
                      ErrorType::ACL_ERROR_INTERNAL_ERROR);

        AsdSip::OpParam::CmatinvBatched opParam = {implParam.dtype, implParam.n, implParam.batchSize};

        OpDesc opDesc;
        opDesc.opName = "CmatinvBatchedOperation";
        opDesc.specificParam = opParam;
        ASDSIP_LOG(DEBUG) << "OpDesc: " << opDesc.opName << "; OpDesc info: " << opParam.ToString();

        // Inputs: the caller's matrix batch followed by the plan-owned helper tensors.
        SVector<aclTensor *> inputs{implParam.A, cachedPlan.uniMatAclReal, cachedPlan.uniMatAclImag,
                                    cachedPlan.offsetAclTensor};
        SVector<aclTensor *> outputs{implParam.Ainv};

        Status launchResult =
            RunAsdOpsV2(cachedPlan.GetStream(), opDesc, inputs, outputs, cachedPlan.GetWorkspace());
        ASDSIP_ECHECK(launchResult.Ok(), launchResult.Message(), ErrorType::ACL_ERROR_INTERNAL_ERROR);

        ASDSIP_LOG(INFO) << "Execute asdBlasCmatinvBatched success when n < 32.";
        return ErrorType::ACL_SUCCESS;
    } catch (const std::bad_cast &castError) {
        ASDSIP_ELOG(ErrorType::ACL_ERROR_INTERNAL_ERROR) << "CmatinvBatched Error (n < 32): " << castError.what();
        return ErrorType::ACL_ERROR_INTERNAL_ERROR;
    }
}
/**
 * Launches "CgetriBatchedOperation" using the plan cached for `handle`
 * (the path taken when n is at least MATRIX_SHAPE_LIMIT).
 *
 * @param handle     BLAS handle whose cached plan is looked up
 * @param implParam  dtype / n / batchSize plus the input and output tensors
 * @return ErrorType::ACL_SUCCESS when the operator ran, otherwise
 *         ErrorType::ACL_ERROR_INTERNAL_ERROR (plan of another type, plan not
 *         initialised, or operator launch failure).
 */
AspbStatus runCgetriBatchedOps(asdBlasHandle handle, CmatinvBatchedImplParam implParam)
{
    try {
        // A reference dynamic_cast throws std::bad_cast when the cached plan has a
        // different dynamic type; the handler below turns that into an error status.
        auto &cachedPlan = dynamic_cast<AsdSip::BlasCgetriBatchedPlan &>(BlasPlanCache::getPlan(handle));
        ASDSIP_ECHECK(cachedPlan.IsInitialized(), "CmatinvBatched (n >= 32) plan init Error!.",
                      ErrorType::ACL_ERROR_INTERNAL_ERROR);

        AsdSip::OpParam::CgetriBatched opParam = {implParam.dtype, implParam.n, implParam.batchSize};

        OpDesc opDesc;
        opDesc.opName = "CgetriBatchedOperation";
        opDesc.specificParam = opParam;
        ASDSIP_LOG(DEBUG) << "OpDesc: " << opDesc.opName << "; OpDesc info: " << opParam.ToString();

        // Inputs: the caller's matrix batch followed by the plan-owned helper tensors.
        SVector<aclTensor *> inputs{implParam.A,
                                    cachedPlan.wBatchAcl,
                                    cachedPlan.gather1OffsetAcl,
                                    cachedPlan.gather2OffsetAcl,
                                    cachedPlan.gather3OffsetAcl,
                                    cachedPlan.eyeBatchMatAcl,
                                    cachedPlan.aWorkAcl,
                                    cachedPlan.workGmAcl};
        SVector<aclTensor *> outputs{implParam.Ainv};

        Status launchResult =
            RunAsdOpsV2(cachedPlan.GetStream(), opDesc, inputs, outputs, cachedPlan.GetWorkspace());
        ASDSIP_ECHECK(launchResult.Ok(), launchResult.Message(), ErrorType::ACL_ERROR_INTERNAL_ERROR);

        ASDSIP_LOG(INFO) << "Execute asdBlasCmatinvBatched success when n >= 32.";
        return ErrorType::ACL_SUCCESS;
    } catch (const std::bad_cast &castError) {
        ASDSIP_ELOG(ErrorType::ACL_ERROR_INTERNAL_ERROR) << "CmatinvBatched Error (n >= 32): " << castError.what();
        return ErrorType::ACL_ERROR_INTERNAL_ERROR;
    }
}
/**
 * Shared implementation behind the batched matrix-inversion entry points.
 *
 * Requires a plan to be cached for `handle`, validates the tensor shapes, and
 * then dispatches on the matrix order: orders below MATRIX_SHAPE_LIMIT go to
 * runCmatinvBatchedOps, all others to runCgetriBatchedOps.
 *
 * @param handle     BLAS handle for which a plan must already exist
 * @param implParam  dtype / n / batchSize plus the input and output tensors
 * @return the status of the first failing step, or of the operator run.
 */
AspbStatus asdBlasCmatinvBatchedImpl(asdBlasHandle handle, CmatinvBatchedImplParam implParam)
{
    ASDSIP_ECHECK(BlasPlanCache::doesPlanExist(handle),
                  "blas CmatinvBatched get cached plan failed.",
                  ErrorType::ACL_ERROR_INTERNAL_ERROR);

    AsdSip::AspbStatus shapeVerdict = asdBlasCmatinvBatchedShapeCheck(implParam);
    if (shapeVerdict != AsdSip::ErrorType::ACL_SUCCESS) {
        return shapeVerdict;
    }

    if (implParam.n < MATRIX_SHAPE_LIMIT) {
        return runCmatinvBatchedOps(handle, implParam);
    }
    return runCgetriBatchedOps(handle, implParam);
}
/**
 * Batched complex matrix inversion for ACL_COMPLEX64 tensors.
 *
 * Calls are serialised on blas_mtx. Both A and Ainv must report ACL_COMPLEX64,
 * n must lie in (0, MAX_MATRIX_SHAPE] and batchSize in (0, MAX_MATRIX_BATCH].
 *
 * @param handle     BLAS handle for which a plan has been created
 * @param n          matrix order
 * @param A          input tensor
 * @param lda        only inspected to log a message when non-positive
 * @param Ainv       output tensor
 * @param lda_inv    only inspected to log a message when non-positive
 * @param info       ignored by this implementation
 * @param batchSize  number of matrices
 * @return ErrorType::ACL_SUCCESS on success;
 *         ErrorType::ACL_ERROR_UNSUPPORTED_DATA_TYPE for a wrong tensor dtype;
 *         ErrorType::ACL_ERROR_INVALID_PARAM for n / batchSize out of range;
 *         otherwise the status reported by asdBlasCmatinvBatchedImpl.
 */
AspbStatus asdBlasCmatinvBatched(asdBlasHandle handle, const int64_t n, aclTensor *A,
    const int64_t lda, aclTensor *Ainv, const int64_t lda_inv, aclTensor *info, int64_t batchSize)
{
    (void)info;
    std::lock_guard<std::mutex> lock(blas_mtx);

    // The messages name the actual parameters (A / Ainv) so a failure can be traced
    // back to the offending argument.
    aclDataType dataType = aclDataType::ACL_DT_UNDEFINED;
    CHECK_STATUS_WITH_ACL_RETURN(aclGetDataType(A, &dataType), "asdBlasCmatinvBatched: aclGetDataType");
    ASDSIP_ECHECK(dataType == aclDataType::ACL_COMPLEX64,
                  "blas asdBlasCmatinvBatched get wrong A tensor dtype.",
                  ErrorType::ACL_ERROR_UNSUPPORTED_DATA_TYPE);
    CHECK_STATUS_WITH_ACL_RETURN(aclGetDataType(Ainv, &dataType), "asdBlasCmatinvBatched: aclGetDataType");
    ASDSIP_ECHECK(dataType == aclDataType::ACL_COMPLEX64,
                  "blas asdBlasCmatinvBatched get wrong Ainv tensor dtype.",
                  ErrorType::ACL_ERROR_UNSUPPORTED_DATA_TYPE);

    // Non-positive leading dimensions are only reported, not rejected.
    if (lda <= 0 || lda_inv <= 0) {
        ASDSIP_LOG(INFO) << "blas asdBlasCmatinvBatched get lda <= 0 or lda_inv <= 0.";
    }

    ASDSIP_ECHECK(n > 0 && batchSize > 0, "blas asdBlasCmatinvBatched get n <= 0 || batchSize <= 0.",
                  ErrorType::ACL_ERROR_INVALID_PARAM);
    // The message describes the failing condition (values above the supported maxima).
    ASDSIP_ECHECK(n <= MAX_MATRIX_SHAPE && batchSize <= MAX_MATRIX_BATCH,
                  "blas asdBlasCmatinvBatched get n > 256 || batchSize > 3000.",
                  ErrorType::ACL_ERROR_INVALID_PARAM);

    return asdBlasCmatinvBatchedImpl(handle, {DTYPE_COMPLEX64, n, batchSize, A, Ainv});
}
/**
 * Batched complex matrix inversion, variant that forwards DTYPE_COMPLEX32 to the
 * operator parameters.
 *
 * Calls are serialised on blas_mtx. n must lie in (0, MAX_MATRIX_SHAPE] and
 * batchSize in (0, MAX_MATRIX_BATCH].
 *
 * NOTE(review): unlike asdBlasCmatinvBatched, no aclGetDataType check is made on
 * A / Ainv here - confirm whether a tensor dtype check should be added.
 *
 * @param handle     BLAS handle for which a plan has been created
 * @param n          matrix order
 * @param A          input tensor
 * @param lda        only inspected to log a message when non-positive
 * @param Ainv       output tensor
 * @param lda_inv    only inspected to log a message when non-positive
 * @param info       ignored by this implementation
 * @param batchSize  number of matrices
 * @return ErrorType::ACL_SUCCESS on success;
 *         ErrorType::ACL_ERROR_INVALID_PARAM for n / batchSize out of range;
 *         otherwise the status reported by asdBlasCmatinvBatchedImpl.
 */
AspbStatus asdBlasHCmatinvBatched(asdBlasHandle handle, const int64_t n, aclTensor *A,
    const int64_t lda, aclTensor *Ainv, const int64_t lda_inv, aclTensor *info, int64_t batchSize)
{
    (void)info;
    std::lock_guard<std::mutex> lock(blas_mtx);

    ASDSIP_ECHECK(n > 0 && batchSize > 0, "blas asdBlasHCmatinvBatched get n <= 0 || batchSize <= 0.",
                  ErrorType::ACL_ERROR_INVALID_PARAM);
    // The message describes the failing condition (values above the supported maxima).
    ASDSIP_ECHECK(n <= MAX_MATRIX_SHAPE && batchSize <= MAX_MATRIX_BATCH,
                  "blas asdBlasHCmatinvBatched get n > 256 || batchSize > 3000.",
                  ErrorType::ACL_ERROR_INVALID_PARAM);

    // Non-positive leading dimensions are only reported, not rejected.
    if (lda <= 0 || lda_inv <= 0) {
        ASDSIP_LOG(INFO) << "blas asdBlasHCmatinvBatched get lda <= 0 or lda_inv <= 0.";
    }

    return asdBlasCmatinvBatchedImpl(handle, {DTYPE_COMPLEX32, n, batchSize, A, Ainv});
}
/**
 * Creates a BlasCmatinvBatchedPlan for matrix order n, registers it for `handle`
 * through BlasPlanCache::MakePlan, and builds its tensors.
 *
 * @param handle  BLAS handle the plan is registered for
 * @param n       matrix order handed to the plan constructor
 * @return ErrorType::ACL_SUCCESS once the plan is marked initialised;
 *         ErrorType::ACL_ERROR_INTERNAL_ERROR when CreateTensor fails (the plan is
 *         then marked failed).
 * @throws std::runtime_error when allocating or registering the plan throws a
 *         std::exception.
 */
AspbStatus makeCmatinvBatchedPlan(asdBlasHandle handle, const int64_t n)
{
    AsdSip::BlasCmatinvBatchedPlan *freshPlan = nullptr;
    try {
        freshPlan = new AsdSip::BlasCmatinvBatchedPlan(n);
        BlasPlanCache::MakePlan(handle, freshPlan);
    } catch (const std::exception &failure) {
        // delete on a null pointer is a no-op, so this also covers a throwing `new`.
        delete freshPlan;
        // NOTE(review): the handle itself is released here as an int; presumably it
        // was allocated that way by the handle-creation routine - confirm, and
        // confirm that callers do not reuse the handle after this exception.
        delete static_cast<int *>(handle);
        ASDSIP_ELOG(ErrorType::ACL_ERROR_INTERNAL_ERROR) << "Make CmatinvBatched Plan failed: " << failure.what();
        throw std::runtime_error("Make CmatinvBatched Plan failed.");
    }

    if (freshPlan->CreateTensor() != ErrorType::ACL_SUCCESS) {
        freshPlan->MarkFailed();
        ASDSIP_ELOG(ErrorType::ACL_ERROR_INTERNAL_ERROR) << "Fail to create blas CmatinvBatched mask tensor.";
        return ErrorType::ACL_ERROR_INTERNAL_ERROR;
    }

    freshPlan->MarkInitialized();
    return ErrorType::ACL_SUCCESS;
}
/**
 * Creates a BlasCgetriBatchedPlan for (n, batchSize, dtype), registers it for
 * `handle` through BlasPlanCache::MakePlan, and builds its tensors.
 *
 * @param handle     BLAS handle the plan is registered for
 * @param n          matrix order handed to the plan constructor
 * @param batchSize  batch count handed to the plan constructor
 * @param dtype      dtype code handed to the plan constructor
 * @return ErrorType::ACL_SUCCESS once the plan is marked initialised;
 *         ErrorType::ACL_ERROR_INTERNAL_ERROR when CreateTensor fails (the plan is
 *         then marked failed).
 * @throws std::runtime_error when allocating or registering the plan throws a
 *         std::exception.
 */
AspbStatus makeCgetriBatchedPlan(asdBlasHandle handle, const int64_t n, const int64_t batchSize, int64_t dtype)
{
    AsdSip::BlasCgetriBatchedPlan *freshPlan = nullptr;
    try {
        freshPlan = new AsdSip::BlasCgetriBatchedPlan(n, batchSize, dtype);
        BlasPlanCache::MakePlan(handle, freshPlan);
    } catch (const std::exception &failure) {
        // delete on a null pointer is a no-op, so this also covers a throwing `new`.
        delete freshPlan;
        // NOTE(review): the handle itself is released here as an int; presumably it
        // was allocated that way by the handle-creation routine - confirm, and
        // confirm that callers do not reuse the handle after this exception.
        delete static_cast<int *>(handle);
        ASDSIP_ELOG(ErrorType::ACL_ERROR_INTERNAL_ERROR)
            << "Make CmatinvBatched (n >= 32) Plan failed: " << failure.what();
        throw std::runtime_error("Make CmatinvBatched (n >= 32) Plan failed.");
    }

    if (freshPlan->CreateTensor() != ErrorType::ACL_SUCCESS) {
        freshPlan->MarkFailed();
        ASDSIP_ELOG(ErrorType::ACL_ERROR_INTERNAL_ERROR)
            << "Fail to create blas CmatinvBatched (n >= 32) mask tensor.";
        return ErrorType::ACL_ERROR_INTERNAL_ERROR;
    }

    freshPlan->MarkInitialized();
    return ErrorType::ACL_SUCCESS;
}
/**
 * Builds the execution plan used by asdBlasCmatinvBatched.
 *
 * Calls are serialised on blas_mtx. Orders below MATRIX_SHAPE_LIMIT get a
 * CmatinvBatched plan; all others get a CgetriBatched plan created with
 * DTYPE_COMPLEX64.
 *
 * NOTE(review): only lower bounds are validated here, whereas the execute entry
 * points also reject n > MAX_MATRIX_SHAPE and batchSize > MAX_MATRIX_BATCH -
 * confirm whether the same limits should apply at plan creation.
 *
 * @param handle     BLAS handle; must not be null
 * @param n          matrix order; must be positive
 * @param batchSize  number of matrices; must be positive
 * @return ErrorType::ACL_ERROR_INTERNAL_ERROR for a null handle,
 *         ErrorType::ACL_ERROR_INVALID_PARAM for non-positive n or batchSize,
 *         otherwise the status of the selected plan builder.
 */
AspbStatus asdBlasMakeCmatinvBatchedPlan(asdBlasHandle handle, const int64_t n, const int64_t batchSize)
{
    std::lock_guard<std::mutex> lock(blas_mtx);

    ASDSIP_ECHECK(handle != nullptr, "blas CmatinvBatched Make Plan failed.",
                  ErrorType::ACL_ERROR_INTERNAL_ERROR);
    ASDSIP_ECHECK(n > 0, "blas asdBlasMakeCmatinvBatchedPlan get n <= 0.",
                  ErrorType::ACL_ERROR_INVALID_PARAM);
    ASDSIP_ECHECK(batchSize > 0, "blas asdBlasMakeCmatinvBatchedPlan get batchSize <= 0.",
                  ErrorType::ACL_ERROR_INVALID_PARAM);

    if (n >= MATRIX_SHAPE_LIMIT) {
        return makeCgetriBatchedPlan(handle, n, batchSize, DTYPE_COMPLEX64);
    }
    return makeCmatinvBatchedPlan(handle, n);
}
/**
 * Builds the execution plan used by asdBlasHCmatinvBatched.
 *
 * Calls are serialised on blas_mtx. Orders below MATRIX_SHAPE_LIMIT get a
 * CmatinvBatched plan; all others get a CgetriBatched plan created with
 * DTYPE_COMPLEX32.
 *
 * NOTE(review): only lower bounds are validated here, whereas the execute entry
 * points also reject n > MAX_MATRIX_SHAPE and batchSize > MAX_MATRIX_BATCH -
 * confirm whether the same limits should apply at plan creation.
 *
 * @param handle     BLAS handle; must not be null
 * @param n          matrix order; must be positive
 * @param batchSize  number of matrices; must be positive
 * @return ErrorType::ACL_ERROR_INTERNAL_ERROR for a null handle,
 *         ErrorType::ACL_ERROR_INVALID_PARAM for non-positive n or batchSize,
 *         otherwise the status of the selected plan builder.
 */
AspbStatus asdBlasMakeHCmatinvBatchedPlan(asdBlasHandle handle, const int64_t n, const int64_t batchSize)
{
    std::lock_guard<std::mutex> lock(blas_mtx);

    ASDSIP_ECHECK(handle != nullptr, "blas HCmatinvBatched Make Plan failed.",
                  ErrorType::ACL_ERROR_INTERNAL_ERROR);
    ASDSIP_ECHECK(n > 0, "blas asdBlasMakeHCmatinvBatchedPlan get n <= 0.",
                  ErrorType::ACL_ERROR_INVALID_PARAM);
    ASDSIP_ECHECK(batchSize > 0, "blas asdBlasMakeHCmatinvBatchedPlan get batchSize <= 0.",
                  ErrorType::ACL_ERROR_INVALID_PARAM);

    if (n >= MATRIX_SHAPE_LIMIT) {
        return makeCgetriBatchedPlan(handle, n, batchSize, DTYPE_COMPLEX32);
    }
    return makeCmatinvBatchedPlan(handle, n);
}
}