/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
#include "base_alignment_strategy.h"
#include "common_utils.h"
#include "graph/symbolizer/symbolic_utils.h"
#include "platform/platform_factory.h"
namespace optimize {
/**
 * Heuristic check for a transposed tail axis, based on the tensor strides of the vectorized axes.
 *
 * Decision steps, in order:
 *  1. With at most one vectorized axis the answer is false.
 *  2. If the stride of the last vectorized axis is not provably zero, the answer is false.
 *  3. If any vectorized axis at position 1 or later has a stride that is not provably zero, the answer is true.
 *  4. Otherwise, walking the vectorized axes from the back, the first axis whose stride is not provably zero
 *     decides: true only when that stride is provably different from one.
 *
 * @param attr tensor attribute whose axis / strides / vectorized_axis are inspected.
 * @return true when the steps above classify the tail axis as transposed. A vectorized axis that is missing
 *         from attr.axis trips GE_ASSERT_TRUE (presumably logs and returns false -- confirm against the macro).
 */
bool IsTailAxisTranspose(const af::AscTensorAttr &attr) {
  const size_t axis_size = attr.vectorized_axis.size();
  if (axis_size <= 1UL) {
    return false;
  }
  auto tail_axis_iter = std::find(attr.axis.begin(), attr.axis.end(), attr.vectorized_axis[axis_size - 1UL]);
  GE_ASSERT_TRUE(tail_axis_iter != attr.axis.end(), "Cannot find vectorized axis [%ld], axis attr may be invalid.",
                 attr.vectorized_axis[axis_size - 1UL]);
  const size_t tail_index = std::distance(attr.axis.begin(), tail_axis_iter);
  if (af::SymbolicUtils::StaticCheckEq(attr.strides[tail_index], af::sym::kSymbolZero) != af::TriBool::kTrue) {
    return false;
  }
  for (size_t i = axis_size - 1UL; i > 0UL; --i) {
    auto iter = std::find(attr.axis.begin(), attr.axis.end(), attr.vectorized_axis[i]);
    // Validate the iterator found for THIS axis; checking the tail iterator here would let a failed lookup
    // slip through and index strides with axis.size().
    GE_ASSERT_TRUE(iter != attr.axis.end(), "Cannot find vectorized axis [%ld], axis attr may be invalid.",
                   attr.vectorized_axis[i]);
    const size_t index = std::distance(attr.axis.begin(), iter);
    if (af::SymbolicUtils::StaticCheckEq(attr.strides[index], af::sym::kSymbolZero) != af::TriBool::kTrue) {
      return true;
    }
  }
  for (auto id = attr.vectorized_axis.rbegin(); id != attr.vectorized_axis.rend(); ++id) {
    auto iter = std::find(attr.axis.begin(), attr.axis.end(), *id);
    GE_ASSERT_TRUE(iter != attr.axis.end(), "Cannot find vectorized axis [%ld], axis attr may be invalid.", *id);
    const size_t index = std::distance(attr.axis.begin(), iter);
    if (af::SymbolicUtils::StaticCheckEq(attr.strides[index], af::sym::kSymbolZero) == af::TriBool::kTrue) {
      continue;
    }
    // First axis (from the back) with a non-zero stride: a stride other than one is treated as a transpose.
    return af::SymbolicUtils::StaticCheckNe(attr.strides[index], af::sym::kSymbolOne) == af::TriBool::kTrue;
  }
  return false;
}
/**
 * Breadth-first search over the out-nodes reachable from `node`, looking for a reduce node.
 *
 * Out-nodes that are not AscNode instances, or that ScheduleUtils::IsBuffer accepts, are neither tested
 * nor expanded. The start node itself is not tested.
 *
 * @param node start of the search.
 * @return true as soon as ScheduleUtils::IsReduce accepts a reachable node, false otherwise.
 */
bool IsHasReduceNode(const af::AscNodePtr &node) {
  std::set<af::Node *> seen;
  std::queue<af::Node *> pending;
  seen.emplace(node.get());
  pending.emplace(node.get());

  while (!pending.empty()) {
    const auto current = pending.front();
    pending.pop();
    for (auto &consumer : current->GetOutNodes()) {
      auto asc_consumer = std::dynamic_pointer_cast<af::AscNode>(consumer);
      const bool skipped = (asc_consumer == nullptr) || optimize::ScheduleUtils::IsBuffer(asc_consumer);
      if (skipped) {
        continue;
      }
      if (optimize::ScheduleUtils::IsReduce(asc_consumer)) {
        return true;
      }
      // Only enqueue a node the first time it is encountered.
      const bool first_visit = seen.insert(asc_consumer.get()).second;
      if (first_visit) {
        pending.push(asc_consumer.get());
      }
    }
  }
  return false;
}
/**
 * A load needs alignment for reduce when its first output looks tail-axis transposed
 * (IsTailAxisTranspose) and a reduce node is reachable downstream (IsHasReduceNode).
 */
bool IsLoadNeedAlignForReduce(const af::AscNodePtr &node) {
  // The attribute check runs first; the graph walk only happens when it passes.
  return IsTailAxisTranspose(node->outputs[0].attr) && IsHasReduceNode(node);
}
* 功能:判断尾轴是否做了transpose
* 判断逻辑:tensor的最后一根有效轴和向量化的最后一根有效轴是否不同
*/
bool IsTailAxisTransposeV2(const af::AscNodePtr &node_load) {
auto attr = node_load->outputs[0].attr;
GE_ASSERT_TRUE(attr.vectorized_axis.size() > 0, "Node %s with unexpected vectorized axis.", node_load->GetNamePtr());
auto last_valid_axis_index = attr.axis.size();
for (auto id = attr.vectorized_axis.rbegin(); id != attr.vectorized_axis.rend(); ++id) {
size_t reverse_pos = std::distance(attr.vectorized_axis.rbegin(), id);
const size_t index = attr.vectorized_axis.size() - reverse_pos - 1;
if (af::SymbolicUtils::StaticCheckEq(attr.vectorized_strides[index], af::sym::kSymbolZero) == af::TriBool::kTrue) {
continue;
}
auto iter = std::find(attr.axis.begin(), attr.axis.end(), *id);
GE_ASSERT_TRUE(iter != attr.axis.end(), "Cannot find vectorized axis [%ld], axis attr may be invalid.", *id);
last_valid_axis_index = std::distance(attr.axis.begin(), iter);
break;
}
if (last_valid_axis_index >= attr.axis.size()) {
GELOGD("Node %s with unexpected vectorized strides.", node_load->GetNamePtr());
return false;
}
for (auto id = attr.axis.begin() + last_valid_axis_index + 1; id != attr.axis.end(); ++id) {
size_t index = std::distance(attr.axis.begin(), id);
if (af::SymbolicUtils::StaticCheckNe(attr.strides[index], af::sym::kSymbolZero) == af::TriBool::kTrue) {
return true;
}
}
return false;
}
* 功能:判断load节点是否需要做对齐
* 判断逻辑:IsTailAxisTransposeV2为A5判断尾轴转置的逻辑;IsVectorizedAxisContinuousInGM和IsLoadNeedAlignForReduce是继承的A3判断是否需要针对reduce做对齐的逻辑
*/
bool IsLoadNeedAlign(const af::AscNodePtr &node_load) {
if (IsHasReduceNode(node_load) &&
(IsTailAxisTransposeV2(node_load) ||
IsTailAxisTranspose(node_load->outputs[0].attr))
) {
return true;
}
return false;
}
/**
 * Default alignment inference: derive the alignment type of a node's outputs from its inputs.
 *
 *  - RemovePad: output 0 is recorded as kFixedNotAligned and nothing else is done.
 *  - Any input recorded as kAligned / kDiscontinuous: every output gets the largest such type and that type
 *    is propagated through BackPropagateAlignment.
 *  - Otherwise, any input recorded as kFixedNotAligned: BackPropagateFixUnAlignType runs and every output
 *    becomes kFixedNotAligned.
 *  - Otherwise every output is kNotAligned.
 *
 * Inputs without an entry in tensor_to_align_type_ are ignored.
 */
ge::Status BaseAlignmentStrategy::DefaultAlignmentInferFunc(const af::AscNodePtr &node) {
  if (af::ops::IsOps<af::ascir_op::RemovePad>(node)) {
    tensor_to_align_type_[&node->outputs[0].attr] = {AlignmentType::kFixedNotAligned};
    return af::SUCCESS;
  }

  bool saw_aligned_input = false;
  bool saw_fixed_unaligned_input = false;
  auto result_type = AlignmentType::kNotAligned;
  for (const auto &in_tensor : node->inputs()) {
    const auto known = tensor_to_align_type_.find(&in_tensor->attr);
    if (known == tensor_to_align_type_.end()) {
      continue;
    }
    const AlignmentType in_type = known->second.align_type;
    if (in_type == AlignmentType::kFixedNotAligned) {
      saw_fixed_unaligned_input = true;
      continue;
    }
    if (in_type == AlignmentType::kAligned || in_type == AlignmentType::kDiscontinuous) {
      saw_aligned_input = true;
      result_type = std::max(result_type, in_type);
    }
  }

  // An aligned input wins over a fixed-unaligned one; the latter only matters when nothing is aligned.
  if (!saw_aligned_input && saw_fixed_unaligned_input) {
    result_type = AlignmentType::kFixedNotAligned;
    GE_ASSERT_SUCCESS(BackPropagateFixUnAlignType(node));
  }
  for (const auto &out_tensor : node->outputs()) {
    tensor_to_align_type_[&out_tensor->attr] = {result_type};
  }
  if (saw_aligned_input) {
    return BackPropagateAlignment(node, result_type);
  }
  return af::SUCCESS;
}
// Broadcast nodes have no dedicated rule here; they are forwarded to the default inference.
ge::Status BaseAlignmentStrategy::BroadcastAlignmentInferFunc(const af::AscNodePtr &node) {
  return this->DefaultAlignmentInferFunc(node);
}
// Concat nodes have no dedicated rule here; they are forwarded to the default inference.
ge::Status BaseAlignmentStrategy::ConcatAlignmentInferFunc(const af::AscNodePtr &node) {
  return this->DefaultAlignmentInferFunc(node);
}
// Element-wise nodes have no dedicated rule here; they are forwarded to the default inference.
ge::Status BaseAlignmentStrategy::EleWiseAlignmentInferFunc(const af::AscNodePtr &node) {
  return this->DefaultAlignmentInferFunc(node);
}
// Load nodes have no dedicated rule here; they are forwarded to the default inference.
ge::Status BaseAlignmentStrategy::LoadAlignmentInferFunc(const af::AscNodePtr &node) {
  return this->DefaultAlignmentInferFunc(node);
}
// Store nodes have no dedicated rule here; they are forwarded to the default inference.
ge::Status BaseAlignmentStrategy::StoreAlignmentInferFunc(const af::AscNodePtr &node) {
  return this->DefaultAlignmentInferFunc(node);
}
// Split nodes have no dedicated rule here; they are forwarded to the default inference.
ge::Status BaseAlignmentStrategy::SplitAlignmentInferFunc(const af::AscNodePtr &node) {
  return this->DefaultAlignmentInferFunc(node);
}
void BaseAlignmentStrategy::InitAlignmentInferFunc() {
compute_type_to_infer_func_[af::ComputeType::kComputeElewise] = [this](const std::shared_ptr<af::AscNode> &node) {
return this->EleWiseAlignmentInferFunc(node);
};
compute_type_to_infer_func_[af::ComputeType::kComputeBroadcast] = [this](const std::shared_ptr<af::AscNode> &node) {
return this->BroadcastAlignmentInferFunc(node);
};
compute_type_to_infer_func_[af::ComputeType::kComputeConcat] = [this](const std::shared_ptr<af::AscNode> &node) {
return this->ConcatAlignmentInferFunc(node);
};
compute_type_to_infer_func_[af::ComputeType::kComputeLoad] = [this](const std::shared_ptr<af::AscNode> &node) {
return this->LoadAlignmentInferFunc(node);
};
compute_type_to_infer_func_[af::ComputeType::kComputeStore] = [this](const std::shared_ptr<af::AscNode> &node) {
return this->StoreAlignmentInferFunc(node);
};
compute_type_to_infer_func_[af::ComputeType::kComputeReduce] = [this](const std::shared_ptr<af::AscNode> &node) {
return this->ReduceAlignmentInferFunc(node);
};
compute_type_to_infer_func_[af::ComputeType::kComputeSplit] = [this](const std::shared_ptr<af::AscNode> &node) {
return this->SplitAlignmentInferFunc(node);
};
}
/**
 * Alignment inference for reduce nodes.
 *
 * For every output, the stride of the last vectorized axis decides the recorded type: kNotAligned when that
 * stride is provably zero, kAligned otherwise. Afterwards BackPropagateAlignment is invoked for the node
 * with its default alignment argument.
 */
ge::Status BaseAlignmentStrategy::ReduceAlignmentInferFunc(const af::AscNodePtr &node) {
  for (const auto &out_tensor : node->outputs()) {
    auto &attr = out_tensor->attr;
    GE_ASSERT_TRUE(!attr.vectorized_axis.empty());
    const auto tail_axis = attr.vectorized_axis.back();
    const auto found = std::find(attr.axis.begin(), attr.axis.end(), tail_axis);
    GE_ASSERT_TRUE(found != attr.axis.end(),
                   "Cannot find vectorized axis [%ld] in [%s]'s output tensor.", tail_axis, node->GetNamePtr());
    const int64_t tail_pos = std::distance(attr.axis.begin(), found);
    const bool tail_stride_is_zero =
        af::SymbolicUtils::StaticCheckEq(attr.strides[tail_pos], af::sym::kSymbolZero) == af::TriBool::kTrue;
    const AlignmentType inferred = tail_stride_is_zero ? AlignmentType::kNotAligned : AlignmentType::kAligned;
    tensor_to_align_type_[&out_tensor->attr] = {inferred};
  }
  GE_ASSERT_SUCCESS(BackPropagateAlignment(node));
  return af::SUCCESS;
}
/**
 * Inserts a RemovePad node after every load whose first output ScheduleUtils::IsNeedDiscontinuousAligned
 * accepts, then re-sorts the graph topologically if anything was inserted.
 *
 * Nothing is done when the platform config reports is_support_compat_mode. A failing
 * ScheduleUtils::AddRemovePadAfter for a single load is skipped rather than treated as an error.
 *
 * @param impl_graph graph to modify in place.
 * @return af::SUCCESS on completion; an assertion failure status when the platform or a node is null, or
 *         when topological sorting fails.
 */
ge::Status BaseAlignmentStrategy::AddRemovePadForTailAxisDiscontinuousLoad(ascir::ImplGraph &impl_graph) {
  // Guard the platform handle before dereferencing it.
  const auto &platform = PlatformFactory::GetInstance().GetPlatform();
  GE_ASSERT_NOTNULL(platform);
  const auto &config = platform->GetPlatformConfig();
  if (config.is_support_compat_mode) {
    return af::SUCCESS;
  }
  bool inserted = false;
  for (const auto &node : impl_graph.GetAllNodes()) {
    GE_ASSERT_NOTNULL(node);
    if (!ScheduleUtils::IsLoad(node) || !ScheduleUtils::IsNeedDiscontinuousAligned(node->outputs[0].attr)) {
      continue;
    }
    af::AscNodePtr remove_pad_node = nullptr;
    if (ScheduleUtils::AddRemovePadAfter(impl_graph, node, remove_pad_node) != af::SUCCESS) {
      continue;
    }
    inserted = true;
  }
  // New nodes were appended out of order; restore a topological node order.
  if (inserted) {
    GE_ASSERT_GRAPH_SUCCESS(ScheduleUtils::TopologicalSorting(impl_graph));
  }
  return af::SUCCESS;
}
/**
 * Decides whether a Pad node can be skipped for the given output tensor.
 *
 * No pad is needed when the repeat of the last vectorized axis already equals its value rounded up with
 * af::sym::Align to (align_width_ / dtype size), or when no other vectorized axis has a provably non-zero
 * stride.
 *
 * @param node           node owning the tensor; only used for log messages.
 * @param out_attr       tensor attribute to inspect (not modified here).
 * @param is_no_need_pad receives the decision.
 * @return af::SUCCESS, or an assertion failure status for an unknown axis / invalid dtype size.
 */
ge::Status BaseAlignmentStrategy::CheckIsNoNeedPad(const af::AscNodePtr &node, af::AscTensorAttr &out_attr,
                                                   bool &is_no_need_pad) const {
  size_t nonzero_stride_axes = 0UL;
  bool tail_is_aligned = false;
  const auto &vec_axes = out_attr.vectorized_axis;
  for (auto rit = vec_axes.rbegin(); rit != vec_axes.rend(); ++rit) {
    const auto pos = std::find(out_attr.axis.begin(), out_attr.axis.end(), *rit);
    GE_ASSERT_TRUE(pos != out_attr.axis.end());
    const size_t axis_idx = std::distance(out_attr.axis.begin(), pos);

    if (rit != vec_axes.rbegin()) {
      // Non-tail axis: count it when its stride is provably non-zero.
      if (af::SymbolicUtils::StaticCheckNe(out_attr.strides[axis_idx], af::sym::kSymbolZero) ==
          af::TriBool::kTrue) {
        ++nonzero_stride_axes;
      }
      continue;
    }

    // Tail axis: compare the repeat with its aligned value.
    const auto dtype_size = af::GetSizeByDataType(out_attr.dtype);
    GE_ASSERT_TRUE(dtype_size > 0, "Node [%s]'s data type size:[%d] is invalid.", node->GetNamePtr(), dtype_size);
    const auto &tail_repeat = out_attr.repeats[axis_idx];
    auto rounded_repeat = af::sym::Align(tail_repeat, align_width_ / dtype_size);
    if (af::SymbolicUtils::StaticCheckEq(tail_repeat, rounded_repeat) != af::TriBool::kTrue) {
      continue;
    }
    tail_is_aligned = true;
    GELOGD("Tail repeat [%s] is aligned, no need to add pad for node:[%s].",
           af::SymbolicUtils::ToString(tail_repeat).c_str(), node->GetNamePtr());
    break;
  }
  is_no_need_pad = tail_is_aligned || (nonzero_stride_axes == 0UL);
  return af::SUCCESS;
}
/**
 * Inserts a Pad node behind every non-buffer output whose alignment record is flagged
 * conflict_with_output, unless CheckIsNoNeedPad says the pad can be skipped (the output is then simply
 * recorded as kAligned).
 *
 * The Pad node inherits the producer's node attributes and the output's tensor attribute, takes over all
 * consumers of the output, and is recorded as kAligned. The graph is re-sorted topologically when at
 * least one Pad was inserted.
 *
 * @param impl_graph graph to modify in place.
 * @return af::SUCCESS on completion; af::UNSUPPORTED when Pad does not support the output's dtype; an
 *         assertion failure status on graph manipulation errors.
 */
ge::Status BaseAlignmentStrategy::AddPadForAlignmentConflictNode(ascir::ImplGraph &impl_graph) {
  bool inserted = false;
  for (const auto &node : impl_graph.GetAllNodes()) {
    GE_ASSERT_NOTNULL(node);
    if (ScheduleUtils::IsBuffer(node)) {
      continue;
    }
    for (size_t i = 0UL; i < node->GetAllOutDataAnchorsSize(); ++i) {
      auto &out_attr = node->outputs[i].attr;
      if (!tensor_to_align_type_[&out_attr].conflict_with_output) {
        continue;
      }
      bool is_no_need_pad{false};
      GE_ASSERT_SUCCESS(CheckIsNoNeedPad(node, out_attr, is_no_need_pad));
      if (is_no_need_pad) {
        tensor_to_align_type_[&out_attr] = {AlignmentType::kAligned};
        continue;
      }
      // The pad is inserted on output i, so its dtype support must be checked against output i's dtype
      // (not unconditionally against output 0).
      const auto &dtype = out_attr.dtype;
      std::vector<af::DataType> exp_dtypes{dtype};
      auto ret = ScheduleUtils::CallAscirInferDataType<af::ascir_op::Pad>({dtype}, exp_dtypes);
      if (ret != af::SUCCESS) {
        GELOGW("Pad is unsupported for dtype [%s] in graph [%s], skip this template.",
               af::TypeUtils::DataTypeToSerialString(dtype).c_str(), impl_graph.GetName().c_str());
        return af::UNSUPPORTED;
      }
      inserted = true;
      const std::string node_name = node->GetName() + "_" + std::to_string(i) + "_pad";
      af::ascir_op::Pad pad_op(node_name.c_str());
      auto pad_node = impl_graph.AddNode(pad_op);
      GE_ASSERT_NOTNULL(pad_node);
      pad_node->attr = node->attr;
      pad_node->outputs[0].attr = out_attr;
      pad_node->attr.api.compute_type = af::ComputeType::kComputeElewise;
      pad_node->attr.api.type = af::ApiType::kAPITypeCompute;
      pad_node->attr.api.unit = af::ComputeUnit::kUnitVector;
      tensor_to_align_type_[&pad_node->outputs[0].attr] = {AlignmentType::kAligned};
      auto out_anchor = node->GetOutDataAnchor(static_cast<int32_t>(i));
      GE_ASSERT_NOTNULL(out_anchor);
      // Re-route the existing consumers to the pad output first, then feed the pad from the original output.
      for (auto &in_anchor : out_anchor->GetPeerInDataAnchors()) {
        GE_ASSERT_SUCCESS(af::GraphUtils::ReplaceEdgeSrc(out_anchor, in_anchor, pad_node->GetOutDataAnchor(0)));
      }
      GE_ASSERT_SUCCESS(af::GraphUtils::AddEdge(out_anchor, pad_node->GetInDataAnchor(0)));
    }
  }
  if (inserted) {
    GE_ASSERT_GRAPH_SUCCESS(ScheduleUtils::TopologicalSorting(impl_graph));
  }
  return af::SUCCESS;
}
/**
 * Chooses align_width_ for the graph: 32 by default, 64 as soon as the graph contains a node whose
 * compute type is kComputeTranspose and whose first output dtype is DT_FLOAT.
 */
ge::Status BaseAlignmentStrategy::SetAlignWidth(const ascir::ImplGraph &impl_graph) {
  align_width_ = 32U;
  for (const auto &graph_node : impl_graph.GetAllNodes()) {
    GE_ASSERT_NOTNULL(graph_node);
    auto out_dtype = graph_node->outputs[0].attr.dtype;
    const bool is_float_transpose =
        (out_dtype == ge::DT_FLOAT) && graph_node->attr.api.compute_type == af::ComputeType::kComputeTranspose;
    if (is_float_transpose) {
      // One such node is enough; no need to look at the rest of the graph.
      align_width_ = 64U;
      break;
    }
  }
  GELOGD("[%s]'s align width is [%d].", impl_graph.GetName().c_str(), align_width_);
  return af::SUCCESS;
}
/**
 * Entry point of the alignment pass. In order:
 *  1. pick the align width for the graph,
 *  2. lazily build the compute-type -> inference-function table,
 *  3. insert RemovePad nodes after tail-axis-discontinuous loads,
 *  4. infer an alignment type for every node,
 *  5. insert Pad nodes where alignment conflicts were recorded (its status is returned unchanged on failure),
 *  6. write the vectorized strides of every non-buffer node.
 */
ge::Status BaseAlignmentStrategy::AlignVectorizedStrides(ascir::ImplGraph &impl_graph) {
  GE_ASSERT_SUCCESS(SetAlignWidth(impl_graph), "Failed to set align width for [%s].", impl_graph.GetName().c_str());
  if (compute_type_to_infer_func_.empty()) {
    InitAlignmentInferFunc();
  }
  GE_ASSERT_SUCCESS(AddRemovePadForTailAxisDiscontinuousLoad(impl_graph), "Failed to add removepad for [%s].",
                    impl_graph.GetName().c_str());

  for (const auto &graph_node : impl_graph.GetAllNodes()) {
    GE_ASSERT_NOTNULL(graph_node);
    GE_ASSERT_SUCCESS(InferAlignmentForOneNode(graph_node));
  }

  const auto pad_status = AddPadForAlignmentConflictNode(impl_graph);
  if (pad_status != af::SUCCESS) {
    return pad_status;
  }

  for (const auto &graph_node : impl_graph.GetAllNodes()) {
    GE_ASSERT_NOTNULL(graph_node);
    if (!ScheduleUtils::IsBuffer(graph_node)) {
      GE_ASSERT_SUCCESS(SetVectorizedStridesForOneNode(graph_node));
    }
  }
  return af::SUCCESS;
}
/**
 * Runs alignment inference for a single node.
 *
 * Buffer nodes are skipped. Other nodes must have at least one input and one output; they are dispatched
 * to the handler registered for their compute type, or to DefaultAlignmentInferFunc when none is registered.
 */
ge::Status BaseAlignmentStrategy::InferAlignmentForOneNode(const af::AscNodePtr &node) {
  if (ScheduleUtils::IsBuffer(node)) {
    return af::SUCCESS;
  }
  GE_ASSERT_TRUE(!node->inputs().empty(), "The inputs of %s(%s) is empty.", node->GetTypePtr(), node->GetNamePtr());
  GE_ASSERT_TRUE(!node->outputs().empty(), "The output of %s(%s) is empty.", node->GetTypePtr(), node->GetNamePtr());

  af::ComputeType kind = node->attr.api.compute_type;
  auto handler = compute_type_to_infer_func_.find(kind);
  if (handler == compute_type_to_infer_func_.end()) {
    GE_ASSERT_SUCCESS(DefaultAlignmentInferFunc(node),
                      "Failed to infer alignment for node: [%s] by default align func, type: %d", node->GetNamePtr(),
                      static_cast<int32_t>(kind));
    return af::SUCCESS;
  }
  GE_ASSERT_SUCCESS(handler->second(node), "Failed to infer alignment for node: [%s], type: %d", node->GetNamePtr(),
                    static_cast<int32_t>(kind));
  return af::SUCCESS;
}
/**
 * Examines the inputs of `node` during alignment propagation.
 *
 * Inputs whose owner is not an AscNode or is a buffer are ignored. For the others: an input already at
 * `aligned_type` is left alone; a kFixedNotAligned input is flagged conflict_with_output; any other input
 * has its owner node enqueued (once) for further processing.
 */
void BaseAlignmentStrategy::SetAlignInfoForNodeInputs(AlignmentType aligned_type, af::AscNode *node,
                                                      std::set<af::Node *> &visited_nodes,
                                                      std::queue<af::Node *> &node_queue) {
  for (const auto &in_tensor : node->inputs()) {
    auto producer = std::dynamic_pointer_cast<af::AscNode>(in_tensor->anchor.GetOwnerNode());
    const bool ignored = (producer == nullptr) || ScheduleUtils::IsBuffer(producer);
    if (ignored) {
      continue;
    }
    auto &info = tensor_to_align_type_[&in_tensor->attr];
    const AlignmentType current = info.align_type;
    if (current == aligned_type) {
      continue;
    }
    if (current != AlignmentType::kFixedNotAligned) {
      // Still adjustable: let the owner node be processed, but only enqueue it once.
      if (visited_nodes.insert(producer.get()).second) {
        node_queue.push(producer.get());
      }
      continue;
    }
    // Fixed tensors cannot follow the requested type; remember the conflict instead.
    info.conflict_with_output = true;
  }
}
/**
 * Moves the outputs of `node` to `aligned_type` during alignment propagation.
 *
 * Outputs already at `aligned_type` are skipped; kFixedNotAligned outputs are only flagged
 * conflict_with_output. Every other output is switched to `aligned_type`, and the non-buffer nodes
 * consuming it are enqueued (once each).
 *
 * @return true when at least one output changed its alignment type.
 */
bool BaseAlignmentStrategy::SetAlignInfoForNodeOutputs(AlignmentType aligned_type, af::AscNode *node,
                                                       std::set<af::Node *> &visited_nodes,
                                                       std::queue<af::Node *> &node_queue) {
  bool any_changed = false;
  for (const auto &out_tensor : node->outputs()) {
    auto &info = tensor_to_align_type_[&out_tensor->attr];
    if (info.align_type == aligned_type) {
      continue;
    }
    if (info.align_type == AlignmentType::kFixedNotAligned) {
      info.conflict_with_output = true;
      continue;
    }

    GELOGD("Node [%s]'s align type need to be changed.", node->GetNamePtr());
    info.align_type = aligned_type;
    any_changed = true;

    // Consumers of a changed output have to be revisited.
    for (const auto &consumer_anchor : out_tensor->anchor.GetPeerInDataAnchorsPtr()) {
      if (consumer_anchor == nullptr) {
        continue;
      }
      auto consumer = std::dynamic_pointer_cast<af::AscNode>(consumer_anchor->GetOwnerNode());
      GE_ASSERT_NOTNULL(consumer);
      if (ScheduleUtils::IsBuffer(consumer)) {
        continue;
      }
      if (visited_nodes.insert(consumer.get()).second) {
        node_queue.push(consumer.get());
      }
    }
  }
  return any_changed;
}
/**
 * Propagates `aligned_type` through the graph starting from the inputs of `node`.
 *
 * Work-list algorithm: each queued node first has its outputs updated; only when that changed something
 * are its inputs examined as well, which may enqueue further nodes.
 */
ge::Status BaseAlignmentStrategy::BackPropagateAlignment(const af::AscNodePtr &node, AlignmentType aligned_type) {
  std::set<af::Node *> seen;
  std::queue<af::Node *> work_list;
  seen.emplace(node.get());
  SetAlignInfoForNodeInputs(aligned_type, node.get(), seen, work_list);

  while (!work_list.empty()) {
    const auto current = dynamic_cast<af::AscNode *>(work_list.front());
    work_list.pop();
    GE_ASSERT_NOTNULL(current);
    if (!SetAlignInfoForNodeOutputs(aligned_type, current, seen, work_list)) {
      continue;
    }
    SetAlignInfoForNodeInputs(aligned_type, current, seen, work_list);
  }
  return af::SUCCESS;
}
/**
 * Marks the input tensors upstream of `node` as kFixedNotAligned.
 *
 * Starting at `node`, every dequeued node has its inputs examined: inputs owned by a non-AscNode or a
 * buffer are ignored, inputs already kFixedNotAligned are skipped, and any other input must currently be
 * kNotAligned -- it is switched to kFixedNotAligned and its owner node is enqueued.
 *
 * @param node node whose upstream tensors are to be fixed as not aligned.
 * @return af::SUCCESS, or an assertion failure status when a queued node is not an AscNode or an input
 *         carries an alignment type other than kNotAligned / kFixedNotAligned.
 */
ge::Status BaseAlignmentStrategy::BackPropagateFixUnAlignType(const af::AscNodePtr &node) {
  std::queue<af::Node *> node_queue;
  node_queue.emplace(node.get());
  while (!node_queue.empty()) {
    const auto curr_node = dynamic_cast<af::AscNode *>(node_queue.front());
    node_queue.pop();
    GE_ASSERT_NOTNULL(curr_node);
    // Walk the inputs of the node just dequeued; iterating the start node here would stop the
    // propagation after the first level.
    for (const auto &input : curr_node->inputs()) {
      auto asc_node = std::dynamic_pointer_cast<af::AscNode>(input->anchor.GetOwnerNode());
      if ((asc_node == nullptr) || ScheduleUtils::IsBuffer(asc_node)) {
        continue;
      }
      auto &align_info = tensor_to_align_type_[&input->attr];
      // Already-fixed tensors are skipped, which also keeps the walk from revisiting them.
      if (align_info.align_type == AlignmentType::kFixedNotAligned) {
        continue;
      }
      GE_ASSERT_TRUE(align_info.align_type == AlignmentType::kNotAligned);
      align_info.align_type = AlignmentType::kFixedNotAligned;
      node_queue.push(asc_node.get());
    }
  }
  return af::SUCCESS;
}
/**
 * Writes the vectorized strides of all outputs of one node.
 *
 * Store nodes copy the vectorized strides of their first input to their first output. For every other
 * node each output must have a recorded alignment type, from which the strides are computed.
 */
ge::Status BaseAlignmentStrategy::SetVectorizedStridesForOneNode(const af::AscNodePtr &node) {
  if (!ScheduleUtils::IsStore(node)) {
    for (const auto &out_tensor : node->outputs()) {
      auto &attr = out_tensor->attr;
      const auto recorded = tensor_to_align_type_.find(&attr);
      GE_ASSERT_TRUE(recorded != tensor_to_align_type_.end(), "Node [%s] has not visited tensor.", node->GetNamePtr());
      GE_ASSERT_SUCCESS(SetVectorizedStridesForTensor(node, attr, recorded->second.align_type));
    }
    return af::SUCCESS;
  }
  node->outputs[0].attr.vectorized_strides = node->inputs[0].attr.vectorized_strides;
  return af::SUCCESS;
}
/**
 * Computes output_attr.vectorized_strides for one tensor according to its alignment type.
 *
 * The vectorized axes are walked from the last to the first while a running size is accumulated:
 *  - an axis whose stride equals zero or whose repeat equals one gets stride zero;
 *  - kDiscontinuous, last axis: the stride is Align(1, factor) and the running size restarts at
 *    Align(1, factor) * repeat;
 *  - kAligned, last axis: the stride is the running size, which then grows by Align(repeat, factor);
 *  - everything else: the stride is the running size, which grows by the repeat of non-zero-stride axes.
 * Here factor = align_width_ / dtype size.
 *
 * @param node        owner of the tensor; only used for log messages.
 * @param output_attr tensor attribute whose vectorized_strides are overwritten.
 * @param align_type  alignment type recorded for the tensor.
 * @return af::SUCCESS, or an assertion failure status for inconsistent attributes.
 */
ge::Status BaseAlignmentStrategy::SetVectorizedStridesForTensor(const af::NodePtr &node, af::AscTensorAttr &output_attr,
                                                                const AlignmentType align_type) {
  const auto &vec_axes = output_attr.vectorized_axis;
  GE_ASSERT_TRUE(!vec_axes.empty(), "Node [%s] output tensor has empty vectorized_axis.", node->GetNamePtr());
  GE_ASSERT_TRUE(output_attr.axis.size() == output_attr.repeats.size(),
                 "Node [%s] output tensor axis and repeats size mismatch.", node->GetNamePtr());
  GE_ASSERT_TRUE(output_attr.axis.size() == output_attr.strides.size(),
                 "Node [%s] output tensor axis and strides size mismatch.", node->GetNamePtr());
  const auto elem_size = af::GetSizeByDataType(output_attr.dtype);
  GE_ASSERT_TRUE(elem_size > 0, "Node [%s] output tensor dtype is invalid.", node->GetNamePtr());
  const uint32_t factor = align_width_ / static_cast<uint32_t>(elem_size);

  const bool want_discontinuous = (align_type == AlignmentType::kDiscontinuous);
  const bool want_aligned = (align_type == AlignmentType::kAligned);

  ascir::SizeExpr running_size = af::sym::kSymbolOne;
  std::vector<ascir::SizeExpr> tail_first_strides;
  tail_first_strides.reserve(vec_axes.size());
  for (auto rit = vec_axes.rbegin(); rit != vec_axes.rend(); ++rit) {
    const auto axis_id = *rit;
    auto axis_pos = std::find(output_attr.axis.begin(), output_attr.axis.end(), axis_id);
    GE_ASSERT_TRUE(axis_pos != output_attr.axis.end(),
                   "Cannot find vectorized axis [%ld] in [%s]'s output tensor.", axis_id, node->GetNamePtr());
    const int64_t axis_index = std::distance(output_attr.axis.begin(), axis_pos);
    const auto &axis_stride = output_attr.strides[axis_index];
    const auto &axis_repeat = output_attr.repeats[axis_index];

    const bool is_tail = (rit == vec_axes.rbegin());
    const bool has_extent = !(ascgen_utils::ExpressEq(axis_stride, af::sym::kSymbolZero) ||
                              ascgen_utils::ExpressEq(axis_repeat, af::sym::kSymbolOne));

    ascir::SizeExpr emitted_stride = af::sym::kSymbolZero;
    if (is_tail && want_discontinuous) {
      auto unit_stride = af::sym::Align(af::sym::kSymbolOne, factor);
      if (has_extent) {
        emitted_stride = unit_stride;
      }
      running_size = unit_stride * axis_repeat;
    } else {
      if (has_extent) {
        emitted_stride = running_size;
      }
      if (is_tail && want_aligned) {
        running_size = running_size * af::sym::Align(axis_repeat, factor);
      } else if (has_extent) {
        running_size = running_size * axis_repeat;
      }
    }
    tail_first_strides.push_back(emitted_stride);
  }
  // The strides were collected tail-first; flip them into vectorized-axis order.
  std::reverse(tail_first_strides.begin(), tail_first_strides.end());
  output_attr.vectorized_strides = tail_first_strides;
  return af::SUCCESS;
}
}