* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* MindIE is licensed under Mulan PSL v2.
*/
#include <gtest/gtest.h>
#include <cstring>
#include <string>
#include <vector>
#include "construct_execute_request.h"
#include "sampling.h"
namespace mindie_llm {
static std::vector<BlockId> DecodeBlockIds(const std::string &bytes)
{
EXPECT_EQ(bytes.size() % sizeof(BlockId), 0u);
std::vector<BlockId> out(bytes.size() / sizeof(BlockId));
if (!out.empty()) {
std::memcpy(out.data(), bytes.data(), bytes.size());
}
return out;
}
TEST(ConstructExecuteRequestTest, ConstructExecuteModelRequest_SerializesRepeatedBlockTables)
{
SequenceGroupMetaData meta{};
meta.requestId_ = "req_0";
meta.serverid_ = "srv";
meta.seqIds_ = {123};
meta.samplingParams_ = std::make_shared<SamplingParams>();
meta.samplingParams_->n = 1;
meta.samplingParams_->bestOf = 1;
meta.blockIds_.push_back(BlockIds{1, 2, 3});
meta.blockIds_.push_back(BlockIds{});
meta.blockIds_.push_back(BlockIds{9});
SequenceGroupMetaDatas metas{};
metas.metaList.push_back(meta);
SchedulerOutputs out{};
out.forwardMode_ = ForwardMode::DECODE;
ExecuteModelRequestPtr req = std::make_unique<model_execute_data::ExecuteModelRequest>();
ConstructExecuteRequest::ConstructExecuteModelRequest(req, metas, out, 0);
ASSERT_NE(req, nullptr);
ASSERT_EQ(req->seq_group_metadata_list_size(), 1);
const auto &protoMeta = req->seq_group_metadata_list(0);
ASSERT_EQ(static_cast<size_t>(protoMeta.block_tables_size()), meta.blockIds_.size());
EXPECT_EQ(DecodeBlockIds(protoMeta.block_tables(0)), (std::vector<BlockId>{1, 2, 3}));
EXPECT_TRUE(DecodeBlockIds(protoMeta.block_tables(1)).empty());
EXPECT_EQ(DecodeBlockIds(protoMeta.block_tables(2)), (std::vector<BlockId>{9}));
}
TEST(ConstructExecuteRequestTest, ConstructPullKVRequest_SerializesSrcDstRepeatedBlockTables)
{
SequenceGroupMetaData meta{};
meta.requestId_ = "req_kv";
meta.serverid_ = "srv";
meta.dpInstanceId_ = 42;
meta.seqIds_ = {7};
meta.samplingParams_ = std::make_shared<SamplingParams>();
meta.samplingParams_->n = 1;
meta.samplingParams_->bestOf = 1;
meta.promptLens_ = {4};
meta.tokenIds_ = {11, 12, 13, 14};
meta.computedLens_ = {0};
meta.remoteComputedLens_ = {0};
meta.blockIds_.push_back(BlockIds{100, 101});
meta.blockIds_.push_back(BlockIds{});
meta.srcBlockIds_.push_back(BlockIds{200});
meta.srcBlockIds_.push_back(BlockIds{});
SequenceGroupMetaDatas metas{};
metas.metaList.push_back(meta);
PullKVRequestPtr req = ConstructExecuteRequest::ConstructPullKVRequest(metas);
ASSERT_NE(req, nullptr);
ASSERT_EQ(req->pull_kv_infos_size(), 1);
const auto &info = req->pull_kv_infos(0);
ASSERT_EQ(info.dst_block_tables_size(), 2);
ASSERT_EQ(info.src_block_tables_size(), 2);
EXPECT_EQ(DecodeBlockIds(info.dst_block_tables(0)), (std::vector<BlockId>{100, 101}));
EXPECT_TRUE(DecodeBlockIds(info.dst_block_tables(1)).empty());
EXPECT_EQ(DecodeBlockIds(info.src_block_tables(0)), (std::vector<BlockId>{200}));
EXPECT_TRUE(DecodeBlockIds(info.src_block_tables(1)).empty());
}
}