* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef GE_SINGLE_OP_SINGLE_OP_IMPL_H_
#define GE_SINGLE_OP_SINGLE_OP_IMPL_H_
#include "graph/utils/object_pool.h"
#include "single_op/task/op_task.h"
#include "hybrid/executor/hybrid_model_executor.h"
#include "hybrid/executor/hybrid_model_rt_v1_executor.h"
#include "common/profiling_definitions.h"
namespace ge {
// State and entry points for one single-op instance.
//
// Every member function is defined out of line, so the notes below only describe what this
// declaration shows. The StreamResource pointer, the mutex pointer and the aclrtStream value are
// stored as raw members and the defaulted destructor does not release any of them.
// NOTE(review): presumably the creator owns those objects and outlives this one - confirm.
class SingleOpImpl {
 public:
  SingleOpImpl(StreamResource *const stream_res, std::mutex *const stream_mutex, aclrtStream const stream);
  ~SingleOpImpl() = default;

  // Public execution entry point; takes the input and output buffers and reports a Status.
  Status ExecuteAsync(const std::vector<DataBuffer> &inputs, const std::vector<DataBuffer> &outputs);

  // Read-only accessor returning an int64_t index.
  // NOTE(review): presumably returns profiling_node_type_index_ - confirm in the .cpp.
  int64_t GetProfilingNodeIndex() const noexcept;

  // Memory-related helpers. NOTE(review): these look tied to allocated_mem_ (MallocOnExecute
  // acquiring it, FreeAllocatedMem releasing it) - verify against the implementation.
  Status MallocOnExecute();
  const uint8_t *GetMemoryBase() const;
  void FreeAllocatedMem();

 private:
  // Internal helpers operating on the same buffer lists that ExecuteAsync receives.
  Status ValidateArgs(const std::vector<DataBuffer> &inputs, const std::vector<DataBuffer> &outputs);
  Status UpdateArgs(const std::vector<DataBuffer> &inputs, const std::vector<DataBuffer> &outputs);
  Status GetArgs(const std::vector<DataBuffer> &inputs, const std::vector<DataBuffer> &outputs);
  bool CheckHostMemInputOptimization(const std::vector<DataBuffer> &input_buffers) const;

  // SingleOpModel gets direct access to the private state below.
  friend class SingleOpModel;

  // No default member initializers on these three: they rely on the constructor to set them.
  StreamResource *stream_resource_;
  std::mutex *stream_mutex_;
  aclrtStream stream_;

  // Address and size lists for inputs and outputs.
  std::vector<const void *> input_addr_list_;
  std::vector<size_t> input_sizes_;
  std::vector<const void *> output_addr_list_;
  std::vector<size_t> output_sizes_;

  std::vector<uintptr_t> args_;
  // Tasks are uniquely owned by this object and destroyed together with it.
  std::vector<std::unique_ptr<OpTask>> tasks_;
  // Nested table of uintptr_t pointers. NOTE(review): presumably slots to patch with argument
  // addresses - confirm in UpdateArgs.
  std::vector<std::vector<uintptr_t *>> arg_table_;
  std::unique_ptr<SingleOpModelParam> model_param_;
  std::vector<GeTensorDesc> inputs_desc_;
  int64_t profiling_node_type_index_{gert::profiling::kUnknownName};
  ComputeGraphPtr root_graph_{nullptr};
  // Raw pointer; the defaulted destructor does not free the block it points to.
  ge::MemBlock *allocated_mem_ = nullptr;
};
// State and entry points for one dynamic single-op instance.
//
// Unlike SingleOpImpl, the execution entry point also takes tensor descriptions, and the output
// descriptions and buffers are passed by non-const reference. Every member function is defined
// out of line. The tensor pool and mutex are held as raw pointers and the defaulted destructor
// does not release them.
// NOTE(review): presumably both are owned by the creator - confirm.
class DynamicSingleOpImpl {
 public:
  DynamicSingleOpImpl(ObjectPool<GeTensor> *const tensor_pool, const uintptr_t resource_id,
                      std::mutex *const stream_mutex, aclrtStream const stream);
  ~DynamicSingleOpImpl() = default;

  // Public execution entry point. input_desc and input_buffers are read-only; output_desc and
  // output_buffers are mutable references, so the implementation is able to modify them.
  Status ExecuteAsync(const std::vector<GeTensorDesc> &input_desc,
                      const std::vector<DataBuffer> &input_buffers,
                      std::vector<GeTensorDesc> &output_desc,
                      std::vector<DataBuffer> &output_buffers);

  // Read-only accessor returning an int64_t index.
  // NOTE(review): presumably returns profiling_node_type_index_ - confirm in the .cpp.
  int64_t GetProfilingNodeIndex() const noexcept;

 private:
  // SingleOpModel gets direct access to the private state below.
  friend class SingleOpModel;

  // Const check over the four argument lists accepted by ExecuteAsync.
  Status ValidateParams(const std::vector<GeTensorDesc> &input_desc,
                        const std::vector<DataBuffer> &inputs,
                        const std::vector<GeTensorDesc> &output_desc,
                        const std::vector<DataBuffer> &outputs) const;

  // Two overloads; the first additionally takes a list of (size_t, uint64_t) pairs.
  Status SetHostTensorValue(const std::vector<std::pair<size_t, uint64_t>> &inputs_size,
                            const std::vector<GeTensorDesc> &input_desc,
                            const std::vector<DataBuffer> &input_buffers);
  Status SetHostTensorValue(const std::vector<GeTensorDesc> &input_desc,
                            const std::vector<DataBuffer> &input_buffers);

  // Non-const here, whereas the SingleOpImpl member of the same name is declared const.
  bool CheckHostMemInputOptimization(const std::vector<DataBuffer> &input_buffers);
  void InjectRuntimeContext();

  // Uniquely owned execution objects, destroyed together with this object.
  std::unique_ptr<OpTask> op_task_;
  std::unique_ptr<hybrid::HybridModel> hybrid_model_;
  std::unique_ptr<hybrid::HybridModelRtV1Executor> hybrid_model_executor_;

  // Host-memory bookkeeping. NOTE(review): key/value meaning is not visible from this header.
  std::map<int32_t, int64_t> hostmem_node_id_map_;
  std::map<int32_t, std::pair<int32_t, int32_t>> input_node_anchor_map_;
  std::vector<NodePtr> node_with_hostmem_;

  // No default member initializer: relies on the constructor.
  ObjectPool<GeTensor> *tensor_pool_;
  int64_t profiling_node_type_index_{-1};
  // The next three have no default member initializers either and rely on the constructor.
  uintptr_t resource_id_;
  std::mutex *stream_mutex_;
  aclrtStream stream_;
  size_t num_inputs_{0U};
  size_t num_outputs_{0U};
  ComputeGraphPtr compute_graph_;
  std::queue<std::unique_ptr<GeTensor>> shared_tensors_;
  RuntimeInferenceContext runtime_context_;
};
}
#endif