910e62b5创建于 1月15日历史提交
// Copyright 2025 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef SERVICES_AUDIO_ML_MODEL_MANAGER_H_
#define SERVICES_AUDIO_ML_MODEL_MANAGER_H_

#include <optional>

#include "base/files/file.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/weak_ptr.h"
#include "base/sequence_checker.h"
#include "base/thread_annotations.h"
#include "mojo/public/cpp/bindings/pending_receiver.h"
#include "mojo/public/cpp/bindings/receiver.h"
#include "services/audio/public/mojom/ml_model_manager.mojom.h"
#include "third_party/abseil-cpp/absl/container/flat_hash_map.h"
#include "third_party/tflite/src/tensorflow/lite/model_builder.h"

namespace audio {

struct ModelWithBuffer;

class MlModelHandle {
 public:
  virtual ~MlModelHandle() = default;

  // Returns a pointer to a TFLite model file, valid for the lifetime of the
  // Modelhandle.
  virtual const tflite::FlatBufferModel* Get() = 0;
};

// Interface for providing Machine Learning models within the audio service.
// This interface is used by components like the AudioProcessorHandler to access
// // ML model information.
class MlModelManager {
 public:
  virtual ~MlModelManager() = default;

  // Returns a handle for a TFLite model file for the neural residual echo
  // estimator. Returns nullptr if no valid model file is available.
  // The model shares its lifetime with the MlModelManager, and the client must
  // stop using and destroy the handle within that lifetime.
  virtual std::unique_ptr<MlModelHandle> GetResidualEchoEstimationModel() = 0;
};

// Implementation of the MlModelManager interface. This class receives model
// information from the browser process through the mojom::MlModelManager Mojo
// interface and provides it to audio service components.
//
// Current Behavior:
// - Model files provided via SetResidualEchoEstimationModel() are loaded and
//   cached for as long as they are used or may be used.
// - GetResidualEchoEstimationModel() returns the last successfully loaded
//   model.
// - StopServingResidualEchoEstimationModel() stops serving all previously set
//   models, but the models are kept alive for as long as existing clients need
//   them.
class MlModelManagerImpl : public MlModelManager, public mojom::MlModelManager {
 public:
  MlModelManagerImpl();
  ~MlModelManagerImpl() override;

  MlModelManagerImpl(const MlModelManagerImpl&) = delete;
  MlModelManagerImpl& operator=(const MlModelManagerImpl&) = delete;

  void BindReceiver(mojo::PendingReceiver<mojom::MlModelManager> receiver);

  // mojom::MlModelManager implementation.
  void SetResidualEchoEstimationModel(base::File tflite_file) override;
  void StopServingResidualEchoEstimationModel() override;

  // MlModelManager implementation.
  std::unique_ptr<MlModelHandle> GetResidualEchoEstimationModel() override;

  // Stops any ongoing loading of model files.
  void CancelModelLoadingTasks();

  bool HasPendingTasksForTesting() const;

 private:
  // Callback for MlModelHandle to track active clients for a model.
  void OnModelHandleDestruction(ModelWithBuffer* model_with_buffer);

  // Continuation for SetResidualEchoEstimationModel after work on background
  // thread after work on background thread.
  void OnResidualEchoEstimationModelRead(
      std::unique_ptr<ModelWithBuffer> model_with_buffer);

  SEQUENCE_CHECKER(sequence_checker_);

  std::optional<mojo::Receiver<mojom::MlModelManager>> receiver_
      GUARDED_BY_CONTEXT(sequence_checker_);

  // The service can hold on to multiple models simultaneously. The models
  // travel between three storages:
  //
  // 1. `unused_serving_model_`: A newly loaded model is placed here. It's
  //    available for use but hasn't been requested yet.
  //
  // 2. `used_serving_model_`: When a client requests a model via
  //    `GetResidualEchoEstimationModel()`, the model is moved from
  //    `unused_serving_model_` to here. This indicates the model is actively
  //    in use. If all clients stop using it, it moves back to
  //    `unused_serving_model_`.
  //
  // 3. `retired_models_`: If a new model is loaded while the current one in
  //    `used_serving_model_` is still being used, the `used_serving_model_`
  //    is moved into this map. This keeps the model alive for existing clients
  //    while allowing new clients to use the new model.
  //
  // At most one of `unused_serving_model_` and `used_serving_model_` is
  // non-null at any given time.
  std::unique_ptr<ModelWithBuffer> unused_serving_model_
      GUARDED_BY_CONTEXT(sequence_checker_);
  std::unique_ptr<ModelWithBuffer> used_serving_model_
      GUARDED_BY_CONTEXT(sequence_checker_);

  // Maps ModelWithBuffer pointers to the respective full ModelWithBuffer
  // object, for easy lookup.
  absl::flat_hash_map<ModelWithBuffer*, std::unique_ptr<ModelWithBuffer>>
      retired_models_ GUARDED_BY_CONTEXT(sequence_checker_);

  base::WeakPtrFactory<MlModelManagerImpl> weak_factory_{this};
};

}  // namespace audio

#endif  // SERVICES_AUDIO_ML_MODEL_MANAGER_H_