#include "chrome/browser/contextual_tasks/contextual_tasks_context_service.h"

#include <memory>

#include "base/containers/contains.h"
#include "base/metrics/histogram_functions.h"
#include "base/notreached.h"
#include "base/stl_util.h"
#include "base/strings/stringprintf.h"
#include "base/strings/utf_string_conversions.h"
#include "base/task/single_thread_task_runner.h"
#include "base/time/default_tick_clock.h"
#include "chrome/browser/contextual_tasks/contextual_tasks_signal_utils.h"
#include "chrome/browser/optimization_guide/optimization_guide_keyed_service.h"
#include "chrome/browser/page_content_annotations/page_content_annotations_web_contents_observer.h"
#include "chrome/browser/page_content_annotations/page_content_extraction_service.h"
#include "chrome/browser/page_content_annotations/page_content_extraction_types.h"
#include "chrome/browser/passage_embeddings/page_embeddings_service.h"
#include "chrome/browser/profiles/profile.h"
#include "chrome/browser/ui/browser_window/public/browser_window_interface.h"
#include "chrome/browser/ui/browser_window/public/browser_window_interface_iterator.h"
#include "chrome/browser/ui/tabs/tab_strip_model.h"
#include "components/contextual_tasks/public/features.h"
#include "components/optimization_guide/core/model_quality/model_quality_log_entry.h"
#include "components/optimization_guide/proto/features/contextual_tasks_context.pb.h"
#include "components/passage_embeddings/passage_embeddings_types.h"
#include "content/public/browser/web_contents.h"
#include "url/gurl.h"
namespace contextual_tasks {
namespace {
// Logs `message` through OPTIMIZATION_GUIDE_LOG under the
// CONTEXTUAL_TASKS_CONTEXT log source, using the optimization guide logger.
// The expansion names `optimization_guide_keyed_service_` directly, so the
// macro only compiles inside ContextualTasksContextService member functions
// (or lambdas there that capture `this`). Being a macro, it is not actually
// scoped by the surrounding anonymous namespace.
#define AUTO_CONTEXT_LOG(message) \
OPTIMIZATION_GUIDE_LOG( \
optimization_guide_common::mojom::LogSource::CONTEXTUAL_TASKS_CONTEXT, \
optimization_guide_keyed_service_->GetOptimizationGuideLogger(), \
(message))
// Relevance signals gathered for a single tab and combined by GetTabScore().
// A signal left as std::nullopt is skipped when scoring.
struct TabSignals {
  // The tab these signals were computed for.
  raw_ptr<content::WebContents> web_contents{nullptr};
  // Best passage-embedding similarity between the tab and the query, if the
  // tab had any scorable embeddings.
  std::optional<float> embedding_score;
  // Time since the tab was last active or interacted with, when positive.
  std::optional<base::TimeDelta> duration_since_last_active;
  // Result of GetMatchingWordsCount(query, tab title).
  std::optional<int> num_query_title_matching_words;
};
// Best and worst (score, passage text) pairs observed while comparing a tab's
// passage embeddings with a query embedding. The initial values are
// sentinels: the best entry starts at 0.0 and the worst at 1.0, each with an
// empty passage, until a passage scores strictly beyond them.
struct TabSimilarityScores {
  std::pair<float, std::string> best_similarity_score{0.0f, std::string()};
  std::pair<float, std::string> worst_similarity_score{1.0f, std::string()};
};
// Compares each of a tab's passage embeddings with `query_embedding` and
// returns the best- and worst-scoring passages.
//
// Returns std::nullopt when no passage was actually scored. That covers two
// cases: the tab has no embeddings, or kOnlyUseTitlesForSimilarity is enabled
// and none of the tab's passages is a title. In the second case, returning the
// untouched TabSimilarityScores would hand back placeholder values that
// callers log and record as if they were real scores.
//
// The best score starts at 0.0 and the worst at 1.0. A passage replaces them
// only when it scores strictly above/below those values, so if every passage
// scores below 0.0 the reported best remains (0.0, "").
//
// `web_contents` is currently unused.
std::optional<TabSimilarityScores> GetEmbeddingScores(
    content::WebContents* web_contents,
    const passage_embeddings::Embedding& query_embedding,
    const std::vector<passage_embeddings::PassageEmbedding>&
        web_contents_embeddings) {
  TabSimilarityScores similarity_scores;
  bool scored_any_passage = false;
  for (const auto& embedding : web_contents_embeddings) {
    if (kOnlyUseTitlesForSimilarity.Get() &&
        embedding.passage.second != passage_embeddings::PassageType::kTitle) {
      continue;
    }
    scored_any_passage = true;
    const float similarity_score =
        embedding.embedding.ScoreWith(query_embedding);
    if (similarity_score > similarity_scores.best_similarity_score.first) {
      similarity_scores.best_similarity_score =
          std::make_pair(similarity_score, embedding.passage.first);
    }
    if (similarity_score < similarity_scores.worst_similarity_score.first) {
      similarity_scores.worst_similarity_score =
          std::make_pair(similarity_score, embedding.passage.first);
    }
  }
  // Either there were no embeddings at all, or all of them were filtered out.
  if (!scored_any_passage) {
    return std::nullopt;
  }
  return similarity_scores;
}
// Probabilistic OR of two scores treated as independent probabilities:
// P(A or B) = 1 - (1 - P(A)) * (1 - P(B)).
float ProbOr(const float score1, const float score2) {
  const float complement1 = 1.0f - score1;
  const float complement2 = 1.0f - score2;
  return 1.0f - complement1 * complement2;
}
// Combines a tab's available relevance signals into a single score by folding
// them together with ProbOr(). Signals that are std::nullopt are skipped.
//  - embedding score: used as-is.
//  - recency: 0.7 ^ (seconds since last active / 180). The contribution
//    shrinks by a factor of 0.7 per 3 minutes of inactivity.
//  - lexical match: 1 - e ^ (-0.85 * number of matching query/title words).
// Note that ProbOr() operates on floats, so the double `score` is narrowed at
// each step.
double GetTabScore(const TabSignals& signals) {
  double score = 0;
  if (signals.embedding_score.has_value()) {
    score = ProbOr(score, *signals.embedding_score);
  }
  if (signals.duration_since_last_active.has_value()) {
    // Use fractional seconds so the decay is continuous. Integer division
    // (`InSeconds() / 180`) truncates the exponent and turns the decay into a
    // step function. With it, every tab active within the last 3 minutes gets a
    // recency score of exactly 1.0, which forces its total score to 1.0.
    const double recency_score = std::pow(
        0.7, signals.duration_since_last_active->InSecondsF() / 180.0);
    score = ProbOr(score, recency_score);
  }
  if (signals.num_query_title_matching_words.has_value()) {
    const float lexical_match_score =
        1.0f - std::exp(-0.85 * *signals.num_query_title_matching_words);
    score = ProbOr(score, lexical_match_score);
  }
  return score;
}
// Records the outcome of one context-determination attempt to UMA.
void RecordContextDeterminationStatus(ContextDeterminationStatus status) {
  static constexpr char kHistogramName[] =
      "ContextualTasks.Context.ContextDeterminationStatus";
  base::UmaHistogramEnumeration(kHistogramName, status);
}
void RecordTabSelectionMetrics(std::set<GURL> relevant_tab_urls,
std::set<GURL> explicit_urls) {
CHECK(!explicit_urls.empty());
std::set<GURL> explicit_url_set =
std::set<GURL>(explicit_urls.begin(), explicit_urls.end());
base::UmaHistogramCounts100("ContextualTasks.Context.ExplicitTabsCount",
explicit_url_set.size());
base::flat_set<GURL> mutual_urls;
std::set_intersection(relevant_tab_urls.begin(), relevant_tab_urls.end(),
explicit_url_set.begin(), explicit_url_set.end(),
std::inserter(mutual_urls, mutual_urls.end()));
base::UmaHistogramCounts100("ContextualTasks.Context.TabOverlapCount",
mutual_urls.size());
base::UmaHistogramPercentage(
"ContextualTasks.Context.TabOverlapPercentage",
explicit_urls.empty() ? 0
: 100 * mutual_urls.size() / explicit_urls.size());
base::flat_set<GURL> excess_urls;
std::set_difference(relevant_tab_urls.begin(), relevant_tab_urls.end(),
explicit_url_set.begin(), explicit_url_set.end(),
std::inserter(excess_urls, excess_urls.end()));
base::UmaHistogramCounts100("ContextualTasks.Context.TabExcessCount",
excess_urls.size());
}
}
// Stores the service's collaborators and starts observing both the embedder
// metadata provider and the page embeddings service. `tick_clock_` defaults to
// base::DefaultTickClock and can be replaced with SetClockForTesting().
// NOTE(review): the collaborators appear to be held as plain (non-owning)
// pointers. Presumably they must outlive this service; confirm against the
// header/factory.
ContextualTasksContextService::ContextualTasksContextService(
Profile* profile,
passage_embeddings::PageEmbeddingsService* page_embeddings_service,
passage_embeddings::EmbedderMetadataProvider* embedder_metadata_provider,
passage_embeddings::Embedder* embedder,
OptimizationGuideKeyedService* optimization_guide_keyed_service,
page_content_annotations::PageContentExtractionService*
page_content_extraction_service)
: profile_(profile),
page_embeddings_service_(page_embeddings_service),
embedder_metadata_provider_(embedder_metadata_provider),
embedder_(embedder),
optimization_guide_keyed_service_(optimization_guide_keyed_service),
page_content_extraction_service_(page_content_extraction_service),
tick_clock_(base::DefaultTickClock::GetInstance()) {
// Metadata updates are presumably delivered to EmbedderMetadataUpdated(),
// which tracks the embedder model version.
scoped_embedder_metadata_provider_observation_.Observe(
embedder_metadata_provider_);
scoped_page_embeddings_service_observation_.Observe(page_embeddings_service_);
}
// Defaulted. The scoped observation members are torn down with the object;
// presumably they stop observing at that point (confirm their types in the
// header).
ContextualTasksContextService::~ContextualTasksContextService() = default;
// Test-only: replaces the tick clock used for the latency metric and for
// "time since last active" computations. The pointer is stored as-is, so it
// presumably must stay valid for as long as the service uses it.
void ContextualTasksContextService::SetClockForTesting(
const base::TickClock* tick_clock) {
tick_clock_ = tick_clock;
}
void ContextualTasksContextService::GetRelevantTabsForQuery(
const TabSelectionOptions& options,
const std::string& query,
const std::vector<GURL>& explicit_urls,
base::OnceCallback<void(std::vector<content::WebContents*>)> callback) {
base::TimeTicks now = tick_clock_->NowTicks();
AUTO_CONTEXT_LOG(base::StringPrintf("Processing query %s in mode %d", query,
options.tab_selection_mode));
if (!embedder_model_version_) {
AUTO_CONTEXT_LOG("Embedder not available");
RecordContextDeterminationStatus(
ContextDeterminationStatus::kEmbedderNotAvailable);
base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
FROM_HERE, base::BindOnce(std::move(callback),
std::vector<content::WebContents*>({})));
return;
}
page_embeddings_service_->ProcessAllEmbeddings();
AUTO_CONTEXT_LOG("Submitted query to embedder");
embedder_->ComputePassagesEmbeddings(
passage_embeddings::PassagePriority::kUrgent, {query},
base::BindOnce(&ContextualTasksContextService::OnQueryEmbeddingReady,
weak_ptr_factory_.GetWeakPtr(), query, options, now,
explicit_urls, std::move(callback)));
}
// Tracks the embedder model version. The version is kept only while the
// reported metadata is valid; GetRelevantTabsForQuery() bails out whenever it
// is unset.
void ContextualTasksContextService::EmbedderMetadataUpdated(
    passage_embeddings::EmbedderMetadata metadata) {
  if (metadata.IsValid()) {
    embedder_model_version_ = metadata.model_version;
  } else {
    embedder_model_version_.reset();
  }
}
// Reports kBackground as this service's default page-embeddings priority.
// NOTE(review): presumably an override of a PageEmbeddingsService observer
// hook (this service observes PageEmbeddingsService in its constructor);
// confirm in the header.
passage_embeddings::PageEmbeddingsService::Priority
ContextualTasksContextService::GetDefaultPriority() const {
return passage_embeddings::PageEmbeddingsService::Priority::kBackground;
}
void ContextualTasksContextService::OnQueryEmbeddingReady(
const std::string& query,
const TabSelectionOptions& options,
base::TimeTicks start_time,
const std::vector<GURL>& explicit_urls,
base::OnceCallback<void(std::vector<content::WebContents*>)> callback,
std::vector<std::string> passages,
std::vector<passage_embeddings::Embedding> embeddings,
passage_embeddings::Embedder::TaskId task_id,
passage_embeddings::ComputeEmbeddingsStatus status) {
if (status != passage_embeddings::ComputeEmbeddingsStatus::kSuccess) {
AUTO_CONTEXT_LOG(
base::StringPrintf("Query embedding for %s failed", query));
RecordContextDeterminationStatus(
ContextDeterminationStatus::kQueryEmbeddingFailed);
std::move(callback).Run({});
return;
}
if (embeddings.size() != 1u) {
AUTO_CONTEXT_LOG(base::StringPrintf(
"Query embedding for %s had unexpected output", query));
RecordContextDeterminationStatus(
ContextDeterminationStatus::kQueryEmbeddingOutputMalformed);
std::move(callback).Run({});
return;
}
AUTO_CONTEXT_LOG(
base::StringPrintf("Processing query embedding for %s", query));
std::vector<content::WebContents*> all_tabs = GetAllEligibleTabs();
if (all_tabs.empty()) {
AUTO_CONTEXT_LOG("No eligible tabs");
RecordContextDeterminationStatus(
ContextDeterminationStatus::kNoEligibleTabs);
std::move(callback).Run({});
return;
}
RecordContextDeterminationStatus(ContextDeterminationStatus::kSuccess);
auto log_entry = std::make_unique<optimization_guide::ModelQualityLogEntry>(
optimization_guide_keyed_service_->GetModelQualityLogsUploaderService()
->GetWeakPtr());
passage_embeddings::Embedding query_embedding = embeddings[0];
auto* quality_log = log_entry->log_ai_data_request()
->mutable_contextual_tasks_context()
->mutable_quality();
quality_log->set_embedding_model_version(
embedder_model_version_.value_or(-1));
std::vector<content::WebContents*> relevant_tabs = SelectRelevantTabs(
query, options, query_embedding, all_tabs, explicit_urls, quality_log);
AUTO_CONTEXT_LOG(base::StringPrintf(
"Number of eligible open tabs for query %s: %d", query, all_tabs.size()));
AUTO_CONTEXT_LOG(base::StringPrintf(
"Number of relevant tabs for query %s: %d", query, relevant_tabs.size()));
base::UmaHistogramTimes("ContextualTasks.Context.ContextCalculationLatency",
tick_clock_->NowTicks() - start_time);
base::UmaHistogramCounts100("ContextualTasks.Context.RelevantTabsCount",
relevant_tabs.size());
if (!explicit_urls.empty()) {
std::set<GURL> relevant_tab_url_set;
for (auto* web_contents : relevant_tabs) {
relevant_tab_url_set.insert(web_contents->GetLastCommittedURL());
}
RecordTabSelectionMetrics(
relevant_tab_url_set,
std::set<GURL>(explicit_urls.begin(), explicit_urls.end()));
}
if (!ShouldLogContextualTasksContextQuality() ||
quality_log->eligible_tabs().size() == 0) {
optimization_guide::ModelQualityLogEntry::Drop(std::move(log_entry));
}
std::move(callback).Run(std::move(relevant_tabs));
}
std::vector<content::WebContents*>
ContextualTasksContextService::GetAllEligibleTabs() {
std::vector<content::WebContents*> all_tabs;
ForEachCurrentBrowserWindowInterfaceOrderedByActivation(
[this, &all_tabs](BrowserWindowInterface* browser) {
if (browser->GetProfile() != profile_) {
return true;
}
TabStripModel* const tab_strip_model = browser->GetTabStripModel();
for (int i = 0; i < tab_strip_model->count(); i++) {
content::WebContents* web_contents =
tab_strip_model->GetWebContentsAt(i);
if (!web_contents) {
continue;
}
if (!web_contents->GetLastCommittedURL().SchemeIsHTTPOrHTTPS()) {
continue;
}
if (!ShouldAddTabToSelection(web_contents)) {
AUTO_CONTEXT_LOG(
base::StringPrintf("Removing %s from relevant set as it is not "
"eligible for server upload",
web_contents->GetLastCommittedURL().spec()));
continue;
}
all_tabs.push_back(web_contents);
}
return true;
});
return all_tabs;
}
// Dispatches tab selection to the strategy named by
// `options.tab_selection_mode`:
//  - kMultiSignalScoring: SelectTabsByMultiSignalScore(), which also fills
//    `quality_log` with one entry per tab.
//  - kEmbeddingsMatch: SelectTabsByEmbeddingsMatch() (does not touch
//    `quality_log`).
std::vector<content::WebContents*>
ContextualTasksContextService::SelectRelevantTabs(
    const std::string& query,
    const TabSelectionOptions& options,
    const passage_embeddings::Embedding& query_embedding,
    const std::vector<content::WebContents*>& all_tabs,
    const std::vector<GURL>& explicit_urls,
    optimization_guide::proto::ContextualTasksContextQuality* quality_log) {
  switch (options.tab_selection_mode) {
    case mojom::TabSelectionMode::kMultiSignalScoring:
      return SelectTabsByMultiSignalScore(query, options, query_embedding,
                                          all_tabs, explicit_urls, quality_log);
    case mojom::TabSelectionMode::kEmbeddingsMatch:
      return SelectTabsByEmbeddingsMatch(query, options, query_embedding,
                                         all_tabs);
  }
  // Every enumerator returns above. Without this, an out-of-range mode value
  // would flow off the end of a non-void function, which is undefined
  // behavior.
  NOTREACHED();
}
// Scores every tab in `all_tabs` by combining three signals with
// GetTabScore(): best embedding similarity, time since last active, and
// query/title word overlap. A tab is selected when its score is at least
// `options.min_model_score` (default: kMinMultiSignalScore).
//
// Each tab gets its own `quality_log` eligible-tab entry holding its signals,
// its aggregate score, and whether it appears in `explicit_urls`, whether or
// not it is selected. Per-tab UMA and debug logs are emitted along the way.
std::vector<content::WebContents*>
ContextualTasksContextService::SelectTabsByMultiSignalScore(
    const std::string& query,
    const TabSelectionOptions& options,
    const passage_embeddings::Embedding& query_embedding,
    const std::vector<content::WebContents*>& all_tabs,
    const std::vector<GURL>& explicit_urls,
    optimization_guide::proto::ContextualTasksContextQuality* quality_log) {
  std::vector<content::WebContents*> selected_tabs;
  for (content::WebContents* tab : all_tabs) {
    optimization_guide::proto::ContextualTasksTabContext* tab_log =
        quality_log->add_eligible_tabs();

    // --- Gather signals. ---
    const std::optional<TabSimilarityScores> scores = GetEmbeddingScores(
        tab, query_embedding, page_embeddings_service_->GetEmbeddings(tab));
    TabSignals signals;
    signals.web_contents = tab;
    if (scores.has_value()) {
      signals.embedding_score = scores->best_similarity_score.first;
    }
    signals.duration_since_last_active = GetDurationSinceLastActive(tab);
    signals.num_query_title_matching_words =
        GetMatchingWordsCount(query, base::UTF16ToUTF8(tab->GetTitle()));

    if (scores.has_value()) {
      AUTO_CONTEXT_LOG(base::StringPrintf(
          "Passage with highest similarity with query %s: %f",
          scores->best_similarity_score.second,
          scores->best_similarity_score.first));
      AUTO_CONTEXT_LOG(
          base::StringPrintf("Passage with lowest similarity with query %s: %f",
                             scores->worst_similarity_score.second,
                             scores->worst_similarity_score.first));
    }

    // --- Record each available signal to UMA and the quality log. ---
    if (signals.embedding_score) {
      const float embedding_score = *signals.embedding_score;
      base::UmaHistogramCounts100(
          "ContextualTasks.Context.EmbeddingSimilarityScore",
          static_cast<int>(std::min(100 * embedding_score, 100.0f)));
      tab_log->set_best_embedding_score(embedding_score);
    }
    if (signals.duration_since_last_active) {
      const base::TimeDelta since_active = *signals.duration_since_last_active;
      base::UmaHistogramTimes("ContextualTasks.Context.DurationSinceLastActive",
                              since_active);
      tab_log->set_seconds_since_last_active(since_active.InSeconds());
    }
    if (signals.num_query_title_matching_words) {
      const int matching_words = *signals.num_query_title_matching_words;
      base::UmaHistogramCounts100("ContextualTasks.Context.MatchingWordsCount",
                                  std::min(matching_words, 100));
      tab_log->set_number_of_common_words(matching_words);
    }

    // --- Aggregate and decide. ---
    const double score = GetTabScore(signals);
    tab_log->set_aggregate_tab_score(score);
    if (score >= options.min_model_score.value_or(kMinMultiSignalScore.Get())) {
      selected_tabs.push_back(tab);
    }
    tab_log->set_was_explicitly_chosen(
        base::Contains(explicit_urls, tab->GetLastCommittedURL()));
    base::UmaHistogramSparse("ContextualTasks.Context.TabScore",
                             static_cast<int>(std::min(100 * score, 100.0)));
    AUTO_CONTEXT_LOG(base::StringPrintf(
        "Query: %s | TabTitle: %s | EmbeddingsScore: %f | "
        "SecondsSinceLastActive: %d | MatchingWordsCount: %d | Score: %f",
        query, base::UTF16ToUTF8(tab->GetTitle()),
        signals.embedding_score.value_or(0.0f),
        signals.duration_since_last_active.has_value()
            ? signals.duration_since_last_active->InSeconds()
            : -1,
        signals.num_query_title_matching_words.value_or(0), score));
  }
  return selected_tabs;
}
// Selects tabs whose best passage-embedding similarity to `query_embedding`
// is strictly greater than `options.min_model_score` (default:
// kMinEmbeddingSimilarityScore). Tabs for which GetEmbeddingScores() yields
// no scores are skipped. `query` is unused here.
std::vector<content::WebContents*>
ContextualTasksContextService::SelectTabsByEmbeddingsMatch(
    const std::string& query,
    const TabSelectionOptions& options,
    const passage_embeddings::Embedding& query_embedding,
    const std::vector<content::WebContents*>& all_tabs) {
  std::vector<content::WebContents*> matching_tabs;
  for (content::WebContents* tab : all_tabs) {
    const std::vector<passage_embeddings::PassageEmbedding> tab_embeddings =
        page_embeddings_service_->GetEmbeddings(tab);
    const GURL& tab_url = tab->GetLastCommittedURL();
    AUTO_CONTEXT_LOG(base::StringPrintf(
        "Comparing query embedding to %llu embeddings for %s",
        tab_embeddings.size(), tab_url.spec()));

    const std::optional<TabSimilarityScores> scores =
        GetEmbeddingScores(tab, query_embedding, tab_embeddings);
    if (!scores.has_value()) {
      continue;
    }

    const float best_score = scores->best_similarity_score.first;
    const std::string& best_passage = scores->best_similarity_score.second;
    const float worst_score = scores->worst_similarity_score.first;
    const std::string& worst_passage = scores->worst_similarity_score.second;
    AUTO_CONTEXT_LOG(
        base::StringPrintf("Passage with highest similarity with query %s: %f",
                           best_passage, best_score));
    AUTO_CONTEXT_LOG(
        base::StringPrintf("Passage with lowest similarity with query %s: %f",
                           worst_passage, worst_score));

    const bool is_match =
        best_score >
        options.min_model_score.value_or(kMinEmbeddingSimilarityScore.Get());
    if (is_match) {
      matching_tabs.push_back(tab);
      AUTO_CONTEXT_LOG(
          base::StringPrintf("Adding %s to relevant set", tab_url.spec()));
    }
  }
  return matching_tabs;
}
// Returns the time elapsed, per `tick_clock_`, since the later of the tab's
// last activation and its last user interaction. Returns std::nullopt when
// that elapsed time is zero or negative.
std::optional<base::TimeDelta>
ContextualTasksContextService::GetDurationSinceLastActive(
    content::WebContents* web_contents) {
  const base::TimeTicks last_touched =
      std::max(web_contents->GetLastActiveTimeTicks(),
               web_contents->GetLastInteractionTimeTicks());
  const base::TimeDelta elapsed = tick_clock_->NowTicks() - last_touched;
  return elapsed.is_positive() ? std::make_optional(elapsed) : std::nullopt;
}
// Returns whether `web_contents` may be offered as context. Two conditions
// must both hold:
//  - The tab is not sensitive. A tab counts as sensitive when its content
//    visibility score lies in [0, kContentVisibilityThreshold).
//  - It is not reported as ineligible for server upload by the page content
//    extraction service (when that service exists). A missing extraction
//    result counts as eligible.
bool ContextualTasksContextService::ShouldAddTabToSelection(
    content::WebContents* web_contents) {
  // A missing observer, or a missing score (mapped to -1), never marks the
  // tab as sensitive.
  bool is_sensitive = false;
  auto* const annotations_observer = page_content_annotations::
      PageContentAnnotationsWebContentsObserver::FromWebContents(web_contents);
  if (annotations_observer) {
    const float visibility_score =
        annotations_observer->content_visibility_score().value_or(-1.0f);
    is_sensitive = visibility_score < kContentVisibilityThreshold.Get() &&
                   visibility_score >= 0.0;
  }

  bool is_eligible_for_server_upload = true;
  if (page_content_extraction_service_) {
    const std::optional<page_content_annotations::ExtractedPageContentResult>
        extraction_result =
            page_content_extraction_service_
                ->GetExtractedPageContentAndEligibilityForPage(
                    web_contents->GetPrimaryPage());
    if (extraction_result.has_value()) {
      is_eligible_for_server_upload =
          extraction_result->is_eligible_for_server_upload;
    }
  }

  return !is_sensitive && is_eligible_for_server_upload;
}
}