* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef AIR_RUNTIME_DEPLOY_EXECUTOR_SCHED_TASK_INFO_H
#define AIR_RUNTIME_DEPLOY_EXECUTOR_SCHED_TASK_INFO_H
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>
#include "ge/ge_api_types.h"
#include "framework/common/ge_types.h"
#include "framework/common/runtime_tensor_desc.h"
#include "acl/acl.h"
namespace ge {
// Argument record holding a notify id and a return code.
// Both members are volatile-qualified. The consumer of this layout is not
// visible in this header; presumably the AICPU notify wait/record kernels
// read it -- TODO confirm against the implementation file.
struct AicpuNotifyKernelArgs {
  volatile uint32_t notify_id;  // identifier of the notify object
  volatile uint16_t ret_code;   // NOTE(review): looks like a status slot written back by the kernel -- confirm
};
// Argument record pairing a model id with a stream id.
// Presumably consumed by the stream-repeat CPU kernel -- TODO confirm; the
// kernel itself is not declared in this header.
struct StreamRepeatKernelArgs {
  uint32_t model_id;   // identifier of the model
  uint32_t stream_id;  // identifier of the stream
};
// Argument record describing a batch of queues.
// The three *_addr members are plain 64-bit integers; presumably each carries
// the address of an array with queue_num entries -- TODO confirm against the
// code that fills this structure.
struct BatchQueueInfoKernelArgs {
  uint32_t queue_num;           // number of queues in the batch
  uint32_t align_interval;      // alignment interval (units are not stated in this header)
  uint64_t align_offsets_addr;  // address of the align-offset data
  uint64_t queue_ids_addr;      // address of the queue-id data
  uint64_t mbuf_addrs_addr;     // address of the mbuf-address data
};
// Argument record for a single queue: a queue id plus one 64-bit mbuf address.
// This struct is declared outside the file's #pragma pack(push, 1) region, so
// on ABIs where uint64_t is 8-byte aligned the compiler places padding between
// the two members.
struct QueueInfoKernelArgs {
  uint32_t queue_id;   // identifier of the queue
  uint64_t mbuf_addr;  // 64-bit mbuf address value
};
// Argument record describing an address mapping: a count and two 64-bit
// address values. Every member defaults to zero.
// Presumably src_addr_list / dst_addr_list each point at addr_num entries --
// TODO confirm against the code that fills this structure.
struct AddrMapInfoKernelArgs {
  uint32_t addr_num = 0U;        // number of addresses
  uint64_t src_addr_list = 0UL;  // address of the source list
  uint64_t dst_addr_list = 0UL;  // address of the destination list
};
// Argument record for the mark-step task.
// Defaults: group_total_count is 1, every other scalar is 0, and both arrays
// are zero-filled (first element given explicitly, the rest value-initialised).
struct MarkStepKernelArgs {
  uint32_t group_total_count = 1U;   // total number of groups
  uint32_t group_index = 0U;         // index of this group
  uint32_t group_policy = 0U;        // group policy value (meaning is not defined in this header)
  uint64_t step_id_addr = 0UL;       // 64-bit address of the step id
  uint64_t reserved[32] = {0UL};     // reserved space, zero-filled
  char_t dump_step[1024] = {'\0'};   // fixed-size character buffer, zero-filled by default
};
#pragma pack(push, 1)
// Argument record for preparing dynamic inputs/outputs.
// Declared inside the file's #pragma pack(push, 1) region. All *_addr members
// are plain 64-bit integers; presumably each carries the address of a buffer
// sized by inputs_num or outputs_num -- TODO confirm against the .cc file.
struct PrepareDynamicInputOutputKernelArgs {
  uint32_t inputs_num;                 // number of inputs
  uint32_t outputs_num;                // number of outputs
  uint64_t input_dynamic_flags_addr;   // address of the per-input dynamic flags
  uint64_t output_tensor_sizes_addr;   // address of the output tensor sizes
  uint64_t input_mbuf_addrs_addr;      // address of the input mbuf addresses
  uint64_t output_mbuf_addrs_addr;     // address of the output mbuf addresses
  uint64_t input_fusion_offsets_addr;  // address of the input fusion offsets
  uint64_t req_msg_mbuf_addr;          // request-message mbuf address value
};
// Parameter record for a gather-dequeue operation.
// Declared inside the file's #pragma pack(push, 1) region. The inputs_align_*
// members carry input-alignment settings; their units and exact meaning are
// not defined in this header.
struct GatherDequeueParam {
  uint32_t input_nums;                  // number of inputs
  int32_t inputs_align_timeout;         // alignment timeout (signed)
  uint32_t inputs_align_max_cache_num;  // alignment max cache count
  uint32_t inputs_align_drop_out;       // NOTE(review): looks like a boolean-style switch -- confirm
  uint64_t queue_ids_addr;              // address of the queue ids
  uint64_t mbuf_addrs_addr;             // address of the mbuf addresses
  uint64_t queue_device_ids_addr;       // address of the queue device ids
  uint64_t queue_device_type_addr;      // address of the queue device types
};
// Argument record for post-processing dynamic outputs.
// Declared inside the file's #pragma pack(push, 1) region. All *_addr members
// are plain 64-bit integers; presumably each carries the address of a buffer
// sized by inputs_num or outputs_num -- TODO confirm against the .cc file.
struct PostprocessDynamicOutputKernelArgs {
  uint32_t inputs_num;                      // number of inputs
  uint32_t outputs_num;                     // number of outputs
  uint64_t resp_msg_mbuf_addr;              // response-message mbuf address value
  uint64_t input_mbuf_addrs_addr;           // address of the input mbuf addresses
  uint64_t output_mbuf_addrs_addr;          // address of the output mbuf addresses
  uint64_t output_dynamic_flags_addr;       // address of the per-output dynamic flags
  uint64_t output_static_tensor_desc_addr;  // address of the static output tensor descriptors
};
#pragma pack(pop)
// Abstract base of the scheduler task classes declared in this header.
// It stores the stream handle, an untyped argument pointer and the argument
// size; derived classes supply Distribute(). Copy construction and copy
// assignment are deleted.
class SchedTaskInfo {
 public:
  // Constructs the task for `stream`; the definition lives outside this header.
  explicit SchedTaskInfo(aclrtStream const stream);
  virtual ~SchedTaskInfo() noexcept = default;

  // Pure virtual: each concrete task provides its own implementation.
  virtual Status Distribute() = 0;

  // Virtual, with a base implementation defined outside this header.
  // NOTE(review): presumably frees whatever args_ refers to -- confirm.
  virtual Status Release();

 protected:
  // Helper available to derived classes; takes the CPU kernel name.
  Status LaunchCpuKernel(const char *kernel_name) const;

  void *args_{nullptr};         // untyped pointer to the task's argument data
  uint64_t args_size_{0UL};     // size associated with args_
  aclrtStream stream_{nullptr}; // stream handle of this task

 private:
  SchedTaskInfo(const SchedTaskInfo &) = delete;
  SchedTaskInfo &operator=(const SchedTaskInfo &) & = delete;
};
// Scheduler task for dequeuing from a single queue.
class SchedTaskModelDequeue : public SchedTaskInfo {
 public:
  explicit SchedTaskModelDequeue(aclrtStream const stream)
      : SchedTaskInfo(stream) {}
  ~SchedTaskModelDequeue() override = default;

  // Prepares the task for `queue_id`. `mbuf_addr` is taken by non-const
  // reference; presumably its location receives the dequeued mbuf address --
  // TODO confirm against the implementation.
  Status Init(const uint32_t queue_id, uint64_t &mbuf_addr);

  Status Distribute() override;
};
// Scheduler task for enqueuing to a single queue.
class SchedTaskModelEnqueue : public SchedTaskInfo {
 public:
  explicit SchedTaskModelEnqueue(aclrtStream const stream)
      : SchedTaskInfo(stream) {}
  ~SchedTaskModelEnqueue() override = default;

  // Prepares the task from a queue id and an mbuf address value, both passed
  // by value.
  Status Init(const uint32_t queue_id, const uint64_t mbuf_addr);

  Status Distribute() override;
};
// Scheduler task for dequeuing from several queues in one batch.
class SchedTaskModelBatchDequeue : public SchedTaskInfo {
 public:
  explicit SchedTaskModelBatchDequeue(aclrtStream const stream)
      : SchedTaskInfo(stream) {}
  ~SchedTaskModelBatchDequeue() override = default;

  // Prepares the task from the queue ids, an align interval and per-queue
  // align offsets. `mbuf_addrs` is taken by non-const reference; presumably it
  // is expected to hold one slot per queue -- TODO confirm.
  Status Init(const std::vector<uint32_t> &queue_ids,
              const uint32_t align_interval,
              const std::vector<uint32_t> &align_offsets,
              std::vector<uint64_t> &mbuf_addrs);

  Status Distribute() override;
};
// Scheduler task for a gather-dequeue over several queues.
class SchedTaskModelGatherDequeue : public SchedTaskInfo {
 public:
  explicit SchedTaskModelGatherDequeue(aclrtStream const stream)
      : SchedTaskInfo(stream) {}
  ~SchedTaskModelGatherDequeue() override = default;

  // Prepares the task from queue attributes and input-alignment attributes.
  // `mbuf_addrs` is taken by non-const reference; presumably it is expected to
  // hold one slot per queue -- TODO confirm.
  Status Init(const std::vector<QueueAttrs> &queues,
              const InputAlignAttrs &input_align_attrs,
              std::vector<uint64_t> &mbuf_addrs);

  Status Distribute() override;
};
// Scheduler task that prepares dynamic inputs and outputs.
class SchedTaskPrepareDynamicInputOutput : public SchedTaskInfo {
 public:
  explicit SchedTaskPrepareDynamicInputOutput(aclrtStream const stream)
      : SchedTaskInfo(stream) {}
  ~SchedTaskPrepareDynamicInputOutput() override = default;

  // Prepares the task from per-input flags, mbuf addresses and fusion offsets
  // plus the output tensor sizes. `output_mbuf_addrs` and `req_msg_mbuf_addr`
  // are taken by non-const reference. `enable_v2` defaults to false;
  // presumably it is stored in enable_v2_ -- TODO confirm.
  Status Init(const std::vector<uint32_t> &input_dynamic_flags,
              const std::vector<uint64_t> &input_mbuf_addrs,
              const std::vector<int32_t> &input_fusion_offsets,
              const std::vector<int64_t> &output_tensor_sizes,
              std::vector<uint64_t> &output_mbuf_addrs,
              uint64_t &req_msg_mbuf_addr,
              const bool enable_v2 = false);

  Status Distribute() override;

 private:
  bool enable_v2_ = false;  // v2 switch, off by default
};
// Scheduler task for enqueuing to several queues in one batch.
class SchedTaskModelBatchEnqueue : public SchedTaskInfo {
 public:
  // No trailing ';' after the inline body: it is redundant, triggers
  // -Wextra-semi, and the other inline constructors in this header omit it.
  explicit SchedTaskModelBatchEnqueue(aclrtStream const stream) : SchedTaskInfo(stream) {}
  ~SchedTaskModelBatchEnqueue() override = default;

  // Prepares the task from the queue ids and the mbuf addresses, both
  // read-only. Presumably the two vectors are expected to have the same
  // length -- TODO confirm against the implementation.
  Status Init(const std::vector<uint32_t> &queue_ids, const std::vector<uint64_t> &mbuf_addrs);

  Status Distribute() override;
};
// Scheduler task that post-processes dynamic outputs.
class SchedTaskPostprocessDynamicOutput : public SchedTaskInfo {
 public:
  // No trailing ';' after the inline body: it is redundant, triggers
  // -Wextra-semi, and the other inline constructors in this header omit it.
  explicit SchedTaskPostprocessDynamicOutput(aclrtStream const stream) : SchedTaskInfo(stream) {}
  ~SchedTaskPostprocessDynamicOutput() override = default;

  // Prepares the task from the response-message mbuf address, the input and
  // output mbuf addresses, the per-output dynamic flags, and an array of
  // `output_static_tensor_num` static tensor descriptors.
  // `enable_v2` defaults to false; presumably it is stored in enable_v2_ --
  // TODO confirm against the implementation.
  Status Init(const uint64_t resp_msg_mbuf_addr, const std::vector<uint64_t> &input_mbuf_addrs,
              const std::vector<uint64_t> &output_mbuf_addrs, const std::vector<uint32_t> &output_dynamic_flags,
              const RuntimeTensorDesc *output_static_tensor_descs, const size_t output_static_tensor_num,
              const bool enable_v2 = false);

  Status Distribute() override;

 private:
  bool enable_v2_ = false;  // v2 switch, off by default
};
// Scheduler task for waiting on a notify. Also serves as the base class of
// SchedTaskNotifyRecord.
class SchedTaskNotifyWait : public SchedTaskInfo {
 public:
  // No trailing ';' after the inline body: it is redundant, triggers
  // -Wextra-semi, and the other inline constructors in this header omit it.
  explicit SchedTaskNotifyWait(aclrtStream const stream) : SchedTaskInfo(stream) {}
  ~SchedTaskNotifyWait() override = default;

  // Prepares the task for the given notify id.
  Status Init(const uint32_t notify_id);

  Status Distribute() override;
};
// Scheduler task for recording a notify. Derives from SchedTaskNotifyWait, so
// it inherits Init(notify_id) and overrides only Distribute().
class SchedTaskNotifyRecord : public SchedTaskNotifyWait {
 public:
  // No trailing ';' after the inline body: it is redundant, triggers
  // -Wextra-semi, and the other inline constructors in this header omit it.
  explicit SchedTaskNotifyRecord(aclrtStream const stream) : SchedTaskNotifyWait(stream) {}
  ~SchedTaskNotifyRecord() override = default;

  Status Distribute() override;
};
// Scheduler task for zero-copy address mapping between source and destination
// address lists.
class SchedTaskZeroCopy : public SchedTaskInfo {
 public:
  // No trailing ';' after the inline body: it is redundant, triggers
  // -Wextra-semi, and the other inline constructors in this header omit it.
  explicit SchedTaskZeroCopy(aclrtStream const stream) : SchedTaskInfo(stream) {}
  ~SchedTaskZeroCopy() override = default;

  // Prepares the task from the source addresses (read-only) and the
  // destination addresses (non-const reference). Presumably both vectors are
  // expected to have the same length -- TODO confirm.
  Status Init(const std::vector<uint64_t> &src_addrs, std::vector<uint64_t> &dst_addrs);

  Status Distribute() override;
};
// Scheduler task for the stream-repeat operation of a model.
class SchedTaskStreamRepeat : public SchedTaskInfo {
 public:
  // No trailing ';' after the inline body: it is redundant, triggers
  // -Wextra-semi, and the other inline constructors in this header omit it.
  explicit SchedTaskStreamRepeat(aclrtStream const stream) : SchedTaskInfo(stream) {}
  ~SchedTaskStreamRepeat() override = default;

  // Prepares the task for the given model id.
  Status Init(const uint32_t model_id);

  Status Distribute() override;
};
// Scheduler task for the mark-step operation.
class SchedTaskMarkStep : public SchedTaskInfo {
 public:
  // No trailing ';' after the inline body: it is redundant, triggers
  // -Wextra-semi, and the other inline constructors in this header omit it.
  explicit SchedTaskMarkStep(aclrtStream const stream) : SchedTaskInfo(stream) {}
  ~SchedTaskMarkStep() override = default;

  // Prepares the task from the group settings, the dump-step string and the
  // step-id address.
  // NOTE(review): MarkStepKernelArgs::dump_step is a 1024-element buffer;
  // presumably `dump_step` must fit into it -- confirm the length check in the
  // implementation.
  Status Init(const uint32_t group_total_count, const uint32_t group_index, const uint32_t group_policy,
              const std::string &dump_step, const uint64_t step_id_addr);

  Status Distribute() override;
};
}
#endif