* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "flow_func_run_context.h"
#include <memory>
#include "common/udf_log.h"
#include "common/common_define.h"
#include "securec.h"
#include "flow_func_helper.h"
#include "flow_func_config_manager.h"
#include "ascend_hal.h"
namespace FlowFunc {
namespace {
constexpr size_t kMaxUserDataSize = 64U;
bool CheckAlignSizeValid(const uint32_t align) {
constexpr size_t kMinAlignSize = 32U;
constexpr size_t kMaxAlignSize = sizeof(RuntimeTensorDesc);
if ((align < kMinAlignSize) || (align > kMaxAlignSize)) {
UDF_LOG_ERROR("alloc failed, as align=%u is out of range[%zu, %zu]", align, kMinAlignSize, kMaxAlignSize);
return false;
}
if (kMaxAlignSize % align != 0) {
UDF_LOG_ERROR("alloc failed, as align=%u must can be divisible by %zu.", align, kMaxAlignSize);
return false;
}
return true;
}
}
std::shared_ptr<FlowMsg> MetaRunContext::AllocTensorMsgWithAlign(const std::vector<int64_t> &shape,
TensorDataType dataType, uint32_t align) {
(void)shape;
(void)dataType;
(void)align;
UDF_LOG_ERROR("Not supported, please implement the AllocTensorMsgWithAlign method.");
return nullptr;
}
std::shared_ptr<FlowMsg> MetaRunContext::AllocTensorListMsg(const std::vector<std::vector<int64_t>> &shapes,
const std::vector<TensorDataType> &dataTypes, uint32_t align) {
(void)shapes;
(void)dataTypes;
(void)align;
UDF_LOG_ERROR("Not supported, please implement the AllocTensorMsgWithAlign method.");
return nullptr;
}
void MetaRunContext::RaiseException(int32_t expCode, uint64_t userContextId) {
(void)expCode;
(void)userContextId;
UDF_LOG_ERROR("Not supported, please implement the RaiseException method.");
}
bool MetaRunContext::GetException(int32_t &expCode, uint64_t &userContextId) {
(void)expCode;
(void)userContextId;
UDF_LOG_ERROR("Not supported, please implement the GetException method.");
return false;
}
std::shared_ptr<FlowMsg> MetaRunContext::AllocRawDataMsg(int64_t size, uint32_t align) {
(void)size;
(void)align;
UDF_LOG_ERROR("Not supported, please implement the AllocRawDataMsg method.");
return std::shared_ptr<FlowMsg>(nullptr);
}
std::shared_ptr<FlowMsg> MetaRunContext::ToFlowMsg(std::shared_ptr<Tensor> tensor) {
(void)tensor;
UDF_LOG_ERROR("Not supported, please implement the ToFlowMsg method.");
return std::shared_ptr<FlowMsg>(nullptr);
}
FlowFuncRunContext::FlowFuncRunContext(
uint32_t device_id, std::shared_ptr<FlowFuncParams> params, WriterCallback writer_call_back, uint32_t processor_id)
: MetaRunContext(), params_(std::move(params)), writer_call_back_(std::move(writer_call_back)), device_id_(device_id),
processor_idx_(processor_id) {
}
int32_t FlowFuncRunContext::SetOutput(uint32_t out_idx, std::shared_ptr<FlowMsg> out_msg) {
if (writer_call_back_ != nullptr) {
return writer_call_back_(out_idx, out_msg);
}
UDF_LOG_ERROR("SetOutput failed for writer_call_back_ is null, instance name[%s]", params_->GetName());
return FLOW_FUNC_FAILED;
}
std::shared_ptr<FlowMsg> FlowFuncRunContext::AllocTensorMsg(const std::vector<int64_t> &shape, TensorDataType data_type) {
return MbufFlowMsg::AllocTensorMsg(shape, data_type, device_id_, input_mbuf_head_);
}
std::shared_ptr<FlowMsg> FlowFuncRunContext::AllocRawDataMsg(int64_t size, uint32_t align) {
return MbufFlowMsg::AllocRawDataMsg(size, device_id_, input_mbuf_head_, align);
}
std::shared_ptr<FlowMsg> FlowFuncRunContext::ToFlowMsg(std::shared_ptr<Tensor> tensor) {
std::shared_ptr<FlowMsg> flow_msg = nullptr;
auto mbuf_tensor = std::dynamic_pointer_cast<MbufTensor>(tensor);
if (mbuf_tensor != nullptr) {
flow_msg = MbufFlowMsg::ToTensorMsg(mbuf_tensor->GetShape(), mbuf_tensor->GetDataType(),
mbuf_tensor->SharedMbuf(), input_mbuf_head_);
}
return flow_msg;
}
std::shared_ptr<FlowMsg> FlowFuncRunContext::AllocTensorListMsg(const std::vector<std::vector<int64_t>> &shapes,
const std::vector<TensorDataType> &data_types, uint32_t align) {
if (!CheckAlignSizeValid(align)) {
UDF_LOG_ERROR("Align size[%u] is invalid.", align);
return nullptr;
}
return MbufFlowMsg::AllocTensorListMsg(shapes, data_types, device_id_, input_mbuf_head_, align);
}
std::shared_ptr<FlowMsg> FlowFuncRunContext::AllocTensorMsgWithAlign(const std::vector<int64_t> &shape,
TensorDataType data_type, uint32_t align) {
if (!CheckAlignSizeValid(align)) {
UDF_LOG_ERROR("Align size[%u] is invalid.", align);
return nullptr;
}
return MbufFlowMsg::AllocTensorMsg(shape, data_type, device_id_, input_mbuf_head_, align);
}
std::shared_ptr<FlowMsg> FlowFuncRunContext::AllocEmptyDataMsg(MsgType msg_type) {
return AllocMbufEmptyDataMsg(msg_type);
}
std::shared_ptr<MbufFlowMsg> FlowFuncRunContext::AllocMbufEmptyDataMsg(MsgType msg_type) const {
if (msg_type == MsgType::MSG_TYPE_TENSOR_DATA) {
return MbufFlowMsg::AllocEmptyTensorMsg(device_id_, input_mbuf_head_);
}
UDF_LOG_ERROR("Unsupported msg type [%u].", static_cast<uint32_t>(msg_type));
return nullptr;
}
int32_t FlowFuncRunContext::RunFlowModel(const char *model_key, const std::vector<std::shared_ptr<FlowMsg>> &input_msgs,
std::vector<std::shared_ptr<FlowMsg>> &output_msgs, int32_t timeout) {
if (model_key == nullptr) {
UDF_LOG_ERROR("param model_key is null, instance name[%s].", params_->GetName());
return FLOW_FUNC_ERR_PARAM_INVALID;
}
auto flow_model = params_->GetFlowModels(model_key);
if (flow_model == nullptr) {
UDF_LOG_ERROR("no flow model found, instance name[%s], model_key=%s.", params_->GetName(), model_key);
return FLOW_FUNC_ERR_PARAM_INVALID;
}
int32_t ret = flow_model->Run(input_msgs, output_msgs, timeout);
if (ret != FLOW_FUNC_SUCCESS) {
UDF_LOG_ERROR("flow model run failed, instance name[%s], model_key=%s.", params_->GetName(), model_key);
return ret;
}
UDF_LOG_DEBUG("flow model run end, instance name[%s], model_key=%s.", params_->GetName(), model_key);
return FLOW_FUNC_SUCCESS;
}
int32_t FlowFuncRunContext::SetInputMbufHead(const MbufInfo &mbuf_info) {
if (mbuf_info.head_buf == nullptr) {
UDF_LOG_ERROR("The mbuf_info.head_buf is nullptr, instance name[%s].", params_->GetName());
return FLOW_FUNC_ERR_PARAM_INVALID;
}
if (mbuf_info.head_buf_len > kMaxMbufHeadLen) {
UDF_LOG_ERROR("The mbuf_info.head_buf_len[%u] > %u, instance name[%s].",
mbuf_info.head_buf_len, kMaxMbufHeadLen, params_->GetName());
return FLOW_FUNC_ERR_PARAM_INVALID;
}
if (mbuf_info.head_buf_len < sizeof(MbufHeadMsg)) {
UDF_LOG_ERROR("The mbuf_info.head_buf_len[%u] < headMsgLen[%zu], instance name[%s].",
mbuf_info.head_buf_len, sizeof(MbufHeadMsg), params_->GetName());
return FLOW_FUNC_ERR_PARAM_INVALID;
}
const auto ret = memcpy_s(input_mbuf_head_.head_buf, static_cast<size_t>(kMaxMbufHeadLen), mbuf_info.head_buf,
static_cast<size_t>(mbuf_info.head_buf_len));
if (ret != EOK) {
UDF_LOG_ERROR("Failed to memcpy mbuf priv info, ret[%d].", ret);
return FLOW_FUNC_FAILED;
}
input_mbuf_head_.head_buf_len = mbuf_info.head_buf_len;
return FLOW_FUNC_SUCCESS;
}
int32_t FlowFuncRunContext::CheckParamsForUserData(const void *data, size_t size, size_t offset) const {
if (data == nullptr) {
UDF_LOG_ERROR("The data is nullptr.");
return FLOW_FUNC_ERR_PARAM_INVALID;
}
if (size == 0U) {
UDF_LOG_ERROR("The size is 0, should in (0, 64].");
return FLOW_FUNC_ERR_PARAM_INVALID;
}
if ((offset >= kMaxUserDataSize) || ((kMaxUserDataSize - offset) < size)) {
UDF_LOG_ERROR("The size + offset need <= %zu, but size = %zu, offset = %zu.",
kMaxUserDataSize, size, offset);
return FLOW_FUNC_ERR_PARAM_INVALID;
}
if (input_mbuf_head_.head_buf_len == 0U) {
UDF_LOG_ERROR("The context mbuf head is invalid.");
return FLOW_FUNC_FAILED;
}
if (input_mbuf_head_.head_buf_len < kMaxUserDataSize) {
UDF_LOG_ERROR("The mbuf head size[%u] < max user data size[%zu].",
input_mbuf_head_.head_buf_len, kMaxUserDataSize);
return FLOW_FUNC_ERR_MEM_BUF_ERROR;
}
return FLOW_FUNC_SUCCESS;
}
int32_t FlowFuncRunContext::GetUserData(void *data, size_t size, size_t offset) const {
const auto ret = CheckParamsForUserData(data, size, offset);
if (ret != FLOW_FUNC_SUCCESS) {
UDF_LOG_ERROR("Failed to get user data, ret[%d].", ret);
return ret;
}
const auto cpy_ret = memcpy_s(data, size, input_mbuf_head_.head_buf + offset, size);
if (cpy_ret != EOK) {
UDF_LOG_ERROR("Failed to get user data, memcpy_s error, size[%zu], offset[%zu], memcpy_s ret[%d].",
size, offset, cpy_ret);
return FLOW_FUNC_FAILED;
}
UDF_LOG_DEBUG("Success to get user data, size[%zu], offset[%zu].", size, offset);
return FLOW_FUNC_SUCCESS;
}
int32_t FlowFuncRunContext::SetOutput(uint32_t out_idx, std::shared_ptr<FlowMsg> out_msg, const OutOptions &options) {
return SetMultiOutputs(out_idx, {out_msg}, options);
}
int32_t FlowFuncRunContext::SetMultiOutputs(uint32_t out_idx, const std::vector<std::shared_ptr<FlowMsg>> &out_msgs,
const OutOptions &options) {
std::vector<std::shared_ptr<FlowMsg>> after_filter_out_msgs;
int32_t ret = BalanceOptionFilter(options, out_msgs, after_filter_out_msgs);
if (ret != FLOW_FUNC_SUCCESS) {
UDF_LOG_ERROR("call balance option filter failed, instance name[%s], out_idx=%u.", params_->GetName(), out_idx);
return ret;
}
UDF_LOG_INFO("after balance option filter out msg size=%zu, instance name[%s]", after_filter_out_msgs.size(),
params_->GetName());
for (size_t i = 0; i < after_filter_out_msgs.size(); ++i) {
int32_t sendRet = SetOutput(out_idx, after_filter_out_msgs[i]);
if (sendRet != FLOW_FUNC_SUCCESS) {
ret = sendRet;
UDF_LOG_ERROR("set output, ret=%d, out_idx=%u, index=%zu", sendRet, out_idx, i);
}
}
return ret;
}
int32_t FlowFuncRunContext::CheckAffinityPolicy(AffinityPolicy policy) const {
if (params_->IsBalanceScatter()) {
if (policy != AffinityPolicy::NO_AFFINITY) {
UDF_LOG_ERROR("balance scatter node only support NO_AFFINITY(%d) policy but (%d), instance name[%s].",
static_cast<int32_t>(AffinityPolicy::NO_AFFINITY),
static_cast<int32_t>(policy), params_->GetName());
return FLOW_FUNC_ERR_PARAM_INVALID;
}
} else if (params_->IsBalanceGather()) {
if (policy == AffinityPolicy::NO_AFFINITY) {
UDF_LOG_ERROR("balance gather node not support NO_AFFINITY(%d) policy, instance name[%s].",
static_cast<int32_t>(AffinityPolicy::NO_AFFINITY), params_->GetName());
return FLOW_FUNC_ERR_PARAM_INVALID;
}
} else {
UDF_LOG_ERROR("Can't use balance config, as node is not balance node, instance name[%s].", params_->GetName());
return FLOW_FUNC_ERR_PARAM_INVALID;
}
return FLOW_FUNC_SUCCESS;
}
int32_t FlowFuncRunContext::BalanceOptionFilter(const OutOptions &options,
const std::vector<std::shared_ptr<FlowMsg>> &out_msgs,
std::vector<std::shared_ptr<FlowMsg>> &after_filter_out_msgs) const {
const auto *balance_config = options.GetBalanceConfig();
if (balance_config == nullptr) {
UDF_LOG_INFO("balance config is not exits, instance name[%s].", params_->GetName());
after_filter_out_msgs = out_msgs;
return FLOW_FUNC_SUCCESS;
}
if (!FlowFuncHelper::IsBalanceConfigValid(*balance_config)) {
UDF_LOG_ERROR("balance config is invalid, instance name[%s].", params_->GetName());
return FLOW_FUNC_ERR_PARAM_INVALID;
}
int32_t check_ret = CheckAffinityPolicy(balance_config->GetAffinityPolicy());
if (check_ret != FLOW_FUNC_SUCCESS) {
UDF_LOG_ERROR("Check AffinityPolicy[%d] failed, instance name[%s].",
static_cast<int32_t>(AffinityPolicy::NO_AFFINITY), params_->GetName());
return check_ret;
}
const auto &data_pos = balance_config->GetDataPos();
if (data_pos.size() != out_msgs.size()) {
UDF_LOG_ERROR("send flow msg but data pos size(%zu) != out_msgs size(%zu), instance name[%s].",
data_pos.size(), out_msgs.size(), params_->GetName());
return FLOW_FUNC_ERR_PARAM_INVALID;
}
const auto &balance_weight = balance_config->GetBalanceWeight();
for (size_t i = 0; i < out_msgs.size(); ++i) {
const auto &pos = data_pos[i];
const auto &out_msg = out_msgs[i];
auto mbuf_flow_msg = std::dynamic_pointer_cast<MbufFlowMsg>(out_msg);
if (mbuf_flow_msg == nullptr) {
UDF_LOG_ERROR("not support custom define flow msg now, instance name[%s].", params_->GetName());
return FLOW_FUNC_ERR_PARAM_INVALID;
}
uint32_t data_label = 0;
uint32_t route_label = 0;
FlowFuncHelper::CalcRouteLabelAndDataLabel(*balance_config, pos, data_label, route_label);
UDF_LOG_INFO("change out_msg[%zu] on pos[%d, %d] route_label=%u, data_label=%u, "
"matrix row_num=%d, col_num=%d, policy=%d",
i, pos.first, pos.second, route_label, data_label, balance_weight.rowNum, balance_weight.colNum,
static_cast<int32_t>(balance_config->GetAffinityPolicy()));
mbuf_flow_msg->SetRouteLabel(route_label);
mbuf_flow_msg->SetDataLabel(data_label);
}
after_filter_out_msgs = out_msgs;
return FLOW_FUNC_SUCCESS;
}
void FlowFuncRunContext::RaiseException(int32_t exp_code, uint64_t user_context_id) {
UDF_LOG_INFO("Raise and report exception event begin, exp_code=%d, user_context_id=%lu.",
exp_code, user_context_id);
const MbufHeadMsg *head_msg = reinterpret_cast<MbufHeadMsg *>(input_mbuf_head_.head_buf +
(input_mbuf_head_.head_buf_len - sizeof(MbufHeadMsg)));
uint64_t current_trans_id = head_msg->trans_id;
UDF_LOG_DEBUG("Get trans_id=%lu from input mbuf head.", current_trans_id);
std::unique_lock<std::mutex> guard(raise_mutex_);
if (trans_id_to_exception_raised_.find(current_trans_id) != trans_id_to_exception_raised_.end()) {
UDF_LOG_WARN("TransId[%lu] is already raised. Skip to raise again.", current_trans_id);
return;
}
event_summary event_info_summary = {};
event_info_summary.pid = getpid();
event_info_summary.event_id = static_cast<EVENT_ID>(UdfEvent::kEventIdRaiseException);
event_info_summary.subevent_id = processor_idx_;
UdfExceptionInfo report_exception_info = {};
report_exception_info.trans_id = current_trans_id;
report_exception_info.exp_code = exp_code;
report_exception_info.user_context_id = user_context_id;
report_exception_info.exp_context_size = input_mbuf_head_.head_buf_len;
const auto mem_ret = memcpy_s(report_exception_info.exp_context, kMaxMbufHeadLen,
input_mbuf_head_.head_buf, input_mbuf_head_.head_buf_len);
if (mem_ret != EOK) {
UDF_LOG_ERROR("[RaiseException] memcpy_s failed . ret=%d, trans_id[%lu], exceptionCode[%d], user_context_id[%lu]",
static_cast<int32_t>(mem_ret), current_trans_id, exp_code, user_context_id);
return;
}
trans_id_to_exception_raised_[current_trans_id] = report_exception_info;
event_info_summary.msg = reinterpret_cast<char *>(¤t_trans_id);
event_info_summary.msg_len = static_cast<uint32_t>(sizeof(current_trans_id));
event_info_summary.dst_engine = FlowFuncConfigManager::GetConfig()->IsRunOnAiCpu() ? ACPU_LOCAL : CCPU_LOCAL;
event_info_summary.grp_id = FlowFuncConfigManager::GetConfig()->GetMainSchedGroupId();
drvError_t ret = halEschedSubmitEvent(params_->GetDeviceId(), &event_info_summary);
if (ret != static_cast<int32_t>(DRV_ERROR_NONE)) {
UDF_LOG_WARN("Send raise exception event failed, ret=%d, trans_id[%lu], exceptionCode[%d], user_context_id[%lu]",
static_cast<int32_t>(ret), current_trans_id, exp_code, user_context_id);
} else {
UDF_LOG_INFO("Send raise exception success. trans_id[%lu], exceptionCode[%d], user_context_id[%lu]",
current_trans_id, exp_code, user_context_id);
}
}
bool FlowFuncRunContext::GetException(int32_t &exp_code, uint64_t &user_context_id) {
if (!exception_existed_) {
UDF_LOG_DEBUG("There is no valid exception during running procedure.");
return false;
}
exp_code = exception_cache_.exp_code;
user_context_id = exception_cache_.user_context_id;
exception_existed_ = false;
return true;
}
int32_t FlowFuncRunContext::GetExceptionByTransId(uint64_t trans_id, UdfExceptionInfo &exception_info) {
std::unique_lock<std::mutex> guard(raise_mutex_);
const auto iter = trans_id_to_exception_raised_.find(trans_id);
if (iter == trans_id_to_exception_raised_.cend()) {
UDF_LOG_ERROR("There is no exception for trans_id[%lu] waiting to report.", trans_id);
return FLOW_FUNC_FAILED;
}
exception_info = iter->second;
return FLOW_FUNC_SUCCESS;
}
bool FlowFuncRunContext::SelfRaiseChkAndDel(uint64_t trans_id) {
std::unique_lock<std::mutex> guard(raise_mutex_);
const auto iter = trans_id_to_exception_raised_.find(trans_id);
if (iter == trans_id_to_exception_raised_.cend()) {
return false;
}
trans_id_to_exception_raised_.erase(iter);
return true;
}
void FlowFuncRunContext::RecordException(const UdfExceptionInfo &exp_info) {
exception_cache_.exp_code = exp_info.exp_code;
exception_cache_.user_context_id = exp_info.user_context_id;
exception_existed_ = true;
}
}