* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef TESTS_UT_COMMON_ASCIR_REDUCE_TEST_HELPERS_H_
#define TESTS_UT_COMMON_ASCIR_REDUCE_TEST_HELPERS_H_
#include <memory>
#include <string>
#include <vector>
#include "ascir_ops.h"
#include "ascir_node_param/ascir_node_param.h"
#include "ascir_node_param/ascir_param_builder.h"
#include "gen_model_info/parser/reduce_specific_params_builder.h"
#include "gen_model_info/parser/tuning_space.h"
namespace ascir_reduce_test_helpers {
struct ReduceTestEnv {
af::AscGraph graph;
ge::Expression s0;
ge::Expression s1;
af::Axis z0;
af::Axis z1;
af::AscNodePtr node;
att::SubAxis axis0;
att::SubAxis axis1;
explicit ReduceTestEnv(const char *reduce_name)
: graph("reduce_test_graph")
{
s0 = graph.CreateSizeVar("s0");
s1 = graph.CreateSizeVar("s1");
z0 = graph.CreateAxis("z0", s0);
z1 = graph.CreateAxis("z1", s1);
af::ascir_op::Data x("x", graph);
af::ascir_op::Load load("load");
af::ascir_op::Max reduce_op(reduce_name);
graph.AddNode(load);
graph.AddNode(reduce_op);
load.x = x.y;
reduce_op.x = load.y;
node = graph.FindNode(reduce_name);
node->attr.sched.axis = {z0.id, z1.id};
node->attr.sched.loop_axis = z1.id;
node->inputs[0].attr.dtype = ge::DT_FLOAT16;
node->inputs[0].attr.axis = {z0.id, z1.id};
node->inputs[0].attr.repeats = {s0, s1};
node->outputs[0].attr.dtype = ge::DT_FLOAT16;
node->outputs[0].attr.axis = {z0.id, z1.id};
axis0.id = z0.id;
axis0.name = "z0";
axis0.repeat = s0;
axis1.id = z1.id;
axis1.name = "z1";
axis1.repeat = s1;
}
void SetIoAttrs(const std::vector<ge::Expression> &in_strides,
const std::vector<ge::Expression> &out_repeats,
const std::vector<ge::Expression> &out_strides)
{
node->inputs[0].attr.strides = in_strides;
node->inputs[0].attr.vectorized_axis = {z0.id, z1.id};
node->inputs[0].attr.vectorized_strides = in_strides;
node->outputs[0].attr.repeats = out_repeats;
node->outputs[0].attr.strides = out_strides;
node->outputs[0].attr.vectorized_axis = {z0.id, z1.id};
node->outputs[0].attr.vectorized_strides = out_strides;
}
};
// Creates a heap-allocated att::Tensor and fills in its descriptive fields.
// @param name     stored in Tensor::name.
// @param axes     stored in Tensor::dim_info; the pointers are copied, not the SubAxis objects,
//                 so the pointees must outlive the returned tensor.
// @param repeats  stored in Tensor::repeat.
// @param strides  stored in Tensor::stride.
// @return the newly created tensor with data_type_size fixed to 2.
inline att::TensorPtr BuildParserTensor(const std::string &name, const std::vector<att::SubAxis *> &axes,
                                        const std::vector<ge::Expression> &repeats,
                                        const std::vector<ge::Expression> &strides)
{
    // Element size in bytes; presumably chosen to match the DT_FLOAT16 tensors used by ReduceTestEnv.
    constexpr auto kElementBytes = 2U;

    auto tensor = std::make_shared<att::Tensor>();
    att::Tensor &fields = *tensor;
    fields.name = name;
    fields.data_type_size = kElementBytes;
    fields.dim_info = axes;
    fields.repeat = repeats;
    fields.stride = strides;
    return tensor;
}
// Assembles an att::NodeInfo of type "Max" for the reduce node held by `env`.
// The input tensor "x" uses repeats (8, 16) / strides (16, 1); the output tensor "y" uses
// repeats (8, 1) / strides (1, 0). Both span env.axis0 and env.axis1, and env.axis1 is the
// only loop axis. Afterwards the graph's node params are enriched and the reduce-specific
// params are filled in; the return values of both calls are deliberately discarded.
// @param env        fixture providing the graph, the node and the SubAxis objects. The result
//                   keeps raw pointers to env.axis0 / env.axis1, so `env` must outlive it.
// @param node_name  stored in NodeInfo::name.
// @return the populated NodeInfo.
inline att::NodeInfo BuildReduceNodeInfo(ReduceTestEnv &env, const std::string &node_name)
{
    const std::vector<att::SubAxis *> both_axes = {&env.axis0, &env.axis1};

    const std::vector<ge::Expression> in_repeats = {ge::Symbol(8), ge::Symbol(16)};
    const std::vector<ge::Expression> in_strides = {ge::Symbol(16), ge::Symbol(1)};
    const std::vector<ge::Expression> out_repeats = {ge::Symbol(8), ge::Symbol(1)};
    const std::vector<ge::Expression> out_strides = {ge::Symbol(1), ge::Symbol(0)};

    att::NodeInfo node_info;
    node_info.name = node_name;
    node_info.node_type = "Max";
    node_info.node_ptr = env.node;
    node_info.inputs = {BuildParserTensor("x", both_axes, in_repeats, in_strides)};
    node_info.outputs = {BuildParserTensor("y", both_axes, out_repeats, out_strides)};
    node_info.loop_axes = {&env.axis1};

    // Enrich first, then fill: the order of these two calls is kept exactly as before.
    (void)ascir_param::EnrichAscirGraphNodeParams(env.graph);
    (void)att::FillReduceSpecificParams(env.node, node_info);
    return node_info;
}
}
#endif