* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
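/*!
 * \file test_graph_only.cpp
 * \brief Graph-only unit tests: LLaMA layer, DeepSeek attention / MLA prolog, RoPE, tail-block loop and transpose.
 */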
#include "gtest/gtest.h"
#include "interface/tensor/logical_tensor.h"
#include "interface/configs/config_manager.h"
#include "tilefwk/tilefwk.h"
#include "interface/inner/tilefwk.h"
#include "operator/models/llama/llama_def.h"
#include "operator/models/deepseek/deepseek_mla.h"
#include "operator/models/deepseek/mla_prolog.h"
#include "operator/models/deepseek/deepseek_spec.h"
#include "tilefwk/data_type.h"
using namespace npu::tile_fwk;
/**
 * Fixture shared by the TEST_F cases in this file.
 *
 * SetUp() runs before each test: it calls Reset() on the Program instance and on the
 * configuration, then applies the host, platform and simulation settings listed below.
 */
class GraphTest : public testing::Test {
public:
    void SetUp() override
    {
        // Start every test from a freshly reset program and configuration.
        Program::GetInstance().Reset();
        config::Reset();

        // Settings applied on top of the reset configuration.
        config::SetHostOption(COMPILE_STAGE, CS_EXECUTE_GRAPH);
        config::SetHostConfig(KEY_STRATEGY, "PVC2_OOO");
        config::SetPlatformConfig(KEY_ENABLE_COST_MODEL, false);
        config::SetSimConfig(KEY_BUILD_TASK_BASED_TOPO, false);
    }

    // No per-test cleanup is performed.
    void TearDown() override {}

    // Suite-level hooks do no work.
    static void SetUpTestCase() {}
    static void TearDownTestCase() {}
};
/**
 * Declares the tensors of one LLaMA layer and builds the "LLAMA" function inside the
 * "LLAMALAYER" program.
 *
 * @param dimsCfg attention dimensions; its b, n, s and d fields size the tensors declared
 *                here, and the whole struct is forwarded to LlamaLayer.
 */
void RunLLamaLayerGraph(const AttentionDims& dimsCfg)
{
    const int batch = dimsCfg.b;
    const int heads = dimsCfg.n;
    const int seqLen = dimsCfg.s;
    const int headDim = dimsCfg.d;

    // The same two products appear in every shape below, so name them once.
    const int rows = batch * seqLen;
    const int width = heads * headDim;

    PROGRAM("LLAMALAYER")
    {
        Tensor H(DataType::DT_FP32, {rows, width}, "H");
        Tensor AW(DataType::DT_FP16, {width, width * 3}, "AW");
        Tensor DW(DataType::DT_FP16, {width, width}, "DW");
        Tensor FW(DataType::DT_FP16, {width, width * 3}, "FW");
        Tensor Res(DT_FP32, {rows, width}, "Res");
        FUNCTION("LLAMA")
        {
            Res = LlamaLayer(H, AW, DW, FW, dimsCfg, SMALL_DFS_VEC_CFG, DFS_CUBE_CFG);
        }
    }
}
// Builds the LLaMA layer graph for the dims initializer {1, 1, 128, 128, DFT_SINGLE_M, DFT_SINGLE_N}.
TEST_F(GraphTest, llama_1_1_128_128)
{
    const AttentionDims dims{1, 1, 128, 128, DFT_SINGLE_M, DFT_SINGLE_N};
    RunLLamaLayerGraph(dims);
}
// Builds the LLaMA layer graph for the dims initializer {1, 1, 256, 128, DFT_SINGLE_M, DFT_SINGLE_N}.
TEST_F(GraphTest, llama_1_1_256_128)
{
    const AttentionDims dims{1, 1, 256, 128, DFT_SINGLE_M, DFT_SINGLE_N};
    RunLLamaLayerGraph(dims);
}
// Builds the LLaMA layer graph for the dims initializer {1, 1, 512, 128, DFT_SINGLE_M, DFT_SINGLE_N}.
TEST_F(GraphTest, llama_1_1_512_128)
{
    const AttentionDims dims{1, 1, 512, 128, DFT_SINGLE_M, DFT_SINGLE_N};
    RunLLamaLayerGraph(dims);
}
// Same dims as llama_1_1_256_128, but with KEY_PRE_CHECK and KEY_POST_CHECK set to false
// for the "PreGraphProcess" pass of the "PVC2_OOO" strategy before the graph is built.
TEST_F(GraphTest, llama_1_1_256_128_mix)
{
    config::SetPassConfig("PVC2_OOO", "PreGraphProcess", KEY_PRE_CHECK, false);
    config::SetPassConfig("PVC2_OOO", "PreGraphProcess", KEY_POST_CHECK, false);

    const AttentionDims dims{1, 1, 256, 128, DFT_SINGLE_M, DFT_SINGLE_N};
    RunLLamaLayerGraph(dims);
}
// Builds function "A" around DeepseekAttention::QkvPre, with weights sized from deepseekConfig1.
TEST_F(GraphTest, deepseek_qkvPre)
{
    // Input extents chosen by this test.
    const int batch = 2;
    const int seqLen = 128;

    // Model dimensions read from deepseekConfig1.
    const int hiddenSize = std::get<int>(deepseekConfig1["hiddenSize"]);
    const int numHeads = std::get<int>(deepseekConfig1["numAttentionHeads"]);
    const int qLoraRank = std::get<int>(deepseekConfig1["qLoraRank"]);
    const int ropeHeadDim = std::get<int>(deepseekConfig1["qkRopeHeadDim"]);
    const int kvLoraRank = std::get<int>(deepseekConfig1["kvLoraRank"]);
    const int vHeadDim = std::get<int>(deepseekConfig1["vHeadDim"]);
    const int nopeHeadDim = std::get<int>(deepseekConfig1["qkNopeHeadDim"]);
    const int qHeadDim = nopeHeadDim + ropeHeadDim;

    Tensor hiddenStates(DT_BF16, {batch, seqLen, hiddenSize}, "hidden_states");

    // Attention weights, all DT_BF16.
    AttentionW weights;
    weights.qAProjW = Tensor(DT_BF16, {hiddenSize, qLoraRank}, "qAProjW");
    weights.qBProjW = Tensor(DT_BF16, {qLoraRank, numHeads * qHeadDim}, "qBProjW");
    weights.kvAProjWithMqaW = Tensor(DT_BF16, {hiddenSize, kvLoraRank + ropeHeadDim}, "kvAProjWithMqaW");
    weights.kvBProjWK = Tensor(DT_BF16, {numHeads, nopeHeadDim, kvLoraRank}, "kvBProjWK");
    weights.kvBProjWV = Tensor(DT_BF16, {numHeads, kvLoraRank, vHeadDim}, "kvBProjWV");
    weights.oProjW = Tensor(DT_BF16, {numHeads * vHeadDim, hiddenSize}, "oProjW");

    std::tuple<Tensor, Tensor> qkv;
    DeepseekAttention attention(deepseekConfig1, weights, 1);
    FUNCTION("A")
    {
        qkv = attention.QkvPre(hiddenStates);
    }
}
// Builds function "AttentionPost": transpose/reshape of the input, a batched matmul with
// kvBProjWV, another transpose/reshape, and a final batched matmul with the unsqueezed oProjW.
TEST_F(GraphTest, TestAttentionPost)
{
    config::SetHostOption(COMPILE_STAGE, CS_EXECUTE_GRAPH);

    const int batch = 1;
    const int numHeads = 2;
    const int seqLen = 128;
    const int dim = 512;
    const int vHeadDim = 128;
    const int hiddenSize = 256;

    std::vector<int64_t> inShape = {batch, numHeads, seqLen, dim};
    Tensor attnPostIn(DT_FP32, inShape, "attnPostIn");
    Tensor kvBProjWV(DT_FP32, {numHeads, dim, vHeadDim}, "kvBProjWV");
    Tensor oProjW(DT_FP32, {numHeads * vHeadDim, hiddenSize}, "oProjW");
    Tensor attenOutput;
    ConfigManager::Instance();

    FUNCTION("AttentionPost")
    {
        // Leading extents and the element type are read back from the input tensor.
        int inBatch = attnPostIn.GetShape()[0];
        int inHeads = attnPostIn.GetShape()[1];
        int inSeq = attnPostIn.GetShape()[2];
        DataType elemType = attnPostIn.GetStorage()->Datatype();

        // Swap axes 1 and 2, then reshape to {inBatch * inSeq, inHeads, dim}.
        TileShape::Current().SetVecTile({1, 1, 32, dim});
        Tensor merged = Reshape(Transpose(attnPostIn, {1, 2}), {inBatch * inSeq, inHeads, dim});

        // Swap axes 0 and 1 of the reshaped tensor.
        TileShape::Current().SetVecTile({32, 1, dim});
        Tensor headsFirst = Transpose(merged, {0, 1});

        // First batched matmul, against kvBProjWV.
        TileShape::Current().SetVecTile(128, 128);
        TileShape::Current().SetCubeTile({32, 32}, {128, 128}, {128, 128});
        Tensor projected = Matrix::BatchMatmul(elemType, headsFirst, kvBProjWV);

        // Swap axes 0 and 1 again and reshape to {inBatch, inSeq, inHeads * vHeadDim}.
        TileShape::Current().SetVecTile({1, 128, 128});
        Tensor projectedSwapped = Transpose(projected, {0, 1});
        Tensor projectedFlat = Reshape(projectedSwapped, {inBatch, inSeq, inHeads * vHeadDim});

        // Second batched matmul, against oProjW with an axis inserted at position 0.
        Tensor outWeight = Unsqueeze(oProjW, 0);
        attenOutput = Matrix::BatchMatmul(elemType, projectedFlat, outWeight);
    }
}
// Builds function "A" around DeepseekAttention::Forward with s = 1 and s2 = 512.
TEST_F(GraphTest, Test_deepseekAttention_s_1)
{
    config::SetHostOption(COMPILE_STAGE, CS_EXECUTE_GRAPH);

    // Extents chosen by this test.
    const int batch = 2;
    const int seqLen = 1;
    const int cacheLen = 512;

    // Model dimensions read from deepseekConfig1.
    const int hiddenSize = std::get<int>(deepseekConfig1["hiddenSize"]);
    const int numHeads = std::get<int>(deepseekConfig1["numAttentionHeads"]);
    const int qLoraRank = std::get<int>(deepseekConfig1["qLoraRank"]);
    const int ropeHeadDim = std::get<int>(deepseekConfig1["qkRopeHeadDim"]);
    const int kvLoraRank = std::get<int>(deepseekConfig1["kvLoraRank"]);
    const int vHeadDim = std::get<int>(deepseekConfig1["vHeadDim"]);
    const int nopeHeadDim = std::get<int>(deepseekConfig1["qkNopeHeadDim"]);
    const int qHeadDim = nopeHeadDim + ropeHeadDim;

    // Inputs passed to Forward.
    Tensor hiddenStates(DT_BF16, {batch, seqLen, hiddenSize}, "hidden_states");
    Tensor attenMask(DT_FP32, {batch, 1, seqLen, cacheLen}, "atten_mask");
    Tensor positionIds(DT_INT32, {batch, seqLen}, "position_ids");
    Tensor cosTable(DT_BF16, {seqLen, ropeHeadDim}, "cos");
    Tensor sinTable(DT_BF16, {seqLen, ropeHeadDim}, "sin");
    Tensor kvLen(DT_INT32, {1, 1}, "kv_len");
    Tensor pastKeyStates(DT_BF16, {batch, 1, cacheLen, kvLoraRank + ropeHeadDim}, "past_key_states");

    // Attention weights, all DT_BF16.
    AttentionW weights;
    weights.qAProjW = Tensor(DT_BF16, {hiddenSize, qLoraRank}, "qAProjW");
    weights.qBProjW = Tensor(DT_BF16, {qLoraRank, numHeads * qHeadDim}, "qBProjW");
    weights.kvAProjWithMqaW = Tensor(DT_BF16, {hiddenSize, kvLoraRank + ropeHeadDim}, "kvAProjWithMqaW");
    weights.kvBProjWK = Tensor(DT_BF16, {numHeads, nopeHeadDim, kvLoraRank}, "kvBProjWK");
    weights.kvBProjWV = Tensor(DT_BF16, {numHeads, kvLoraRank, vHeadDim}, "kvBProjWV");
    weights.oProjW = Tensor(DT_BF16, {numHeads * vHeadDim, hiddenSize}, "oProjW");

    RoPETileShapeConfig ropeTiles{{32, 32}, {1, 32, 32}, {1, 1, 32, 32}, {1, 1, 32, 32, 2}};
    Tensor output;
    DeepseekAttention attention(deepseekConfig1, weights, 1);
    ConfigManager::Instance();

    FUNCTION("A")
    {
        output = attention.Forward(
            hiddenStates, attenMask, positionIds, cosTable, sinTable, kvLen, pastKeyStates, ropeTiles);
    }
}
// Builds function "A" around DeepseekAttention::AtentionPreForward with s = 1 and s2 = 256.
TEST_F(GraphTest, Test_deepseekAttention_pre)
{
    config::SetHostOption(COMPILE_STAGE, CS_EXECUTE_GRAPH);

    // Extents chosen by this test.
    const int batch = 2;
    const int seqLen = 1;
    const int cacheLen = 256;

    // Model dimensions read from deepseekConfig1.
    const int hiddenSize = std::get<int>(deepseekConfig1["hiddenSize"]);
    const int numHeads = std::get<int>(deepseekConfig1["numAttentionHeads"]);
    const int qLoraRank = std::get<int>(deepseekConfig1["qLoraRank"]);
    const int ropeHeadDim = std::get<int>(deepseekConfig1["qkRopeHeadDim"]);
    const int kvLoraRank = std::get<int>(deepseekConfig1["kvLoraRank"]);
    const int vHeadDim = std::get<int>(deepseekConfig1["vHeadDim"]);
    const int nopeHeadDim = std::get<int>(deepseekConfig1["qkNopeHeadDim"]);
    const int qHeadDim = nopeHeadDim + ropeHeadDim;

    // Inputs passed to AtentionPreForward.
    Tensor hiddenStates(DT_BF16, {batch, seqLen, hiddenSize}, "hidden_states");
    Tensor attenMask(DT_FP32, {batch, 1, seqLen, cacheLen}, "atten_mask");
    Tensor positionIds(DT_INT32, {batch, seqLen}, "position_ids");
    Tensor cosTable(DT_BF16, {seqLen, ropeHeadDim}, "cos");
    Tensor sinTable(DT_BF16, {seqLen, ropeHeadDim}, "sin");
    Tensor kvLen(DT_INT32, {1, 1}, "kv_len");
    Tensor pastKeyStates(DT_BF16, {batch, 1, cacheLen, kvLoraRank + ropeHeadDim}, "past_key_states");

    // Attention weights, all DT_BF16.
    AttentionW weights;
    weights.qAProjW = Tensor(DT_BF16, {hiddenSize, qLoraRank}, "qAProjW");
    weights.qBProjW = Tensor(DT_BF16, {qLoraRank, numHeads * qHeadDim}, "qBProjW");
    weights.kvAProjWithMqaW = Tensor(DT_BF16, {hiddenSize, kvLoraRank + ropeHeadDim}, "kvAProjWithMqaW");
    weights.kvBProjWK = Tensor(DT_BF16, {numHeads, nopeHeadDim, kvLoraRank}, "kvBProjWK");
    weights.kvBProjWV = Tensor(DT_BF16, {numHeads, kvLoraRank, vHeadDim}, "kvBProjWV");
    weights.oProjW = Tensor(DT_BF16, {numHeads * vHeadDim, hiddenSize}, "oProjW");

    RoPETileShapeConfig ropeTiles{{32, 32}, {1, 32, 32}, {1, 1, 32, 32}, {1, 1, 32, 32, 2}};
    std::tuple<Tensor, Tensor> outputs;
    DeepseekAttention attention(deepseekConfig1, weights, 1);
    ConfigManager::Instance();

    FUNCTION("A")
    {
        outputs = attention.AtentionPreForward(
            hiddenStates, attenMask, positionIds, cosTable, sinTable, kvLen, pastKeyStates, ropeTiles);
    }
}
// Builds function "RoPE": transposes qPe, reshapes kPe to 4-D and calls ApplyRotaryPosEmb on DT_BF16 tensors.
TEST_F(GraphTest, test_operation_rope_subgraph_deepseekv3_bf16)
{
    RoPETileShapeConfig ropeTiles{
        {64, 64},
        {1, 64, 64},
        {1, 64, 1, 64},
        {1, 64, 1, 32, 2}
    };

    const int batch = 1;
    const int numHeads = 32;
    const int seqLen = 1;
    const int ropeHeadDim = 64;

    std::vector<int64_t> qPeShape{batch, seqLen, numHeads, ropeHeadDim};
    std::vector<int64_t> kPeShape{batch, seqLen, ropeHeadDim};
    std::vector<int64_t> idsShape{batch, seqLen};
    std::vector<int64_t> cosShape{seqLen, ropeHeadDim};
    std::vector<int64_t> qEmbedShape{batch, numHeads, seqLen, ropeHeadDim};
    std::vector<int64_t> kEmbedShape{batch, 1, seqLen, ropeHeadDim};

    Tensor qPe(DT_BF16, qPeShape, "qPe");
    Tensor kPe(DT_BF16, kPeShape, "kPe");
    // cos and sin share one shape.
    Tensor cosTable(DT_BF16, cosShape, "cos");
    Tensor sinTable(DT_BF16, cosShape, "sin");
    Tensor positionIds(DT_INT32, idsShape, "positionIds");
    Tensor qEmbed(DT_BF16, qEmbedShape, "qEmbed");
    Tensor kEmbed(DT_BF16, kEmbedShape, "kEmbed");
    ConfigManager::Instance();

    FUNCTION("RoPE")
    {
        TileShape::Current().SetVecTile({1, 1, 64, 64});
        // Swap axes 1 and 2 of qPe.
        auto qPeSwapped = Transpose(qPe, {1, 2});

        // Insert a size-1 axis at position 1 of kPe, using its own extents.
        int kBatch = kPe.GetShape()[0];
        int kSeq = kPe.GetShape()[1];
        int kDim = kPe.GetShape()[2];
        auto kPe4d = Reshape(kPe, {kBatch, 1, kSeq, kDim});

        ApplyRotaryPosEmb(qPeSwapped, kPe4d, cosTable, sinTable, positionIds, qEmbed, kEmbed, 1, ropeTiles);
    }
}
/**
 * Builds the "MlaProlog_T" function (inside program "MlaProlog") around one MlaProlog call.
 *
 * @tparam splitReduceLastDim forwarded unchanged to MlaProlog.
 * @tparam splitK             forwarded unchanged to MlaProlog.
 * @tparam nz                 not referenced by this helper.
 * @param params    dimensions in the order {b, s, s2, n, h, qLoraRank, qkNopeHeadDim,
 *                  qkRopeHeadDim, kvLoraRank}; trailing extra entries are ignored. The values
 *                  are read with bounds-checked at(), so a vector with fewer than nine entries
 *                  throws std::out_of_range instead of reading past the end.
 * @param inputType 0 selects DT_FP16, 1 selects DT_BF16, any other value selects DT_FP32.
 * @param isQuant   when true, w_qb is declared as DT_INT8 and a DT_FP32 "w_qb_scale" tensor is
 *                  added to the function arguments and stored in quantInputs.dequantScaleWUqQr.
 */
template <bool splitReduceLastDim = true, bool splitK = false, bool nz = false>
void TestMlaPrologV2(std::vector<int>& params, int inputType, bool isQuant = false)
{
    // at() rather than operator[]: a short params vector must fail loudly, not invoke UB.
    int b = params.at(0);
    int s = params.at(1);
    int s2 = params.at(2);
    int n = params.at(3);
    int h = params.at(4);
    int qLoraRank = params.at(5);
    int qkNopeHeadDim = params.at(6);
    int qkRopeHeadDim = params.at(7);
    int kvLoraRank = params.at(8);
    int q_head_dim = qkNopeHeadDim + qkRopeHeadDim;

    // DT_FP32 is the fallback for every inputType other than 0 and 1.
    DataType dType = DT_FP32;
    if (inputType == 0) {
        dType = DT_FP16;
    } else if (inputType == 1) {
        dType = DT_BF16;
    }
    DataType dTypeQuantIn = isQuant ? DT_INT8 : dType;

    std::vector<int64_t> x_shape = {b, s, h};
    std::vector<int64_t> w_qa_shape = {h, qLoraRank};
    std::vector<int64_t> w_qb_shape = {qLoraRank, n * q_head_dim};
    std::vector<int64_t> w_kv_a_shape = {h, kvLoraRank + qkRopeHeadDim};
    std::vector<int64_t> w_kv_b_k_shape = {n, qkNopeHeadDim, kvLoraRank};
    std::vector<int64_t> cos_shape = {b, s, qkRopeHeadDim};
    std::vector<int64_t> gamma_cq_shape = {qLoraRank};
    std::vector<int64_t> gamma_ckv_shape = {kvLoraRank};
    std::vector<int64_t> kv_len_shape = {b, s};
    std::vector<int64_t> kv_cache_shape = {b, 1, s2, kvLoraRank};
    std::vector<int64_t> kr_cache_shape = {b, 1, s2, qkRopeHeadDim};
    std::vector<int64_t> q_out_shape = {b, s, n, kvLoraRank};
    std::vector<int64_t> q_rope_out_shape = {b, s, n, qkRopeHeadDim};
    // The scale shape is only filled in (and only used) for the quantized variant.
    std::vector<int64_t> w_qb_scale_shape;
    if (isQuant) {
        w_qb_scale_shape = {1, n * q_head_dim};
    }
    ConfigManager::Instance();
    PROGRAM("MlaProlog")
    {
        Tensor x(dType, x_shape, "x");
        Tensor w_qa(dType, w_qa_shape, "w_qa");
        Tensor w_qb(dTypeQuantIn, w_qb_shape, "w_qb");
        Tensor w_kv_a(dType, w_kv_a_shape, "w_kv_a");
        Tensor w_kv_b_k(dType, w_kv_b_k_shape, "w_kv_b_k");
        Tensor gamma_cq(dType, gamma_cq_shape, "gamma_cq");
        Tensor gamma_ckv(dType, gamma_ckv_shape, "gamma_ckv");
        // cos and sin share one shape.
        Tensor cos(dType, cos_shape, "cos");
        Tensor sin(dType, cos_shape, "sin");
        Tensor kv_len(DT_INT64, kv_len_shape, "kv_len");
        Tensor kv_cache(dType, kv_cache_shape, "kv_cache");
        Tensor kr_cache(dType, kr_cache_shape, "kr_cache");
        Tensor output_q(dType, q_out_shape, "output_q");
        Tensor output_q_rope(dType, q_rope_out_shape, "output_q_rope");
        RoPETileShapeConfigNew ropeConfig{
            {32, 1, 64},
            {1, 1, 32, 64},
            {32, 1, 1, 64},
            {32, 1, 1, 32, 2}
        };
        MlaQuantInputs quantInputs;
        // The two branches differ only in the extra w_qb_scale tensor; kv_cache and kr_cache
        // are passed to MlaProlog twice in both.
        if (isQuant) {
            Tensor w_qb_scale = Tensor(DT_FP32, w_qb_scale_shape, "w_qb_scale");
            quantInputs.dequantScaleWUqQr = w_qb_scale;
            config::SetBuildStatic(true);
            FUNCTION(
                "MlaProlog_T", {x, w_qa, w_qb, w_qb_scale, w_kv_b_k, w_kv_a, gamma_cq, gamma_ckv, sin, cos, kv_len,
                                kv_cache, kr_cache, output_q, output_q_rope})
            {
                MlaProlog(
                    x, w_qa, w_qb, w_kv_b_k, w_kv_a, gamma_cq, gamma_ckv, sin, cos, kv_len, kv_cache, kr_cache,
                    quantInputs, ropeConfig, output_q, output_q_rope, kv_cache, kr_cache, 1e-5f, 1e-5f, "BNSD",
                    splitReduceLastDim, splitK);
            };
        } else {
            config::SetBuildStatic(true);
            FUNCTION(
                "MlaProlog_T", {x, w_qa, w_qb, w_kv_b_k, w_kv_a, gamma_cq, gamma_ckv, sin, cos, kv_len, kv_cache,
                                kr_cache, output_q, output_q_rope})
            {
                MlaProlog(
                    x, w_qa, w_qb, w_kv_b_k, w_kv_a, gamma_cq, gamma_ckv, sin, cos, kv_len, kv_cache, kr_cache,
                    quantInputs, ropeConfig, output_q, output_q_rope, kv_cache, kr_cache, 1e-5f, 1e-5f, "BNSD",
                    splitReduceLastDim, splitK);
            };
        }
    }
}
/**
 * Builds the "MlaProlog_T" function (inside program "MlaProlog") around
 * DeepseekAttention::MlaPrologFoward, using DT_BF16 tensors.
 *
 * The DeepseekAttention object is constructed from the global g_deepseekConfig; params only
 * sizes the tensors declared here.
 *
 * @param params dimensions in the order {b, s, s2, n, h, qLoraRank, qkNopeHeadDim,
 *               qkRopeHeadDim, kvLoraRank}; trailing extra entries are ignored. The values
 *               are read with bounds-checked at(), so a vector with fewer than nine entries
 *               throws std::out_of_range instead of reading past the end.
 */
void TestMlaProlog(std::vector<int>& params)
{
    // at() rather than operator[]: a short params vector must fail loudly, not invoke UB.
    int b = params.at(0);
    int s = params.at(1);
    int s2 = params.at(2);
    int n = params.at(3);
    int h = params.at(4);
    int qLoraRank = params.at(5);
    int qkNopeHeadDim = params.at(6);
    int qkRopeHeadDim = params.at(7);
    int kvLoraRank = params.at(8);
    int q_head_dim = qkNopeHeadDim + qkRopeHeadDim;
    DataType dType = DT_BF16;

    std::vector<int64_t> x_shape = {b, s, h};
    std::vector<int64_t> w_qa_shape = {h, qLoraRank};
    std::vector<int64_t> w_qb_shape = {qLoraRank, n * q_head_dim};
    std::vector<int64_t> w_kv_a_shape = {h, kvLoraRank + qkRopeHeadDim};
    std::vector<int64_t> w_kv_b_k_shape = {n, qkNopeHeadDim, kvLoraRank};
    std::vector<int64_t> position_ids_shape = {b, s};
    std::vector<int64_t> cos_shape = {s, qkRopeHeadDim};
    std::vector<int64_t> past_key_states_shape = {b, 1, s2, kvLoraRank + qkRopeHeadDim};
    std::vector<int64_t> kv_len_shape = {1, 1};
    std::vector<int64_t> q_shape = {b, n, s, kvLoraRank + qkRopeHeadDim};
    PROGRAM("MlaProlog")
    {
        Tensor x(dType, x_shape, "x");
        Tensor w_qa(dType, w_qa_shape, "w_qa");
        Tensor w_qb(dType, w_qb_shape, "w_qb");
        Tensor w_kv_a(dType, w_kv_a_shape, "w_kv_a");
        Tensor w_kv_b_k(dType, w_kv_b_k_shape, "w_kv_b_k");
        Tensor position_ids(DT_INT32, position_ids_shape, "position_ids");
        // cos and sin share one shape.
        Tensor cos(dType, cos_shape, "cos");
        Tensor sin(dType, cos_shape, "sin");
        Tensor past_key_states(dType, past_key_states_shape, "past_key_states");
        Tensor kv_len(DT_INT32, kv_len_shape, "kv_len");
        Tensor output_q(dType, q_shape, "output_q");

        AttentionW aw;
        aw.qAProjW = w_qa;
        aw.qBProjW = w_qb;
        aw.kvAProjWithMqaW = w_kv_a;
        aw.kvBProjWK = w_kv_b_k;
        // These two weights are left as default-constructed tensors.
        Tensor kvBProjWV;
        Tensor oProjW;
        aw.kvBProjWV = kvBProjWV;
        aw.oProjW = oProjW;
        // NOTE(review): res is never read; kept because it is unclear from here whether
        // default-constructing Tensor objects inside PROGRAM has side effects.
        std::tuple<Tensor, Tensor> res;
        DeepseekAttention Attention(g_deepseekConfig, aw, 1);
        RoPETileShapeConfig ropeTileConfig{
            {32, 64},
            {1, 32, 64},
            {1, 32, 1, 64},
            {1, 32, 1, 64, 64}
        };
        config::SetBuildStatic(true);
        FUNCTION(
            "MlaProlog_T", {x, w_qa, w_qb, w_kv_a, w_kv_b_k, position_ids, cos, sin, past_key_states, kv_len, output_q})
        {
            // Element 0 of the result is written to output_q, element 1 back into past_key_states.
            auto q_kv = Attention.MlaPrologFoward(x, position_ids, cos, sin, kv_len, past_key_states, ropeTileConfig);
            output_q = q_kv[0];
            past_key_states = q_kv[1];
        }
    }
}
// Overwrites the dimensions stored in the global g_deepseekConfig, then builds the MLA prolog
// graph through TestMlaProlog with matching shape parameters.
TEST_F(GraphTest, test_attention_bf16_4_1024_1024_32_256)
{
    config::SetPassOption(VEC_NBUFFER_SETTING, std::map<int64_t, int64_t>{{-1, 2}});

    // References into g_deepseekConfig: the assignments below change the global entries.
    // NOTE(review): nothing in this test restores the previous values.
    int& hiddenSize = std::get<int>(g_deepseekConfig["hiddenSize"]);
    int& numHeads = std::get<int>(g_deepseekConfig["numAttentionHeads"]);
    int& qRank = std::get<int>(g_deepseekConfig["qLoraRank"]);
    int& ropeHeadDim = std::get<int>(g_deepseekConfig["qkRopeHeadDim"]);
    int& kvRank = std::get<int>(g_deepseekConfig["kvLoraRank"]);
    int& valueHeadDim = std::get<int>(g_deepseekConfig["vHeadDim"]);
    int& nopeHeadDim = std::get<int>(g_deepseekConfig["qkNopeHeadDim"]);

    const int batch = 4;
    const int seqLen = 1;
    const int cacheLen = 1024;

    hiddenSize = 1024;
    numHeads = 32;
    qRank = 256;
    nopeHeadDim = 128;
    ropeHeadDim = 64;
    kvRank = 512;
    valueHeadDim = 128;

    // Order expected by TestMlaProlog; the trailing valueHeadDim entry is not read there.
    std::vector<int> params = {
        batch, seqLen, cacheLen, numHeads, hiddenSize, qRank, nopeHeadDim, ropeHeadDim, kvRank, valueHeadDim};
    TestMlaProlog(params);
}
/**
 * Builds function "main" with one dynamic loop "L0" whose range is GetInputShape(t0, 0) / s.
 *
 * Each iteration reads an entry of blockTable, takes two views of t0 at row offset 64 * i,
 * adds the second view to itself and assembles the sum into out at the same offset.
 *
 * @param t0         input tensor the views are taken from.
 * @param blockTable tensor read with GetTensorData at index {i, 0} in each iteration.
 * @param out        tensor the per-iteration result is assembled into.
 * @param s          divisor of the loop range and extent used in the view shapes.
 */
void TestLoopTailBlock(const Tensor& t0, const Tensor& blockTable, Tensor& out, int s)
{
    // Row stride between consecutive iterations; fixed here, independent of s.
    int rowStride = 64;
    FUNCTION("main", {t0, blockTable}, {out})
    {
        LOOP("L0", FunctionType::DYNAMIC_LOOP, i, LoopRange(GetInputShape(t0, 0) / s))
        {
            SymbolicScalar tableEntry = GetTensorData(blockTable, {i, 0});

            // This {s, s} view is not consumed below; it is kept because creating it may
            // still contribute an operation to the graph.
            Tensor wholeView = View(t0, {s, s}, {tableEntry, s}, {rowStride * i, 0});

            // The {s / 2, s} view is the one that feeds the Add.
            Tensor halfView = View(t0, {s / 2, s}, {tableEntry, s}, {rowStride * i, 0});
            Tensor doubled = Add(halfView, halfView);
            Assemble(doubled, {rowStride * i, 0}, out);
        }
    }
}
// Runs TestLoopTailBlock on DT_FP32 tensors of shape {8 * 64, 64} with a DT_INT32 {8, 1} block table.
TEST_F(GraphTest, TestTailBlock)
{
    TileShape::Current().SetVecTile(32, 32);
    TileShape::Current().SetCubeTile({32, 32}, {32, 32}, {32, 32});

    const int blockExtent = 64;
    const int blockCount = 8;

    Tensor t0(DT_FP32, {blockCount * blockExtent, blockExtent}, "t0");
    Tensor blockTable{DT_INT32, {blockCount, 1}, "blockTable"};
    Tensor out(DT_FP32, {blockCount * blockExtent, blockExtent}, "out");

    TestLoopTailBlock(t0, blockTable, out, blockExtent);
}
// Builds function "MLA_3D_2": transposes axes 0 and 1 of an {8, 32, 128} tensor, then adds a 0.0 element.
TEST_F(GraphTest, TestTranspose_MLA_3D_2_add)
{
    const int rows = 8;
    const int heads = 32;
    const int width = 128;

    std::vector<int64_t> shape{rows, heads, width};
    std::vector<int64_t> resShape{heads, rows, width};

    PROGRAM("Transpose")
    {
        Tensor input(DataType::DT_FP32, shape, "input");
        Tensor output(DataType::DT_FP32, resShape, "res");
        config::SetBuildStatic(true);
        FUNCTION("MLA_3D_2", {input, output})
        {
            // Vec tile used for the transpose.
            TileShape::Current().SetVecTile(NUM_2, NUM_2, NUM_128);
            auto swapped = Transpose(input, {0, 1});

            // Vec tile used for the add.
            TileShape::Current().SetVecTile(NUM_8, NUM_8, NUM_128);
            output = Add(swapped, Element(DataType::DT_FP32, 0.0));
        }
    }
}
// Builds function "MLA_3D_2": transposes axes 0 and 1 of an {8, 32, 128} tensor, then reshapes the
// result twice, once to resShape and once with {-1}.
TEST_F(GraphTest, TestTranspose_MLA_3D_2_reshape)
{
    const int rows = 8;
    const int heads = 32;
    const int width = 128;

    std::vector<int64_t> shape{rows, heads, width};
    std::vector<int64_t> transposeShape{heads, rows, width};
    std::vector<int64_t> resShape{heads, rows * width};
    std::vector<int64_t> flattenShape{heads * rows * width};

    PROGRAM("Transpose")
    {
        Tensor input(DataType::DT_FP32, shape, "input");
        Tensor output1(DataType::DT_FP32, transposeShape, "res1");
        Tensor output2(DataType::DT_FP32, resShape, "res2");
        // output3 is declared and assigned below but is not part of the FUNCTION argument list.
        Tensor output3(DataType::DT_FP32, flattenShape, "res3");
        config::SetBuildStatic(true);
        FUNCTION("MLA_3D_2", {input, output1, output2})
        {
            // Vec tile used for the transpose.
            TileShape::Current().SetVecTile(NUM_2, NUM_2, NUM_128);
            output1 = Transpose(input, {0, 1});

            // Vec tile used for both reshapes.
            TileShape::Current().SetVecTile(NUM_8, NUM_8, NUM_128);
            output2 = Reshape(output1, resShape);
            output3 = Reshape(output1, {-1});
        }
    }
}