* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "es_showcase.h"

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <vector>

#include "es_all_ops.h"
#include "utils.h"
#include "ge/ge_api.h"
using namespace ge;
using namespace ge::es;
namespace es_showcase {
/**
 * @brief Runs a graph in a freshly created GE session and dumps its inputs and outputs to files.
 *
 * The inputs are written with ge::Utils::PrintTensorsToFile (prefix "input") before execution.
 * On success, the outputs are written with the caller-supplied prefix.
 *
 * @param graph          Graph to add and execute; registered under a per-process counter id.
 * @param inputs         Input tensors passed to ge::Session::RunGraph.
 * @param output_prefix  Prefix handed to ge::Utils::PrintTensorsToFile for the output tensors.
 * @return 0 on success; -1 if the session cannot be created or AddGraph/RunGraph fails.
 */
int RunGraph(ge::Graph &graph, const std::vector<ge::Tensor> &inputs,
             const std::string &output_prefix) {
  ge::Utils::PrintTensorsToFile(inputs, "input");
  std::map<ge::AscendString, ge::AscendString> options;
  // RAII ownership: the session is released on every return path, instead of relying on a
  // hand-written `delete` in each branch. `nothrow` keeps the null-on-failure allocation check.
  std::unique_ptr<ge::Session> session(new (std::nothrow) ge::Session(options));
  if (session == nullptr) {
    std::cout << "Failed to create ge::Session" << std::endl;
    return -1;
  }
  // Counter-based ids so repeated calls never reuse an id.
  // NOTE(review): plain static counter, not safe for concurrent callers.
  static uint32_t next = 0;
  const uint32_t graph_id = next++;
  auto ret = session->AddGraph(graph_id, graph);
  if (ret != ge::SUCCESS) {
    std::cout << "AddGraph failed, ret = " << ret << std::endl;
    return -1;
  }
  std::vector<ge::Tensor> outputs;
  ret = session->RunGraph(graph_id, inputs, outputs);
  // The graph is removed whether or not execution succeeded; removal is best-effort.
  (void)session->RemoveGraph(graph_id);
  if (ret != ge::SUCCESS) {
    std::cout << "RunGraph failed, ret = " << ret << std::endl;
    return -1;
  }
  ge::Utils::PrintTensorsToFile(outputs, output_prefix);
  return 0;
}
/**
 * @brief Builds the "MakeTransformerSubGraph" showcase graph with EsGraphBuilder.
 *
 * Three graph inputs, two graph outputs:
 *   - output 0: int64 indices of the final TopKV2 (k = 4) over the masked, bias-added scores;
 *   - output 1: the sigmoid scores gathered at those indices, passed through RealDiv by the
 *               scalar 1e-6, multiplied by 2.5 and cast to DT_FLOAT.
 *
 * Pipeline: MatMul(input1 flattened to rows of 7168, input2 transposed) -> Sigmoid
 * -> reshape to rows of 256 and add input3 as a broadcast bias -> TopKV2 / ReduceSum / TopKV2
 * to select 4 entries -> expand the selection into a fixed {256, 256} mask -> MaskedFill
 * -> TopKV2 -> GatherElements from the un-biased sigmoid scores.
 * NOTE(review): the op pattern resembles a mixture-of-experts router; that reading is inferred
 * from the ops, not stated anywhere in this file.
 *
 * The hard-coded constants (7168, 256, {256, 256}) mean the graph presumably only works for
 * inputs like those fed by MakeTransformerGraphByEsAndRun: {32, 8, 7168}, {256, 7168}, {256},
 * which give 256 score rows. TODO confirm behavior for other input shapes.
 *
 * Node creation order follows statement order here, so avoid reordering statements when editing.
 *
 * @return the built graph; ownership passes to the caller via std::unique_ptr.
 */
std::unique_ptr<ge::Graph> MakeTransformerGraphByEs() {
auto builder = std::make_unique<EsGraphBuilder>("MakeTransformerSubGraph");
// Three graph inputs (input1, input2, input3).
auto [input1, input2, input3] = builder->CreateInputs<3>();
// Flatten input1 to rows of 7168, then multiply by input2 transposed with perm {1, 0};
// both operands are cast to DT_FLOAT first.
auto reshape_result1 = Reshape(input1, std::vector<int64_t>{-1, 7168});
auto matmul_result1 = MatMul(
Cast(reshape_result1, ge::DT_FLOAT),
Transpose(Cast(input2, ge::DT_FLOAT), std::vector<int64_t>{1, 0}));
// Sigmoid scores. Kept separately because the final gather reads these (without the bias).
auto sigmoid_result1 = Sigmoid(matmul_result1);
// View the scores as rows of 256 and add input3 (unsqueezed at axis 0, cast to float) as a
// broadcast bias.
auto reshape_result2 = Reshape(sigmoid_result1, std::vector<int64_t>{-1, 256});
auto add_result1 = reshape_result2 + Cast(Unsqueeze(input3, std::vector<int64_t>{0}), ge::DT_FLOAT);
// Top-2 values along the last axis, then summed along the last axis.
// NOTE(review): TopKV2's trailing positional args look like (sorted, dim, largest, index dtype),
// and 3 is presumably ge::DT_INT32 -- confirm against es_all_ops.h.
auto [values1, indices1] = TopKV2(add_result1, builder->CreateScalar(2), true, -1, true, 3);
auto reducesum_result1 = ReduceSum(values1, std::vector<int64_t>{-1});
// Pick the 4 largest sums; indices cast to int64 for ScatterElements.
// NOTE(review): if ReduceSum drops the reduced axis (its keep_dims default is not visible here),
// this TopKV2 runs on a rank-1 tensor and selects 4 rows, not 4 entries per row -- confirm intent.
auto [values2, indices2] = TopKV2(reducesum_result1, builder->CreateScalar(4), false, -1, true, 3);
auto cast_result2 = Cast(indices2, ge::DT_INT64);
// 0/1 selection vector: zeros shaped like the sums, with 1.0 scattered at the selected indices
// (ScatterElements axis left at its default).
auto scatterelements_result1 = ScatterElements(
ZerosLike(reducesum_result1),
cast_result2,
Fill(ge::es::Shape(cast_result2), Cast(builder->CreateScalar(1.0f), ge::DT_FLOAT))
);
// Expand the selection to a fixed {256, 256} mask: unsqueeze at the last axis, broadcast,
// wrap in Identity.
auto identity_result1 = Identity(
BroadcastTo(Unsqueeze(scatterelements_result1, std::vector<int64_t>{-1}), std::vector<int64_t>{256, 256})
);
// Fill the bias-added scores with 0 where the mask is 0 (LogicalNot of the bool-cast mask).
auto maskedfill_result1 = MaskedFill(
add_result1,
LogicalNot(Cast(Reshape(identity_result1, std::vector<int64_t>{256, 256}), ge::DT_BOOL)),
builder->CreateScalar(0.0f)
);
// Final top-4 over the masked scores; indices cast to int64 for GatherElements and output 0.
auto [values3, indices3] = TopKV2(maskedfill_result1, builder->CreateScalar(4), false, -1, true, 3);
auto cast_result3 = Cast(indices3, ge::DT_INT64);
// Gather the un-biased sigmoid scores at the chosen indices along axis 1.
auto gatherelements_result1 = GatherElements(sigmoid_result1, cast_result3, 1);
// NOTE(review): dividing by the scalar 1e-6 scales the weights by 1e6. If an epsilon-guarded
// normalisation (divide by sum + eps) was intended, a ReduceSum is missing -- confirm.
auto realdiv_result1 = RealDiv(gatherelements_result1, builder->CreateScalar(1e-6f));
// Graph outputs: {selected indices, weights * 2.5 cast to DT_FLOAT}.
return builder->BuildAndReset({cast_result3, Cast(realdiv_result1 * builder->CreateScalar(2.5f), ge::DT_FLOAT)});
}
void MakeTransformerGraphByEsAndDump() {
std::unique_ptr<ge::Graph> graph = MakeTransformerGraphByEs();
graph->DumpToFile(ge::Graph::DumpFormat::kOnnx, ge::AscendString("make_transformer_graph"));
}
int MakeTransformerGraphByEsAndRun() {
std::unique_ptr<ge::Graph> graph = MakeTransformerGraphByEs();
std::vector<ge::Tensor> inputs;
std::random_device rd;
std::mt19937 gen(rd());
std::normal_distribution<float> dist(0.0f, 1.0f);
std::vector<float> input1_data(32 * 8 * 7168);
for (auto &val : input1_data) {
val = dist(gen);
}
std::vector<float> input2_data(256 * 7168);
for (auto &val : input2_data) {
val = dist(gen);
}
std::vector<float> input3_data(256);
for (auto &val : input3_data) {
val = dist(gen);
}
inputs.push_back(*ge::Utils::StubTensor<float>(input1_data, {32, 8, 7168}));
inputs.push_back(*ge::Utils::StubTensor<float>(input2_data, {256, 7168}));
inputs.push_back(*ge::Utils::StubTensor<float>(input3_data, {256}));
return RunGraph(*graph, inputs, "Transformer");
}
}