#include "chromeos/components/quick_answers/understanding/intent_generator.h"
#include <map>
#include "base/i18n/break_iterator.h"
#include "base/i18n/case_conversion.h"
#include "base/no_destructor.h"
#include "base/strings/string_split.h"
#include "base/strings/utf_string_conversions.h"
#include "chromeos/components/quick_answers/public/cpp/quick_answers_state.h"
#include "chromeos/components/quick_answers/quick_answers_model.h"
#include "chromeos/components/quick_answers/utils/quick_answers_metrics.h"
#include "chromeos/components/quick_answers/utils/quick_answers_utils.h"
#include "chromeos/components/quick_answers/utils/spell_checker.h"
#include "chromeos/components/quick_answers/utils/translation_v2_utils.h"
#include "chromeos/constants/chromeos_features.h"
#include "chromeos/services/machine_learning/public/cpp/service_connection.h"
#include "chromeos/services/machine_learning/public/mojom/machine_learning_service.mojom.h"
#include "components/translate/core/browser/translate_download_manager.h"
#include "third_party/abseil-cpp/absl/strings/ascii.h"
#include "ui/base/l10n/l10n_util.h"
namespace quick_answers {
namespace {
using ::chromeos::machine_learning::mojom::LoadModelResult;
using ::chromeos::machine_learning::mojom::TextAnnotationPtr;
using ::chromeos::machine_learning::mojom::TextAnnotationRequest;
using ::chromeos::machine_learning::mojom::TextAnnotationRequestPtr;
using ::chromeos::machine_learning::mojom::TextClassifier;
// RewriteIntent() downgrades a unit intent to kUnknown when the selected text
// is more than this many UTF-16 code units longer than the extracted entity.
constexpr int kUnitConversionIntentAndSelectionLengthDiffThreshold = 5;
// Largest selected_text.length() for which MaybeGenerateTranslationIntent()
// still attempts language detection when the rich card feature is off.
constexpr int kTranslationTextLengthThreshold = 100;
// Same limit, used instead when
// chromeos::features::IsQuickAnswersRichCardEnabled() returns true.
constexpr int kRichAnswersTranslationTextLengthThreshold = 250;
// RewriteIntent() downgrades a dictionary intent to kUnknown when the selected
// text is more than this many UTF-16 code units longer than the entity.
constexpr int kDefinitionIntentAndSelectionLengthDiffThreshold = 2;
// ShouldSkipDefinition() returns true for any text that contains at least one
// of these characters.
constexpr char kInvalidCharactersSet[] = "()[]{}<>_&|!";
// Language code compared against the results of l10n_util::GetLanguage().
constexpr char kEnglishLanguage[] = "en";
const std::map<std::string, IntentType>& GetIntentTypeMap() {
static base::NoDestructor<std::map<std::string, IntentType>> kIntentTypeMap(
{{"unit", IntentType::kUnit}, {"dictionary", IntentType::kDictionary}});
return *kIntentTypeMap;
}
// Scans |annotations| in order for the first entity whose name is a key of
// GetIntentTypeMap().
//
// On a match, writes that entity name to |*type| and returns true; at that
// point |*entity_str| holds the span of |selected_text| covered by the
// matching annotation. Returns false when no annotation carries a known
// entity name.
//
// |*entity_str| is overwritten for every in-range annotation that is visited,
// so on a false return it holds the span of the last such annotation (and is
// left untouched when there was none).
bool ExtractEntity(const std::string& selected_text,
                   const std::vector<TextAnnotationPtr>& annotations,
                   std::string* entity_str,
                   std::string* type) {
  // The annotation offsets are applied to the UTF-16 form of the text, so the
  // conversion is done once here instead of once per annotation.
  const std::u16string selected_text_utf16 = base::UTF8ToUTF16(selected_text);

  // Bind by reference: the map is a shared singleton, copying it for every
  // annotation is pure overhead.
  const auto& intent_type_map = GetIntentTypeMap();

  for (const auto& annotation : annotations) {
    // substr() fails hard when the start position lies past the end of the
    // string. The offsets are produced by the text classifier, so skip any
    // annotation whose start offset does not fit the selected text.
    const size_t start = static_cast<size_t>(annotation->start_offset);
    if (start > selected_text_utf16.length()) {
      continue;
    }

    *entity_str = base::UTF16ToUTF8(selected_text_utf16.substr(
        start, annotation->end_offset - annotation->start_offset));

    // The first entity with a known name wins.
    for (const auto& entity : annotation->entities) {
      if (intent_type_map.find(entity->name) != intent_type_map.end()) {
        *type = entity->name;
        return true;
      }
    }
  }
  return false;
}
// Returns IntentType::kUnknown when |selected_text| is longer than
// |entity_str| (both measured in UTF-16 code units) by more than the threshold
// configured for |intent|; otherwise returns |intent| unchanged. Only kUnit
// and kDictionary have a threshold.
IntentType RewriteIntent(const std::string& selected_text,
                         const std::string& entity_str,
                         const IntentType intent) {
  const int extra_length = base::UTF8ToUTF16(selected_text).length() -
                           base::UTF8ToUTF16(entity_str).length();

  switch (intent) {
    case IntentType::kUnit:
      if (extra_length >
          kUnitConversionIntentAndSelectionLengthDiffThreshold) {
        return IntentType::kUnknown;
      }
      break;
    case IntentType::kDictionary:
      if (extra_length > kDefinitionIntentAndSelectionLengthDiffThreshold) {
        return IntentType::kUnknown;
      }
      break;
    default:
      break;
  }
  return intent;
}
// Returns true when |detected_language| equals the language part of at least
// one entry in the comma-separated preferred-languages list held by
// QuickAnswersState. Entries are whitespace-trimmed and empty ones dropped.
bool IsPreferredLanguage(std::string_view detected_language) {
  const auto preferred_locales =
      base::SplitString(QuickAnswersState::Get()->preferred_languages(), ",",
                        base::TRIM_WHITESPACE, base::SPLIT_WANT_NONEMPTY);

  for (const std::string& preferred_locale : preferred_locales) {
    const bool same_language =
        l10n_util::GetLanguage(preferred_locale) == detected_language;
    if (same_language) {
      return true;
    }
  }
  return false;
}
// Decides whether a dictionary intent for |text| should be dropped.
//
// Returns true when any of the following holds:
//  * English is neither the application language nor a preferred language;
//  * some character after the first changes under lower-casing while the text
//    as a whole is not already equal to its upper-cased form;
//  * |text| contains a character from kInvalidCharactersSet.
//
// |text| is expected to be non-empty (DCHECKed).
bool ShouldSkipDefinition(const std::string& text) {
  auto ui_language =
      l10n_util::GetLanguage(QuickAnswersState::Get()->application_locale());
  const bool english_available = ui_language == kEnglishLanguage ||
                                 IsPreferredLanguage(kEnglishLanguage);
  if (!english_available) {
    return true;
  }

  DCHECK(text.length());
  const std::u16string wide_text = base::UTF8ToUTF16(text);
  const std::u16string tail = wide_text.substr(1);

  const bool tail_has_uppercase = tail != base::i18n::ToLower(tail);
  const bool entirely_uppercase = wide_text == base::i18n::ToUpper(wide_text);
  if (tail_has_uppercase && !entirely_uppercase) {
    return true;
  }

  return text.find_first_of(kInvalidCharactersSet) != std::string::npos;
}
// Returns true only when TranslationV2Utils::IsSupported() accepts both
// |source_language| and |target_language|. The target is not queried when the
// source is rejected.
bool AreTranslationLanguagesSupported(std::string_view source_language,
                                      std::string_view target_language) {
  if (!TranslationV2Utils::IsSupported(source_language)) {
    return false;
  }
  return TranslationV2Utils::IsSupported(target_language);
}
// Returns true when |word| contains at least one ASCII digit.
bool HasDigits(const std::string& word) {
  bool found_digit = false;
  for (const char ch : word) {
    // Cast first so that bytes with the high bit set are passed as
    // non-negative values.
    found_digit = absl::ascii_isdigit(static_cast<unsigned char>(ch));
    if (found_digit) {
      break;
    }
  }
  return found_digit;
}
}
// Creates a generator that checks spelling through |spell_checker| and reports
// its result by running |complete_callback|. Both arguments are moved into the
// corresponding members.
IntentGenerator::IntentGenerator(base::WeakPtr<SpellChecker> spell_checker,
IntentGeneratorCallback complete_callback)
: spell_checker_(std::move(spell_checker)),
complete_callback_(std::move(complete_callback)) {}
// If the completion callback has not been consumed yet, runs it with an empty
// text and IntentType::kUnknown so that it is always invoked.
IntentGenerator::~IntentGenerator() {
  if (!complete_callback_) {
    return;
  }
  std::move(complete_callback_)
      .Run(IntentInfo(std::string(), IntentType::kUnknown));
}
// Entry point: starts intent generation for |request|.
//
// When the first word-break segment of the selected text is a word that spans
// the whole text, the lower-cased text is sent to the spell checker and the
// flow continues in CheckSpellingCallback(). Otherwise the flow goes straight
// to MaybeLoadTextClassifier().
void IntentGenerator::GenerateIntent(const QuickAnswersRequest& request) {
  // Held by value: the break iterator below and the lower-casing further down
  // both read from this single conversion.
  const std::u16string u16_text = base::UTF8ToUTF16(request.selected_text);
  base::i18n::BreakIterator iter(u16_text,
                                 base::i18n::BreakIterator::BREAK_WORD);
  if (!iter.Init() || !iter.Advance()) {
    NOTREACHED() << "Failed to load BreakIterator.";
  }

  DCHECK(spell_checker_.get()) << "spell_checker_ should exist when the "
                                  "always trigger feature is enabled";

  // prev() == 0 and pos() == length means the first segment covers the entire
  // selection.
  if (iter.IsWord() && iter.prev() == 0 && iter.pos() == u16_text.length()) {
    // Reuse |u16_text| rather than converting the selection a second time.
    auto text = base::UTF16ToUTF8(base::i18n::ToLower(u16_text));
    spell_checker_->CheckSpelling(
        text, base::BindOnce(&IntentGenerator::CheckSpellingCallback,
                             weak_factory_.GetWeakPtr(), request));
    return;
  }

  MaybeLoadTextClassifier(request);
}
// Test-only helper: forwards to FlushForTesting() on |text_classifier_|.
void IntentGenerator::FlushForTesting() {
text_classifier_.FlushForTesting();
}
// Asks the machine learning service to load the text classifier when
// QuickAnswersState allows the text annotator; the result is delivered to
// LoadModelCallback(). Otherwise completes immediately with
// IntentType::kUnknown for the selected text.
void IntentGenerator::MaybeLoadTextClassifier(
    const QuickAnswersRequest& request) {
  if (!QuickAnswersState::Get()->ShouldUseQuickAnswersTextAnnotator()) {
    std::move(complete_callback_)
        .Run(IntentInfo(request.selected_text, IntentType::kUnknown));
    return;
  }

  auto on_model_loaded = base::BindOnce(&IntentGenerator::LoadModelCallback,
                                        weak_factory_.GetWeakPtr(), request);
  chromeos::machine_learning::ServiceConnection::GetInstance()
      ->GetMachineLearningService()
      .LoadTextClassifier(text_classifier_.BindNewPipeAndPassReceiver(),
                          std::move(on_model_loaded));
}
// Handles the spell checker's verdict for |request|.
//
// A correctly spelled selection without ASCII digits completes as a dictionary
// intent (with the application locale and |language|) and records the
// dictionary-intent metrics with source kHunspell. Anything else falls back to
// MaybeLoadTextClassifier().
void IntentGenerator::CheckSpellingCallback(const QuickAnswersRequest& request,
                                            bool correctness,
                                            const std::string& language) {
  const bool is_dictionary_word =
      correctness && !HasDigits(request.selected_text);
  if (!is_dictionary_word) {
    MaybeLoadTextClassifier(request);
    return;
  }

  std::move(complete_callback_)
      .Run(IntentInfo(request.selected_text, IntentType::kDictionary,
                      QuickAnswersState::Get()->application_locale(),
                      language));
  RecordDictionaryIntentSource(DictionaryIntentSource::kHunspell);
  RecordDictionaryIntentLanguage(language);
}
void IntentGenerator::LoadModelCallback(const QuickAnswersRequest& request,
LoadModelResult result) {
if (result != LoadModelResult::OK) {
LOG(ERROR) << "Failed to load TextClassifier.";
std::move(complete_callback_)
.Run(IntentInfo(request.selected_text, IntentType::kUnknown));
return;
}
if (text_classifier_) {
TextAnnotationRequestPtr text_annotation_request =
TextAnnotationRequest::New();
text_annotation_request->text = request.selected_text;
text_annotation_request->default_locales =
QuickAnswersState::Get()->application_locale();
text_annotation_request->trigger_dictionary_on_beginner_words = true;
text_classifier_->Annotate(
std::move(text_annotation_request),
base::BindOnce(&IntentGenerator::AnnotationCallback,
weak_factory_.GetWeakPtr(), request));
}
}
void IntentGenerator::AnnotationCallback(
const QuickAnswersRequest& request,
std::vector<TextAnnotationPtr> annotations) {
std::string entity_str;
std::string type;
if (ExtractEntity(request.selected_text, annotations, &entity_str, &type)) {
auto intent_type_map = GetIntentTypeMap();
auto it = intent_type_map.find(type);
if (it != intent_type_map.end()) {
bool definition_ineligible =
!QuickAnswersState::IsIntentEligible(Intent::kDefinition);
bool unit_conversion_ineligible =
!QuickAnswersState::IsIntentEligible(Intent::kUnitConversion);
if ((it->second == IntentType::kDictionary && definition_ineligible) ||
(it->second == IntentType::kUnit && unit_conversion_ineligible)) {
MaybeGenerateTranslationIntent(request);
return;
}
if (it->second == IntentType::kDictionary &&
ShouldSkipDefinition(request.selected_text)) {
MaybeGenerateTranslationIntent(request);
return;
}
std::move(complete_callback_)
.Run(IntentInfo(
entity_str,
RewriteIntent(request.selected_text, entity_str, it->second),
QuickAnswersState::Get()->application_locale()));
if (it->second == IntentType::kDictionary) {
RecordDictionaryIntentSource(DictionaryIntentSource::kTextClassifier);
RecordDictionaryIntentLanguage(kEnglishLanguage);
}
return;
}
}
MaybeGenerateTranslationIntent(request);
}
void IntentGenerator::MaybeGenerateTranslationIntent(
const QuickAnswersRequest& request) {
DCHECK(complete_callback_);
if (!QuickAnswersState::IsIntentEligible(Intent::kTranslation) ||
chromeos::features::IsQuickAnswersV2TranslationDisabled()) {
std::move(complete_callback_)
.Run(IntentInfo(request.selected_text, IntentType::kUnknown));
return;
}
size_t translation_text_length_threshold =
chromeos::features::IsQuickAnswersRichCardEnabled()
? kRichAnswersTranslationTextLengthThreshold
: kTranslationTextLengthThreshold;
if (QuickAnswersState::Get()->application_locale().empty() ||
request.selected_text.length() > translation_text_length_threshold) {
std::move(complete_callback_)
.Run(IntentInfo(request.selected_text, IntentType::kUnknown));
return;
}
language_detector_ =
std::make_unique<LanguageDetector>(text_classifier_.get());
language_detector_->DetectLanguage(
request.context.surrounding_text, request.selected_text,
base::BindOnce(&IntentGenerator::LanguageDetectorCallback,
weak_factory_.GetWeakPtr(), request));
}
// Final step of the translation path: releases the language detector and runs
// the completion callback.
//
// A translation intent (application language and detected language attached)
// is reported only when a language was detected, it differs from the
// application language, it is not one of the preferred languages, and both
// languages are supported for translation. Otherwise IntentType::kUnknown is
// reported.
void IntentGenerator::LanguageDetectorCallback(
    const QuickAnswersRequest& request,
    std::optional<std::string> detected_locale) {
  language_detector_.reset();

  auto ui_language =
      l10n_util::GetLanguage(QuickAnswersState::Get()->application_locale());
  auto source_language = detected_locale.has_value()
                             ? l10n_util::GetLanguage(detected_locale.value())
                             : std::string_view();

  const bool offer_translation =
      !source_language.empty() && source_language != ui_language &&
      !IsPreferredLanguage(source_language) &&
      AreTranslationLanguagesSupported(source_language, ui_language);
  if (!offer_translation) {
    std::move(complete_callback_)
        .Run(IntentInfo(request.selected_text, IntentType::kUnknown));
    return;
  }

  std::move(complete_callback_)
      .Run(IntentInfo(request.selected_text, IntentType::kTranslation,
                      ui_language, source_language));
}
}