/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
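/*!
* \file dev_start_args.h
* \brief Device start arguments (DevStartArgs) for dynamic program launch, the runtime-data ring
* buffer header (RuntimeDataRingBufferHead) and the RuntimeYield polling helper.
*/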
#pragma once
#include <atomic>
#include <chrono>
#include <cstdint>
#include <sstream>
#include <string>
#include <thread>
#include <unordered_map>
#include "machine/utils/dynamic/dev_encode_program.h"
#include "machine/utils/dynamic/device_task.h"
namespace npu::tile_fwk::dynamic {
// Indentation widths (in spaces) for dump output: one and two nesting levels.
constexpr uint32_t DUMP_INDEX_SIZE_2 = 2;
constexpr uint32_t DUMP_INDEX_SIZE_4 = 4;
/// One entry of DevStartArgs::inputSymbolList: a single signed 64-bit symbol value.
struct DevInputSymbol {
    int64_t value; // raw symbol value (left uninitialized by default construction)
};
/**
 * Runtime data addresses; filled in by DevStartArgs::InitProgram.
 *
 * Every member is value-initialized, so a descriptor that was never filled in reads as
 * null/zero rather than indeterminate memory. Previously only the two pointer members were
 * initialized and the three address fields were left indeterminate.
 */
struct DeviceRuntimeDataDesc {
    DeviceTaskCtrl* taskCtrlPool{nullptr};
    DeviceTaskCtrlQueue* taskQueueList{nullptr};
    uint64_t generalAddr{0};
    uint64_t dynamicCellMatchAddr{0};
    uint64_t stitchPoolAddr{0};
};
/**
 * Control-side counters carried in DevStartArgs::devCtrlState.
 * NOTE(review): the field meanings below are inferred from the names only; confirm them with the scheduler.
 */
struct DevCtrlState {
    uint32_t schAicpuNum{MAX_SCHEDULE_AICPU_NUM}; // scheduling AICPU count; defaults to MAX_SCHEDULE_AICPU_NUM
    uint32_t taskCtrlIndex{0};                   // current task-control index; starts at 0
};
// Thread index reserved for the control thread (value 0). It is not referenced in this header.
// NOTE(review): presumably compared against a thread index such as DevScheState::threadIdx; confirm.
#define CTRL_THREAD_INDEX 0
/**
 * Scheduling state embedded in DevStartArgs::devScheState.
 * Both counters are std::atomic, so they can be read and updated from several threads.
 * Each one starts at zero.
 */
struct DevScheState {
    std::atomic<int> threadIdx{0}; // NOTE(review): presumably handed out to identify threads; confirm
    std::atomic<int> finished{0};  // NOTE(review): presumably counts threads that have finished; confirm
};
/**
 * Start arguments for a dynamic device program, extending DevStartArgsBase.
 *
 * Holds the context workspace, the bound DevAscendProgram, the input symbol table, and the
 * runtime, control and scheduling state of one launch. sizeof(DevStartArgs) is checked against
 * DEV_ARGS_SIZE at namespace scope, so do not add or reorder data members casually.
 *
 * devTensorList (from DevStartArgsBase) holds the inputTensorSize inputs first, followed by the
 * outputs. GetOutputTensor offsets its index by inputTensorSize.
 */
struct DevStartArgs : DevStartArgsBase {
    uint64_t contextWorkspaceAddr;
    uint64_t contextWorkspaceSize;
    DevAscendProgram* devProg;
    DevInputSymbol* inputSymbolList;
    uint64_t inputSymbolSize;
    const void* controlFlowEntry{nullptr};
    DeviceRuntimeDataDesc deviceRuntimeDataDesc;
    DevCtrlState devCtrlState;
    DevScheState devScheState;

    /**
     * Bind `prog` and compute the runtime data addresses from its device runtime offsets.
     * @param prog program to bind. It is dereferenced, so it must be non-null.
     * @param base address to which the taskCtrlPool, taskQueue, general and stitchPool offsets are added.
     */
    void InitProgram(DevAscendProgram* prog, uint64_t base)
    {
        devProg = prog;
        // Fetch the offset table once. `const auto&` is correct whether the getter returns by
        // value (the temporary's lifetime is extended) or by reference.
        const auto& offsets = devProg->GetDeviceRuntimeOffset();
        deviceRuntimeDataDesc.taskCtrlPool = reinterpret_cast<DeviceTaskCtrl*>(base + offsets.taskCtrlPoolOffset);
        deviceRuntimeDataDesc.taskQueueList = reinterpret_cast<DeviceTaskCtrlQueue*>(base + offsets.taskQueueOffset);
        deviceRuntimeDataDesc.generalAddr = base + offsets.generalOffset;
        // Unlike the other fields, this one is not base-relative; it is copied from the program's devArgs.
        deviceRuntimeDataDesc.dynamicCellMatchAddr = devProg->devArgs.dynamicCellMatchAddr;
        deviceRuntimeDataDesc.stitchPoolAddr = base + offsets.stitchPoolOffset;
    }

    /**
     * Record the context workspace address and the program, and clear the input symbol table.
     * NOTE(review): contextWorkspaceSize and deviceRuntimeDataDesc are not touched here.
     * @param tDevProg program to store in devProg.
     * @param workspace context workspace; its address is stored in contextWorkspaceAddr.
     */
    void InitWorkspace(DevAscendProgram* tDevProg, void* workspace)
    {
        contextWorkspaceAddr = reinterpret_cast<uint64_t>(workspace);
        devProg = tDevProg;
        inputSymbolList = nullptr;
        inputSymbolSize = 0;
    }

    /// Element `index` of a DevLocalVector whose offsets are relative to the start of this object.
    template <typename T>
    const T& At(const DevLocalVector<T>& localvec, int index) const
    {
        return *reinterpret_cast<const T*>(reinterpret_cast<const uint8_t*>(this) + localvec.Offset(index));
    }
    /// Mutable overload of At().
    template <typename T>
    T& At(const DevLocalVector<T>& localvec, int index)
    {
        return *reinterpret_cast<T*>(reinterpret_cast<uint8_t*>(this) + localvec.Offset(index));
    }

    int GetInputTensorSize() const { return inputTensorSize; }
    const DevTensorData& GetInputTensor(int index) const { return devTensorList[index]; }
    DevTensorData& GetInputTensor(int index) { return devTensorList[index]; }
    int GetOutputTensorSize() const { return outputTensorSize; }
    // Outputs are stored after the inputs in devTensorList.
    const DevTensorData& GetOutputTensor(int index) const { return devTensorList[index + inputTensorSize]; }
    DevTensorData& GetOutputTensor(int index) { return devTensorList[index + inputTensorSize]; }
    // inputSymbolSize is uint64_t; the int return type is kept for callers, so the narrowing is explicit.
    int GetInputSymbolSize() const { return static_cast<int>(inputSymbolSize); }
    const DevInputSymbol& GetInputSymbol(int index) const { return inputSymbolList[index]; }
    DevInputSymbol& GetInputSymbol(int index) { return inputSymbolList[index]; }

    /**
     * Human-readable dump of the tensors, workspace address, memory budgets and program address.
     * @param indent base indentation. Entries are indented by indent + DUMP_INDEX_SIZE_2 spaces.
     * @return the multi-line description, wrapped in "DevStartArgs { ... }".
     */
    std::string Dump(int indent = 0) const
    {
        const std::string INDENTINNER(indent + DUMP_INDEX_SIZE_2, ' ');
        std::ostringstream oss;
        // Inputs and outputs share one line format: "<tag><i>: #address:<addr> #shape:[d0,d1,...]".
        auto dumpTensor = [&](const char* tag, int i, const DevTensorData& tensor) {
            oss << INDENTINNER << tag << i << ": #address:" << AddressDescriptor::DumpAddress(tensor.address);
            oss << " #shape:[";
            for (int j = 0; j < tensor.shape.dimSize; j++) {
                oss << Delim(j != 0, ",");
                oss << tensor.shape.dim[j];
            }
            oss << "]\n";
        };
        oss << "DevStartArgs {"
            << "\n";
        for (int i = 0; i < GetInputTensorSize(); i++) {
            dumpTensor("#input-", i, GetInputTensor(i));
        }
        for (int i = 0; i < GetOutputTensorSize(); i++) {
            dumpTensor("#output-", i, GetOutputTensor(i));
        }
        oss << INDENTINNER << "#workspaceAddr:" << AddressDescriptor::DumpAddress(contextWorkspaceAddr) << "\n";
        // Dump is a diagnostic and must not crash. If no program is bound (devProg == nullptr),
        // skip the budgets instead of dereferencing null.
        if (devProg != nullptr) {
            oss << INDENTINNER << "#tensorMemBudget:" << devProg->memBudget.tensor.Total() << "\n";
            oss << INDENTINNER << "#metadataMemBudget:" << devProg->memBudget.metadata.Total() << "\n";
        }
        oss << INDENTINNER << "#devProg:" << AddressDescriptor::DumpAddress(reinterpret_cast<uintdevptr_t>(devProg))
            << "\n";
        oss << "}";
        return oss.str();
    }

    // Declared here and defined out of line (not in this header).
    // NOTE(review): presumably maps symbol names to handler ids; confirm.
    static std::unordered_map<std::string, SymbolHandlerId> symbolIndexDict;
};
// Compile-time guard: sizeof(DevStartArgs) must stay strictly below DEV_ARGS_SIZE.
static_assert(DEV_ARGS_SIZE > sizeof(DevStartArgs), "dev start args is too large");
/**
 * Give up the CPU inside a polling loop.
 * @param microseconds how long to sleep. 0 (the default) yields the rest of the time slice.
 *
 * std::this_thread::sleep_for returns immediately for a non-positive duration (libstdc++ does
 * not even enter the kernel). Sleeping for 0 us therefore turned every polling loop, such as
 * RuntimeDataRingBufferHead::AllocateWait, into a pure busy spin. The zero case now calls
 * std::this_thread::yield() so other threads can make progress.
 */
static inline void RuntimeYield(uint64_t microseconds = 0)
{
    if (microseconds == 0) {
        std::this_thread::yield();
        return;
    }
    std::this_thread::sleep_for(std::chrono::microseconds(microseconds));
}
// Default slot count for a RuntimeDataRingBufferHead. It is not referenced in this header.
// NOTE(review): presumably passed as runtimeDataCount to Initialize()/GetRingBufferSize() by callers; confirm.
#define DEFAULT_RUNTIME_DATA_RING_BUFFER_COUNT 4
/**
 * Header of a fixed-capacity ring buffer of equally sized runtime-data slots.
 *
 * The slot storage (data_) directly follows the header. The object must therefore live in a
 * buffer of at least GetRingBufferSize(runtimeDataSize, runtimeDataCount) bytes, and
 * Initialize() must be called before use.
 *
 * Indices only grow and are mapped to slots as `index % runtimeDataCount_`:
 *  - indexPending_ is the number of slots handed out so far, i.e. the last allocated index.
 *  - indexFinished_ is the number of slots released so far.
 * Slots must be released in allocation order (Deallocate asserts this). The buffer is full when
 * runtimeDataCount_ slots are outstanding.
 *
 * NOTE(review): runtimeDataCount must be non-zero. With 0 the buffer is always Full() and the
 * slot mapping divides by zero.
 * NOTE(review): Allocate() increments indexPending_ after AllocateWait() without a CAS, so
 * concurrent producers are presumably not supported; confirm.
 */
struct RuntimeDataRingBufferHead {
public:
    /**
     * Reset the header to an empty buffer.
     * @param runtimeDataSize requested slot size, rounded up to a multiple of AlignSize.
     * @param runtimeDataCount number of slots.
     * @param arch architecture handed to the timeout check in AllocateWait().
     */
    void Initialize(uint64_t runtimeDataSize, uint64_t runtimeDataCount, ArchInfo arch = ArchInfo::DAV_2201)
    {
        runtimeDataSize_ = GetAlignedSize(runtimeDataSize);
        runtimeDataCount_ = runtimeDataCount;
        archInfo_ = arch;
        indexFinished_ = 0;
        indexPending_ = 0;
    }
    /// True when all runtimeDataCount_ slots are allocated and not yet released.
    bool Full() const { return indexFinished_ + runtimeDataCount_ <= indexPending_; }
    /// True when no slot is outstanding.
    bool Empty() const { return indexFinished_ == indexPending_; }
    /// Poll until a slot is free. A timeout only produces a warning; it does not abort.
    void AllocateWait()
    {
        TIMEOUT_CHECK_INIT_WARN_ONLY(archInfo_);
        while (Full()) {
            RuntimeYield();
            __PYPTO_TIMEOUT_CHECK_WARN_ONLY(
                "#ringbuffer.alloc: AllocateWait, ring buffer full.");
        }
    }
    /// Wait for a free slot, publish it, and return its storage.
    uint8_t* Allocate()
    {
        AllocateWait();
        uint64_t index = ++indexPending_;
        return GetRuntimeData(index);
    }
    /// Two-phase allocation, step 1: wait for a free slot and return it without publishing it.
    uint8_t* AllocatePrepare()
    {
        AllocateWait();
        return GetRuntimeData(indexPending_ + 1);
    }
    /// Two-phase allocation, step 2: publish the slot returned by AllocatePrepare().
    void AllocateSubmit() { ++indexPending_; }
    /// Release the oldest outstanding slot. `ptr` must be exactly that slot.
    void Deallocate(uint8_t* ptr)
    {
        uint8_t* nextFree = GetRuntimeData(indexFinished_ + 1);
        ASSERT(DevCommonErr::PARAM_CHECK_FAILED, nextFree == ptr);
        indexFinished_ += 1;
    }
    // Pure observers. They are const-qualified so they can be used through a const header.
    uint64_t GetRuntimeDataSize() const { return runtimeDataSize_; }
    uint64_t GetRuntimeDataCount() const { return runtimeDataCount_; }
    uint64_t GetIndexFinished() const { return indexFinished_; }
    uint64_t GetIndexPending() const { return indexPending_; }
    /// Index of the oldest outstanding slot (the next one to be deallocated).
    uint64_t GetIndexCurrent() const { return indexFinished_ + 1; }
    /// Slot number of the most recently allocated index.
    uint64_t GetIndexPendingIndex() const { return GetIndexPending() % GetRuntimeDataCount(); }
    /// Storage of the slot that `index` maps to.
    uint8_t* GetRuntimeData(uint64_t index) { return &data_[runtimeDataSize_ * (index % runtimeDataCount_)]; }
    /// Start of the slot storage.
    uint8_t* GetRuntimeData() { return &data_[0]; }
    uint8_t* GetRuntimeDataCurrent() { return GetRuntimeData(GetIndexCurrent()); }
    uint8_t* GetRuntimeDataPending() { return GetRuntimeData(GetIndexPending()); }
    static constexpr int AlignSize = 0x10;
    /// Round `size` up to a multiple of AlignSize (AlignSize is a power of two).
    static constexpr uint64_t GetAlignedSize(uint64_t size) { return (size + AlignSize - 1) & ~(AlignSize - 1); }
    /// Total bytes needed for the header plus `runtimeDataCount` aligned slots.
    static constexpr uint64_t GetRingBufferSize(uint64_t runtimeDataSize, uint64_t runtimeDataCount)
    {
        return sizeof(RuntimeDataRingBufferHead) + GetAlignedSize(runtimeDataSize) * runtimeDataCount;
    }
private:
    uint64_t runtimeDataSize_;
    uint64_t runtimeDataCount_;
    ArchInfo archInfo_{ArchInfo::DAV_2201};
    std::atomic<uint64_t> indexFinished_;
    std::atomic<uint64_t> indexPending_;
    unsigned char data_[0]; // slot storage follows the header (zero-length array extension)
};
}