* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "atb/runner/mki_node_implement.h"
#include <securec.h>
#include <mki/utils/time/timer.h>
#include "atb/utils/log.h"
#include "atb/utils/config.h"
#include "atb/utils/tensor_util.h"
#include "atb/utils/statistic.h"
#include "atb/utils/store_util.h"
#include "atb/utils/probe.h"
namespace atb {
// Binds this node to an Mki operation and an optional pre-infer-shape hook.
// A null operation is only reported through the log; construction still completes.
MkiNodeImplement::MkiNodeImplement(Mki::Operation *op, MkiInferShapePreFunc func)
    : operation_(op), mkiInferShapePreFunc_(func)
{
    if (operation_ != nullptr) {
        ATB_LOG(INFO) << operation_->GetName() << " MkiNodeImplement Init";
        return;
    }
    ATB_LOG(ERROR) << "Operation implement init incorrectly";
}
// No custom teardown is needed; members clean up through their own destructors.
MkiNodeImplement::~MkiNodeImplement() = default;
std::string MkiNodeImplement::GetName() const
{
return kernel_ == nullptr ? operation_->GetName() : kernel_->GetName();
}
// Returns the node to its initial state: resets run info and launch param,
// clears both cache/tiling flags and resets the kernel pointer.
void MkiNodeImplement::Reset()
{
    runInfo_.Reset();
    launchParam_.Reset();

    tilingBufferFilled_ = false;
    kernelCacheValid_ = false;

    kernel_.reset();
}
bool MkiNodeImplement::BuildLaunchParam(const SVector<Mki::Tensor *> &inTensors, SVector<ViewFunc> &inTensorViewFuncs,
const Mki::OpDesc &opDesc, size_t outTensorNum)
{
launchParam_.SetParam(opDesc.specificParam);
launchParam_.GetInTensors().clear();
launchParam_.GetOutTensors().clear();
kernelCacheValid_ = false;
for (size_t i = 0; i < inTensors.size(); ++i) {
Mki::Tensor *tensor = inTensors.at(i);
if (i < inTensorViewFuncs.size() && inTensorViewFuncs.at(i)) {
Mki::Tensor viewTensor = *tensor;
viewTensor.desc.dims.clear();
ATB_LOG(INFO) << GetLogPrefix() << " inTensorViewFuncs[" << i
<< "], tensor->desc.dims:" << TensorUtil::AsdOpsDimsToString(tensor->desc.dims)
<< ", viewTensor.desc.dims:" << TensorUtil::AsdOpsDimsToString(viewTensor.desc.dims);
inTensorViewFuncs.at(i)(tensor->desc.dims, viewTensor.desc.dims);
if (tensor->desc.format != Mki::TENSOR_FORMAT_FRACTAL_NZ && viewTensor.Numel() != tensor->Numel()) {
ATB_LOG(ERROR) << GetLogPrefix() << " mki node: inTensorViewFuncs[" << i
<< "], viewTensor.Numel:" << viewTensor.Numel() << ", tensor.Numel:" << tensor->Numel();
return false;
}
ATB_LOG(INFO) << GetLogPrefix() << " mki node: view inTensor[" << i
<< "], old:" << TensorUtil::AsdOpsDimsToString(tensor->desc.dims)
<< ", new:" << TensorUtil::AsdOpsDimsToString(viewTensor.desc.dims);
launchParam_.AddInTensor(viewTensor);
} else {
launchParam_.AddInTensor(*tensor);
}
}
for (size_t i = 0; i < outTensorNum; ++i) {
Mki::Tensor tensor;
launchParam_.AddOutTensor(tensor);
}
ATB_LOG(DEBUG) << GetLogPrefix() << " launchParam:\n" << launchParam_.ToString();
return true;
}
bool MkiNodeImplement::OperationGetBestKernel()
{
Mki::Kernel *kernel = operation_->GetBestKernel(launchParam_);
if (kernel == nullptr) {
ATB_LOG(ERROR) << GetLogPrefix() << " " << operation_->GetName()
<< " get best kernel fail, kernel count:" << operation_->GetKernelList().size();
return false;
}
kernel_.reset(kernel);
ATB_LOG(DEBUG) << GetLogPrefix() << " best kernel:" << kernel_->GetName() << ", addr: " << kernel_;
return true;
}
bool MkiNodeImplement::PlanKernelInferShape()
{
ATB_LOG(INFO) << GetLogPrefix() << " " << operation_->GetName() << " infer shape start, launchParam:\n"
<< launchParam_.ToString();
if (mkiInferShapePreFunc_) {
ATB_LOG(DEBUG) << GetLogPrefix() << " " << operation_->GetName()
<< " call inferShapePreFunc, old launchParam:\n"
<< launchParam_.ToString();
mkiInferShapePreFunc_(launchParam_);
ATB_LOG(DEBUG) << GetLogPrefix() << " " << operation_->GetName()
<< " call inferShapePreFunc, new launchParam:\n"
<< launchParam_.ToString();
}
Mki::Status st = operation_->InferShape(launchParam_);
if (!st.Ok()) {
ATB_LOG(ERROR) << GetLogPrefix() << " " << operation_->GetName()
<< " mki node infer shape fail, error:" << st.Message();
return false;
}
ATB_LOG(INFO) << GetLogPrefix() << " " << operation_->GetName() << " mki node infer shape success, launchParam:\n"
<< launchParam_.ToString();
return true;
}
// Keeps the current kernel when the cache flag is set and a kernel is present;
// otherwise asks the operation for the best kernel again.
bool MkiNodeImplement::UpdateBestKernel()
{
    const bool canReuseKernel = kernelCacheValid_ && kernel_;
    if (!canReuseKernel) {
        return OperationGetBestKernel();
    }
    ATB_LOG(DEBUG) << GetLogPrefix() << " " << operation_->GetName()
                   << " has got best kernel: " << kernel_->GetName();
    return true;
}
// Gives mutable access to the input tensors held by launchParam_.
Mki::SVector<Mki::Tensor> &MkiNodeImplement::GetInTensors()
{
    Mki::SVector<Mki::Tensor> &inputs = launchParam_.GetInTensors();
    return inputs;
}
// Gives mutable access to the output tensors held by launchParam_.
Mki::SVector<Mki::Tensor> &MkiNodeImplement::GetOutTensors()
{
    Mki::SVector<Mki::Tensor> &outputs = launchParam_.GetOutTensors();
    return outputs;
}
// Returns the tiling size the kernel reports for launchParam_.
// Logs an error and returns 0 when no kernel is set.
size_t MkiNodeImplement::GetTilingSize() const
{
    if (kernel_ != nullptr) {
        return kernel_->GetTilingSize(launchParam_);
    }
    ATB_LOG(ERROR) << GetLogPrefix() << " kernel is null";
    return 0;
}
// Returns the total scratch size from the kernel's kernel info.
// Logs a warning and returns -1 when no kernel is set.
int64_t MkiNodeImplement::GetWorkspaceSize() const
{
    if (kernel_ != nullptr) {
        return kernel_->GetKernelInfo().GetTotalScratchSize();
    }
    ATB_LOG(WARN) << GetLogPrefix() << " kernel is null";
    return -1;
}
// Configures the kernel's tiling mode and initializes it with launchParam_.
//
// hostTilingBuffer  host tiling buffer handed to the kernel when launchWithTiling is false.
// tilingSize        size passed together with hostTilingBuffer.
// launchWithTiling  when true the kernel is switched to launch-with-tiling and no host
//                   tiling address is set.
//
// Returns ERROR_INVALID_PARAM when no kernel is set, NO_ERROR on success, and otherwise
// the atb error mapped from the Mki status code. A code missing from ATB_MKI_ERROR_HASH is
// reported as ERROR_RT_FAIL instead of dereferencing the map's end iterator.
Status MkiNodeImplement::InitKernelInfo(uint8_t *hostTilingBuffer, uint64_t tilingSize, bool launchWithTiling)
{
    ATB_LOG(DEBUG) << GetLogPrefix() << " init kernel info";
    if (kernel_ == nullptr) {
        ATB_LOG(ERROR) << GetLogPrefix() << " kernel is null";
        return ERROR_INVALID_PARAM;
    }
    if (launchWithTiling) {
        ATB_LOG(INFO) << GetLogPrefix() << " use tiling optimize";
        kernel_->SetLaunchWithTiling(true);
    } else {
        ATB_LOG(DEBUG) << GetLogPrefix() << " set tiling info, tilingSize: " << tilingSize;
        kernel_->SetLaunchWithTiling(false);
        kernel_->SetTilingHostAddr(hostTilingBuffer, tilingSize);
    }
    Mki::Status status = kernel_->Init(launchParam_);
    if (status.Ok()) {
        return NO_ERROR;
    }
    ATB_LOG(ERROR) << GetLogPrefix() << " InitRunInfo failed!";
    const auto mapped = atb::ATB_MKI_ERROR_HASH.find(Mki::ErrorType(status.Code()));
    if (mapped == atb::ATB_MKI_ERROR_HASH.end()) {
        ATB_LOG(ERROR) << GetLogPrefix() << " mki error code has no atb mapping, st: " << status.ToString();
        return atb::ErrorType::ERROR_RT_FAIL;
    }
    return mapped->second;
}
// Records the workspace buffer address as the scratch device address in runInfo_.
void MkiNodeImplement::SetWorkspaceDeviceAddr(uint8_t *deviceWorkspaceBuffer)
{
    uint8_t *const scratchAddr = deviceWorkspaceBuffer;
    runInfo_.SetScratchDeviceAddr(scratchAddr);
}
// Records the tiling buffer address as the tiling device address in runInfo_.
void MkiNodeImplement::SetTilingDeviceAddr(uint8_t *deviceTilingBuffer)
{
    uint8_t *const tilingAddr = deviceTilingBuffer;
    runInfo_.SetTilingDeviceAddr(tilingAddr);
}
// Asks the kernel to build its launch arguments into argsHostBuffer_ from launchParam_ and runInfo_.
//
// Returns ERROR_INVALID_PARAM when no kernel is set (same convention as Run), NO_ERROR on
// success, and otherwise the atb error mapped from the Mki status code. A code missing from
// ATB_MKI_ERROR_HASH is reported as ERROR_RT_FAIL instead of dereferencing the map's end iterator.
Status MkiNodeImplement::BuildArgs()
{
    if (kernel_ == nullptr) {
        ATB_LOG(ERROR) << GetLogPrefix() << " kernel is null";
        return ERROR_INVALID_PARAM;
    }
    Mki::Status st = kernel_->BuildArgs(launchParam_, runInfo_, argsHostBuffer_);
    if (st.Ok()) {
        ATB_LOG(DEBUG) << GetLogPrefix() << " " << kernel_->GetName()
                       << " BuildArgs success, launchParam\n:" << launchParam_.ToString();
        return NO_ERROR;
    }
    ATB_LOG(ERROR) << GetLogPrefix() << " " << kernel_->GetName()
                   << " BuildArgs fail, launchParam\n:" << launchParam_.ToString() << "\nst: " << st.ToString();
    const auto mapped = atb::ATB_MKI_ERROR_HASH.find(Mki::ErrorType(st.Code()));
    if (mapped == atb::ATB_MKI_ERROR_HASH.end()) {
        ATB_LOG(ERROR) << GetLogPrefix() << " mki error code has no atb mapping, st: " << st.ToString();
        return atb::ErrorType::ERROR_RT_FAIL;
    }
    return mapped->second;
}
// Launches the kernel on the given stream.
// When a device args buffer has been set the kernel is run with those prebuilt args,
// otherwise it is run from launchParam_ and runInfo_.
//
// Returns ERROR_INVALID_PARAM when no kernel is set, NO_ERROR on success, and otherwise
// the atb error mapped from the Mki status code. A code missing from ATB_MKI_ERROR_HASH is
// reported as ERROR_RT_FAIL instead of dereferencing the map's end iterator.
Status MkiNodeImplement::Run(aclrtStream stream)
{
    Mki::Status st;
    if (kernel_ == nullptr) {
        ATB_LOG(ERROR) << GetLogPrefix() << " kernel is null";
        return ERROR_INVALID_PARAM;
    }
    runInfo_.SetStream(stream);
    ATB_LOG(INFO) << GetLogPrefix() << " " << kernel_->GetName() << " run start, launchParam:\n"
                  << launchParam_.ToString() << " runInfo:\n"
                  << runInfo_.ToString();
    bool isDeviceAddr = true;
    if (argsDeviceBuffer_ != nullptr) {
        st = kernel_->RunWithArgs(argsDeviceBuffer_, stream, isDeviceAddr);
    } else {
        st = kernel_->Run(launchParam_, runInfo_);
    }
    if (st.Ok()) {
        ATB_LOG(INFO) << GetLogPrefix() << " " << kernel_->GetName()
                      << " run end, launchParam\n:" << launchParam_.ToString();
        return NO_ERROR;
    }
    ATB_LOG(ERROR) << GetLogPrefix() << " " << kernel_->GetName()
                   << " run fail, launchParam\n:" << launchParam_.ToString() << "\nst: " << st.ToString();
    const auto mapped = atb::ATB_MKI_ERROR_HASH.find(Mki::ErrorType(st.Code()));
    if (mapped == atb::ATB_MKI_ERROR_HASH.end()) {
        ATB_LOG(ERROR) << GetLogPrefix() << " mki error code has no atb mapping, st: " << st.ToString();
        return atb::ErrorType::ERROR_RT_FAIL;
    }
    return mapped->second;
}
// Tries to serve tiling data (and the matching kernel) from the kernel cache.
//
// kernelCache             cache queried with kernelIndex and launchParam_.
// kernelIndex             index of this node's kernel slot in the cache.
// kernelHostTilingBuffer  destination for the cached tiling bytes.
// maxTilingSize           capacity of kernelHostTilingBuffer.
// tilingSizeFetched       out: size of the cached tiling; 0 when nothing was fetched. When the
//                         cached tiling is larger than maxTilingSize it holds that cached size.
// launchWithTiling        requested launch mode; must match the cached kernel's mode.
//
// Returns true when cached tiling is available (tilingBufferFilled_ is then set), false on a
// cache miss or any failure. When the cache yields a kernel with a matching launch mode it is
// cloned/copied into kernel_ and the kernel cache flag is set, even if no tiling buffer follows.
bool MkiNodeImplement::GetCachedTiling(KernelCache &kernelCache, size_t kernelIndex, uint8_t *kernelHostTilingBuffer,
                                       uint64_t maxTilingSize, uint64_t &tilingSizeFetched, bool launchWithTiling)
{
    tilingBufferFilled_ = false;
    // Give the out-parameter a defined value on every early-return path below.
    tilingSizeFetched = 0;
    Mki::Timer kernelCacheGetTilingTimer;
    const Mki::Kernel *kernelCached = nullptr;
    TilingBufferPtr cachedTilingBuffer = kernelCache.GetTiling(kernelIndex, launchParam_, kernelCached);
    if (kernelCached != nullptr) {
        bool cachedTilingLaunchStatus = kernelCached->GetKernelInfo().GetLaunchWithTiling();
        if (cachedTilingLaunchStatus != launchWithTiling) {
            ATB_LOG(INFO) << "Cache miss because of status of tilingLaunch mismatch.";
            return false;
        }
        if (!kernel_) {
            kernel_.reset(kernelCached->Clone());
        } else {
            Mki::Status st = kernel_->Copy(*kernelCached);
            if (!st.Ok()) {
                ATB_LOG(ERROR) << GetLogPrefix()
                               << " mki node copy kernelCached to kernel fail, error: " << st.Message();
                return false;
            }
        }
        kernelCacheValid_ = true;
    }
    GetOpSetupStatistic().kernelCacheGetTilingTime += kernelCacheGetTilingTimer.ElapsedMicroSecond();
    if (cachedTilingBuffer == nullptr) {
        return false;
    }
    tilingSizeFetched = cachedTilingBuffer->size();
    if (tilingSizeFetched > maxTilingSize) {
        ATB_LOG(ERROR) << GetLogPrefix() << " MkiNodeImplement do not have enough tiling buffer for cached tiling";
        return false;
    }
    // The host copy is skipped only when launching with tiling and tiling saving is off.
    if (!launchWithTiling || Probe::IsSaveTiling()) {
        int ret = memcpy_s(kernelHostTilingBuffer, maxTilingSize, cachedTilingBuffer->data(), tilingSizeFetched);
        if (ret != EOK) {
            ATB_LOG(ERROR) << GetLogPrefix() << " MkiNodeImplement memcpy_s cached tiling fail, error:" << ret;
            return false;
        }
    }
    tilingBufferFilled_ = true;
    return true;
}
// Stores the given host tiling data in the kernel cache, keyed by kernelIndex and
// launchParam_, together with the current kernel pointer.
void MkiNodeImplement::AddTiling(KernelCache &kernelCache, size_t kernelIndex, uint8_t *hostTilingBuffer,
                                 size_t tilingSize) const
{
    auto *currentKernel = kernel_.get();
    kernelCache.AddTiling(kernelIndex, hostTilingBuffer, tilingSize, launchParam_, currentKernel);
    ATB_LOG(DEBUG) << GetLogPrefix() << " AddTiling end, runinfo\n:" << runInfo_.ToString();
}
// Rebuilds the log prefix as "<prefix>[<kernelId>]".
void MkiNodeImplement::ResetLogPrefix(const std::string &prefix, size_t kernelId)
{
    std::stringstream formatted;
    formatted << prefix;
    formatted << "[" << kernelId << "]";
    logPrefix_ = formatted.str();
}
// Reports whether the tiling-buffer-filled flag is currently set.
bool MkiNodeImplement::GetTilingFilledFlag() const
{
    const bool filled = tilingBufferFilled_;
    return filled;
}
void MkiNodeImplement::SetTilingFilledFlag(bool flag)
{
tilingBufferFilled_ = flag;
}
std::string MkiNodeImplement::GetLogPrefix() const
{
return logPrefix_;
}
// Forwards the current launchParam_ to StoreUtil::SaveLaunchParam with the given stream and directory.
void MkiNodeImplement::SaveLaunchParam(aclrtStream stream, const std::string &dirPath) const
{
    const auto &paramToSave = launchParam_;
    StoreUtil::SaveLaunchParam(stream, paramToSave, dirPath);
}
// Returns the key used for msprof reporting, or nullptr (with an error log) when no kernel is set.
void *MkiNodeImplement::GetMsprofInfoKey() const
{
if (kernel_ == nullptr) {
ATB_LOG(ERROR) << GetLogPrefix() << " kernel is nullptr, get msprof key failed";
return nullptr;
}
// Reinterprets the kernel object's address as a void** and returns the first
// pointer-sized value stored in the object.
// NOTE(review): presumably this is the vptr of a polymorphic Mki::Kernel, which would make
// the key identify the kernel's dynamic type — confirm against Mki::Kernel's layout.
return *(reinterpret_cast<void **>(kernel_.get()));
}
// Appends every launch tensor to allTensors: inputs first (tagged true), then outputs (tagged false).
void MkiNodeImplement::GetReportTensors(Mki::SVector<std::pair<bool, Mki::Tensor>> &allTensors) const
{
    std::size_t idx = 0;
    while (idx < launchParam_.GetInTensorCount()) {
        allTensors.push_back(std::make_pair(true, launchParam_.GetInTensors().at(idx)));
        ++idx;
    }
    idx = 0;
    while (idx < launchParam_.GetOutTensorCount()) {
        allTensors.push_back(std::make_pair(false, launchParam_.GetOutTensors().at(idx)));
        ++idx;
    }
}
uint32_t MkiNodeImplement::GetOpType() const
{
return kernel_->GetType();
}
uint32_t MkiNodeImplement::GetBlockDim() const
{
const Mki::KernelInfo &kernelInfo = kernel_->GetKernelInfo();
return kernelInfo.GetBlockDim();
}
void MkiNodeImplement::SetArgsDeviceBuffer(void *deviceBuffer)
{
argsDeviceBuffer_ = deviceBuffer;
}
void MkiNodeImplement::SetArgsHostBuffer(void *hostBuffer)
{
argsHostBuffer_ = hostBuffer;
}
// Returns the device-side args buffer pointer as last set (may be null).
void *MkiNodeImplement::GetArgsDeviceBuffer()
{
    void *const deviceArgs = argsDeviceBuffer_;
    return deviceArgs;
}
// Returns the host-side args buffer pointer as last set (may be null).
void *MkiNodeImplement::GetArgsHostBuffer()
{
    void *const hostArgs = argsHostBuffer_;
    return hostArgs;
}
// Returns the args size from the kernel's kernel info.
// Logs an error and returns 0 when no kernel is set.
uint64_t MkiNodeImplement::GetArgsSize()
{
    if (kernel_ != nullptr) {
        return kernel_->GetKernelInfo().GetArgsSize();
    }
    ATB_LOG(ERROR) << GetLogPrefix() << " kernel is null";
    return 0;
}
// Rebuilds launchParam_ from explicit input and output tensors (no view functions applied).
//
// inTensors   input tensors copied into launchParam_; each entry must be non-null.
// outTensors  output tensors copied into launchParam_; each entry must be non-null.
// opDesc      supplies the specific param stored into launchParam_.
//
// Returns ERROR_INVALID_PARAM when any tensor pointer is null, NO_ERROR otherwise.
// The cached-kernel flag is always invalidated because the launch param is rebuilt.
Status MkiNodeImplement::BuildLaunchParam(const SVector<Mki::Tensor *> &inTensors,
                                          const SVector<Mki::Tensor *> &outTensors, const Mki::OpDesc &opDesc)
{
    launchParam_.SetParam(opDesc.specificParam);
    launchParam_.GetInTensors().clear();
    launchParam_.GetOutTensors().clear();
    kernelCacheValid_ = false;
    for (size_t i = 0; i < inTensors.size(); ++i) {
        Mki::Tensor *tensor = inTensors.at(i);
        // Guard against a null entry: it would otherwise be dereferenced below.
        if (tensor == nullptr) {
            ATB_LOG(ERROR) << GetLogPrefix() << " mki node: inTensors[" << i << "] is null";
            return ERROR_INVALID_PARAM;
        }
        launchParam_.AddInTensor(*tensor);
    }
    for (size_t i = 0; i < outTensors.size(); ++i) {
        Mki::Tensor *tensor = outTensors.at(i);
        if (tensor == nullptr) {
            ATB_LOG(ERROR) << GetLogPrefix() << " mki node: outTensors[" << i << "] is null";
            return ERROR_INVALID_PARAM;
        }
        launchParam_.AddOutTensor(*tensor);
    }
    return NO_ERROR;
}
const std::unordered_map<const Mki::ErrorType, atb::ErrorType> InitAtbMkiErrorHash() noexcept
{
return {{Mki::ErrorType::NO_ERROR, atb::ErrorType::NO_ERROR},
{Mki::ErrorType::ERROR_INVALID_VALUE, atb::ErrorType::ERROR_INVALID_PARAM},
{Mki::ErrorType::ERROR_OPERATION_NOT_EXIST, atb::ErrorType::ERROR_INVALID_PARAM},
{Mki::ErrorType::ERROR_TACTIC_NOT_EXIST, atb::ErrorType::ERROR_INVALID_PARAM},
{Mki::ErrorType::ERROR_KERNEL_NOT_EXIST, atb::ErrorType::ERROR_INVALID_PARAM},
{Mki::ErrorType::ERROR_ATTR_NOT_EXIST, atb::ErrorType::ERROR_INVALID_PARAM},
{Mki::ErrorType::ERROR_ATTR_INVALID_TYPE, atb::ErrorType::ERROR_INVALID_PARAM},
{Mki::ErrorType::ERROR_LAUNCH_KERNEL_ERROR, atb::ErrorType::ERROR_RT_FAIL},
{Mki::ErrorType::ERROR_SYNC_STREAM_ERROR, atb::ErrorType::ERROR_RT_FAIL},
{Mki::ErrorType::ERROR_INFERSHAPE_ERROR, atb::ErrorType::ERROR_RT_FAIL},
{Mki::ErrorType::ERROR_NOT_CONSISTANT, atb::ErrorType::ERROR_INVALID_PARAM},
{Mki::ErrorType::ERROR_ALLOC_HOST, atb::ErrorType::ERROR_OUT_OF_HOST_MEMORY},
{Mki::ErrorType::ERROR_MEMERY_COPY_ERROR, atb::ErrorType::ERROR_COPY_HOST_MEMORY_FAIL},
{Mki::ErrorType::ERROR_RUN_TIME_ERROR, atb::ErrorType::ERROR_RT_FAIL}};
}
// Mki -> atb error translation table, filled once by InitAtbMkiErrorHash().
// Lookups with find() must compare the result against end(): the table only contains
// the codes listed in InitAtbMkiErrorHash().
const std::unordered_map<const Mki::ErrorType, atb::ErrorType> ATB_MKI_ERROR_HASH = InitAtbMkiErrorHash();
}