* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "data_aligner.h"
#include <set>
#include "common/inner_error_codes.h"
#include "common/udf_log.h"
namespace FlowFunc {
namespace {
// Sentinel value for align_timeout_: when the configured timeout equals this value,
// DataAligner::TryTakeExpired() returns without scanning for expired entries.
constexpr int64_t kDataNeverExpireTime = -1;
// Byte-packed view of the message head stored in an mbuf's private-info area.
// DataAligner::GetTransIdAndDataLabel() overlays this struct on the trailing
// sizeof(MbufHeadMsg) bytes of that area, so the field order and sizes below must
// not change. Only trans_id and data_label are read in this file; the meaning of
// the remaining fields is not established here (rsv0/rsv are presumably reserved
// padding -- confirm against the producer's definition).
#pragma pack(push, 1)
struct MbufHeadMsg {
    uint64_t trans_id;    // read by GetTransIdAndDataLabel(); first half of the alignment key
    uint16_t version;
    uint16_t msg_type;
    int32_t ret_code;
    uint64_t start_time;
    uint64_t end_time;
    uint32_t flags;
    uint8_t data_flag;
    uint8_t rsv0[3];
    int32_t worker_id;
    uint32_t step_id;
    uint8_t rsv[8];
    uint32_t data_label;  // read by GetTransIdAndDataLabel(); second half of the alignment key
    uint32_t route_label;
};
#pragma pack(pop)
// The struct is reinterpreted over raw bytes, so lock its size at compile time:
// any accidental field change then fails the build instead of misreading the head.
static_assert(sizeof(MbufHeadMsg) == 64, "MbufHeadMsg must stay exactly 64 bytes");
}
// Frees every mbuf still held by this entry. For each buffer removed, the matching
// per-queue counter in queue_cache_num_ is decremented so the counters stay in step
// with what is actually cached.
CachedData::~CachedData() {
    const size_t queue_count = queue_data_.size();
    for (size_t q = 0; q < queue_count; ++q) {
        auto &pending = queue_data_[q];
        while (!pending.empty()) {
            auto buf = pending.front();
            pending.pop();
            // Null entries are tolerated: they are only un-counted, never freed.
            if (buf != nullptr) {
                halMbufFree(buf);
            }
            --queue_cache_num_[q];
        }
    }
    queue_data_.clear();
}
bool CachedData::IsExpire(const std::chrono::steady_clock::time_point ¤t_time, int32_t timeout) const {
std::chrono::duration<double> elapsed = current_time - start_time_;
int64_t elapsed_second = std::chrono::duration_cast<std::chrono::milliseconds>(elapsed).count();
if (elapsed_second >= timeout) {
HICAID_LOG_WARN("data elapsed %ld ms, reach expire time %d ms", elapsed_second, timeout);
return true;
}
return false;
}
// An entry is complete when it has at least one queue and every queue holds at
// least one buffer. An entry with no queues at all is never complete.
bool CachedData::IsComplete() const {
    const size_t queue_count = queue_data_.size();
    bool complete = (queue_count != 0);
    for (size_t q = 0; complete && q < queue_count; ++q) {
        complete = !queue_data_[q].empty();
    }
    return complete;
}
// Returns true when no queue holds any buffer (also true when there are no queues).
bool CachedData::IsEmpty() const {
    size_t q = 0;
    // Skip over leading empty queues; stopping early means a non-empty one was found.
    while (q < queue_data_.size() && queue_data_[q].empty()) {
        ++q;
    }
    return q == queue_data_.size();
}
// Appends `buf` to queue `idx` and bumps that queue's cached-buffer counter.
// No bounds check is done here; the caller must pass a valid queue index.
void CachedData::Push(size_t idx, Mbuf *buf) {
    auto &target_queue = queue_data_[idx];
    target_queue.push(buf);
    queue_cache_num_[idx] += 1;
}
// Pops at most one buffer from the front of each queue into `data`.
// `data` is resized to the number of queues; slot i receives the front of queue i.
// Slots whose queue is empty are not written: they keep whatever resize() left
// there (nullptr for newly created slots, the previous value otherwise).
void CachedData::Take(std::vector<Mbuf *> &data) {
    const size_t queue_count = queue_data_.size();
    data.resize(queue_count);
    for (size_t slot = 0; slot < queue_count; ++slot) {
        auto &pending = queue_data_[slot];
        if (!pending.empty()) {
            data[slot] = pending.front();
            pending.pop();
            --queue_cache_num_[slot];
        }
    }
}
// Reads the transaction id and data label from the message head of `mbuf`.
//   mbuf:       buffer to inspect; must not be null.
//   trans_id:   out, set only on success.
//   data_label: out, set only on success.
// Returns HICAID_SUCCESS on success, HICAID_ERR_PARAM_INVALID for a null mbuf, and
// HICAID_FAILED when the private info cannot be fetched or is too short.
int32_t DataAligner::GetTransIdAndDataLabel(Mbuf *mbuf, uint64_t &trans_id, uint32_t &data_label) {
    if (mbuf == nullptr) {
        HICAID_LOG_ERROR("get trans_id failed as mbuf is null.");
        return HICAID_ERR_PARAM_INVALID;
    }

    void *priv_info = nullptr;
    uint32_t priv_info_len = 0U;
    const int32_t drv_ret = halMbufGetPrivInfo(mbuf, &priv_info, &priv_info_len);
    if (drv_ret != DRV_ERROR_NONE || priv_info == nullptr) {
        HICAID_LOG_ERROR("Failed to get mbuf priv info, ret[%d].", drv_ret);
        return HICAID_FAILED;
    }

    const size_t head_size = sizeof(MbufHeadMsg);
    if (priv_info_len < head_size) {
        HICAID_LOG_ERROR("mbuf priv info len=%u can't be less than to sizeof(MbufHeadMsg)=%zu.",
                         priv_info_len, head_size);
        return HICAID_FAILED;
    }

    // The message head occupies the last sizeof(MbufHeadMsg) bytes of the private info.
    const uint8_t *head_bytes = static_cast<const uint8_t *>(priv_info) + (priv_info_len - head_size);
    const auto *head_msg = reinterpret_cast<const MbufHeadMsg *>(head_bytes);
    trans_id = head_msg->trans_id;
    data_label = head_msg->data_label;
    return HICAID_SUCCESS;
}
// Returns the index of the queue with the fewest cached buffers according to
// queue_cache_num_. Ties go to the lowest index; returns 0 when there are no queues.
size_t DataAligner::SelectNextIndex() const {
    size_t best_idx = 0;
    for (size_t candidate = 1; candidate < queue_cache_num_.size(); ++candidate) {
        // Strict comparison keeps the earliest index among equal counts.
        if (queue_cache_num_[candidate] < queue_cache_num_[best_idx]) {
            best_idx = candidate;
        }
    }
    return best_idx;
}
// Caches `mbuf` under its (trans_id, data_label) key for queue `idx`; when every
// queue has a buffer for that key, one buffer per queue is moved into `aligned_data`.
//   idx:          index of the queue the buffer came from; must be < queue_num_.
//   mbuf:         buffer to cache. It is freed here (and success returned) when its
//                 trans_id is in the exception list.
//   aligned_data: written only when the key becomes complete.
// Returns HICAID_SUCCESS, HICAID_ERR_PARAM_INVALID for a bad index or null mbuf, or
// the failure code from header parsing / map insertion.
int32_t DataAligner::PushAndAlignData(size_t idx, Mbuf *mbuf, std::vector<Mbuf *> &aligned_data) {
    if (idx >= queue_num_) {
        HICAID_LOG_ERROR("save failed as idx=%zu must be less than %zu.", idx, queue_num_);
        return HICAID_ERR_PARAM_INVALID;
    }

    uint64_t trans_id = 0;
    uint32_t data_label = 0;
    const int32_t parse_ret = GetTransIdAndDataLabel(mbuf, trans_id, data_label);
    if (parse_ret != HICAID_SUCCESS) {
        HICAID_LOG_ERROR("get trans_id and data label failed, idx=%zu.", idx);
        return parse_ret;
    }

    // Header parsing needs no shared state, so the lock is taken only from here on.
    std::unique_lock<std::mutex> guard(mutex_);
    const bool is_exception = exception_trans_id_.find(trans_id) != exception_trans_id_.cend();
    if (is_exception) {
        HICAID_LOG_WARN("trans_id=%lu is in exception list, drop it, data_label=%u.", trans_id, data_label);
        halMbufFree(mbuf);
        return HICAID_SUCCESS;
    }

    std::pair<uint64_t, uint32_t> align_key(trans_id, data_label);
    auto entry = wait_align_data_.find(align_key);
    if (entry == wait_align_data_.end()) {
        // First buffer seen for this key: create its cache entry.
        auto emplaced = wait_align_data_.emplace(align_key, CachedData(queue_cache_num_));
        if (!emplaced.second) {
            HICAID_LOG_ERROR("emplace trans_id=%lu data_label[%u] failed.", align_key.first,
                             align_key.second);
            return HICAID_FAILED;
        }
        entry = emplaced.first;
    }

    auto &cached = entry->second;
    cached.Push(idx, mbuf);
    HICAID_LOG_DEBUG("save trans_id[%lu], data_label[%u] to the [%zu]th queue success.", trans_id, data_label, idx);
    if (!cached.IsComplete()) {
        return HICAID_SUCCESS;
    }

    HICAID_LOG_INFO("trans_id[%lu], data_label[%u] is complete.", trans_id, data_label);
    cached.Take(aligned_data);
    // Keep the entry if any queue still has buffers queued behind the ones just taken.
    if (cached.IsEmpty()) {
        (void)wait_align_data_.erase(entry);
    }
    return HICAID_SUCCESS;
}
// Under the aligner lock, first tries to hand out an expired entry; only if that
// left `data` empty does it fall back to the over-limit handling.
void DataAligner::TryTakeExpiredOrOverLimitData(std::vector<Mbuf *> &data) {
    std::unique_lock<std::mutex> guard(mutex_);
    TryTakeExpired(data);
    if (data.empty()) {
        TryTakeOverLimit(data);
    }
}
// Handles cache entries whose wait time has reached align_timeout_.
// Does not take mutex_ itself; the visible caller holds it.
//  - Does nothing when the cache is empty or the timeout is kDataNeverExpireTime.
//  - With drop_when_not_align_ set, every expired entry is erased and `data` is untouched.
//  - Otherwise the first expired entry found has one buffer per queue moved into
//    `data`, is erased if that emptied it, and the scan stops.
void DataAligner::TryTakeExpired(std::vector<Mbuf *> &data) {
    if (wait_align_data_.empty() || align_timeout_ == kDataNeverExpireTime) {
        return;
    }

    const auto now = std::chrono::steady_clock::now();
    auto entry = wait_align_data_.begin();
    while (entry != wait_align_data_.end()) {
        if (!entry->second.IsExpire(now, align_timeout_)) {
            ++entry;
            continue;
        }

        if (drop_when_not_align_) {
            HICAID_LOG_WARN("data trans_id=%lu, data_label=%u is expire, need drop it.", entry->first.first,
                            entry->first.second);
            entry = wait_align_data_.erase(entry);
            continue;
        }

        entry->second.Take(data);
        if (!entry->second.IsEmpty()) {
            HICAID_LOG_WARN(
                "data trans_id=%lu, data_label=%u is expire, but has data left after take, need take next time.",
                entry->first.first, entry->first.second);
        } else {
            // Log before erasing: the key is read through the iterator.
            HICAID_LOG_WARN("data trans_id=%lu, data_label=%u is expire, and handle finish.", entry->first.first,
                            entry->first.second);
            (void)wait_align_data_.erase(entry);
        }
        // Only one entry is handed out per call.
        return;
    }
}
// Relieves the cache when it holds more than align_max_cache_num_ entries.
// Does not take mutex_ itself; the visible caller holds it.
//  - With drop_when_not_align_ set, entries are erased from the front of the
//    container until the size is back within the limit; `data` is untouched.
//  - Otherwise the front entry has one buffer per queue moved into `data` and is
//    erased only if that emptied it.
void DataAligner::TryTakeOverLimit(std::vector<Mbuf *> &data) {
    if (wait_align_data_.size() <= align_max_cache_num_) {
        return;
    }

    if (drop_when_not_align_) {
        while (wait_align_data_.size() > align_max_cache_num_) {
            auto victim = wait_align_data_.begin();
            HICAID_LOG_WARN("cache size=%zu is over limit size %u, drop trans_id=%lu, data_label=%u.",
                            wait_align_data_.size(), align_max_cache_num_, victim->first.first,
                            victim->first.second);
            wait_align_data_.erase(victim);
        }
        return;
    }

    auto front_entry = wait_align_data_.begin();
    front_entry->second.Take(data);
    if (!front_entry->second.IsEmpty()) {
        HICAID_LOG_WARN("cache size=%zu is over limit size %u, take trans_id=%lu, data_label=%u and has data left, "
                        "need take next time.",
                        wait_align_data_.size(), align_max_cache_num_, front_entry->first.first,
                        front_entry->first.second);
        return;
    }
    // Log first: the message reports the size and key as they were before the erase.
    HICAID_LOG_WARN("cache size=%zu is over limit size %u, take trans_id=%lu, data_label=%u finish.",
                    wait_align_data_.size(), align_max_cache_num_, front_entry->first.first,
                    front_entry->first.second);
    (void)wait_align_data_.erase(front_entry);
}
// Marks `trans_id` as an exception and, under the lock, erases every cached entry
// whose key carries that trans_id (any data_label). Later buffers with this
// trans_id are dropped by PushAndAlignData() until the id is deleted again.
void DataAligner::AddExceptionTransId(uint64_t trans_id) {
    std::unique_lock<std::mutex> guard(mutex_);
    (void)exception_trans_id_.emplace(trans_id);

    size_t dropped = 0;
    for (auto entry = wait_align_data_.begin(); entry != wait_align_data_.cend();) {
        if (entry->first.first != trans_id) {
            ++entry;
            continue;
        }
        entry = wait_align_data_.erase(entry);
        ++dropped;
    }
    HICAID_LOG_INFO("add exception trans_id end, trans_id=%lu, total exception trans_id count=%zu, drop_cache_cnt=%zu.",
                    trans_id, exception_trans_id_.size(), dropped);
}
// Removes `trans_id` from the exception list under the lock; a no-op if it is absent.
void DataAligner::DeleteExceptionTransId(uint64_t trans_id) {
    // The lock is held for the whole scope and never released early.
    std::lock_guard<std::mutex> guard(mutex_);
    (void)exception_trans_id_.erase(trans_id);
}
// Drops all cached entries and exception ids under the lock, then zeroes the
// per-queue counters.
void DataAligner::Reset() {
    std::lock_guard<std::mutex> guard(mutex_);
    exception_trans_id_.clear();
    // Clear the cache before touching the counters: destroying cached entries runs
    // ~CachedData(), which frees their buffers and decrements counters itself.
    wait_align_data_.clear();
    for (auto &cached_count : queue_cache_num_) {
        cached_count = 0;
    }
}
}