* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* ------------------------------------------------------------------------- */
#include "ArgsRawContext.h"
#include "acl_rt_impl/AscendclImplOrigin.h"
#include "utils/Numeric.h"
#include "utils/InjectLogger.h"
using namespace std;
bool ArgsRawContext::GetTilingData(std::vector<uint8_t> &data) const
{
if (placeHolderArray_.empty() || data.empty()) {
DEBUG_LOG("placeHolderArray_ size=%lu, tiling_data_size=%lu", placeHolderArray_.size(), data.size());
return false;
}
aclrtPlaceHolderInfo tilingInfo{};
for (auto const &placeholder : placeHolderArray_) {
if (placeholder.addrOffset > tilingInfo.addrOffset) {
tilingInfo.addrOffset = placeholder.addrOffset;
tilingInfo.dataOffset = placeholder.dataOffset;
}
}
const uint8_t *buff = static_cast<const uint8_t *>(args_);
const uint8_t *tilingData = buff + tilingInfo.dataOffset;
data.assign(tilingData, tilingData + data.size());
return true;
}
bool ArgsRawContext::ExpandArgs(void *param, size_t paramSize, uint32_t ¶mOffset)
{
argsWithMemInfo_.clear();
constexpr uint32_t ALIGN_SIZE = 8;
if (argsSize_ > MAX_ALL_PARAM_SIZE) {
ERROR_LOG("argsSize is over max size: %zu", MAX_ALL_PARAM_SIZE);
return false;
}
uint32_t alignArgsSize = CeilByAlignSize<ALIGN_SIZE>(argsSize_);
argsWithMemInfo_.resize(alignArgsSize, 0);
auto copyMode = aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_HOST;
if (isDeviceArgs_) {
copyMode = aclrtMemcpyKind::ACL_MEMCPY_DEVICE_TO_HOST;
}
aclrtMemcpyImplOrigin(argsWithMemInfo_.data(), argsSize_, args_, argsSize_, copyMode);
uint8_t *paramPtr = reinterpret_cast<uint8_t*>(param);
if (placeHolderArray_.empty()) {
paramOffset = argsWithMemInfo_.size();
} else {
uint32_t lastAddrOffset {};
for (auto const &placeholder : placeHolderArray_) {
lastAddrOffset = std::max(lastAddrOffset, placeholder.addrOffset);
}
paramOffset = lastAddrOffset + sizeof(uintptr_t);
if (DeviceContext::Local().NeedOverflowStatus()) { paramOffset += sizeof(uintptr_t); }
for (size_t i = 0; i < placeHolderArray_.size(); i++) {
placeHolderArray_[i].dataOffset += paramSize;
}
}
argsWithMemInfo_.insert(argsWithMemInfo_.cbegin() + paramOffset, paramPtr, paramPtr + paramSize);
argsSize_ = argsWithMemInfo_.size();
if (!isDeviceArgs_) {
args_ = argsWithMemInfo_.data();
return true;
}
FreeArgs();
aclError ret = aclrtMallocImplOrigin(&args_, argsSize_, aclrtMemMallocPolicy::ACL_MEM_MALLOC_HUGE_FIRST);
if (ret != ACL_ERROR_NONE) {
return false;
}
needFree_ = true;
ret = aclrtMemcpyImplOrigin(args_, argsSize_, argsWithMemInfo_.data(), argsSize_,
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);
if (ret != ACL_ERROR_NONE) {
return false;
}
return true;
}
/**
 * @brief Compute the byte offset in the argument buffer that follows the last kernel parameter.
 *
 * Without placeholders this is argsSize_ passed through CeilByAlignSize<8>. With placeholders it
 * is the largest addrOffset plus one pointer-sized slot, plus a second pointer-sized slot when
 * DeviceContext::Local().NeedOverflowStatus() is true.
 *
 * @return The computed offset, or 0 when argsSize_ exceeds MAX_ALL_PARAM_SIZE.
 */
uint32_t ArgsRawContext::GetLastParamOffset()
{
    constexpr uint32_t ALIGN_SIZE = 8;
    if (argsSize_ > MAX_ALL_PARAM_SIZE) {
        ERROR_LOG("argsSize is over max size: %zu", MAX_ALL_PARAM_SIZE);
        return 0;
    }

    if (!placeHolderArray_.empty()) {
        uint32_t maxAddrOffset = 0;
        for (const auto &entry : placeHolderArray_) {
            if (entry.addrOffset > maxAddrOffset) {
                maxAddrOffset = entry.addrOffset;
            }
        }
        // One slot for the last address itself, one more when the overflow-status slot is present.
        const size_t slotCount = DeviceContext::Local().NeedOverflowStatus() ? 2 : 1;
        return static_cast<uint32_t>(maxAddrOffset + slotCount * sizeof(uintptr_t));
    }

    return CeilByAlignSize<ALIGN_SIZE>(argsSize_);
}
bool ArgsRawContext::Save(const std::string &outputPath, DumperContext &config, OpMemInfo &memInfo, bool isSink)
{
(void)isSink;
uint64_t *args;
if (isDeviceArgs_) {
void *hostArgs_;
aclError ret = aclrtMallocHostImplOrigin(&hostArgs_, argsSize_);
if (ret != ACL_ERROR_NONE) {
return false;
}
ret = aclrtMemcpyImplOrigin(hostArgs_, argsSize_, args_, argsSize_, aclrtMemcpyKind::ACL_MEMCPY_DEVICE_TO_HOST);
if (ret != ACL_ERROR_NONE) {
aclrtFreeHostImplOrigin(hostArgs_);
return false;
}
args = static_cast<uint64_t *>(hostArgs_);
} else {
args = static_cast<uint64_t *>(args_);
}
for (size_t i = 0; i < memInfo.inputParamsAddrInfos.size(); i++) {
auto &addrInfo = memInfo.inputParamsAddrInfos[i];
addrInfo.addr = *(args + memInfo.skipNum + i);
}
if (isDeviceArgs_) {
aclrtFreeHostImplOrigin(args);
}
std::vector<uint8_t> tilingData;
tilingData.resize(memInfo.tilingDataSize);
return DumpInputData(outputPath, memInfo.inputParamsAddrInfos, config) &&
(!GetTilingData(tilingData) || DumpTilingData(outputPath, tilingData, config));
}
/**
 * @brief Create an independent copy of this context.
 *
 * The object is copy-constructed, then any args_ pointer the copy still shares with this object
 * is fixed up:
 *  - if args_ points at this object's argsWithMemInfo_ storage, the copy is re-pointed at its
 *    own argsWithMemInfo_ so it does not dangle once this object is destroyed;
 *  - if this object owns args_ (needFree_), the buffer is duplicated on the device so the two
 *    destructors do not release the same allocation.
 * Both fix-ups are skipped when the copy's args_ already differs from this object's.
 *
 * NOTE(review): the copy constructor is not visible from this file; if it already performs a
 * deep copy, the guards below make this function equivalent to a plain copy.
 *
 * @return The new context, or nullptr when duplicating an owned device buffer fails.
 */
ArgsContextSP ArgsRawContext::Clone(void) const
{
    auto clone = std::make_shared<ArgsRawContext>(*this);

    const bool pointsIntoOwnBuffer = !argsWithMemInfo_.empty() &&
        args_ == static_cast<const void *>(argsWithMemInfo_.data());
    if (pointsIntoOwnBuffer && clone->args_ == args_) {
        clone->args_ = clone->argsWithMemInfo_.data();
    }

    if (needFree_ && args_ != nullptr && clone->args_ == args_) {
        // Detach first so that an early return can never leave two owners of one allocation.
        clone->args_ = nullptr;
        clone->needFree_ = false;

        void *deviceCopy = nullptr;
        aclError ret = aclrtMallocImplOrigin(&deviceCopy, argsSize_,
                                             aclrtMemMallocPolicy::ACL_MEM_MALLOC_HUGE_FIRST);
        if (ret != ACL_ERROR_NONE) {
            ERROR_LOG("clone: malloc device args failed, ret=%d", static_cast<int>(ret));
            return nullptr;
        }
        ret = aclrtMemcpyImplOrigin(deviceCopy, argsSize_, args_, argsSize_,
                                    aclrtMemcpyKind::ACL_MEMCPY_DEVICE_TO_DEVICE);
        if (ret != ACL_ERROR_NONE) {
            ERROR_LOG("clone: copy device args failed, ret=%d", static_cast<int>(ret));
            aclrtFreeImplOrigin(deviceCopy);
            return nullptr;
        }
        clone->args_ = deviceCopy;
        clone->needFree_ = true;
    }
    return clone;
}
/**
 * @brief Release the argument buffer if this object owns it.
 *
 * Does nothing unless needFree_ is set and args_ is non-null; otherwise frees args_ through
 * aclrtFreeImplOrigin and resets both members.
 */
void ArgsRawContext::FreeArgs()
{
    const bool ownsBuffer = needFree_ && (args_ != nullptr);
    if (!ownsBuffer) {
        return;
    }
    aclrtFreeImplOrigin(args_);
    args_ = nullptr;
    needFree_ = false;
}
/** @brief Destructor: hands cleanup of the argument buffer to FreeArgs(). */
ArgsRawContext::~ArgsRawContext()
{
    this->FreeArgs();
}