// 910e62b5 — created Jan 15 (commit history).
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/accessibility/phrase_segmentation/dependency_parser_model_loader.h"

#include <cstddef>

#include "base/files/file_path.h"
#include "base/files/file_util.h"
#include "base/metrics/histogram_macros.h"
#include "components/optimization_guide/core/delivery/optimization_guide_model_provider.h"

namespace {

// Opens the model file at |model_file_path| for reading. Returns an invalid
// base::File if the path does not exist (or if opening it fails).
base::File LoadModelFile(const base::FilePath& model_file_path) {
  if (!base::PathExists(model_file_path)) {
    return base::File();
  }

  return base::File(model_file_path,
                    base::File::FLAG_OPEN | base::File::FLAG_READ);
}

// Closes |model_file| if it holds a valid handle; invalid files are ignored.
void CloseModelFile(base::File model_file) {
  if (!model_file.IsValid()) {
    return;
  }
  model_file.Close();
}

// Util class for recording the result of loading the dependency parser model.
// The result is emitted to the "WasLoaded" histogram from the destructor, so
// it reflects whether SetLoaded() was called before the recorder went out of
// scope. Copying is disabled so that each recorder emits exactly one sample.
class ScopedModelLoadingResultRecorder {
 public:
  ScopedModelLoadingResultRecorder() = default;
  ScopedModelLoadingResultRecorder(const ScopedModelLoadingResultRecorder&) =
      delete;
  ScopedModelLoadingResultRecorder& operator=(
      const ScopedModelLoadingResultRecorder&) = delete;
  ~ScopedModelLoadingResultRecorder() {
    UMA_HISTOGRAM_BOOLEAN(
        "Accessibility.DependencyParserModelLoader.DependencyParserModel."
        "WasLoaded",
        was_loaded_);
  }

  // Marks the load as successful; the destructor will then record `true`.
  void SetLoaded() { was_loaded_ = true; }

 private:
  bool was_loaded_ = false;
};

// The maximum number of pending model requests allowed to be kept by the
// DependencyParserModelLoader. Typed as size_t to match the container size it
// is compared against.
constexpr size_t kMaxPendingRequestsAllowed = 100;

}  // namespace

DependencyParserModelLoader::DependencyParserModelLoader(
    optimization_guide::OptimizationGuideModelProvider* opt_guide,
    const scoped_refptr<base::SequencedTaskRunner>& background_task_runner)
    : opt_guide_(opt_guide), background_task_runner_(background_task_runner) {
  // Subscribe to updates for the phrase-segmentation optimization target.
  // No model metadata is requested; the background task runner is handed to
  // the provider along with the observer registration.
  constexpr auto kTarget =
      optimization_guide::proto::OPTIMIZATION_TARGET_PHRASE_SEGMENTATION;
  opt_guide_->AddObserverForOptimizationTargetModel(
      kTarget, /*model_metadata=*/std::nullopt, background_task_runner_, this);
}

DependencyParserModelLoader::~DependencyParserModelLoader() {
  // Unsubscribe from phrase-segmentation model updates.
  // NOTE(review): Shutdown() deliberately skips this call on the grounds that
  // the optimization guide may already be gone; confirm it is still alive
  // when this destructor runs.
  constexpr auto kTarget =
      optimization_guide::proto::OPTIMIZATION_TARGET_PHRASE_SEGMENTATION;
  opt_guide_->RemoveObserverForOptimizationTargetModel(kTarget, this);

  // Shutdown is happening, so any requester still waiting is told that no
  // model file is available.
  NotifyModelUpdatesAndClear(/*is_model_available=*/false);
}

void DependencyParserModelLoader::Shutdown() {
  // The observer is intentionally NOT removed here. Like this loader, the
  // optimization guide is a (BrowserContext) keyed service, and it is expected
  // to be cleaned up first.
  //
  // Release the loaded model file (closed on the background runner), then
  // answer every pending request with "no model": nothing new will arrive
  // once shutdown has started.
  UnloadModelFile();
  NotifyModelUpdatesAndClear(/*is_model_available=*/false);
}

// Releases the currently loaded model file, if any.
//
// The file is closed on |background_task_runner_| rather than on this
// sequence (presumably because closing a file may block). Afterwards the
// loader holds no model file at all.
void DependencyParserModelLoader::UnloadModelFile() {
  if (!dependency_parser_model_file_) {
    return;
  }
  background_task_runner_->PostTask(
      FROM_HERE, base::BindOnce(&CloseModelFile,
                                std::move(*dependency_parser_model_file_)));
  // Moving the base::File out leaves an invalid file behind inside the
  // optional. Disengage the optional so that `!dependency_parser_model_file_`
  // checks treat the loader as unloaded, and so that a repeated call does not
  // post a redundant close task.
  dependency_parser_model_file_.reset();
}

void DependencyParserModelLoader::NotifyModelUpdatesAndClear(
    bool is_model_available) {
  for (auto& pending_request : pending_model_requests_) {
    std::move(pending_request).Run(is_model_available);
  }
  pending_model_requests_.clear();
}

void DependencyParserModelLoader::OnModelUpdated(
    optimization_guide::proto::OptimizationTarget optimization_target,
    base::optional_ref<const optimization_guide::ModelInfo> model_info) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  // Updates for any other optimization target are not relevant here.
  const bool is_phrase_segmentation =
      optimization_target ==
      optimization_guide::proto::OPTIMIZATION_TARGET_PHRASE_SEGMENTATION;
  if (!is_phrase_segmentation) {
    return;
  }

  if (model_info.has_value()) {
    // Open the new model file on the background runner. The result is handed
    // back to OnModelFileLoaded() through a weak pointer, so the reply is
    // dropped if this loader's weak pointers are invalidated first.
    background_task_runner_->PostTaskAndReplyWithResult(
        FROM_HERE,
        base::BindOnce(&LoadModelFile, model_info->GetModelFilePath()),
        base::BindOnce(&DependencyParserModelLoader::OnModelFileLoaded,
                       weak_ptr_factory_.GetWeakPtr()));
    return;
  }

  // No model info: release the current model and fail everyone waiting.
  UnloadModelFile();
  NotifyModelUpdatesAndClear(/*is_model_available=*/false);
}

void DependencyParserModelLoader::OnModelFileLoaded(base::File model_file) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  // Emits the "WasLoaded" histogram when this function returns; the sample is
  // `false` unless SetLoaded() is reached below.
  ScopedModelLoadingResultRecorder result_recorder;
  if (model_file.IsValid()) {
    // Swap out any previously held model for the freshly opened one, then
    // tell every waiting requester that a model is now available.
    UnloadModelFile();
    dependency_parser_model_file_ = std::move(model_file);
    result_recorder.SetLoaded();
    NotifyModelUpdatesAndClear(/*is_model_available=*/true);
  }
  // An invalid file leaves the current model and any pending requests as
  // they were.
}

base::File DependencyParserModelLoader::GetDependencyParserModelFile() {
  DCHECK(IsModelAvailable());
  // Callers receive a duplicate handle; the loader keeps its own copy open.
  if (dependency_parser_model_file_) {
    // A held model file is expected to be valid.
    DCHECK(dependency_parser_model_file_->IsValid());
    return dependency_parser_model_file_->Duplicate();
  }
  // Nothing loaded: hand back an invalid file.
  return base::File();
}

void DependencyParserModelLoader::NotifyOnModelFileAvailable(
    NotifyModelAvailableCallback callback) {
  DCHECK(!IsModelAvailable());
  // Cap the queue: once kMaxPendingRequestsAllowed requests are waiting,
  // further requests are answered with `false` right away.
  if (pending_model_requests_.size() >= kMaxPendingRequestsAllowed) {
    std::move(callback).Run(false);
    return;
  }
  // Otherwise wait for the next model update to run the callback.
  pending_model_requests_.emplace_back(std::move(callback));
}