* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "ascendc_ir.h"
#include "ascir_ops.h"
#include "ascir_utils.h"
#include "ascir_ops_utils.h"
using namespace std;
using namespace ge;
using namespace af;
using namespace af::ops;
using namespace af::ascir_op;
/**
 * Builds the un-fused two-input concat graph: x1/x2 (Data) -> load1/load2 -> concat -> store -> y (Output).
 *
 * Axes are created in the order z0 (size s0), zo (size s1 + s2), zo_s_0 (size s1), zo_s_1 (size s2); the last
 * two are created with {zo.id} passed to CreateAxis. x1/load1 are described over {z0, zo_s_0}, x2/load2 over
 * {z0, zo_s_1}, and concat/store/y over {z0, zo}. Every tensor is DT_FLOAT.
 *
 * @param graph graph that receives the size variables, axes and nodes.
 */
void LoadConcatStore_BeforeAutofuse(af::AscGraph &graph) {
  auto rows = graph.CreateSizeVar("s0");
  auto lhs_cols = graph.CreateSizeVar("s1");
  auto rhs_cols = graph.CreateSizeVar("s2");

  auto row_axis = graph.CreateAxis("z0", rows);
  auto cat_axis = graph.CreateAxis("zo", lhs_cols + rhs_cols);
  auto lhs_axis = graph.CreateAxis("zo_s_0", Axis::Type::kAxisTypeOriginal, lhs_cols, {cat_axis.id}, af::kIdNone);
  auto rhs_axis = graph.CreateAxis("zo_s_1", Axis::Type::kAxisTypeOriginal, rhs_cols, {cat_axis.id}, af::kIdNone);

  Data in_lhs("x1", graph);
  Data in_rhs("x2", graph);
  Load ld_lhs("load1");
  Load ld_rhs("load2");
  af::ascir_op::Concat cat("concat");
  Store st("store");
  Output out("y");

  // Writes the schedule axes and the output tensor description (dtype, axes, repeats, strides) of one op.
  const auto describe = [](auto &op, const std::vector<AxisId> &axes, const std::vector<af::Expression> &repeats,
                           const std::vector<af::Expression> &strides) {
    op.attr.sched.axis = axes;
    op.y.dtype = ge::DT_FLOAT;
    *op.y.axis = axes;
    *op.y.repeats = repeats;
    *op.y.strides = strides;
  };

  const std::vector<AxisId> lhs_axes{row_axis.id, lhs_axis.id};
  const std::vector<af::Expression> lhs_repeats{rows, lhs_cols};
  const std::vector<af::Expression> lhs_strides{lhs_cols, One};

  const std::vector<AxisId> rhs_axes{row_axis.id, rhs_axis.id};
  const std::vector<af::Expression> rhs_repeats{rows, rhs_cols};
  const std::vector<af::Expression> rhs_strides{rhs_cols, One};

  // concat, store and y all share the same description over the full zo axis.
  const std::vector<AxisId> out_axes{row_axis.id, cat_axis.id};
  const std::vector<af::Expression> out_repeats{rows, lhs_cols + rhs_cols};
  const std::vector<af::Expression> out_strides{lhs_cols + rhs_cols, One};

  describe(in_lhs, lhs_axes, lhs_repeats, lhs_strides);
  describe(in_rhs, rhs_axes, rhs_repeats, rhs_strides);

  ld_lhs.x = in_lhs.y;
  describe(ld_lhs, lhs_axes, lhs_repeats, lhs_strides);

  ld_rhs.x = in_rhs.y;
  describe(ld_rhs, rhs_axes, rhs_repeats, rhs_strides);

  cat.x = {ld_lhs.y, ld_rhs.y};
  describe(cat, out_axes, out_repeats, out_strides);
  cat.attr.tmp_buffers = {{{af::Symbol(65536), -1}, MemAttr(), 0}};

  st.x = cat.y;
  describe(st, out_axes, out_repeats, out_strides);

  out.x = st.y;
  describe(out, out_axes, out_repeats, out_strides);
}
/**
 * Builds the un-fused two-input concat graph for 3-D tensors where the concatenated axis is the middle one.
 *
 * Axes are created in the order z0 (s0), z1 (s1 + s2), z1_s_0 (s1), z1_s_1 (s2), z2 (s3); z1_s_0 / z1_s_1 are
 * created with {z1.id} passed to CreateAxis. x1/load1 use {z0, z1_s_0, z2}, x2/load2 use {z0, z1_s_1, z2}, and
 * concat/store/y use {z0, z1, z2}. Every tensor is DT_FLOAT.
 *
 * @param graph graph that receives the size variables, axes and nodes.
 */
void LoadConcatStore_BeforeAutofuseConcatInterAxis(af::AscGraph &graph) {
  auto s0 = graph.CreateSizeVar("s0");
  auto s1 = graph.CreateSizeVar("s1");
  auto s2 = graph.CreateSizeVar("s2");
  auto s3 = graph.CreateSizeVar("s3");

  auto outer = graph.CreateAxis("z0", s0);
  auto middle = graph.CreateAxis("z1", s1 + s2);
  auto middle_lhs = graph.CreateAxis("z1_s_0", Axis::Type::kAxisTypeOriginal, s1, {middle.id}, af::kIdNone);
  auto middle_rhs = graph.CreateAxis("z1_s_1", Axis::Type::kAxisTypeOriginal, s2, {middle.id}, af::kIdNone);
  auto inner = graph.CreateAxis("z2", s3);

  Data in_lhs("x1", graph);
  Data in_rhs("x2", graph);
  Load ld_lhs("load1");
  Load ld_rhs("load2");
  af::ascir_op::Concat cat("concat");
  Store st("store");
  Output out("y");

  // Writes the schedule axes and the output tensor description (dtype, axes, repeats, strides) of one op.
  const auto describe = [](auto &op, const std::vector<AxisId> &axes, const std::vector<af::Expression> &repeats,
                           const std::vector<af::Expression> &strides) {
    op.attr.sched.axis = axes;
    op.y.dtype = ge::DT_FLOAT;
    *op.y.axis = axes;
    *op.y.repeats = repeats;
    *op.y.strides = strides;
  };

  const std::vector<AxisId> lhs_axes{outer.id, middle_lhs.id, inner.id};
  const std::vector<af::Expression> lhs_repeats{s0, s1, s3};
  const std::vector<af::Expression> lhs_strides{s1 * s3, s3, One};

  const std::vector<AxisId> rhs_axes{outer.id, middle_rhs.id, inner.id};
  const std::vector<af::Expression> rhs_repeats{s0, s2, s3};
  const std::vector<af::Expression> rhs_strides{s2 * s3, s3, One};

  // concat, store and y all share the same description over the full z1 axis.
  const std::vector<AxisId> out_axes{outer.id, middle.id, inner.id};
  const std::vector<af::Expression> out_repeats{s0, s1 + s2, s3};
  const std::vector<af::Expression> out_strides{(s1 + s2) * s3, s3, One};

  describe(in_lhs, lhs_axes, lhs_repeats, lhs_strides);
  describe(in_rhs, rhs_axes, rhs_repeats, rhs_strides);

  ld_lhs.x = in_lhs.y;
  describe(ld_lhs, lhs_axes, lhs_repeats, lhs_strides);

  ld_rhs.x = in_rhs.y;
  describe(ld_rhs, rhs_axes, rhs_repeats, rhs_strides);

  cat.x = {ld_lhs.y, ld_rhs.y};
  describe(cat, out_axes, out_repeats, out_strides);
  cat.attr.tmp_buffers = {{{af::Symbol(16384), -1}, MemAttr(), 0}};

  st.x = cat.y;
  describe(st, out_axes, out_repeats, out_strides);

  out.x = st.y;
  describe(out, out_axes, out_repeats, out_strides);
}
/**
 * Builds the un-fused two-input concat graph for 3-D tensors where the concatenated axis is the last one.
 *
 * Axes are created in the order z0 (s0), z1 (s1), z2 (s2 + s3), z2_s_0 (s2), z2_s_1 (s3); z2_s_0 / z2_s_1 are
 * created with {z2.id} passed to CreateAxis. x1/load1 use {z0, z1, z2_s_0}, x2/load2 use {z0, z1, z2_s_1}, and
 * concat/store/y use {z0, z1, z2}. Every tensor is DT_FLOAT.
 *
 * @param graph graph that receives the size variables, axes and nodes.
 */
void LoadConcatStore_BeforeAutofuse3dLastAxis(af::AscGraph &graph) {
  auto s0 = graph.CreateSizeVar("s0");
  auto s1 = graph.CreateSizeVar("s1");
  auto s2 = graph.CreateSizeVar("s2");
  auto s3 = graph.CreateSizeVar("s3");

  auto outer = graph.CreateAxis("z0", s0);
  auto middle = graph.CreateAxis("z1", s1);
  auto last = graph.CreateAxis("z2", s2 + s3);
  auto last_lhs = graph.CreateAxis("z2_s_0", Axis::Type::kAxisTypeOriginal, s2, {last.id}, af::kIdNone);
  auto last_rhs = graph.CreateAxis("z2_s_1", Axis::Type::kAxisTypeOriginal, s3, {last.id}, af::kIdNone);

  Data in_lhs("x1", graph);
  Data in_rhs("x2", graph);
  Load ld_lhs("load1");
  Load ld_rhs("load2");
  af::ascir_op::Concat cat("concat");
  Store st("store");
  Output out("y");

  // Writes the schedule axes and the output tensor description (dtype, axes, repeats, strides) of one op.
  const auto describe = [](auto &op, const std::vector<AxisId> &axes, const std::vector<af::Expression> &repeats,
                           const std::vector<af::Expression> &strides) {
    op.attr.sched.axis = axes;
    op.y.dtype = ge::DT_FLOAT;
    *op.y.axis = axes;
    *op.y.repeats = repeats;
    *op.y.strides = strides;
  };

  const std::vector<AxisId> lhs_axes{outer.id, middle.id, last_lhs.id};
  const std::vector<af::Expression> lhs_repeats{s0, s1, s2};
  const std::vector<af::Expression> lhs_strides{s1 * s2, s2, One};

  const std::vector<AxisId> rhs_axes{outer.id, middle.id, last_rhs.id};
  const std::vector<af::Expression> rhs_repeats{s0, s1, s3};
  const std::vector<af::Expression> rhs_strides{s1 * s3, s3, One};

  // concat, store and y all share the same description over the full z2 axis.
  const std::vector<AxisId> out_axes{outer.id, middle.id, last.id};
  const std::vector<af::Expression> out_repeats{s0, s1, s2 + s3};
  const std::vector<af::Expression> out_strides{s1 * (s2 + s3), s2 + s3, One};

  describe(in_lhs, lhs_axes, lhs_repeats, lhs_strides);
  describe(in_rhs, rhs_axes, rhs_repeats, rhs_strides);

  ld_lhs.x = in_lhs.y;
  describe(ld_lhs, lhs_axes, lhs_repeats, lhs_strides);

  ld_rhs.x = in_rhs.y;
  describe(ld_rhs, rhs_axes, rhs_repeats, rhs_strides);

  cat.x = {ld_lhs.y, ld_rhs.y};
  describe(cat, out_axes, out_repeats, out_strides);
  cat.attr.tmp_buffers = {{{af::Symbol(16384), -1}, MemAttr(), 0}};

  st.x = cat.y;
  describe(st, out_axes, out_repeats, out_strides);

  out.x = st.y;
  describe(out, out_axes, out_repeats, out_strides);
}
/**
 * Stamps the compute type on every node of the two-input graph and propagates the output dtype
 * load1 -> concat -> store.
 *
 * Nodes are looked up by name ("x1", "x2", "load1", "load2", "concat", "store", "y"); the lookup results are
 * dereferenced without a null check.
 *
 * @param graph graph previously populated with the nodes listed above.
 */
void LoadConcatStore_AfterInferOutput(af::AscGraph &graph) {
  // Both Data inputs are marked kComputeInvalid.
  for (const char *input_name : {"x1", "x2"}) {
    graph.FindNode(input_name)->attr.api.compute_type = ComputeType::kComputeInvalid;
  }

  auto first_load = graph.FindNode("load1");
  first_load->attr.api.compute_type = ComputeType::kComputeLoad;
  graph.FindNode("load2")->attr.api.compute_type = ComputeType::kComputeLoad;

  // concat takes its dtype from load1's output.
  auto concat_node = graph.FindNode("concat");
  concat_node->outputs[0].attr.dtype = static_cast<ge::DataType>(first_load->outputs[0].attr.dtype);
  concat_node->attr.api.compute_type = ComputeType::kComputeConcat;

  // store takes its dtype from concat's output.
  auto store_node = graph.FindNode("store");
  store_node->outputs[0].attr.dtype = static_cast<ge::DataType>(concat_node->outputs[0].attr.dtype);
  store_node->attr.api.compute_type = ComputeType::kComputeStore;

  graph.FindNode("y")->attr.api.compute_type = ComputeType::kComputeInvalid;
}
/**
 * Fills in the API type and compute unit of every node of the two-input graph.
 *
 * x1, x2 and y become buffer nodes with no unit; load1, load2 and store become compute nodes on kUnitMTE2;
 * concat becomes a compute node on kUnitVector. Nodes are looked up by name and dereferenced without a null
 * check.
 *
 * @param graph graph previously populated with x1, x2, load1, load2, concat, store and y.
 */
void LoadConcatStore_AfterGetApiInfo(af::AscGraph &graph) {
  const auto tag = [&graph](const char *node_name, auto api_type, auto unit) {
    auto node = graph.FindNode(node_name);
    node->attr.api.type = api_type;
    node->attr.api.unit = unit;
  };

  tag("x1", ApiType::kAPITypeBuffer, ComputeUnit::kUnitNone);
  tag("x2", ApiType::kAPITypeBuffer, ComputeUnit::kUnitNone);
  tag("load1", ApiType::kAPITypeCompute, ComputeUnit::kUnitMTE2);
  tag("load2", ApiType::kAPITypeCompute, ComputeUnit::kUnitMTE2);
  tag("concat", ApiType::kAPITypeCompute, ComputeUnit::kUnitVector);
  // NOTE(review): store is tagged with the same MTE2 unit as the loads - confirm this is intended.
  tag("store", ApiType::kAPITypeCompute, ComputeUnit::kUnitMTE2);
  tag("y", ApiType::kAPITypeBuffer, ComputeUnit::kUnitNone);
}
/**
 * Applies the scheduling result to the 2-D two-input graph: tile-splits z0, block-splits the outer tile axis,
 * and records loop axis, vectorized axes and vectorized strides on x1, x2, load1, load2, concat and store.
 *
 * Expects GetAllAxis() to return the axes in the order z0, zo, zo_s_0, zo_s_1 (the order in which
 * LoadConcatStore_BeforeAutofuse creates them). The "y" node is not touched.
 *
 * @param graph     graph to annotate; nodes are looked up by name and dereferenced without a null check.
 * @param alignment -1 aligns both input and output row strides to 8 elements; any other value is used as the
 *                  input alignment and the output alignment becomes 1.
 */
void LoadConcatStore_AfterScheduler(af::AscGraph &graph, int32_t alignment) {
  const bool use_default_alignment = (alignment == -1);
  const int32_t input_alignment = use_default_alignment ? 8 : alignment;
  const int32_t output_alignment = use_default_alignment ? 8 : 1;

  auto all_axis = graph.GetAllAxis();
  auto z0 = all_axis[0]->id;
  auto zo = all_axis[1]->id;
  auto zo_s_0 = all_axis[2]->id;
  auto zo_s_1 = all_axis[3]->id;

  auto [z0T, z0t] = graph.TileSplit(z0);
  auto [z0TB, z0Tb] = graph.BlockSplit(z0T->id);
  // Copy the ids into plain locals: C++17 lambdas cannot capture structured bindings.
  const auto tile_outer_id = z0T->id;
  const auto tile_inner_id = z0t->id;
  const auto block_outer_id = z0TB->id;
  const auto block_inner_id = z0Tb->id;

  // Row stride is the aligned size of the innermost axis; the innermost stride is One.
  const std::vector<af::Expression> out_strides{af::sym::Align(graph.FindAxis(zo)->size, output_alignment), One};
  const std::vector<af::Expression> lhs_strides{af::sym::Align(graph.FindAxis(zo_s_0)->size, input_alignment), One};
  const std::vector<af::Expression> rhs_strides{af::sym::Align(graph.FindAxis(zo_s_1)->size, input_alignment), One};

  const std::vector<AxisId> out_axes{tile_inner_id, zo};
  const std::vector<AxisId> lhs_axes{tile_inner_id, zo_s_0};
  const std::vector<AxisId> rhs_axes{tile_inner_id, zo_s_1};

  // Applies both splits to one node, loops it over z0Tb and records its vectorized layout.
  const auto schedule = [&](const char *node_name, const std::vector<AxisId> &axes,
                            const std::vector<af::Expression> &strides) {
    auto node = graph.FindNode(node_name);
    graph.ApplySplit(node, tile_outer_id, tile_inner_id);
    graph.ApplySplit(node, block_outer_id, block_inner_id);
    node->attr.sched.loop_axis = block_inner_id;
    node->outputs[0].attr.vectorized_axis = axes;
    node->outputs[0].attr.vectorized_strides = strides;
  };

  schedule("x1", lhs_axes, lhs_strides);
  schedule("x2", rhs_axes, rhs_strides);
  schedule("load1", lhs_axes, lhs_strides);
  schedule("load2", rhs_axes, rhs_strides);
  schedule("concat", out_axes, out_strides);
  schedule("store", out_axes, out_strides);
}
/**
 * Applies the scheduling result to the 3-D graph whose concat axis is the middle one: tile-splits z0,
 * block-splits the outer tile axis, and records loop axis, vectorized axes and vectorized strides on x1, x2,
 * load1, load2, concat and store.
 *
 * Expects GetAllAxis() to return the axes in the order z0, z1, z1_s_0, z1_s_1, z2 (the order in which
 * LoadConcatStore_BeforeAutofuseConcatInterAxis creates them). The "y" node is not touched.
 *
 * @param graph graph to annotate; nodes are looked up by name and dereferenced without a null check.
 */
void LoadConcatStore_AfterSchedulerConcatInterAxis(af::AscGraph &graph) {
  // Element alignment of the innermost axis. The original computed 32 / sizeof(float) into a local that was
  // never read and hard-coded 8 at every use; one named constant replaces both.
  constexpr int kInnerAlign = 8;

  auto all_axis = graph.GetAllAxis();
  auto z0 = all_axis[0]->id;
  auto z1 = all_axis[1]->id;
  auto z1_s_0 = all_axis[2]->id;
  auto z1_s_1 = all_axis[3]->id;
  auto z2 = all_axis[4]->id;

  auto [z0T, z0t] = graph.TileSplit(z0);
  auto [z0TB, z0Tb] = graph.BlockSplit(z0T->id);
  // Copy the ids into plain locals: C++17 lambdas cannot capture structured bindings.
  const auto tile_outer_id = z0T->id;
  const auto tile_inner_id = z0t->id;
  const auto block_outer_id = z0TB->id;
  const auto block_inner_id = z0Tb->id;

  // All tensors share the same aligned innermost (z2) extent; only the middle extent differs.
  auto z2_aligned = af::sym::Align(graph.FindAxis(z2)->size, kInnerAlign);
  const std::vector<af::Expression> out_strides{graph.FindAxis(z1)->size * z2_aligned, z2_aligned, One};
  const std::vector<af::Expression> lhs_strides{graph.FindAxis(z1_s_0)->size * z2_aligned, z2_aligned, One};
  const std::vector<af::Expression> rhs_strides{graph.FindAxis(z1_s_1)->size * z2_aligned, z2_aligned, One};

  const std::vector<AxisId> out_axes{tile_inner_id, z1, z2};
  const std::vector<AxisId> lhs_axes{tile_inner_id, z1_s_0, z2};
  const std::vector<AxisId> rhs_axes{tile_inner_id, z1_s_1, z2};

  // Applies both splits to one node, loops it over z0Tb and records its vectorized layout.
  const auto schedule = [&](const char *node_name, const std::vector<AxisId> &axes,
                            const std::vector<af::Expression> &strides) {
    auto node = graph.FindNode(node_name);
    graph.ApplySplit(node, tile_outer_id, tile_inner_id);
    graph.ApplySplit(node, block_outer_id, block_inner_id);
    node->attr.sched.loop_axis = block_inner_id;
    node->outputs[0].attr.vectorized_axis = axes;
    node->outputs[0].attr.vectorized_strides = strides;
  };

  schedule("x1", lhs_axes, lhs_strides);
  schedule("x2", rhs_axes, rhs_strides);
  schedule("load1", lhs_axes, lhs_strides);
  schedule("load2", rhs_axes, rhs_strides);
  schedule("concat", out_axes, out_strides);
  schedule("store", out_axes, out_strides);
}
/**
 * Applies the scheduling result to the 3-D graph whose concat axis is the last one: tile-splits z0,
 * block-splits the outer tile axis, and records loop axis, vectorized axes and vectorized strides on x1, x2,
 * load1, load2, concat and store.
 *
 * Expects GetAllAxis() to return the axes in the order z0, z1, z2, z2_s_0, z2_s_1 (the order in which
 * LoadConcatStore_BeforeAutofuse3dLastAxis creates them). The "y" node is not touched.
 *
 * @param graph graph to annotate; nodes are looked up by name and dereferenced without a null check.
 */
void LoadConcatStore_AfterScheduler3dLastAxis(af::AscGraph &graph) {
  // Element alignment of the innermost axis. The original computed 32 / sizeof(float) into a local that was
  // never read and hard-coded 8 at every use; one named constant replaces both.
  constexpr int kInnerAlign = 8;

  auto all_axis = graph.GetAllAxis();
  auto z0 = all_axis[0]->id;
  auto z1 = all_axis[1]->id;
  auto z2 = all_axis[2]->id;
  auto z2_s_0 = all_axis[3]->id;
  auto z2_s_1 = all_axis[4]->id;

  auto [z0T, z0t] = graph.TileSplit(z0);
  auto [z0TB, z0Tb] = graph.BlockSplit(z0T->id);
  // Copy the ids into plain locals: C++17 lambdas cannot capture structured bindings.
  const auto tile_outer_id = z0T->id;
  const auto tile_inner_id = z0t->id;
  const auto block_outer_id = z0TB->id;
  const auto block_inner_id = z0Tb->id;

  // Each tensor's strides are {z1 size * aligned innermost size, aligned innermost size, One}.
  const auto strides_for = [&](AxisId innermost) {
    auto inner_aligned = af::sym::Align(graph.FindAxis(innermost)->size, kInnerAlign);
    return std::vector<af::Expression>{graph.FindAxis(z1)->size * inner_aligned, inner_aligned, One};
  };
  const std::vector<af::Expression> out_strides = strides_for(z2);
  const std::vector<af::Expression> lhs_strides = strides_for(z2_s_0);
  const std::vector<af::Expression> rhs_strides = strides_for(z2_s_1);

  const std::vector<AxisId> out_axes{tile_inner_id, z1, z2};
  const std::vector<AxisId> lhs_axes{tile_inner_id, z1, z2_s_0};
  const std::vector<AxisId> rhs_axes{tile_inner_id, z1, z2_s_1};

  // Applies both splits to one node, loops it over z0Tb and records its vectorized layout.
  const auto schedule = [&](const char *node_name, const std::vector<AxisId> &axes,
                            const std::vector<af::Expression> &strides) {
    auto node = graph.FindNode(node_name);
    graph.ApplySplit(node, tile_outer_id, tile_inner_id);
    graph.ApplySplit(node, block_outer_id, block_inner_id);
    node->attr.sched.loop_axis = block_inner_id;
    node->outputs[0].attr.vectorized_axis = axes;
    node->outputs[0].attr.vectorized_strides = strides;
  };

  schedule("x1", lhs_axes, lhs_strides);
  schedule("x2", rhs_axes, rhs_strides);
  schedule("load1", lhs_axes, lhs_strides);
  schedule("load2", rhs_axes, rhs_strides);
  schedule("concat", out_axes, out_strides);
  schedule("store", out_axes, out_strides);
}
/**
 * Records the memory placement of the first output of each node in the two-input graph.
 *
 * x1, x2 and store (tensor ids 0, 1, 5) are placed in GM with global allocation and no buffer/queue.
 * load1, load2 and concat (tensor ids 2, 3, 4) are placed in UB with queue allocation, using queue/reuse ids
 * 0, 1, 2, depth 2 and buf_num 2; the loads sit at kPositionVecIn and concat at kPositionVecOut.
 * Nodes are looked up by name and dereferenced without a null check.
 *
 * @param graph graph previously populated with x1, x2, load1, load2, concat and store.
 */
void LoadConcatStore_AfterQueBufAlloc(af::AscGraph &graph) {
  // GM-resident tensor: global allocation, no buffer, no queue.
  const auto place_in_gm = [&graph](const char *node_name, auto tensor_id) {
    auto node = graph.FindNode(node_name);
    node->outputs[0].attr.mem.tensor_id = tensor_id;
    node->outputs[0].attr.mem.alloc_type = AllocType::kAllocTypeGlobal;
    node->outputs[0].attr.mem.hardware = MemHardware::kMemHardwareGM;
    node->outputs[0].attr.mem.position = Position::kPositionGM;
    node->outputs[0].attr.buf.id = af::kIdNone;
    node->outputs[0].attr.que.id = af::kIdNone;
    node->outputs[0].attr.opt.ref_tensor = af::kIdNone;
    node->outputs[0].attr.opt.merge_scope = af::kIdNone;
  };

  // UB-resident tensor backed by a queue; `slot` is written to both que.id and mem.reuse_id.
  const auto place_in_queue = [&graph](const char *node_name, auto tensor_id, auto position, auto slot) {
    auto node = graph.FindNode(node_name);
    node->outputs[0].attr.mem.tensor_id = tensor_id;
    node->outputs[0].attr.mem.alloc_type = AllocType::kAllocTypeQueue;
    node->outputs[0].attr.mem.hardware = MemHardware::kMemHardwareUB;
    node->outputs[0].attr.mem.position = position;
    node->outputs[0].attr.buf.id = af::kIdNone;
    node->outputs[0].attr.que.id = slot;
    node->outputs[0].attr.mem.reuse_id = slot;
    node->outputs[0].attr.que.depth = 2;
    node->outputs[0].attr.que.buf_num = 2;
    node->outputs[0].attr.opt.ref_tensor = af::kIdNone;
    node->outputs[0].attr.opt.merge_scope = af::kIdNone;
  };

  place_in_gm("x1", 0);
  place_in_gm("x2", 1);
  place_in_queue("load1", 2, Position::kPositionVecIn, 0);
  place_in_queue("load2", 3, Position::kPositionVecIn, 1);
  place_in_queue("concat", 4, Position::kPositionVecOut, 2);
  place_in_gm("store", 5);
}
/**
 * Builds the un-fused seven-input concat graph: x1..x7 (Data) -> load1..load7 -> concat -> store -> y (Output).
 *
 * Axes are created in the order z0 (s0), z1 (s1 + ... + s7), z1_s_1 .. z1_s_7 (sizes s1 .. s7); each z1_s_k is
 * created with {z1.id} passed to CreateAxis. xk/loadk are described over {z0, z1_s_k} and concat/store/y over
 * {z0, z1}. Every tensor is DT_INT64.
 *
 * @param graph graph that receives the size variables, axes and nodes.
 */
void LoadConcatStore_BeforeAutofuse7Inputs(af::AscGraph &graph) {
  auto s0 = graph.CreateSizeVar("s0");
  auto s1 = graph.CreateSizeVar("s1");
  auto s2 = graph.CreateSizeVar("s2");
  auto s3 = graph.CreateSizeVar("s3");
  auto s4 = graph.CreateSizeVar("s4");
  auto s5 = graph.CreateSizeVar("s5");
  auto s6 = graph.CreateSizeVar("s6");
  auto s7 = graph.CreateSizeVar("s7");

  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1 + s2 + s3 + s4 + s5 + s6 + s7);
  auto z1_s_1 = graph.CreateAxis("z1_s_1", Axis::Type::kAxisTypeOriginal, s1, {z1.id}, af::kIdNone);
  auto z1_s_2 = graph.CreateAxis("z1_s_2", Axis::Type::kAxisTypeOriginal, s2, {z1.id}, af::kIdNone);
  auto z1_s_3 = graph.CreateAxis("z1_s_3", Axis::Type::kAxisTypeOriginal, s3, {z1.id}, af::kIdNone);
  auto z1_s_4 = graph.CreateAxis("z1_s_4", Axis::Type::kAxisTypeOriginal, s4, {z1.id}, af::kIdNone);
  auto z1_s_5 = graph.CreateAxis("z1_s_5", Axis::Type::kAxisTypeOriginal, s5, {z1.id}, af::kIdNone);
  auto z1_s_6 = graph.CreateAxis("z1_s_6", Axis::Type::kAxisTypeOriginal, s6, {z1.id}, af::kIdNone);
  auto z1_s_7 = graph.CreateAxis("z1_s_7", Axis::Type::kAxisTypeOriginal, s7, {z1.id}, af::kIdNone);

  Data x1("x1", graph);
  Data x2("x2", graph);
  Data x3("x3", graph);
  Data x4("x4", graph);
  Data x5("x5", graph);
  Data x6("x6", graph);
  Data x7("x7", graph);
  Load load1("load1");
  Load load2("load2");
  Load load3("load3");
  Load load4("load4");
  Load load5("load5");
  Load load6("load6");
  Load load7("load7");
  af::ascir_op::Concat concat("concat");
  Store store("store");
  Output y("y");

  // Writes the schedule axes and the output tensor description (dtype, axes, repeats, strides) of one op.
  const auto describe = [](auto &op, const std::vector<AxisId> &axes, const std::vector<af::Expression> &repeats,
                           const std::vector<af::Expression> &strides) {
    op.attr.sched.axis = axes;
    op.y.dtype = ge::DT_INT64;
    *op.y.axis = axes;
    *op.y.repeats = repeats;
    *op.y.strides = strides;
  };

  // Inputs: each xk is s0 rows of sk elements.
  describe(x1, {z0.id, z1_s_1.id}, {s0, s1}, {s1, One});
  describe(x2, {z0.id, z1_s_2.id}, {s0, s2}, {s2, One});
  describe(x3, {z0.id, z1_s_3.id}, {s0, s3}, {s3, One});
  describe(x4, {z0.id, z1_s_4.id}, {s0, s4}, {s4, One});
  describe(x5, {z0.id, z1_s_5.id}, {s0, s5}, {s5, One});
  describe(x6, {z0.id, z1_s_6.id}, {s0, s6}, {s6, One});
  describe(x7, {z0.id, z1_s_7.id}, {s0, s7}, {s7, One});

  // Loads: wired to their input first, then described exactly like that input.
  load1.x = x1.y;
  describe(load1, {z0.id, z1_s_1.id}, {s0, s1}, {s1, One});
  load2.x = x2.y;
  describe(load2, {z0.id, z1_s_2.id}, {s0, s2}, {s2, One});
  load3.x = x3.y;
  describe(load3, {z0.id, z1_s_3.id}, {s0, s3}, {s3, One});
  load4.x = x4.y;
  describe(load4, {z0.id, z1_s_4.id}, {s0, s4}, {s4, One});
  load5.x = x5.y;
  describe(load5, {z0.id, z1_s_5.id}, {s0, s5}, {s5, One});
  load6.x = x6.y;
  describe(load6, {z0.id, z1_s_6.id}, {s0, s6}, {s6, One});
  load7.x = x7.y;
  describe(load7, {z0.id, z1_s_7.id}, {s0, s7}, {s7, One});

  // concat, store and y all share the same description over the full z1 axis.
  const std::vector<AxisId> out_axes{z0.id, z1.id};
  const std::vector<af::Expression> out_repeats{s0, s1 + s2 + s3 + s4 + s5 + s6 + s7};
  const std::vector<af::Expression> out_strides{s1 + s2 + s3 + s4 + s5 + s6 + s7, One};

  concat.x = {load1.y, load2.y, load3.y, load4.y, load5.y, load6.y, load7.y};
  describe(concat, out_axes, out_repeats, out_strides);
  concat.attr.tmp_buffers = {{{af::Symbol(16384), -1}, MemAttr(), 0}};

  store.x = concat.y;
  describe(store, out_axes, out_repeats, out_strides);

  y.x = store.y;
  describe(y, out_axes, out_repeats, out_strides);
}
/**
 * Stamps the compute type on every node of the seven-input graph and sets/propagates output dtypes.
 *
 * x1..x7 and y are marked kComputeInvalid; load1..load7 get DT_INT64 outputs and kComputeLoad; concat takes its
 * dtype from load1, and store and y take theirs from concat. Nodes are looked up by name and dereferenced
 * without a null check.
 *
 * Fix: the original looked up x3..x7 but wrote to x2 each time, so x3..x7 never received a compute type.
 *
 * @param graph graph previously populated with x1..x7, load1..load7, concat, store and y.
 */
void LoadConcatStore_AfterInferOutput7Inputs(af::AscGraph &graph) {
  // Iterating over the names makes it impossible to update the wrong node.
  for (const char *input_name : {"x1", "x2", "x3", "x4", "x5", "x6", "x7"}) {
    graph.FindNode(input_name)->attr.api.compute_type = ComputeType::kComputeInvalid;
  }

  for (const char *load_name : {"load1", "load2", "load3", "load4", "load5", "load6", "load7"}) {
    auto load = graph.FindNode(load_name);
    load->outputs[0].attr.dtype = ge::DT_INT64;
    load->attr.api.compute_type = ComputeType::kComputeLoad;
  }

  auto load1 = graph.FindNode("load1");
  auto concat = graph.FindNode("concat");
  concat->outputs[0].attr.dtype = static_cast<ge::DataType>(load1->outputs[0].attr.dtype);
  concat->attr.api.compute_type = ComputeType::kComputeConcat;

  auto store = graph.FindNode("store");
  store->outputs[0].attr.dtype = static_cast<ge::DataType>(concat->outputs[0].attr.dtype);
  store->attr.api.compute_type = ComputeType::kComputeStore;

  auto y = graph.FindNode("y");
  y->attr.api.compute_type = ComputeType::kComputeInvalid;
  y->outputs[0].attr.dtype = static_cast<ge::DataType>(concat->outputs[0].attr.dtype);
}
/**
 * Fills in the API type and compute unit of every node of the seven-input graph.
 *
 * x1..x7 and y become buffer nodes with no unit; load1..load7 and store become compute nodes on kUnitMTE2;
 * concat becomes a compute node on kUnitVector. Nodes are looked up by name and dereferenced without a null
 * check.
 *
 * @param graph graph previously populated with x1..x7, load1..load7, concat, store and y.
 */
void LoadConcatStore_AfterGetApiInfo7Inputs(af::AscGraph &graph) {
  const auto tag = [&graph](const char *node_name, auto api_type, auto unit) {
    auto node = graph.FindNode(node_name);
    node->attr.api.type = api_type;
    node->attr.api.unit = unit;
  };

  for (const char *input_name : {"x1", "x2", "x3", "x4", "x5", "x6", "x7"}) {
    tag(input_name, ApiType::kAPITypeBuffer, ComputeUnit::kUnitNone);
  }
  for (const char *load_name : {"load1", "load2", "load3", "load4", "load5", "load6", "load7"}) {
    tag(load_name, ApiType::kAPITypeCompute, ComputeUnit::kUnitMTE2);
  }
  tag("concat", ApiType::kAPITypeCompute, ComputeUnit::kUnitVector);
  // NOTE(review): store is tagged with the same MTE2 unit as the loads - confirm this is intended.
  tag("store", ApiType::kAPITypeCompute, ComputeUnit::kUnitMTE2);
  tag("y", ApiType::kAPITypeBuffer, ComputeUnit::kUnitNone);
}
void LoadConcatStore_AfterScheduler7Inputs(af::AscGraph &graph) {
auto all_axis = graph.GetAllAxis();
auto z0 = all_axis[0]->id;
auto z1 = all_axis[1]->id;
auto z1_s_1 = all_axis[2]->id;
auto z1_s_2 = all_axis[3]->id;
auto z1_s_3 = all_axis[4]->id;
auto z1_s_4 = all_axis[5]->id;
auto z1_s_5 = all_axis[6]->id;
auto z1_s_6 = all_axis[7]->id;
auto z1_s_7 = all_axis[8]->id;
auto [z0T, z0t] = graph.TileSplit(z0);
auto [z0TB, z0Tb] = graph.BlockSplit(z0T->id);
vector<AxisId> vectorized_axis{z0t->id, z1};
vector<af::Expression> vectorized_strides{One, One};
uint32_t align_size = 32 / sizeof(int64_t);
vectorized_strides[0] = af::sym::Align(graph.FindAxis(vectorized_axis[1])->size, align_size);
vector<af::Expression> vectorized_strides_x1{af::sym::Align(graph.FindAxis(z1_s_1)->size, align_size), One};
vector<af::Expression> vectorized_strides_x2{af::sym::Align(graph.FindAxis(z1_s_2)->size, align_size), One};
vector<af::Expression> vectorized_strides_x3{af::sym::Align(graph.FindAxis(z1_s_3)->size, align_size), One};
vector<af::Expression> vectorized_strides_x4{af::sym::Align(graph.FindAxis(z1_s_4)->size, align_size), One};
vector<af::Expression> vectorized_strides_x5{af::sym::Align(graph.FindAxis(z1_s_5)->size, align_size), One};
vector<af::Expression> vectorized_strides_x6{af::sym::Align(graph.FindAxis(z1_s_6)->size, align_size), One};
vector<af::Expression> vectorized_strides_x7{af::sym::Align(graph.FindAxis(z1_s_7)->size, align_size), One};
auto x1 = graph.FindNode("x1");
graph.ApplySplit(x1, z0T->id, z0t->id);
graph.ApplySplit(x1, z0TB->id, z0Tb->id);
x1->attr.sched.loop_axis = z0Tb->id;
x1->outputs[0].attr.vectorized_axis = {z0t->id, z1_s_1};
x1->outputs[0].attr.vectorized_strides = vectorized_strides_x1;
auto x2 = graph.FindNode("x2");
graph.ApplySplit(x2, z0T->id, z0t->id);
graph.ApplySplit(x2, z0TB->id, z0Tb->id);
x2->attr.sched.loop_axis = z0Tb->id;
x2->outputs[0].attr.vectorized_axis = {z0t->id, z1_s_2};
x2->outputs[0].attr.vectorized_strides = vectorized_strides_x2;
auto x3 = graph.FindNode("x3");
graph.ApplySplit(x3, z0T->id, z0t->id);
graph.ApplySplit(x3, z0TB->id, z0Tb->id);
x3->attr.sched.loop_axis = z0Tb->id;
x3->outputs[0].attr.vectorized_axis = {z0t->id, z1_s_3};
x3->outputs[0].attr.vectorized_strides = vectorized_strides_x3;
auto x4 = graph.FindNode("x4");
graph.ApplySplit(x4, z0T->id, z0t->id);
graph.ApplySplit(x4, z0TB->id, z0Tb->id);
x4->attr.sched.loop_axis = z0Tb->id;
x4->outputs[0].attr.vectorized_axis = {z0t->id, z1_s_4};
x4->outputs[0].attr.vectorized_strides = vectorized_strides_x4;
auto x5 = graph.FindNode("x5");
graph.ApplySplit(x5, z0T->id, z0t->id);
graph.ApplySplit(x5, z0TB->id, z0Tb->id);
x5->attr.sched.loop_axis = z0Tb->id;
x5->outputs[0].attr.vectorized_axis = {z0t->id, z1_s_5};
x5->outputs[0].attr.vectorized_strides = vectorized_strides_x5;
auto x6 = graph.FindNode("x6");
graph.ApplySplit(x6, z0T->id, z0t->id);
graph.ApplySplit(x6, z0TB->id, z0Tb->id);
x6->attr.sched.loop_axis = z0Tb->id;
x6->outputs[0].attr.vectorized_axis = {z0t->id, z1_s_6};
x6->outputs[0].attr.vectorized_strides = vectorized_strides_x6;
auto x7 = graph.FindNode("x7");
graph.ApplySplit(x7, z0T->id, z0t->id);
graph.ApplySplit(x7, z0TB->id, z0Tb->id);
x7->attr.sched.loop_axis = z0Tb->id;
x7->outputs[0].attr.vectorized_axis = {z0t->id, z1_s_7};
x7->outputs[0].attr.vectorized_strides = vectorized_strides_x7;
auto load1 = graph.FindNode("load1");
graph.ApplySplit(load1, z0T->id, z0t->id);
graph.ApplySplit(load1, z0TB->id, z0Tb->id);
load1->attr.sched.loop_axis = z0Tb->id;
load1->outputs[0].attr.vectorized_axis = {z0t->id, z1_s_1};
load1->outputs[0].attr.vectorized_strides = vectorized_strides_x1;
auto load2 = graph.FindNode("load2");
graph.ApplySplit(load2, z0T->id, z0t->id);
graph.ApplySplit(load2, z0TB->id, z0Tb->id);
load2->attr.sched.loop_axis = z0Tb->id;
load2->outputs[0].attr.vectorized_axis = {z0t->id, z1_s_2};
load2->outputs[0].attr.vectorized_strides = vectorized_strides_x2;
auto load3 = graph.FindNode("load3");
graph.ApplySplit(load3, z0T->id, z0t->id);
graph.ApplySplit(load3, z0TB->id, z0Tb->id);
load3->attr.sched.loop_axis = z0Tb->id;
load3->outputs[0].attr.vectorized_axis = {z0t->id, z1_s_3};
load3->outputs[0].attr.vectorized_strides = vectorized_strides_x3;
auto load4 = graph.FindNode("load4");
graph.ApplySplit(load4, z0T->id, z0t->id);
graph.ApplySplit(load4, z0TB->id, z0Tb->id);
load4->attr.sched.loop_axis = z0Tb->id;
load4->outputs[0].attr.vectorized_axis = {z0t->id, z1_s_4};
load4->outputs[0].attr.vectorized_strides = vectorized_strides_x4;
auto load5 = graph.FindNode("load5");
graph.ApplySplit(load5, z0T->id, z0t->id);
graph.ApplySplit(load5, z0TB->id, z0Tb->id);
load5->attr.sched.loop_axis = z0Tb->id;
load5->outputs[0].attr.vectorized_axis = {z0t->id, z1_s_5};
load5->outputs[0].attr.vectorized_strides = vectorized_strides_x5;
auto load6 = graph.FindNode("load6");
graph.ApplySplit(load6, z0T->id, z0t->id);
graph.ApplySplit(load6, z0TB->id, z0Tb->id);
load6->attr.sched.loop_axis = z0Tb->id;
load6->outputs[0].attr.vectorized_axis = {z0t->id, z1_s_6};
load6->outputs[0].attr.vectorized_strides = vectorized_strides_x6;
auto load7 = graph.FindNode("load7");
graph.ApplySplit(load7, z0T->id, z0t->id);
graph.ApplySplit(load7, z0TB->id, z0Tb->id);
load7->attr.sched.loop_axis = z0Tb->id;
load7->outputs[0].attr.vectorized_axis = {z0t->id, z1_s_7};
load7->outputs[0].attr.vectorized_strides = vectorized_strides_x7;
auto concat = graph.FindNode("concat");
graph.ApplySplit(concat, z0T->id, z0t->id);
graph.ApplySplit(concat, z0TB->id, z0Tb->id);
concat->attr.sched.loop_axis = z0Tb->id;
concat->outputs[0].attr.vectorized_axis = vectorized_axis;
concat->outputs[0].attr.vectorized_strides = vectorized_strides;
auto store = graph.FindNode("store");
graph.ApplySplit(store, z0T->id, z0t->id);
graph.ApplySplit(store, z0TB->id, z0Tb->id);
store->attr.sched.loop_axis = z0Tb->id;
store->outputs[0].attr.vectorized_axis = vectorized_axis;
store->outputs[0].attr.vectorized_strides = vectorized_strides;
}
/**
 * Fills in the memory / buffer / queue attributes on output 0 of every node of
 * the seven-input Load -> Concat -> Store graph.
 *
 *  - x1..x7 : global allocation, GM hardware and position, tensor ids 0..6,
 *             no buffer and no queue.
 *  - load1..load7 : queue allocation in UB at the VecIn position, tensor ids
 *             8..14, queue id and reuse id 0..6.
 *  - concat : queue allocation in UB at the VecOut position, tensor id 16,
 *             queue id and reuse id 7.
 *  - store  : global allocation, GM hardware and position, tensor id 17,
 *             no buffer and no queue.
 *
 * Tensor ids 7 and 15 are not assigned by this function.
 *
 * @param graph graph in which nodes named x1..x7, load1..load7, concat and
 *              store are looked up by name.
 */
void LoadConcatStore_AfterQueBufAlloc7Inputs(af::AscGraph &graph) {
  // Output 0 of `node_name` becomes a global-memory tensor without buffer or queue.
  auto assign_global = [&graph](const char *node_name, int tensor_id) {
    auto node = graph.FindNode(node_name);
    auto &out_attr = node->outputs[0].attr;
    out_attr.mem.tensor_id = tensor_id;
    out_attr.mem.alloc_type = AllocType::kAllocTypeGlobal;
    out_attr.mem.hardware = MemHardware::kMemHardwareGM;
    out_attr.mem.position = Position::kPositionGM;
    out_attr.buf.id = af::kIdNone;
    out_attr.que.id = af::kIdNone;
    out_attr.opt.ref_tensor = af::kIdNone;
    out_attr.opt.merge_scope = af::kIdNone;
  };

  // Output 0 of `node_name` becomes a UB queue tensor (depth 1, two buffers);
  // `queue_id` is used for both the queue id and the reuse id.
  auto assign_queue = [&graph](const char *node_name, int tensor_id, Position position, int queue_id) {
    auto node = graph.FindNode(node_name);
    auto &out_attr = node->outputs[0].attr;
    out_attr.mem.tensor_id = tensor_id;
    out_attr.mem.alloc_type = AllocType::kAllocTypeQueue;
    out_attr.mem.hardware = MemHardware::kMemHardwareUB;
    out_attr.mem.position = position;
    out_attr.buf.id = af::kIdNone;
    out_attr.que.id = queue_id;
    out_attr.mem.reuse_id = queue_id;
    out_attr.que.depth = 1;
    out_attr.que.buf_num = 2;
    out_attr.opt.ref_tensor = af::kIdNone;
    out_attr.opt.merge_scope = af::kIdNone;
  };

  const int input_count = 7;
  const char *const data_names[input_count] = {"x1", "x2", "x3", "x4", "x5", "x6", "x7"};
  const char *const load_names[input_count] = {"load1", "load2", "load3", "load4", "load5", "load6", "load7"};
  const int first_load_tensor_id = 8;

  // Graph inputs: tensor ids 0..6.
  for (int idx = 0; idx < input_count; ++idx) {
    assign_global(data_names[idx], idx);
  }

  // Loads: tensor ids 8..14, each with its own queue 0..6.
  for (int idx = 0; idx < input_count; ++idx) {
    assign_queue(load_names[idx], first_load_tensor_id + idx, Position::kPositionVecIn, idx);
  }

  // Concat result goes through queue 7 at the VecOut position.
  assign_queue("concat", 16, Position::kPositionVecOut, 7);

  // Store output is written back to global memory.
  assign_global("store", 17);
}
/**
 * Builds the two-input Load -> Concat -> Store graph used by the small-tail
 * concat cases.
 *
 * Axes: z0 has size s0 (a size variable named "s0"); zo has size s1 + s2 and
 * has two axes created from it, zo_s_0 (size s1) and zo_s_1 (size s2). s1 and
 * s2 are size variables created from the first two entries of
 * `concat_dim_sizes`.
 *
 * @param graph            graph that receives the axes and nodes.
 * @param data_type        dtype written to the output of every node.
 * @param concat_dim_sizes sizes of the two concatenated inputs along the concat
 *                         axis; only entries 0 and 1 are read.
 * @throws std::out_of_range if `concat_dim_sizes` has fewer than two entries
 *                           (previously this was an unchecked out-of-bounds read).
 */
void LoadConcatStore_SmallTailBeforeAutofuse(AscGraph &graph,
                                             ge::DataType data_type,
                                             const std::vector<int64_t> &concat_dim_sizes) {
  auto s0 = graph.CreateSizeVar("s0");
  // Bounds-checked access: a vector with fewer than two entries must not be indexed silently.
  auto s1 = graph.CreateSizeVar(concat_dim_sizes.at(0));
  auto s2 = graph.CreateSizeVar(concat_dim_sizes.at(1));

  auto z0 = graph.CreateAxis("z0", s0);
  auto zo = graph.CreateAxis("zo", s1 + s2);
  auto zo_s_0 = graph.CreateAxis("zo_s_0", Axis::Type::kAxisTypeOriginal, s1, {zo.id}, af::kIdNone);
  auto zo_s_1 = graph.CreateAxis("zo_s_1", Axis::Type::kAxisTypeOriginal, s2, {zo.id}, af::kIdNone);

  Data x1("x1", graph);
  Data x2("x2", graph);
  Load load1("load1");
  Load load2("load2");
  af::ascir_op::Concat concat("concat");
  Store store("store");
  Output y("y");

  // First input: [s0, s1] with strides {s1, 1}.
  x1.attr.sched.axis = {z0.id, zo_s_0.id};
  x1.y.dtype = data_type;
  *x1.y.axis = {z0.id, zo_s_0.id};
  *x1.y.repeats = {s0, s1};
  *x1.y.strides = {s1, One};

  // Second input: [s0, s2] with strides {s2, 1}.
  x2.attr.sched.axis = {z0.id, zo_s_1.id};
  x2.y.dtype = data_type;
  *x2.y.axis = {z0.id, zo_s_1.id};
  *x2.y.repeats = {s0, s2};
  *x2.y.strides = {s2, One};

  // Each load mirrors the layout of the input it reads.
  load1.x = x1.y;
  load1.attr.sched.axis = {z0.id, zo_s_0.id};
  load1.y.dtype = data_type;
  *load1.y.axis = {z0.id, zo_s_0.id};
  *load1.y.repeats = {s0, s1};
  *load1.y.strides = {s1, One};

  load2.x = x2.y;
  load2.attr.sched.axis = {z0.id, zo_s_1.id};
  load2.y.dtype = data_type;
  *load2.y.axis = {z0.id, zo_s_1.id};
  *load2.y.repeats = {s0, s2};
  *load2.y.strides = {s2, One};

  // Concat output: [s0, s1 + s2] over the (z0, zo) axes.
  concat.x = {load1.y, load2.y};
  concat.attr.sched.axis = {z0.id, zo.id};
  concat.y.dtype = data_type;
  *concat.y.axis = {z0.id, zo.id};
  *concat.y.repeats = {s0, s1 + s2};
  *concat.y.strides = {s1+s2, One};
  // Single temporary-buffer entry for the concat node, described by Symbol(16384).
  concat.attr.tmp_buffers = {{{af::Symbol(16384), -1}, MemAttr(), 0}};

  store.x = concat.y;
  store.attr.sched.axis = {z0.id, zo.id};
  store.y.dtype = data_type;
  *store.y.axis = {z0.id, zo.id};
  *store.y.repeats = {s0, s1 + s2};
  *store.y.strides = {s1+s2, One};

  y.x = store.y;
  y.attr.sched.axis = {z0.id, zo.id};
  y.y.dtype = data_type;
  *y.y.axis = {z0.id, zo.id};
  *y.y.repeats = {s0, s1 + s2};
  *y.y.strides = {s1 + s2, One};
}
/**
 * Small-tail variant of the "after infer output" step.
 * Forwards directly to LoadConcatStore_AfterInferOutput; nothing
 * small-tail-specific is done here.
 *
 * @param graph graph passed through unchanged to LoadConcatStore_AfterInferOutput.
 */
void LoadConcatStore_SmallTailAfterInferOutput(AscGraph &graph) {
LoadConcatStore_AfterInferOutput(graph);
}
/**
 * Small-tail variant of the "after get API info" step.
 * Forwards directly to LoadConcatStore_AfterGetApiInfo; nothing
 * small-tail-specific is done here.
 *
 * @param graph graph passed through unchanged to LoadConcatStore_AfterGetApiInfo.
 */
void LoadConcatStore_SmallTailAfterGetApiInfo(AscGraph &graph) {
LoadConcatStore_AfterGetApiInfo(graph);
}
/**
 * Small-tail variant of the "after scheduler" step.
 * Forwards directly to LoadConcatStore_AfterScheduler; nothing
 * small-tail-specific is done here.
 *
 * @param graph     graph passed through unchanged to LoadConcatStore_AfterScheduler.
 * @param alignment forwarded as-is to LoadConcatStore_AfterScheduler.
 */
void LoadConcatStore_SmallTailAfterScheduler(AscGraph &graph, int32_t alignment) {
LoadConcatStore_AfterScheduler(graph, alignment);
}
/**
 * Small-tail variant of the "after queue/buffer allocation" step.
 *
 * Runs LoadConcatStore_AfterQueBufAlloc on the graph and then sets the boolean
 * attribute "_concat_small_tail" to true on the op desc of the node named
 * "concat".
 *
 * @param graph graph containing a node named "concat".
 */
void LoadConcatStore_SmallTailAfterQueBufAlloc(AscGraph &graph) {
  // Shared allocation attributes first; the small-tail flag is added on top.
  LoadConcatStore_AfterQueBufAlloc(graph);

  const char *const small_tail_attr = "_concat_small_tail";
  const bool small_tail_enabled = true;
  auto concat_node = graph.FindNode("concat");
  af::AttrUtils::SetBool(concat_node->GetOpDesc(), small_tail_attr, small_tail_enabled);
}