// 910e62b5 (created Jan 15; from commit history)
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "services/passage_embeddings/passage_embeddings_service.h"

#include "base/path_service.h"
#include "base/test/bind.h"
#include "base/test/metrics/histogram_tester.h"
#include "base/test/task_environment.h"
#include "base/test/test_future.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "services/passage_embeddings/passage_embedder.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace passage_embeddings {
namespace {

// Non-zero `input_window_size` passed to MakeModelParams(). The
// invalid-window test deliberately passes 0 instead.
constexpr uint32_t kInputWindowSize{256};
// Number of values each result's `embeddings` is expected to contain.
constexpr size_t kEmbeddingsOutputSize{768};

class PassageEmbeddingsServiceTest : public testing::Test {
 public:
  PassageEmbeddingsServiceTest()
      : service_impl_(service_.BindNewPipeAndPassReceiver()) {}

  mojo::Remote<mojom::PassageEmbeddingsService>& service() { return service_; }

  void SetUp() override {
    base::FilePath test_data_dir;
    base::PathService::Get(base::DIR_SRC_TEST_DATA_ROOT, &test_data_dir);
    test_data_dir = test_data_dir.AppendASCII("services")
                        .AppendASCII("test")
                        .AppendASCII("data")
                        .AppendASCII("passage_embeddings");
    embeddings_path_ =
        test_data_dir.AppendASCII("dummy_embeddings_model.tflite");
    sp_path_ = test_data_dir.AppendASCII("sentencepiece.model");
  }

  mojom::PassageEmbeddingsLoadModelsParamsPtr MakeModelParams(
      base::FilePath embeddings_path,
      base::FilePath sp_path,
      uint32_t input_window_size) {
    auto params = mojom::PassageEmbeddingsLoadModelsParams::New();
    params->embeddings_model = base::File(
        embeddings_path, base::File::FLAG_OPEN | base::File::FLAG_READ);
    params->sp_model =
        base::File(sp_path, base::File::FLAG_OPEN | base::File::FLAG_READ);
    params->input_window_size = input_window_size;
    return params;
  }

  mojom::PassageEmbedderParamsPtr MakeEmbedderParams() {
    auto params = mojom::PassageEmbedderParams::New();
    params->user_initiated_priority_num_threads = 4;
    params->passive_priority_num_threads = 1;
    params->embedder_cache_size = 1000;
    return params;
  }

 protected:
  base::FilePath embeddings_path_;
  base::FilePath sp_path_;
  base::HistogramTester histogram_tester_;

 private:
  base::test::TaskEnvironment task_environment_;
  mojo::Remote<mojom::PassageEmbeddingsService> service_;
  PassageEmbeddingsService service_impl_;
};

// With the embeddings and SentencePiece model files resolved in SetUp() and a
// non-zero input window, LoadModels() reports success.
TEST_F(PassageEmbeddingsServiceTest, LoadValidModels) {
  mojo::Remote<mojom::PassageEmbedder> embedder;
  base::test::TestFuture<bool> loaded;
  service()->LoadModels(
      MakeModelParams(embeddings_path_, sp_path_, kInputWindowSize),
      MakeEmbedderParams(),
      embedder.BindNewPipeAndPassReceiver(),
      loaded.GetCallback());
  EXPECT_TRUE(loaded.Get());
}

// Passing the SentencePiece file where the embeddings model is expected still
// loads. Generating embeddings afterwards returns no results.
TEST_F(PassageEmbeddingsServiceTest, LoadModelsWithInvalidEmbeddingsModel) {
  mojo::Remote<mojom::PassageEmbedder> embedder_remote;
  base::test::TestFuture<bool> load_models_future;
  service()->LoadModels(MakeModelParams(sp_path_, sp_path_, kInputWindowSize),
                        MakeEmbedderParams(),
                        embedder_remote.BindNewPipeAndPassReceiver(),
                        load_models_future.GetCallback());
  // LoadModels succeeds since the model file can still be read. ASSERT rather
  // than EXPECT: the rest of the test uses `embedder_remote`, which is pointless
  // (and may wait on a reply that never arrives) if loading failed.
  ASSERT_TRUE(load_models_future.Get());

  base::test::TestFuture<std::vector<mojom::PassageEmbeddingsResultPtr>>
      execute_future;
  embedder_remote->GenerateEmbeddings({"foo"},
                                      mojom::PassagePriority::kUserInitiated,
                                      execute_future.GetCallback());
  std::vector<mojom::PassageEmbeddingsResultPtr> results =
      execute_future.Take();
  // Execution fails since the embeddings model is invalid.
  EXPECT_EQ(results.size(), 0u);
}

// Handing the embeddings model file in as the SentencePiece model makes
// LoadModels() report failure.
TEST_F(PassageEmbeddingsServiceTest, LoadModelsWithInvalidSpModel) {
  mojo::Remote<mojom::PassageEmbedder> embedder;
  base::test::TestFuture<bool> loaded;
  const base::FilePath& wrong_sp_path = embeddings_path_;
  service()->LoadModels(
      MakeModelParams(embeddings_path_, wrong_sp_path, kInputWindowSize),
      MakeEmbedderParams(),
      embedder.BindNewPipeAndPassReceiver(),
      loaded.GetCallback());
  EXPECT_FALSE(loaded.Get());
}

// Valid model files combined with a zero input window make LoadModels()
// report failure.
TEST_F(PassageEmbeddingsServiceTest, LoadModelsWithInvalidInputWindowSize) {
  constexpr uint32_t kZeroInputWindowSize = 0u;
  mojo::Remote<mojom::PassageEmbedder> embedder;
  base::test::TestFuture<bool> loaded;
  service()->LoadModels(
      MakeModelParams(embeddings_path_, sp_path_, kZeroInputWindowSize),
      MakeEmbedderParams(),
      embedder.BindNewPipeAndPassReceiver(),
      loaded.GetCallback());
  EXPECT_FALSE(loaded.Get());
}

// After a successful load, each input passage (including the empty one) yields
// one result of kEmbeddingsOutputSize values. All are recorded as cache misses.
TEST_F(PassageEmbeddingsServiceTest, RespondsWithEmbeddings) {
  mojo::Remote<mojom::PassageEmbedder> embedder_remote;
  base::test::TestFuture<bool> load_models_future;
  service()->LoadModels(
      MakeModelParams(embeddings_path_, sp_path_, kInputWindowSize),
      MakeEmbedderParams(), embedder_remote.BindNewPipeAndPassReceiver(),
      load_models_future.GetCallback());
  // Stop here if loading failed rather than querying an embedder that has no
  // models (the reply may never arrive).
  ASSERT_TRUE(load_models_future.Get());

  base::test::TestFuture<std::vector<mojom::PassageEmbeddingsResultPtr>>
      execute_future;
  embedder_remote->GenerateEmbeddings({"hello", "world", ""},
                                      mojom::PassagePriority::kUserInitiated,
                                      execute_future.GetCallback());
  auto results = execute_future.Take();
  // One result per input passage. The histogram check below only makes sense
  // if this holds.
  ASSERT_EQ(results.size(), 3u);
  for (const auto& result : results) {
    EXPECT_EQ(result->embeddings.size(), kEmbeddingsOutputSize);
  }

  // The three passages are distinct, so none is a cache hit.
  histogram_tester_.ExpectUniqueSample(kCacheHitMetricName, false, 3);
}

// Repeated passages within a single GenerateEmbeddings() call produce identical
// embeddings. The repeats are expected to be recorded as cache hits.
TEST_F(PassageEmbeddingsServiceTest, CacheHits) {
  mojo::Remote<mojom::PassageEmbedder> embedder_remote;
  base::test::TestFuture<bool> load_models_future;
  service()->LoadModels(
      MakeModelParams(embeddings_path_, sp_path_, kInputWindowSize),
      MakeEmbedderParams(), embedder_remote.BindNewPipeAndPassReceiver(),
      load_models_future.GetCallback());
  // Stop here if loading failed rather than querying an embedder that has no
  // models (the reply may never arrive).
  ASSERT_TRUE(load_models_future.Get());

  base::test::TestFuture<std::vector<mojom::PassageEmbeddingsResultPtr>>
      execute_future;
  embedder_remote->GenerateEmbeddings(
      {"hello", "world", "hello", "world", "foo", ""},
      mojom::PassagePriority::kUserInitiated, execute_future.GetCallback());
  auto results = execute_future.Take();

  // ASSERT, not EXPECT: the indexed accesses below would read out of bounds
  // if fewer results came back.
  ASSERT_EQ(results.size(), 6u);

  // "hello" (indices 0 and 2) and "world" (indices 1 and 3) are repeated.
  EXPECT_EQ(results[0]->embeddings, results[2]->embeddings);
  EXPECT_EQ(results[1]->embeddings, results[3]->embeddings);

  for (const auto& result : results) {
    EXPECT_EQ(result->embeddings.size(), kEmbeddingsOutputSize);
  }

  // 2 of the 6 passages are repeats (hits); the remaining 4 are misses.
  histogram_tester_.ExpectTotalCount(kCacheHitMetricName, 6);
  histogram_tester_.ExpectBucketCount(kCacheHitMetricName, true, 2);
  histogram_tester_.ExpectBucketCount(kCacheHitMetricName, false, 4);
}

// Passive requests are expected to record the "Passage" duration histograms.
// User-initiated requests are expected to record the "Query" ones. Each sample
// corresponds to one passage.
TEST_F(PassageEmbeddingsServiceTest, RecordsDurationHistogramsWithPriority) {
  mojo::Remote<mojom::PassageEmbedder> embedder_remote;
  base::test::TestFuture<bool> load_models_future;
  service()->LoadModels(
      MakeModelParams(embeddings_path_, sp_path_, kInputWindowSize),
      MakeEmbedderParams(), embedder_remote.BindNewPipeAndPassReceiver(),
      load_models_future.GetCallback());
  // Previously the load result was discarded. Without loaded models the
  // histogram expectations below are meaningless, so stop here instead.
  ASSERT_TRUE(load_models_future.Take());

  base::test::TestFuture<std::vector<mojom::PassageEmbeddingsResultPtr>>
      execute_future;
  embedder_remote->GenerateEmbeddings({"hello", "world"},
                                      mojom::PassagePriority::kPassive,
                                      execute_future.GetCallback());
  // Take() waits for the reply and clears the future, so `execute_future` can
  // be reused for the next request.
  EXPECT_EQ(execute_future.Take().size(), 2u);

  embedder_remote->GenerateEmbeddings({"foo"},
                                      mojom::PassagePriority::kUserInitiated,
                                      execute_future.GetCallback());
  EXPECT_EQ(execute_future.Take().size(), 1u);

  // Two passive passages -> two "Passage" samples of each kind.
  histogram_tester_.ExpectTotalCount(
      "History.Embeddings.Embedder.PassageEmbeddingsGenerationDuration", 2);
  histogram_tester_.ExpectTotalCount(
      "History.Embeddings.Embedder.PassageEmbeddingsGenerationThreadDuration",
      2);
  // One user-initiated passage -> one "Query" sample of each kind.
  histogram_tester_.ExpectTotalCount(
      "History.Embeddings.Embedder.QueryEmbeddingsGenerationDuration", 1);
  histogram_tester_.ExpectTotalCount(
      "History.Embeddings.Embedder.QueryEmbeddingsGenerationThreadDuration", 1);
}

}  // namespace
}  // namespace passage_embeddings