#include "chromeos/services/machine_learning/cpp/ash/handwriting_model_loader.h"
#include <string>
#include <string_view>
#include <utility>
#include "base/command_line.h"
#include "base/functional/callback_helpers.h"
#include "base/metrics/histogram_macros.h"
#include "chromeos/services/machine_learning/public/cpp/ml_switches.h"
#include "chromeos/services/machine_learning/public/cpp/service_connection.h"
#include "third_party/cros_system_api/dbus/service_constants.h"
namespace ash {
namespace machine_learning {
namespace {
// Short local names for the ML-service mojom types used throughout this file.
using ::chromeos::machine_learning::mojom::HandwritingRecognizerSpecPtr;
using ::chromeos::machine_learning::mojom::LoadHandwritingModelResult;
// NOTE: despite its name, this alias is the pending *receiver* end of a
// mojom::HandwritingRecognizer pipe, not a recognizer object.
using HandwritingRecognizer = mojo::PendingReceiver<
::chromeos::machine_learning::mojom::HandwritingRecognizer>;
// Completion callback type of MachineLearningService::LoadHandwritingModel;
// every code path in this file either runs it or forwards it to that method.
using LoadHandwritingModelCallback = ::chromeos::machine_learning::mojom::
MachineLearningService::LoadHandwritingModelCallback;
// Records `val` to the
// "MachineLearningService.HandwritingModel.LoadModelResult.Event" UMA
// enumeration histogram.
//
// Uses the two-argument form of UMA_HISTOGRAM_ENUMERATION, which derives the
// exclusive bucket boundary from the mojom-generated
// `LoadHandwritingModelResult::kMaxValue` (+1), so every enumerator gets its
// own bucket. The three-argument form previously used here treated
// LOAD_MODEL_FILES_ERROR as the *exclusive* boundary. That sent
// LOAD_MODEL_FILES_ERROR, and any enumerator ordered after it, into the
// overflow bucket. Those later enumerators are presumably the DLC_* values
// recorded in this file (confirm the ordering in the mojom).
// NOTE(review): the bucket count grows with this change; make sure the
// matching enums.xml entry lists all enumerators.
void RecordLoadHandwritingModelResult(const LoadHandwritingModelResult val) {
  UMA_HISTOGRAM_ENUMERATION(
      "MachineLearningService.HandwritingModel.LoadModelResult.Event", val);
}
// DLC id of the libhandwriting package. It is used both to find the package
// among the DLCs reported by dlcservice and to build the install request.
constexpr char kLibHandwritingDlcId[] = "libhandwriting";
// The only `spec->language` values LoadHandwritingModelFromRootfsOrDlc
// accepts; anything else is rejected with LANGUAGE_NOT_SUPPORTED_ERROR.
constexpr char kLanguageCodeEn[] = "en";
constexpr char kLanguageCodeGesture[] = "gesture_in_context";
bool HandwritingSwitchHasValue(const std::string& value) {
base::CommandLine* command_line = base::CommandLine::ForCurrentProcess();
return command_line->HasSwitch(::switches::kOndeviceHandwritingSwitch) &&
command_line->GetSwitchValueASCII(
switches::kOndeviceHandwritingSwitch) == value;
}
// True when the on-device handwriting switch is "use_rootfs". This selects
// the path that asks MachineLearningService to load the model directly,
// without going through dlcservice.
bool IsLibHandwritingRootfsEnabled() {
  constexpr char kRootfsSwitchValue[] = "use_rootfs";
  return HandwritingSwitchHasValue(kRootfsSwitchValue);
}
// True when the on-device handwriting switch is "use_dlc". This selects the
// path that locates and installs the libhandwriting DLC before loading.
bool IsLibHandwritingDlcEnabled() {
  constexpr char kDlcSwitchValue[] = "use_dlc";
  return HandwritingSwitchHasValue(kDlcSwitchValue);
}
// Final step of the DLC path. Invoked once dlcservice reports the outcome of
// installing the libhandwriting DLC.
//
// On an install error, records and reports DLC_INSTALL_ERROR through
// `callback`. On success, `spec`, `receiver` and `callback` are all handed to
// MachineLearningService::LoadHandwritingModel. The result is then reported
// from there.
void OnInstallDlcComplete(HandwritingRecognizerSpecPtr spec,
                          HandwritingRecognizer receiver,
                          LoadHandwritingModelCallback callback,
                          const DlcserviceClient::InstallResult& result) {
  const bool install_failed = result.error != dlcservice::kErrorNone;
  if (install_failed) {
    RecordLoadHandwritingModelResult(
        LoadHandwritingModelResult::DLC_INSTALL_ERROR);
    std::move(callback).Run(LoadHandwritingModelResult::DLC_INSTALL_ERROR);
    return;
  }

  // Installation succeeded: forward everything to the ML service.
  chromeos::machine_learning::ServiceConnection::GetInstance()
      ->GetMachineLearningService()
      .LoadHandwritingModel(std::move(spec), std::move(receiver),
                            std::move(callback));
}
// Continuation of LoadHandwritingModelFromRootfsOrDlc's DLC path, invoked with
// the result of DlcserviceClient::GetExistingDlcs().
//
// If an entry with id kLibHandwritingDlcId is listed in `dlcs_with_content`,
// asks `dlc_client` to install it and continues in OnInstallDlcComplete.
// Otherwise records and reports DLC_DOES_NOT_EXIST through `callback`.
//
// NOTE(review): `err` is never read. If GetExistingDlcs fails and the
// libhandwriting entry is absent from `dlcs_with_content`, the failure is
// reported as DLC_DOES_NOT_EXIST rather than as a dlcservice error. Confirm
// this is intended.
void OnGetExistingDlcsComplete(
    HandwritingRecognizerSpecPtr spec,
    HandwritingRecognizer receiver,
    LoadHandwritingModelCallback callback,
    DlcserviceClient* const dlc_client,
    std::string_view err,
    const dlcservice::DlcsWithContent& dlcs_with_content) {
  bool has_libhandwriting = false;
  for (const auto& existing : dlcs_with_content.dlc_infos()) {
    if (existing.id() == kLibHandwritingDlcId) {
      has_libhandwriting = true;
      break;
    }
  }

  if (!has_libhandwriting) {
    RecordLoadHandwritingModelResult(
        LoadHandwritingModelResult::DLC_DOES_NOT_EXIST);
    std::move(callback).Run(LoadHandwritingModelResult::DLC_DOES_NOT_EXIST);
    return;
  }

  // The DLC is known to dlcservice: request installation and finish loading
  // in OnInstallDlcComplete. The third Install() argument (presumably a
  // progress callback) is a no-op.
  dlcservice::InstallRequest request;
  request.set_id(kLibHandwritingDlcId);
  dlc_client->Install(request,
                      base::BindOnce(&OnInstallDlcComplete, std::move(spec),
                                     std::move(receiver), std::move(callback)),
                      base::DoNothing());
}
}
// Loads the handwriting model described by `spec`, passing `receiver` along
// and reporting the outcome through `callback`.
//
// The on-device handwriting command-line switch selects the source:
//   * "use_rootfs": forwards straight to MachineLearningService.
//   * "use_dlc":    asks `dlc_client` for the existing DLCs. If the
//                   libhandwriting DLC is listed, it is installed and then
//                   forwarded to MachineLearningService.
// With neither value, FEATURE_NOT_SUPPORTED_ERROR is reported. Only the "en"
// and "gesture_in_context" languages are accepted; any other language yields
// LANGUAGE_NOT_SUPPORTED_ERROR. Each failure detected here is also recorded
// to UMA before `callback` runs.
//
// `dlc_client` is only dereferenced on the DLC path.
void LoadHandwritingModelFromRootfsOrDlc(HandwritingRecognizerSpecPtr spec,
                                         HandwritingRecognizer receiver,
                                         LoadHandwritingModelCallback callback,
                                         DlcserviceClient* const dlc_client) {
  const bool rootfs_enabled = IsLibHandwritingRootfsEnabled();
  const bool dlc_enabled = IsLibHandwritingDlcEnabled();

  if (!(rootfs_enabled || dlc_enabled)) {
    RecordLoadHandwritingModelResult(
        LoadHandwritingModelResult::FEATURE_NOT_SUPPORTED_ERROR);
    std::move(callback).Run(
        LoadHandwritingModelResult::FEATURE_NOT_SUPPORTED_ERROR);
    return;
  }

  const bool language_supported = spec->language == kLanguageCodeEn ||
                                  spec->language == kLanguageCodeGesture;
  if (!language_supported) {
    RecordLoadHandwritingModelResult(
        LoadHandwritingModelResult::LANGUAGE_NOT_SUPPORTED_ERROR);
    std::move(callback).Run(
        LoadHandwritingModelResult::LANGUAGE_NOT_SUPPORTED_ERROR);
    return;
  }

  if (!rootfs_enabled) {
    // DLC path: the flow continues asynchronously in OnGetExistingDlcsComplete.
    dlc_client->GetExistingDlcs(base::BindOnce(
        &OnGetExistingDlcsComplete, std::move(spec), std::move(receiver),
        std::move(callback), dlc_client));
    return;
  }

  // Rootfs path: the ML service loads the library itself.
  chromeos::machine_learning::ServiceConnection::GetInstance()
      ->GetMachineLearningService()
      .LoadHandwritingModel(std::move(spec), std::move(receiver),
                            std::move(callback));
}
}
}