#include "chrome/browser/accessibility/phrase_segmentation/dependency_parser_model_loader.h"
#include "base/files/file_path.h"
#include "base/files/file_util.h"
#include "base/metrics/histogram_macros.h"
#include "components/optimization_guide/core/delivery/optimization_guide_model_provider.h"
namespace {
// Opens the file at |model_file_path| with FLAG_OPEN | FLAG_READ and returns
// it. When the path does not exist, a default-constructed base::File is
// returned instead and no open is attempted.
base::File LoadModelFile(const base::FilePath& model_file_path) {
  return base::PathExists(model_file_path)
             ? base::File(model_file_path,
                          base::File::FLAG_OPEN | base::File::FLAG_READ)
             : base::File();
}
// Takes ownership of |model_file| and explicitly closes it when it is valid.
// An invalid file is left untouched.
void CloseModelFile(base::File model_file) {
  if (model_file.IsValid()) {
    model_file.Close();
  }
}
// Scoped helper that emits one boolean histogram sample when it is destroyed.
// The sample is true only if SetLoaded() was called on this instance before
// destruction; otherwise it is false.
class ScopedModelLoadingResultRecorder {
 public:
  ScopedModelLoadingResultRecorder() = default;

  ~ScopedModelLoadingResultRecorder() {
    UMA_HISTOGRAM_BOOLEAN(
        "Accessibility.DependencyParserModelLoader.DependencyParserModel."
        "WasLoaded",
        loaded_);
  }

  // Marks the load as successful; the destructor will then record true.
  void SetLoaded() { loaded_ = true; }

 private:
  // Value reported by the destructor. Starts out false.
  bool loaded_{false};
};
// Upper bound on how many availability callbacks may be queued at once.
// NotifyOnModelFileAvailable() runs any callback beyond this limit
// immediately with false instead of queueing it.
constexpr int kMaxPendingRequestsAllowed{100};
}
// Stores |opt_guide| and |background_task_runner|, then registers this object
// with |opt_guide| as an observer for the phrase segmentation optimization
// target. |opt_guide| is dereferenced unconditionally, so it must be non-null.
DependencyParserModelLoader::DependencyParserModelLoader(
    optimization_guide::OptimizationGuideModelProvider* opt_guide,
    const scoped_refptr<base::SequencedTaskRunner>& background_task_runner)
    : opt_guide_(opt_guide),
      background_task_runner_(background_task_runner) {
  // The stored runner refers to the same task runner as the constructor
  // argument; no extra metadata is supplied with the registration.
  opt_guide_->AddObserverForOptimizationTargetModel(
      optimization_guide::proto::OPTIMIZATION_TARGET_PHRASE_SEGMENTATION,
      /*model_metadata=*/std::nullopt, background_task_runner_, this);
}
// Unregisters this object as an observer for the phrase segmentation target
// and then runs every still-pending availability callback with false.
DependencyParserModelLoader::~DependencyParserModelLoader() {
  const auto target =
      optimization_guide::proto::OPTIMIZATION_TARGET_PHRASE_SEGMENTATION;
  opt_guide_->RemoveObserverForOptimizationTargetModel(target, this);

  // Anything still queued is told that no model is available.
  NotifyModelUpdatesAndClear(/*is_model_available=*/false);
}
// Releases the held model file (closing is delegated to UnloadModelFile())
// and then runs all pending availability callbacks with false.
void DependencyParserModelLoader::Shutdown() {
  // The file is released first so the callbacks below run after the unload
  // has been requested.
  UnloadModelFile();

  constexpr bool kModelAvailable = false;
  NotifyModelUpdatesAndClear(kModelAvailable);
}
// Hands the currently held model file, if any, to |background_task_runner_|
// so that CloseModelFile() runs there, and then drops the file from this
// object. Does nothing when no file is held.
void DependencyParserModelLoader::UnloadModelFile() {
  if (!dependency_parser_model_file_) {
    return;
  }

  // Ownership of the file moves into the posted task; the close itself is
  // performed by CloseModelFile() on the background task runner.
  background_task_runner_->PostTask(
      FROM_HERE, base::BindOnce(&CloseModelFile,
                                std::move(*dependency_parser_model_file_)));

  // Moving out of the held value leaves the holder engaged around a
  // moved-from file. Reset it so that a repeated unload does not post another
  // close task and so that GetDependencyParserModelFile() takes its
  // "no file" path instead of reaching its validity DCHECK.
  // NOTE(review): assumes |dependency_parser_model_file_| is an
  // optional-like holder (e.g. std::optional<base::File>) - confirm in the
  // header.
  dependency_parser_model_file_.reset();
}
void DependencyParserModelLoader::NotifyModelUpdatesAndClear(
bool is_model_available) {
for (auto& pending_request : pending_model_requests_) {
std::move(pending_request).Run(is_model_available);
}
pending_model_requests_.clear();
}
// Observer callback for model updates. Updates for any target other than
// phrase segmentation are ignored. When |model_info| carries a value, the
// model file at its path is opened on |background_task_runner_| and the
// result is delivered to OnModelFileLoaded(). When |model_info| is empty, the
// held file is unloaded and pending callbacks are run with false.
void DependencyParserModelLoader::OnModelUpdated(
    optimization_guide::proto::OptimizationTarget optimization_target,
    base::optional_ref<const optimization_guide::ModelInfo> model_info) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  const bool is_phrase_segmentation_target =
      optimization_target ==
      optimization_guide::proto::OPTIMIZATION_TARGET_PHRASE_SEGMENTATION;
  if (!is_phrase_segmentation_target) {
    return;
  }

  if (model_info.has_value()) {
    // The reply is bound to a weak pointer, so it is not bound to a raw
    // |this|.
    background_task_runner_->PostTaskAndReplyWithResult(
        FROM_HERE,
        base::BindOnce(&LoadModelFile, model_info->GetModelFilePath()),
        base::BindOnce(&DependencyParserModelLoader::OnModelFileLoaded,
                       weak_ptr_factory_.GetWeakPtr()));
    return;
  }

  // No model info: release what is held and report unavailability.
  UnloadModelFile();
  NotifyModelUpdatesAndClear(/*is_model_available=*/false);
}
// Receives the file opened by LoadModelFile(). A valid |model_file| replaces
// the currently held file and pending callbacks are run with true. An invalid
// |model_file| changes nothing. Either way, a histogram sample describing the
// outcome is recorded when the function returns.
void DependencyParserModelLoader::OnModelFileLoaded(base::File model_file) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  // Reports false on destruction unless SetLoaded() is called below.
  ScopedModelLoadingResultRecorder load_result;

  if (model_file.IsValid()) {
    // Release the previous file before adopting the new one.
    UnloadModelFile();
    dependency_parser_model_file_ = std::move(model_file);
    load_result.SetLoaded();
    NotifyModelUpdatesAndClear(/*is_model_available=*/true);
  }
}
// Returns a duplicate of the held model file. Callers are expected to call
// this only while IsModelAvailable() is true (DCHECKed). When no file is
// held, a default-constructed base::File is returned.
base::File DependencyParserModelLoader::GetDependencyParserModelFile() {
  DCHECK(IsModelAvailable());

  if (dependency_parser_model_file_) {
    DCHECK(dependency_parser_model_file_->IsValid());
    return dependency_parser_model_file_->Duplicate();
  }

  return base::File();
}
// Queues |callback| to be run by a later NotifyModelUpdatesAndClear(). Callers
// are expected to use this only while the model is not available (DCHECKed).
// When the queue already holds kMaxPendingRequestsAllowed entries, |callback|
// is run immediately with false instead of being queued.
void DependencyParserModelLoader::NotifyOnModelFileAvailable(
    NotifyModelAvailableCallback callback) {
  DCHECK(!IsModelAvailable());

  const bool queue_is_full =
      !(pending_model_requests_.size() < kMaxPendingRequestsAllowed);
  if (queue_is_full) {
    std::move(callback).Run(false);
    return;
  }

  pending_model_requests_.emplace_back(std::move(callback));
}