/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*
* The code snippet comes from Ascend project.
*
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef METADEF_CXX_INC_EXE_GRAPH_TILING_CONTEXT_H_
#define METADEF_CXX_INC_EXE_GRAPH_TILING_CONTEXT_H_
#include <cstddef>
#include <cstdint>
#include <limits>
#include <type_traits>

#include "storage_shape.h"
#include "tensor.h"
#include "continuous_vector.h"
#include "extended_kernel_context.h"
#include "tiling_data.h"
#include "external/ge_common/ge_api_error_codes.h"
namespace fe {
// Forward declaration: this header only handles fe::PlatFormInfos through pointers,
// so the complete type is not needed here.
class PlatFormInfos;
}  // namespace fe
namespace gert {
* tiling kernel的context
*/
class TilingContext : public ExtendedKernelContext {
public:
const void *GetCompileInfo() const {
const auto compute_node_info = GetComputeNodeInfo();
if (compute_node_info == nullptr) {
return nullptr;
}
const size_t index = compute_node_info->GetInputsNum() + compute_node_info->GetOutputsNum();
const auto av = GetInput(index);
if (av == nullptr) {
return nullptr;
}
return av->GetValue<void *>();
}
* 获取CompileInfo
* @tparam T CompileInfo的类型
* @return CompileInfo的指针
*/
template<typename T>
const T *GetCompileInfo() const {
return reinterpret_cast<const T *>(GetCompileInfo());
}
* 获取输入shape,输入shape中包含了原始shape与运行时shape
* @param index 输入index
* @return 输入shape指针,index非法时返回空指针
*/
const StorageShape *GetInputShape(const size_t index) const {
const auto compute_node_info = GetComputeNodeInfo();
if (compute_node_info == nullptr) {
return nullptr;
}
if (index >= compute_node_info->GetInputsNum()) {
return nullptr;
}
return GetInputPointer<StorageShape>(index);
}
* 获取输入tensor
* 若算子被配置为'tiling'数据依赖,则返回的Tensor对象中保存了Host内存地址;反之,内存地址为nullptr。
* @param index 输入index
* @return 输入tensor指针,index非法时返回空指针
*/
const Tensor *GetInputTensor(const size_t index) const {
return GetInputPointer<Tensor>(index);
}
* 基于算子IR原型定义,获取`OPTIONAL_INPUT`类型的输入tensor指针
* 若算子被配置为'tiling'数据依赖,则返回的Tensor对象中保存了Host内存地址;反之,内存地址为nullptr。
* @param ir_index IR原型定义中的index
* @return tensor指针,index非法,或该INPUT没有实例化时,返回空指针
*/
const Tensor *GetOptionalInputTensor(const size_t ir_index) const {
return GetDynamicInputPointer<Tensor>(ir_index, 0);
}
* 基于算子IR原型定义,获取`DYNAMIC_INPUT`类型的输入Tensor指针
* 若算子被配置为'tiling'数据依赖,则返回的Tensor对象中保存了Host内存地址;反之,内存地址为nullptr。
* @param ir_index IR原型定义中的index
* @param relative_index 该输入实例化后的相对index,例如某个DYNAMIC_INPUT实例化了3个输入,那么relative_index的有效范围是[0,2]
* @return Tensor指针,index或relative_index非法时,返回空指针
*/
const Tensor *GetDynamicInputTensor(const size_t ir_index, const size_t relative_index) const {
return GetDynamicInputPointer<Tensor>(ir_index, relative_index);
}
* 基于算子IR原型定义,获取`OPTIONAL_INPUT`类型的输入shape指针,shape中包含了原始shape与运行时shape
* @param ir_index IR原型定义中的index
* @return shape指针,index非法,或该INPUT没有实例化时,返回空指针
*/
const StorageShape *GetOptionalInputShape(const size_t ir_index) const {
return GetDynamicInputPointer<StorageShape>(ir_index, 0);
}
* 基于算子IR原型定义,获取`REQUIRED_INPUT`类型的输入Tensor指针
* 若算子被配置为'tiling'数据依赖,则返回的Tensor对象中保存了Host内存地址;反之,内存地址为nullptr。
* @param ir_index IR原型定义中的index
* @return Tensor指针,index非法时,返回空指针
*/
const Tensor *GetRequiredInputTensor(const size_t ir_index) const {
return GetDynamicInputPointer<Tensor>(ir_index, 0);
}
* 基于算子IR原型定义,获取`REQUIRED_INPUT`类型的输入shape指针,shape中包含了原始shape与运行时shape
* @param ir_index IR原型定义中的index
* @return shape指针,index非法,或该INPUT没有实例化时,返回空指针
*/
const StorageShape *GetRequiredInputShape(const size_t ir_index) const {
return GetDynamicInputPointer<StorageShape>(ir_index, 0);
}
* 基于算子IR原型定义,获取`DYNAMIC_INPUT`类型的输入shape指针,shape中包含了原始shape与运行时shape
* @param ir_index IR原型定义中的index
* @param relative_index 该输入实例化后的相对index,例如某个DYNAMIC_INPUT实例化了3个输入,那么relative_index的有效范围是[0,2]
* @return shape指针,index或relative_index非法时,返回空指针
*/
const StorageShape *GetDynamicInputShape(const size_t ir_index, const size_t relative_index) const {
return GetDynamicInputPointer<StorageShape>(ir_index, relative_index);
}
* 根据输出index,获取输出shape指针,shape中包含了原始shape与运行时shape
* @param index 输出index
* @return 输出shape指针,index非法时,返回空指针
*/
const StorageShape *GetOutputShape(const size_t index) const {
const auto compute_node_info = GetComputeNodeInfo();
if (compute_node_info == nullptr) {
return nullptr;
}
if (index >= compute_node_info->GetOutputsNum()) {
return nullptr;
}
const size_t offset = compute_node_info->GetInputsNum();
return GetInputPointer<StorageShape>(offset + index);
}
* outputs, tiling的outputs以如下顺序排列:
* outputs[0]: tiling-key
* outputs[1]: block-dim
* outputs[2]: atomic-clean-flag
* outputs[3]: tiling-data
* outputs[4]: workspace sizes
* outputs[5]: tiling condition
* outputs[6]: schedule mode
* outputs[7]: local memory size
* outputs[8]: aicpu block-dim
*/
enum TilingOutputIndex : uint32_t {
kOutputTilingKey,
kOutputNumBlocks,
kOutputAtomicCleanFlag,
kOutputTilingData,
kOutputWorkspace,
kOutputTilingCond,
kOutputScheduleMode,
kOutputLocalMemorySize,
kOutputAicpuNumBlocks,
kOutputNum
};
* outputs[0]: fallible tiling condition
*/
enum FallibleTilingOutputIndex : uint32_t {
kTilingStatus = TilingOutputIndex::kOutputNum,
kFallibleOutputNum
};
* 设置tiling key
* @param tiling_key tiling key
* @return 成功时返回ge::GRAPH_SUCCESS
*/
ge::graphStatus SetTilingKey(const uint64_t tiling_key) {
const auto p = GetOutputPointer<uint64_t>(kOutputTilingKey);
if (p == nullptr) {
return ge::GRAPH_FAILED;
}
*p = tiling_key;
return ge::GRAPH_SUCCESS;
}
* 获取tiling key
* @return tiling key,获取失败时
*/
uint64_t GetTilingKey() const {
const auto p = GetOutputPointer<uint64_t>(kOutputTilingKey);
if (p == nullptr) {
return std::numeric_limits<uint64_t>::max();
}
return *p;
}
* 设置schedule_mode
* @param schedule_mode schedule_mode
* @return 成功时返回ge::GRAPH_SUCCESS
*/
ge::graphStatus SetScheduleMode(const uint32_t schedule_mode) {
const auto p = GetOutputPointer<uint32_t>(kOutputScheduleMode);
if (p == nullptr) {
return ge::GRAPH_FAILED;
}
*p = schedule_mode;
return ge::GRAPH_SUCCESS;
}
* 获取设置schedule_mode
* @return 设置schedule_mode,获取失败时
*/
uint32_t GetScheduleMode() const {
const auto p = GetOutputPointer<uint32_t>(kOutputScheduleMode);
if (p == nullptr) {
return 0U;
}
return *p;
}
* 设置block dim
* @param num_blocks block dim
* @return 成功时返回ge::GRAPH_SUCCESS
*/
ge::graphStatus SetNumBlocks(const uint32_t num_blocks) {
const auto p = GetOutputPointer<uint32_t>(kOutputNumBlocks);
if (p == nullptr) {
return ge::GRAPH_FAILED;
}
*p = num_blocks;
return ge::GRAPH_SUCCESS;
}
* 获取block dim
* @return block dim
*/
uint32_t GetNumBlocks() const {
const auto p = GetOutputPointer<uint32_t>(kOutputNumBlocks);
if (p == nullptr) {
return std::numeric_limits<uint32_t>::max();
}
return *p;
}
* 设置aicpu block dim(融合算子使用)
* @param num_blocks block dim
* @return 成功时返回ge::GRAPH_SUCCESS
*/
ge::graphStatus SetAicpuNumBlocks(uint32_t num_blocks) {
const auto p = GetOutputPointer<uint32_t>(kOutputAicpuNumBlocks);
if (p == nullptr) {
return ge::GRAPH_FAILED;
}
*p = num_blocks;
return ge::GRAPH_SUCCESS;
}
* 获取aicpu block dim(融合算子使用)
* @return block dim
*/
uint32_t GetAicpuNumBlocks() const {
const auto p = GetOutputPointer<uint32_t>(kOutputAicpuNumBlocks);
if (p == nullptr) {
return std::numeric_limits<uint32_t>::max();
}
return *p;
}
* 设置tiling cond
* @param tiling_cond tiling condition
* @return 成功时返回ge::GRAPH_SUCCESS
*/
ge::graphStatus SetTilingCond(int32_t tiling_cond) {
const auto p = GetOutputPointer<int32_t>(kOutputTilingCond);
if (p == nullptr) {
return ge::GRAPH_FAILED;
}
*p = tiling_cond;
return ge::GRAPH_SUCCESS;
}
* 获取tiling cond
* @return tiling cond:有效的tiling_cond大于等于0,若该值无效返回-1
*/
int32_t GetTilingCond() const {
const auto p = GetOutputPointer<int32_t>(kOutputTilingCond);
if (p == nullptr) {
return -1;
}
return *p;
}
* 设置是否需要atomic clean
* @param atomic true/false代表是否需要做atomic clean
* @return 成功时返回ge::GRAPH_SUCCESS
*/
ge::graphStatus SetNeedAtomic(const bool atomic) {
const auto p = GetOutputPointer<bool>(kOutputAtomicCleanFlag);
if (p == nullptr) {
return ge::GRAPH_FAILED;
}
*p = atomic;
return ge::GRAPH_SUCCESS;
}
* 获取是否需要atomic clean
* @return true/false
*/
bool NeedAtomic() const {
const auto p = GetOutputPointer<bool>(kOutputAtomicCleanFlag);
if (p == nullptr) {
return false;
}
return *p;
}
* 获取有类型的tiling data指针
* @tparam T tiling data类型,sizeof(T)不可以大于编译结果中指定的最大tiling data长度
* @return tiling data指针,失败时返回空指针
*/
template<typename T>
auto GetTilingData() -> T* {
auto tiling_data = GetRawTilingData();
if (tiling_data == nullptr) {
return nullptr;
}
if (tiling_data->GetCapacity() < sizeof(T)) {
return nullptr;
}
tiling_data->SetDataSize(sizeof(T));
return static_cast<T *>(tiling_data->GetData());
}
* 获取无类型的tiling data码流
* @return tiling data指针,失败时返回空指针
*/
TilingData *GetRawTilingData() {
return *GetOutputPointer<TilingData *>(kOutputTilingData);
}
* 获取workspace sizes指针
* @param workspace_count workspace的个数,传入的workspace个数不可以超过编译时指定的最大workspace个数
* @return workspace sizes指针
*/
size_t *GetWorkspaceSizes(const size_t workspace_count) {
const auto workspace = GetOutputPointer<TypedContinuousVector<size_t>>(kOutputWorkspace);
if (workspace == nullptr) {
return nullptr;
}
if (workspace->SetSize(workspace_count) != ge::SUCCESS) {
return nullptr;
}
return workspace->MutableData();
}
* 获取 workspace 个数
* @return workspace 个数
*/
size_t GetWorkspaceNum() const {
const auto workspace = GetOutputPointer<TypedContinuousVector<size_t>>(kOutputWorkspace);
if (workspace == nullptr) {
return 0U;
}
return workspace->GetSize();
}
* 获取 fe::PlatFormInfos 指针
* @return fe::PlatFormInfos 指针
*/
fe::PlatFormInfos *GetPlatformInfo() const {
const auto compute_node_info = GetComputeNodeInfo();
if (compute_node_info == nullptr) {
return nullptr;
}
const size_t index = compute_node_info->GetInputsNum() + compute_node_info->GetOutputsNum();
const auto av = GetInput(index + 1U);
if (av == nullptr) {
return nullptr;
}
return av->GetValue<fe::PlatFormInfos *>();
}
* 获取 确定性计算变量
* @return int32 变量
*/
int32_t GetDeterministic() const {
const auto compute_node_info = GetComputeNodeInfo();
if (compute_node_info == nullptr) {
return std::numeric_limits<int32_t>::max();
}
const size_t index = compute_node_info->GetInputsNum() + compute_node_info->GetOutputsNum();
const auto av = GetInput(index + 3U);
if (av == nullptr) {
return std::numeric_limits<int32_t>::max();
}
return av->GetValue<int32_t>();
}
* 设置 local memory size, 默认值为0
* @param local_memory_size
* @return 成功返回ge::GRAPH_SUCCESS
*/
ge::graphStatus SetLocalMemorySize(const uint32_t local_memory_size) {
const auto p = GetOutputPointer<uint32_t>(kOutputLocalMemorySize);
if (p == nullptr) {
return ge::GRAPH_FAILED;
}
*p = local_memory_size;
return ge::GRAPH_SUCCESS;
}
* 获取 local memory size,默认值为0
* @return local memory size, 失败返回 std::numeric_limits<uint32_t>::max()
*/
uint32_t GetLocalMemorySize() const {
const auto p = GetOutputPointer<uint32_t>(kOutputLocalMemorySize);
if (p == nullptr) {
return std::numeric_limits<uint32_t>::max();
}
return *p;
}
};
static_assert(std::is_standard_layout<TilingContext>::value, "The class TilingContext must be a POD");
}
#endif