* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef EXECUTOR_GRAPH_LOAD_MODEL_MANAGER_DEPLOY_DATAFLOW_DATA_ALIGNER_H_
#define EXECUTOR_GRAPH_LOAD_MODEL_MANAGER_DEPLOY_DATAFLOW_DATA_ALIGNER_H_
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <map>
#include <mutex>
#include <queue>
#include <set>
#include <utility>
#include <vector>
#include "graph/ge_tensor.h"
#include "ge/ge_data_flow_api.h"
#include "common/ge_common/ge_types.h"
#include "dflow/base/deploy/exchange_service.h"
namespace ge {
// Bundles a tensor with an exchange-service message header and a fixed-size block of user-data bytes.
struct TensorWithHeader {
// Tensor payload.
GeTensor tensor;
// Message header; value-initialized by default.
// NOTE(review): presumably the header that accompanied the tensor on the queue - confirm against callers.
ExchangeService::MsgInfo msg_info{};
// User-data buffer of kMaxUserDataSize bytes; zero-initialized by default.
int8_t user_data[kMaxUserDataSize]{};
};
/**
 * Cache made of one FIFO queue of TensorWithHeader per entry of the counter vector it is constructed from.
 *
 * The object keeps a reference to that counter vector, so the vector must outlive every AlignCacheData
 * built from it. Copying and both assignments are deleted; only move construction is available.
 * Member functions are defined outside this header.
 */
class AlignCacheData {
 public:
  /**
   * @param cache_nums counter vector, stored by reference (not copied); its size determines how many
   *                   queues are created.
   */
  explicit AlignCacheData(std::vector<size_t> &cache_nums) : cache_nums_(cache_nums), queue_data_(cache_nums.size()) {}
  ~AlignCacheData();

  AlignCacheData(const AlignCacheData &context) = delete;
  AlignCacheData &operator=(const AlignCacheData &context) & = delete;
  // Move construction keeps referring to the same counter vector as the source object.
  AlignCacheData(AlignCacheData &&context) noexcept = default;
  AlignCacheData &operator=(AlignCacheData &&context) &noexcept = delete;

  // NOTE(review): presumably true once every queue holds at least one element - confirm in the implementation.
  bool IsComplete() const;

  /**
   * @param current_time time point to evaluate against (steady clock).
   * @param timeout      timeout value; the unit is not visible in this header - TODO confirm.
   * NOTE(review): presumably compares the elapsed time since start_time_ with timeout - confirm.
   */
  bool IsExpire(const std::chrono::steady_clock::time_point &current_time, int32_t timeout) const;

  // NOTE(review): presumably true when no queue holds any element - confirm in the implementation.
  bool IsEmpty() const;

  /**
   * @param idx  index of the target queue.
   * @param data element to store; taken by value.
   */
  Status Push(size_t idx, TensorWithHeader data);

  /**
   * @param data receives the tensors taken from the cache.
   * @param info receives the associated data-flow information.
   */
  Status Take(std::vector<GeTensor> &data, DataFlowInfo &info);

 private:
  // Set to steady_clock::now() on construction; a move-constructed object takes over the source's value.
  std::chrono::steady_clock::time_point start_time_ = std::chrono::steady_clock::now();
  // Counter vector owned by the caller of the constructor.
  std::vector<size_t> &cache_nums_;
  // One queue per entry of cache_nums_ at construction time.
  std::vector<std::queue<TensorWithHeader>> queue_data_;
};
/**
 * Aligner over a fixed list of queue indexes. It owns the per-queue cache counters and a map of
 * AlignCacheData entries that are still waiting to be aligned.
 *
 * Member functions other than GetQueueIdxes() are defined outside this header, so the notes below
 * describe only what the signatures show.
 */
class DataFlowDataAligner {
 public:
  // Predicate called with a trans id. NOTE(review): presumably returns true when that id must be ignored - confirm.
  using CheckIgnoreTransIdFunc = std::function<bool(uint64_t)>;

  /**
   * @param queue_idxes                queue indexes handled by this aligner.
   * @param input_align_attrs          alignment attributes, taken by value.
   * @param check_ignore_trans_id_func predicate described at CheckIgnoreTransIdFunc.
   */
  DataFlowDataAligner(const std::vector<uint32_t> &queue_idxes,
                      InputAlignAttrs input_align_attrs,
                      const CheckIgnoreTransIdFunc &check_ignore_trans_id_func);
  ~DataFlowDataAligner();

  /**
   * @param queue_idx          queue the element belongs to.
   * @param tensor_with_header element to push, taken by value.
   * @param output             receives the output tensors.
   * @param info               receives the associated data-flow information.
   * @param is_aligned         out flag. NOTE(review): presumably set when output holds an aligned group - confirm.
   */
  Status PushAndAlignData(uint32_t queue_idx,
                          TensorWithHeader tensor_with_header,
                          std::vector<GeTensor> &output,
                          DataFlowInfo &info,
                          bool &is_aligned);

  // Returns a queue index. NOTE(review): selection policy is not visible in this header.
  uint32_t SelectNextQueueIdx();

  /**
   * @param data       receives the tensors that were taken.
   * @param info       receives the associated data-flow information.
   * @param has_output out flag. NOTE(review): presumably set when data was filled - confirm.
   */
  Status TryTakeExpiredOrOverLimitData(std::vector<GeTensor> &data, DataFlowInfo &info, bool &has_output);

  void ClearCacheByTransId(uint64_t trans_id);
  void ClearCache();

  // Read-only access to the queue indexes held by this aligner.
  const std::vector<uint32_t> &GetQueueIdxes() const { return queue_idxes_; }

 private:
  // Key type of the waiting-entry map. NOTE(review): meaning of the two fields is not visible here.
  using CacheKey = std::pair<uint64_t, uint32_t>;

  // Same parameter contract as TryTakeExpiredOrOverLimitData.
  Status TryTakeExpired(std::vector<GeTensor> &data, DataFlowInfo &info, bool &has_output);
  Status TryTakeOverLimit(std::vector<GeTensor> &data, DataFlowInfo &info, bool &has_output);

  // NOTE(review): presumably guards the cache members below - confirm in the implementation.
  std::mutex mt_;
  const std::vector<uint32_t> queue_idxes_;
  // NOTE(review): presumably maps a queue index to its position in queue_idxes_ - confirm.
  std::map<uint32_t, size_t> queue_idx_order_;
  const InputAlignAttrs align_attrs_;
  CheckIgnoreTransIdFunc check_ignore_trans_id_func_;
  std::vector<size_t> cache_nums_;
  std::map<CacheKey, AlignCacheData> wait_align_data_;
  std::set<uint64_t> exception_trans_id_set_;
};
}
#endif