* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "tensor_layout_utils.h"
#include "ascendc_ir/ascendc_ir_core/ascendc_ir_def.h"
#include "graph/symbolizer/symbolic_utils.h"
#include "ascir_ops_utils.h"
#include "ascgen_log.h"
#include "ascir_ops.h"
#include "common_utils.h"
#include "schedule_utils.h"
namespace optimize {
/**
 * @brief Scans the vectorized axes of a load tensor and records where its memory layout stops being contiguous.
 *
 * The vectorized axes are visited from the last entry of attr.vectorized_axis to the first (presumably
 * innermost first -- TODO confirm against the layout convention). Each vectorized axis id is looked up in
 * attr.axis to obtain the index into attr.repeats / attr.strides.
 *
 * - Tail check: the first visited axis whose stride is not 0 is treated as the tail axis; if its stride is
 *   not 1, info.is_tail_axis_discontinuous is set. Axes whose repeat is 1 still take part in this check.
 * - Gap counting: axes whose repeat is 1 or whose stride is 0 are skipped. For the remaining axes a running
 *   stride starts at 1 and is multiplied by each axis' repeat; whenever an axis' stride differs from the
 *   running value a gap is counted and the running value restarts from that axis' stride. More than one
 *   gap sets info.has_multiple_discontinuities.
 *
 * Both flags are only ever set to true and never cleared, so the caller must initialize @p info.
 *
 * @param attr Tensor attributes (axis, repeats, strides, vectorized_axis) of the loaded tensor.
 * @param info Output flags, updated in place.
 * @return ge::SUCCESS on success. The GE_ASSERT* checks bail out with an error status when a vectorized axis
 *         id is missing from attr.axis or its index is out of range for strides/repeats.
 */
af::Status TensorLayoutUtils::AnalyzeLoadDiscontinuity(const af::AscTensorAttr &attr, DiscontinuityInfo &info) {
  // `axis`, `strides`, `repeats`, `axis_iter` and `idx` keep these exact names because the GE_ASSERT*
  // conditions below may be stringified into error logs.
  const auto &axis = attr.axis;
  const auto &strides = attr.strides;
  const auto &repeats = attr.repeats;
  const auto &vec_axes = attr.vectorized_axis;

  bool tail_seen = false;
  int32_t gap_count = 0;
  // Stride the next axis must have to extend the contiguous run built from the axes visited so far.
  af::Expression run_stride = af::sym::kSymbolOne;

  for (auto vec_it = vec_axes.rbegin(); vec_it != vec_axes.rend(); ++vec_it) {
    const auto axis_id = *vec_it;
    const auto axis_iter = std::find(axis.begin(), axis.end(), axis_id);
    GE_ASSERT(axis_iter != axis.end(), "Failed to find vectorized_axis id [%ld] from axis.", axis_id);
    const size_t idx = static_cast<size_t>(std::distance(axis.begin(), axis_iter));
    GE_ASSERT_TRUE(idx < strides.size() && idx < repeats.size());

    const auto &cur_stride = strides[idx];
    const auto &cur_repeat = repeats[idx];
    const bool zero_stride = ascgen_utils::ExpressEq(cur_stride, af::sym::kSymbolZero);

    if (!tail_seen && !zero_stride) {
      // The first axis with a non-zero stride decides whether the tail is contiguous.
      // NOTE(review): this runs before the repeat==1 filter below, so a size-1 axis with a non-zero,
      // non-unit stride is flagged as a discontinuous tail -- confirm this is intended.
      tail_seen = true;
      const bool unit_stride = ascgen_utils::ExpressEq(cur_stride, af::sym::kSymbolOne);
      if (!unit_stride) {
        info.is_tail_axis_discontinuous = true;
      }
    }

    // Size-1 axes and zero-stride axes are skipped when counting gaps.
    const bool single_element = ascgen_utils::ExpressEq(cur_repeat, af::sym::kSymbolOne);
    if (!single_element && !zero_stride) {
      if (!ascgen_utils::ExpressEq(cur_stride, run_stride)) {
        // Gap found: start a new contiguous run at this axis' stride.
        ++gap_count;
        run_stride = cur_stride;
      }
      run_stride = run_stride * cur_repeat;
    }
  }

  if (gap_count >= 2) {
    info.has_multiple_discontinuities = true;
  }
  return ge::SUCCESS;
}
}  // namespace optimize