910e62b5创建于 1月15日历史提交
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "services/passage_embeddings/passage_embedder.h"

#include <utility>

#include "base/files/file.h"
#include "base/files/memory_mapped_file.h"
#include "base/metrics/histogram_functions.h"
#include "base/timer/elapsed_timer.h"
#include "base/trace_event/trace_event.h"
#include "base/trace_event/trace_id_helper.h"
#include "base/trace_event/typed_macros.h"
#include "build/build_config.h"
#include "services/passage_embeddings/passage_embeddings_op_resolver.h"
#include "third_party/sentencepiece/src/src/sentencepiece_model.pb.h"

namespace {
// Records duration and trace event for embeddings generation.
void RecordEmbeddingsDurationMetrics(
    bool is_passive,
    base::TimeTicks start_time,
    base::TimeDelta elapsed_time,
    std::optional<base::TimeDelta> elapsed_thread_time) {
  const auto trace_track =
      perfetto::Track(base::trace_event::GetNextGlobalTraceId());

  if (is_passive) {
    TRACE_EVENT_BEGIN("loading", "PassageEmbeddingsGeneration", trace_track,
                      start_time);
    if (elapsed_thread_time.has_value()) {
      base::UmaHistogramMediumTimes(
          "History.Embeddings.Embedder."
          "PassageEmbeddingsGenerationThreadDuration",
          *elapsed_thread_time);
    }
    base::UmaHistogramMediumTimes(
        "History.Embeddings.Embedder.PassageEmbeddingsGenerationDuration",
        elapsed_time);
  } else {
    TRACE_EVENT_BEGIN("loading", "QueryEmbeddingsGeneration", trace_track,
                      start_time);
    if (elapsed_thread_time.has_value()) {
      base::UmaHistogramMediumTimes(
          "History.Embeddings.Embedder.QueryEmbeddingsGenerationThreadDuration",
          *elapsed_thread_time);
    }
    base::UmaHistogramMediumTimes(
        "History.Embeddings.Embedder.QueryEmbeddingsGenerationDuration",
        elapsed_time);
  }

  TRACE_EVENT_END("loading", trace_track, start_time + elapsed_time);
}
}  // namespace

namespace passage_embeddings {

PassageEmbedder::PassageEmbedder(
    mojo::PendingReceiver<mojom::PassageEmbedder> receiver,
    mojom::PassageEmbedderParamsPtr embedder_params,
    base::OnceCallback<void()> on_disconnect)
    : receiver_(this, std::move(receiver)),
      embeddings_cache_(embedder_params->embedder_cache_size),
      user_initiated_priority_num_threads_(
          embedder_params->user_initiated_priority_num_threads),
      urgent_priority_num_threads_(
          embedder_params->urgent_priority_num_threads),
      passive_priority_num_threads_(
          embedder_params->passive_priority_num_threads),
      allow_gpu_execution_(embedder_params->allow_gpu_execution) {
  receiver_.set_disconnect_handler(std::move(on_disconnect));
}

PassageEmbedder::~PassageEmbedder() = default;

bool PassageEmbedder::LoadModels(
    base::File embeddings_model_file,
    base::File sp_file,
    uint32_t embeddings_input_window_size,
    std::unique_ptr<tflite::task::core::TfLiteEngine> tflite_engine) {
  UnloadModelFiles();

  embeddings_model_file_ = std::move(embeddings_model_file);

  tflite_engine_overridden_ = !!tflite_engine;
  override_tflite_engine_ = std::move(tflite_engine);

  base::ElapsedTimer sp_timer;
  bool sp_load_success = LoadSentencePieceModelFile(std::move(sp_file));
  base::UmaHistogramBoolean(
      "History.Embeddings.Embedder.SentencePieceModelLoadSucceeded",
      sp_load_success);
  if (!sp_load_success) {
    return false;
  }
  base::UmaHistogramMediumTimes(
      "History.Embeddings.Embedder.SentencePieceModelLoadDuration",
      sp_timer.Elapsed());

  embeddings_input_window_size_ = embeddings_input_window_size;

  return true;
}

bool PassageEmbedder::LoadSentencePieceModelFile(base::File sp_file) {
  base::MemoryMappedFile sp_model;
  bool was_mapped = sp_model.Initialize(std::move(sp_file));
  if (!was_mapped) {
    return false;
  }

  auto model_proto = std::make_unique<sentencepiece::ModelProto>();
  model_proto->ParseFromArray(sp_model.data(), sp_model.length());
  sp_processor_ = std::make_unique<sentencepiece::SentencePieceProcessor>();
  if (!(sp_processor_->Load(std::move(model_proto)).ok())) {
    sp_processor_.reset();
    return false;
  }
  return true;
}

bool PassageEmbedder::BuildExecutionTask() {
  CHECK_NE(current_priority_, mojom::PassagePriority::kUnknown);
  // Do nothing if an override model has been loaded.
  if (tflite_engine_overridden_ && !override_tflite_engine_) {
    return true;
  }

  loaded_model_.reset();

  // Load the override model if it is set but not loaded yet.
  if (tflite_engine_overridden_) {
    loaded_model_ = std::make_unique<PassageEmbedderExecutionTask>(
        std::move(override_tflite_engine_));
    override_tflite_engine_.reset();
    return true;
  }

  // Build a new task from the model bytes and the task priority.
  auto tflite_engine = std::make_unique<tflite::task::core::TfLiteEngine>(
      std::make_unique<PassageEmbeddingsOpResolver>(allow_gpu_execution_));

  base::ElapsedTimer embeddings_timer;
#if BUILDFLAG(IS_WIN)
  absl::Status model_load_status = tflite_engine->BuildModelFromFileHandle(
      embeddings_model_file_.GetPlatformFile());
#else
  absl::Status model_load_status = tflite_engine->BuildModelFromFileDescriptor(
      embeddings_model_file_.GetPlatformFile());
#endif
  base::UmaHistogramBoolean(
      "History.Embeddings.Embedder.EmbeddingsModelLoadSucceeded",
      model_load_status.ok());
  if (!model_load_status.ok()) {
    return false;
  }
  base::UmaHistogramMediumTimes(
      "History.Embeddings.Embedder.EmbeddingsModelLoadDuration",
      embeddings_timer.Elapsed());

  int num_threads;
  switch (current_priority_) {
    case mojom::PassagePriority::kUserInitiated:
      num_threads = user_initiated_priority_num_threads_;
      break;
    case mojom::PassagePriority::kUrgent:
      num_threads = urgent_priority_num_threads_;
      break;
    case mojom::PassagePriority::kPassive:
      num_threads = passive_priority_num_threads_;
      break;
    case mojom::PassagePriority::kUnknown:
      return false;
  }

  absl::Status interpreter_status = tflite_engine->InitInterpreter(num_threads);
  if (!interpreter_status.ok()) {
    return false;
  }

  loaded_model_ =
      std::make_unique<PassageEmbedderExecutionTask>(std::move(tflite_engine));

  return true;
}

void PassageEmbedder::UnloadModelFiles() {
  sp_processor_.reset();
  loaded_model_.reset();
  embeddings_model_file_.Close();
}

std::optional<OutputType> PassageEmbedder::Execute(InputType input) {
  if (!loaded_model_) {
    return std::nullopt;
  }
  return loaded_model_->Execute(input);
}

void PassageEmbedder::GenerateEmbeddings(
    const std::vector<std::string>& inputs,
    mojom::PassagePriority priority,
    PassageEmbedder::GenerateEmbeddingsCallback callback) {
  std::vector<mojom::PassageEmbeddingsResultPtr> results;
  CHECK_NE(priority, mojom::PassagePriority::kUnknown);
  if (!sp_processor_ || !sp_processor_->status().ok()) {
    std::move(callback).Run({});
    return;
  }

  // Rebuild the execution task if necessary.
  if (current_priority_ != priority) {
    current_priority_ = priority;
    BuildExecutionTask();
  }

  for (const std::string& input : inputs) {
    mojom::PassageEmbeddingsResultPtr result =
        mojom::PassageEmbeddingsResult::New();

    auto cache_value = embeddings_cache_.Get(input);
    bool cache_hit = cache_value != embeddings_cache_.end();
    base::UmaHistogramBoolean(kCacheHitMetricName, cache_hit);
    if (cache_hit) {
      result->embeddings = cache_value->second;
      results.push_back(std::move(result));
      continue;
    }

    std::vector<int> tokenized;
    base::ElapsedTimer tokenize_timer;
    auto status = sp_processor_->Encode(input, &tokenized);
    base::UmaHistogramBoolean(
        "History.Embeddings.Embedder.TokenizationSucceeded", status.ok());
    if (!status.ok()) {
      std::move(callback).Run({});
      return;
    }
    base::UmaHistogramCounts1000(
        "History.Embeddings.Embedder.PassageTokenCount", tokenized.size());
    if (tokenized.size() < embeddings_input_window_size_) {
      tokenized.push_back(sp_processor_->eos_id());
    }
    base::UmaHistogramBoolean("History.Embeddings.Embedder.InputTruncated",
                              tokenized.size() > embeddings_input_window_size_);
    tokenized.resize(embeddings_input_window_size_);
    base::TimeDelta tokenize_elapsed = tokenize_timer.Elapsed();
    base::UmaHistogramMediumTimes(
        "History.Embeddings.Embedder.TokenizationDuration", tokenize_elapsed);

    const auto tokenize_start_time = tokenize_timer.start_time();
    const auto trace_track =
        perfetto::Track(base::trace_event::GetNextGlobalTraceId());
    TRACE_EVENT_BEGIN("loading", "PassageTokenization", trace_track,
                      tokenize_start_time);
    TRACE_EVENT_END("loading", trace_track,
                    tokenize_start_time + tokenize_elapsed);

    base::ElapsedThreadTimer execute_thread_timer;
    base::ElapsedTimer execute_timer;
    std::optional<std::vector<float>> embeddings = Execute(tokenized);
    base::UmaHistogramBoolean(
        "History.Embeddings.Embedder.EmbeddingsGenerationSucceeded",
        !!embeddings);
    if (!embeddings) {
      std::move(callback).Run({});
      return;
    }

    RecordEmbeddingsDurationMetrics(
        priority == mojom::PassagePriority::kPassive,
        execute_timer.start_time(), execute_timer.Elapsed(),
        execute_thread_timer.is_supported()
            ? std::optional<base::TimeDelta>(execute_thread_timer.Elapsed())
            : std::nullopt);

    result->embeddings = *embeddings;
    embeddings_cache_.Put({input, *embeddings});

    results.push_back(std::move(result));
  }
  std::move(callback).Run(std::move(results));
}

}  // namespace passage_embeddings