#ifndef COMPONENTS_ASSIST_RANKER_BINARY_CLASSIFIER_PREDICTOR_H_
#define COMPONENTS_ASSIST_RANKER_BINARY_CLASSIFIER_PREDICTOR_H_
#include <memory>

#include "base/memory/weak_ptr.h"
#include "components/assist_ranker/base_predictor.h"
#include "components/assist_ranker/proto/ranker_example.pb.h"
// Forward declarations: these types are only named by reference or as a
// scoped_refptr<> template argument in this header, so full definitions are
// not required here.
namespace base {
class FilePath;
}  // namespace base

namespace network {
class SharedURLLoaderFactory;
}  // namespace network
namespace assist_ranker {

class GenericLogisticRegressionInference;

// Predictor for RankerModels that produce a binary decision. Instances are
// obtained through Create(); the class is final, non-copyable, and can hand
// out WeakPtrs to itself via AsWeakPtr().
class BinaryClassifierPredictor final : public BasePredictor {
 public:
  BinaryClassifierPredictor(const BinaryClassifierPredictor&) = delete;
  BinaryClassifierPredictor& operator=(const BinaryClassifierPredictor&) =
      delete;

  ~BinaryClassifierPredictor() override;

  // Factory: builds a predictor from `config`, the model location
  // `model_path`, and `url_loader_factory`. The loader factory is taken by
  // value as a scoped_refptr, so callers may std::move() their reference in.
  // NOTE(review): whether this can return nullptr (e.g. for an unusable
  // `config`) is decided in the .cc -- confirm before assuming non-null.
  [[nodiscard]] static std::unique_ptr<BinaryClassifierPredictor> Create(
      const PredictorConfig& config,
      const base::FilePath& model_path,
      scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory);

  // Computes a binary decision for `example` and writes it to `*prediction`.
  // The return value reports whether a prediction was produced; presumably it
  // is false when no model / inference module is available yet -- confirm in
  // the implementation. Callers must check it ([[nodiscard]]).
  [[nodiscard]] bool Predict(const RankerExample& example, bool* prediction);

  // Like Predict(), but writes a float score for `example` to `*prediction`
  // instead of a bool. NOTE(review): the score's range (e.g. [0, 1] for a
  // logistic model) is not established by this header -- confirm.
  [[nodiscard]] bool PredictScore(const RankerExample& example,
                                  float* prediction);

  // Validates `model` for use with this predictor type and reports the
  // outcome as a RankerModelStatus. Static: needs no predictor instance.
  static RankerModelStatus ValidateModel(const RankerModel& model);

  // Returns a WeakPtr to this predictor, vended by `weak_ptr_factory_`.
  base::WeakPtr<BinaryClassifierPredictor> AsWeakPtr() {
    return weak_ptr_factory_.GetWeakPtr();
  }

 protected:
  // BasePredictor hook. NOTE(review): presumably creates `inference_module_`
  // from the loaded model -- confirm in the .cc.
  bool Initialize() override;

 private:
  friend class BinaryClassifierPredictorTest;

  // Only Create() (a member) and the friend test fixture can reach this
  // constructor. `explicit` forbids implicit conversion from a PredictorConfig.
  explicit BinaryClassifierPredictor(const PredictorConfig& config);

  // Inference engine backing Predict()/PredictScore(); exclusively owned.
  std::unique_ptr<GenericLogisticRegressionInference> inference_module_;

  // Must stay the last member: members are destroyed in reverse declaration
  // order, so outstanding WeakPtrs are invalidated before any other member is
  // torn down.
  base::WeakPtrFactory<BinaryClassifierPredictor> weak_ptr_factory_{this};
};

}  // namespace assist_ranker
#endif