910e62b5创建于 1月15日历史提交
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/compose/proactive_nudge_tracker.h"

#include <memory>

#include "base/functional/callback_helpers.h"
#include "base/memory/scoped_refptr.h"
#include "base/metrics/histogram_functions.h"
#include "base/numerics/safe_conversions.h"
#include "base/rand_util.h"
#include "base/strings/utf_string_conversions.h"
#include "base/time/time.h"
#include "base/types/cxx23_to_underlying.h"
#include "chrome/browser/compose/proto/compose_optimization_guide.pb.h"
#include "components/autofill/content/browser/content_autofill_client.h"
#include "components/autofill/core/common/signatures.h"
#include "components/compose/core/browser/compose_metrics.h"
#include "components/compose/core/browser/config.h"
#include "components/segmentation_platform/public/constants.h"
#include "components/segmentation_platform/public/input_context.h"
#include "components/segmentation_platform/public/segmentation_platform_service.h"
#include "url/gurl.h"
#include "url/origin.h"

namespace compose {

namespace {
using segmentation_platform::processing::ProcessedValue;

scoped_refptr<segmentation_platform::InputContext> PopulateInputContextForField(
    const ProactiveNudgeTracker::Signals& signals,
    const compose::ComposeHintMetadata& compose_hint) {
  auto input_context =
      base::MakeRefCounted<segmentation_platform::InputContext>();

  input_context->metadata_args.emplace(
      "field_max_length",
      ProcessedValue::FromFloat(signals.field.max_length()));
  input_context->metadata_args.emplace(
      "field_width_pixels",
      ProcessedValue::FromFloat(signals.field.bounds().width()));
  input_context->metadata_args.emplace(
      "field_height_pixels",
      ProcessedValue::FromFloat(signals.field.bounds().height()));
  input_context->metadata_args.emplace(
      "field_form_control_type", ProcessedValue::FromFloat(base::to_underlying(
                                     signals.field.form_control_type())));
  input_context->metadata_args.emplace(
      "total_field_count",
      ProcessedValue::FromFloat(signals.form.fields().size()));

  int multiline_field_count = 0;
  for (const auto& f : signals.form.fields()) {
    if (f.form_control_type() == autofill::FormControlType::kTextArea ||
        f.form_control_type() == autofill::FormControlType::kContentEditable) {
      ++multiline_field_count;
    }
  }
  input_context->metadata_args.emplace(
      "multiline_field_count",
      ProcessedValue::FromFloat(multiline_field_count));
  input_context->metadata_args.emplace(
      "time_spent_on_page",
      ProcessedValue::FromFloat(
          (base::TimeTicks::Now() - signals.page_change_time).InSecondsF()));

  input_context->metadata_args.emplace(
      "field_signature",
      ProcessedValue(autofill::HashFieldSignature(
          autofill::CalculateFieldSignatureForField(signals.field))));
  input_context->metadata_args.emplace(
      "form_signature", ProcessedValue(autofill::HashFormSignature(
                            autofill::CalculateFormSignature(signals.form))));

  input_context->metadata_args.emplace("page_url",
                                       ProcessedValue(signals.page_url));
  input_context->metadata_args.emplace(
      "origin", ProcessedValue(signals.page_origin.GetURL()));

  input_context->metadata_args.emplace(
      "field_value", ProcessedValue(base::UTF16ToUTF8(signals.field.value())));
  input_context->metadata_args.emplace(
      "field_selected_text",
      ProcessedValue(base::UTF16ToUTF8(signals.field.selected_text())));
  input_context->metadata_args.emplace(
      "field_placeholder_text",
      ProcessedValue(base::UTF16ToUTF8(signals.field.placeholder())));
  input_context->metadata_args.emplace(
      "field_label", ProcessedValue(base::UTF16ToUTF8(signals.field.label())));
  input_context->metadata_args.emplace(
      "field_aria_label",
      ProcessedValue(base::UTF16ToUTF8(signals.field.aria_label())));
  input_context->metadata_args.emplace(
      "field_aria_description",
      ProcessedValue(base::UTF16ToUTF8(signals.field.aria_description())));

  for (auto& pair : compose_hint.model_params()) {
    input_context->metadata_args.emplace(pair.first,
                                         ProcessedValue(pair.second));
  }
  return input_context;
}

}  // namespace

ProactiveNudgeTracker::Signals::Signals() = default;
ProactiveNudgeTracker::Signals::~Signals() = default;
ProactiveNudgeTracker::Signals::Signals(Signals&&) = default;
ProactiveNudgeTracker::Signals& ProactiveNudgeTracker::Signals::operator=(
    Signals&&) = default;

// Tracks user engagement as a result of the nudge.
class ProactiveNudgeTracker::EngagementTracker {
 public:
  EngagementTracker(
      autofill::FieldGlobalId field_global_id,
      segmentation_platform::TrainingRequestId training_request_id,
      ProactiveNudgeTracker* nudge_tracker)
      : field_global_id_(field_global_id),
        training_request_id_(training_request_id),
        nudge_tracker_(nudge_tracker) {}
  EngagementTracker(const EngagementTracker&) = delete;
  EngagementTracker& operator=(const EngagementTracker&) = delete;

  ~EngagementTracker() {
    ReportIfFirst(ProactiveNudgeDerivedEngagement::kIgnored);
  }

  void ComposeSessionCompleted(ComposeSessionCloseReason session_close_reason,
                               const compose::ComposeSessionEvents& events) {
    if (events.inserted_results) {
      ReportIfFirst(
          ProactiveNudgeDerivedEngagement::kAcceptedComposeSuggestion);
    } else if (events.compose_requests_count > 0) {
      ReportIfFirst(
          ProactiveNudgeDerivedEngagement::kGeneratedComposeSuggestion);
    } else {
      ReportIfFirst(ProactiveNudgeDerivedEngagement::kOpenedComposeMinimalUse);
    }
  }

  void UserDisabledNudge(bool single_site_only) {
    if (single_site_only) {
      ReportIfFirst(
          ProactiveNudgeDerivedEngagement::kNudgeDisabledOnSingleSite);
    } else {
      ReportIfFirst(ProactiveNudgeDerivedEngagement::kNudgeDisabledOnAllSites);
    }
  }

 private:
  void ReportIfFirst(ProactiveNudgeDerivedEngagement engagement) {
    if (reported_) {
      return;
    }
    nudge_tracker_->CollectTrainingData(training_request_id_, engagement);
    reported_ = true;
  }

  bool reported_ = false;
  autofill::FieldGlobalId field_global_id_;
  segmentation_platform::TrainingRequestId training_request_id_;
  raw_ptr<ProactiveNudgeTracker> nudge_tracker_;
};

ProactiveNudgeTracker::State::State() = default;
ProactiveNudgeTracker::State::~State() = default;

float ProactiveNudgeTracker::Delegate::SegmentationFallbackShowResult() {
  return base::RandFloat();
}

float ProactiveNudgeTracker::Delegate::SegmentationForceShowResult() {
  return base::RandFloat();
}

ProactiveNudgeTracker::ProactiveNudgeTracker(
    segmentation_platform::SegmentationPlatformService* segmentation_service,
    Delegate* delegate)
    : segmentation_service_(segmentation_service), delegate_(delegate) {
  CHECK(segmentation_service_);
  CHECK(delegate_);
}

void ProactiveNudgeTracker::StartObserving(content::WebContents* web_contents) {
  autofill_managers_observation_.Observe(
      autofill::ContentAutofillClient::FromWebContents(web_contents),
      autofill::ScopedAutofillManagersObservation::InitializationPolicy::
          kObservePreexistingManagers);
}

ProactiveNudgeTracker::~ProactiveNudgeTracker() {
  // Destroy all engagement trackers first to ensure CollectTrainingData() is
  // called before other parts of this class are destroyed.
  engagement_trackers_.clear();
}

bool ProactiveNudgeTracker::ProactiveNudgeRequestedForFormField(
    Signals signals) {
  if (!MatchesCurrentField(signals.field.renderer_form_id(),
                           signals.field.global_id())) {
    ResetState();
    state_ = std::make_unique<State>();
    state_->signals = std::move(signals);
  }

  if (state_->show_state == ShowState::kShown) {
    return false;
  }

  nudge_currently_requested_ = true;
  UpdateStateForCurrentFormField();
  nudge_currently_requested_ = false;
  return state_->show_state == ShowState::kShown;
}

bool ProactiveNudgeTracker::IsTimerRunning() {
  return state_ && state_->timer.IsRunning();
}

void ProactiveNudgeTracker::FocusChangedInPage() {
  ResetState();
}

void ProactiveNudgeTracker::Clear() {
  engagement_trackers_.clear();
  seen_fields_.clear();
  ResetState();
}

void ProactiveNudgeTracker::OnAfterFocusOnFormField(
    autofill::AutofillManager& manager,
    autofill::FormGlobalId form,
    autofill::FieldGlobalId field) {
  DVLOG(2) << "ProactiveNudgeTracker: OnAfterFocusOnFormField";
  // If this focus is on the current field, we are (presumably) already focused
  // and this is a no-op.
  if (MatchesCurrentField(form, field) || state_ == nullptr) {
    return;
  }

  // Now we should transition to the kInitial state.
  ResetState();
}

void ProactiveNudgeTracker::OnAfterTextFieldValueChanged(
    autofill::AutofillManager& manager,
    autofill::FormGlobalId form,
    autofill::FieldGlobalId field,
    const std::u16string& text_value) {
  if (!MatchesCurrentField(form, field)) {
    return;
  }
  ++state_->text_change_count;
  UpdateStateForCurrentFormField();
}

void ProactiveNudgeTracker::OnAfterCaretMovedInFormField(
    autofill::AutofillManager& manager,
    const autofill::FormGlobalId& form,
    const autofill::FieldGlobalId& field,
    const std::u16string& selection,
    const gfx::Rect& caret_bounds) {
  if (!MatchesCurrentField(form, field)) {
    return;
  }

  bool selection_valid =
      GetComposeConfig().selection_nudge_enabled &&
      selection.size() >= GetComposeConfig().selection_nudge_length;

  if (IsTimerRunning() && state_->selection_nudge_requested &&
      !selection_valid) {
    // Cancel the timer if the selection is no longer valid.
    state_->timer_canceled = true;
  } else {
    state_->selection_nudge_requested = selection_valid;
    if (IsTimerRunning()) {
      // Extend the timer if it is currently running. This will restart with
      // the correct delay if the state has changed.
      StartOrRestartTimer();
    }
  }
  UpdateStateForCurrentFormField();
}

void ProactiveNudgeTracker::ResetState() {
  DVLOG(2) << "ProactiveNudgeTracker: ResetState";
  weak_ptr_factory_.InvalidateWeakPtrs();
  state_.reset();
}

void ProactiveNudgeTracker::UpdateStateForCurrentFormField() {
  while (auto new_state = CheckForStateTransition()) {
    TransitionToState(new_state.value());
  }
}

std::optional<ProactiveNudgeTracker::ShowState>
ProactiveNudgeTracker::CheckForStateTransition() {
  switch (state_->show_state) {
    case ShowState::kInitial:
      // Block if the cached result should not show the nudge.
      if (!CachedSegmentationResult().value_or(true)) {
        return ShowState::kBlockedBySegmentation;
      }
      if (CanStartFocusTimer() || CanStartTextSettledTimer() ||
          CanStartSelectionTimer()) {
        return ShowState::kWaitingForTimerToStop;
      }
      // Remain in initial state until any delay timer can be started.
      return std::nullopt;
    case ShowState::kWaitingForTimerToStop:
      if (state_->timer_canceled) {
        return ShowState::kTimerCanceled;
      }
      if (!IsTimerRunning()) {
        return ShowState::kWaitingForSegmentation;
      }
      // Continue to wait if the timer is running or not canceled.
      return std::nullopt;
    case ShowState::kWaitingForSegmentation:
      // Use cached segmentation result if possible.
      if (auto result = CachedSegmentationResult(); result.has_value()) {
        return result.value() ? ShowState::kWaitingForProactiveNudgeRequest
                              : ShowState::kBlockedBySegmentation;
      }
      if (state_->segmentation_result.has_value()) {
        bool segmentation_succeeded =
            !state_->segmentation_result->ordered_labels.empty() &&
            state_->segmentation_result->ordered_labels[0] ==
                segmentation_platform::kComposePrmotionLabelShow;
        seen_fields_.emplace(state_->signals.field.global_id(),
                             segmentation_succeeded);
        return segmentation_succeeded
                   ? ShowState::kWaitingForProactiveNudgeRequest
                   : ShowState::kBlockedBySegmentation;
      }
      return std::nullopt;
    case ShowState::kWaitingForProactiveNudgeRequest:
      if (nudge_currently_requested_) {
        return ShowState::kShown;
      }
      return std::nullopt;
    case ShowState::kTimerCanceled:
      if (state_->timer_canceled) {
        return std::nullopt;
      }
      return ShowState::kShown;
    case ShowState::kShown:
      if (CanStartSelectionTimer()) {
        // Start waiting for the selection delay timer.
        return ShowState::kWaitingForTimerToStop;
      }
      return std::nullopt;
    case ShowState::kBlockedBySegmentation:
      return std::nullopt;
  }
}

void ProactiveNudgeTracker::TransitionToState(ShowState new_show_state) {
  switch (new_show_state) {
    case ShowState::kInitial:
      NOTREACHED();
    case ShowState::kWaitingForTimerToStop:
      BeginWaitingForTimerToStop();
      break;
    case ShowState::kWaitingForSegmentation:
      BeginSegmentation();
      break;
    case ShowState::kWaitingForProactiveNudgeRequest:
      BeginWaitingForProactiveNudgeRequest();
      break;
    case ShowState::kBlockedBySegmentation:
      BeginBlockedBySegmentation();
      break;
    case ShowState::kShown:
      BeginShown();
      break;
    case ShowState::kTimerCanceled:
      BeginTimerCanceled();
      break;
  }
  state_->show_state = new_show_state;
}

void ProactiveNudgeTracker::BeginWaitingForTimerToStop() {
  StartOrRestartTimer();
}

bool ProactiveNudgeTracker::CanStartFocusTimer() {
  if (!GetComposeConfig().proactive_nudge_enabled) {
    return false;
  }
  if (GetComposeConfig().proactive_nudge_field_per_navigation &&
      CachedSegmentationResult().has_value()) {
    return false;
  }
  return GetComposeConfig().proactive_nudge_focus_delay > base::Seconds(0);
}

bool ProactiveNudgeTracker::CanStartTextSettledTimer() {
  if (!GetComposeConfig().proactive_nudge_enabled) {
    return false;
  }
  if (GetComposeConfig().proactive_nudge_field_per_navigation &&
      CachedSegmentationResult().has_value()) {
    return false;
  }
  if (state_->text_change_count <
      GetComposeConfig().proactive_nudge_text_change_count) {
    return false;
  }
  return GetComposeConfig().proactive_nudge_text_settled_delay >
         base::Seconds(0);
}

bool ProactiveNudgeTracker::CanStartSelectionTimer() {
  if (!state_ || !CachedSegmentationResult().value_or(true)) {
    return false;
  }
  if (GetComposeConfig().selection_nudge_once_per_focus &&
      state_->selection_nudge_shown) {
    return false;
  }
  return GetComposeConfig().selection_nudge_delay > base::Seconds(0) &&
         state_->selection_nudge_requested;
}

void ProactiveNudgeTracker::StartOrRestartTimer() {
  if (!state_) {
    return;
  }
  if (state_->timer.IsRunning()) {
    state_->timer.Stop();
  }

  base::TimeDelta delay = GetComposeConfig().proactive_nudge_focus_delay;
  if (CanStartSelectionTimer()) {
    delay = GetComposeConfig().selection_nudge_delay;
  } else if (CanStartTextSettledTimer()) {
    delay = GetComposeConfig().proactive_nudge_text_settled_delay;
  }

  if (delay.is_zero()) {
    return;
  }
  state_->timer.Start(FROM_HERE, delay, this,
                      &ProactiveNudgeTracker::ShowTimerElapsed);
}

void ProactiveNudgeTracker::BeginTimerCanceled() {
  if (!state_) {
    return;
  }
  state_->timer.Stop();
  state_->selection_nudge_requested = false;
  state_->timer_canceled = false;
}

void ProactiveNudgeTracker::BeginSegmentation() {
  if (!compose::GetComposeConfig().proactive_nudge_segmentation) {
    autofill::FieldGlobalId field_global_id = state_->signals.field.global_id();
    seen_fields_.emplace(field_global_id, true);
    return;
  }
  segmentation_platform::PredictionOptions options;
  options.on_demand_execution = true;

  segmentation_service_->GetClassificationResult(
      segmentation_platform::kComposePromotionKey, options,
      PopulateInputContextForField(state_->signals,
                                   delegate_->GetComposeHintMetadata()),
      base::BindOnce(&ProactiveNudgeTracker::GotClassificationResult,
                     weak_ptr_factory_.GetWeakPtr()));
}

void ProactiveNudgeTracker::BeginWaitingForProactiveNudgeRequest() {
  if (state_->segmentation_result &&
      (state_->segmentation_result_ignored_for_training ||
       compose::GetComposeConfig()
           .proactive_nudge_always_collect_training_data)) {
    engagement_trackers_[state_->signals.field.global_id()] =
        std::make_unique<EngagementTracker>(
            state_->signals.field.global_id(),
            state_->segmentation_result->request_id, this);
  }
  if (nudge_currently_requested_) {
    // If no async calls were made the first request is still active and there
    // is no need to request the nudge again.
    return;
  }
  compose::ComposeEntryPoint entry_point =
      state_->selection_nudge_requested
          ? compose::ComposeEntryPoint::kSelectionNudge
          : compose::ComposeEntryPoint::kProactiveNudge;
  delegate_->ShowProactiveNudge(state_->signals.field.renderer_form_id(),
                                state_->signals.field.global_id(), entry_point);
}

void ProactiveNudgeTracker::BeginBlockedBySegmentation() {
  if (state_->show_state != ShowState::kWaitingForSegmentation) {
    return;
  }

  if (state_->selection_nudge_requested) {
    state_->selection_nudge_requested = false;
    return;
  }

  delegate_->GetPageUkmTracker()->ComposeProactiveNudgeShouldShow();
  compose::LogComposeProactiveNudgeShowStatus(
      compose::ComposeShowStatus::kProactiveNudgeBlockedBySegmentationPlatform);
}

void ProactiveNudgeTracker::BeginShown() {
  if (!state_ ||
      state_->show_state != ShowState::kWaitingForProactiveNudgeRequest) {
    return;
  }

  if (state_->selection_nudge_requested) {
    state_->selection_nudge_requested = false;
    state_->selection_nudge_shown = true;
    compose::LogComposeSelectionNudgeCtr(
        compose::ComposeNudgeCtrEvent::kNudgeDisplayed);
    return;
  }

  compose::LogComposeProactiveNudgeCtr(
      compose::ComposeNudgeCtrEvent::kNudgeDisplayed);
  compose::LogComposeProactiveNudgeShowStatus(
      compose::ComposeShowStatus::kShouldShow);
  delegate_->GetPageUkmTracker()->ProactiveNudgeShown();
  delegate_->GetPageUkmTracker()->ComposeProactiveNudgeShouldShow();
}

void ProactiveNudgeTracker::ShowTimerElapsed() {
  DVLOG(2) << "ProactiveNudgeTracker: ShowTimerElapsed";
  // If we are not waiting for the timer, the elapsed timer is stale and should
  // be ignored.
  if (!state_ || state_->show_state != ShowState::kWaitingForTimerToStop) {
    return;
  }
  UpdateStateForCurrentFormField();
}


void ProactiveNudgeTracker::GotClassificationResult(
    const segmentation_platform::ClassificationResult& result) {
  if (!state_ || state_->show_state != ShowState::kWaitingForSegmentation) {
    return;
  }

  state_->segmentation_result = result;

  switch (state_->segmentation_result->status) {
    case segmentation_platform::PredictionStatus::kFailed:
    case segmentation_platform::PredictionStatus::kNotReady:
      if (delegate_->SegmentationFallbackShowResult() <
          compose::GetComposeConfig().proactive_nudge_show_probability) {
        // Override default DontShow decision.
        state_->segmentation_result->ordered_labels.emplace_back(
            segmentation_platform::kComposePrmotionLabelShow);
      }
      break;
    case segmentation_platform::PredictionStatus::kSucceeded:
      if (delegate_->SegmentationForceShowResult() <
          compose::GetComposeConfig().proactive_nudge_force_show_probability) {
        state_->segmentation_result->ordered_labels = {
            segmentation_platform::kComposePrmotionLabelShow};
        state_->segmentation_result_ignored_for_training = true;
      }
      break;
  }

  UpdateStateForCurrentFormField();
}

void ProactiveNudgeTracker::CollectTrainingData(
    const segmentation_platform::TrainingRequestId training_request_id,
    ProactiveNudgeDerivedEngagement engagement) {
  segmentation_platform::TrainingLabels training_labels;
  base::UmaHistogramEnumeration("Compose.ProactiveNudge.DerivedEngagement",
                                engagement);
  training_labels.output_metric =
      std::make_pair("Compose.ProactiveNudge.DerivedEngagement",
                     static_cast<base::HistogramBase::Sample32>(engagement));
  ukm::SourceId source =
      state_ ? state_->signals.ukm_source_id : ukm::kInvalidSourceId;
  segmentation_service_->CollectTrainingData(
      segmentation_platform::proto::SegmentId::
          OPTIMIZATION_TARGET_SEGMENTATION_COMPOSE_PROMOTION,
      training_request_id, source, training_labels, base::DoNothing());
}

bool ProactiveNudgeTracker::MatchesCurrentField(autofill::FormGlobalId form,
                                                autofill::FieldGlobalId field) {
  return state_ && state_->signals.field.renderer_form_id() == form &&
         state_->signals.field.global_id() == field;
}

void ProactiveNudgeTracker::ComposeSessionCompleted(
    autofill::FieldGlobalId field_global_id,
    ComposeSessionCloseReason session_close_reason,
    const compose::ComposeSessionEvents& events) {
  auto iter = engagement_trackers_.find(field_global_id);
  if (iter != engagement_trackers_.end()) {
    iter->second->ComposeSessionCompleted(session_close_reason, events);
    engagement_trackers_.erase(iter);
  }
  ResetState();
}

void ProactiveNudgeTracker::OnUserDisabledNudge(bool single_site_only) {
  for (auto& iter : engagement_trackers_) {
    iter.second->UserDisabledNudge(single_site_only);
  }
  engagement_trackers_.clear();
}

std::optional<bool> ProactiveNudgeTracker::CachedSegmentationResult() {
  if (!state_) {
    return std::nullopt;
  }
  if (auto iter = seen_fields_.find(state_->signals.field.global_id());
      iter != seen_fields_.end()) {
    return iter->second;
  }
  return std::nullopt;
}

}  // namespace compose