#include "chrome/browser/ai/ai_crx_component.h"
#include <cstdint>
#include <memory>
#include <utility>
#include "base/barrier_closure.h"
#include "base/task/current_thread.h"
#include "base/test/gtest_util.h"
#include "base/time/time.h"
#include "chrome/browser/ai/ai_model_download_progress_manager.h"
#include "chrome/browser/ai/ai_test_utils.h"
#include "chrome/browser/ai/ai_utils.h"
#include "testing/gmock/include/gmock/gmock-matchers.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace on_device_ai {
using component_updater::CrxUpdateItem;
using testing::_;
using update_client::ComponentState;
// Test fixture for AICrxComponent.
//
// Fake components registered with CreateComponent() are served by a mock
// component update service. Tests drive progress with SendUpdate() and advance
// the mock clock with FastForwardBy().
class AICrxComponentTest : public testing::Test {
 public:
  AICrxComponentTest() = default;
  ~AICrxComponentTest() override = default;

 protected:
  // Passes an update item for `component` in `state`, with `downloaded_bytes`,
  // to the mock service's SendUpdate(). Presumably the mock forwards it to its
  // registered observers; confirm in AITestUtils.
  void SendUpdate(const AITestUtils::FakeComponent& component,
                  ComponentState state,
                  uint64_t downloaded_bytes) {
    component_update_service_.SendUpdate(
        component.CreateUpdateItem(state, downloaded_bytes));
  }

  // Advances the mock clock by `delta`, running any tasks that become due.
  void FastForwardBy(base::TimeDelta delta) {
    task_environment_.FastForwardBy(delta);
  }

  // Registers a fake component keyed by `id` with `total_bytes`, so that the
  // mock service's GetComponentDetails() can report it. CHECK-fails if `id` is
  // already registered. The returned reference stays valid for the fixture's
  // lifetime because std::map insertions never invalidate element references.
  AITestUtils::FakeComponent& CreateComponent(std::string id,
                                              uint64_t total_bytes) {
    auto [iter, emplaced] = fake_components_.try_emplace(id, id, total_bytes);
    CHECK(emplaced);
    return iter->second;
  }

  // Must be the first data member so that it is constructed before, and
  // destroyed after, every other member that may rely on the current thread's
  // task runner or mock clock. base/test/task_environment.h recommends this
  // (and protected visibility).
  base::test::SingleThreadTaskEnvironment task_environment_{
      base::test::TaskEnvironment::TimeSource::MOCK_TIME};
  AITestUtils::MockComponentUpdateService component_update_service_;

 private:
  void SetUp() override {
    // Any component registered via CreateComponent() is reported as a kNew
    // item with downloaded_bytes 0. Unknown ids report "not found" (false).
    // With only WillRepeatedly(), gMock allows any number of calls.
    EXPECT_CALL(component_update_service_, GetComponentDetails(_, _))
        .WillRepeatedly([this](const std::string& id, CrxUpdateItem* item) {
          auto iter = fake_components_.find(id);
          if (iter == fake_components_.end()) {
            return false;
          }
          *item = iter->second.CreateUpdateItem(ComponentState::kNew, 0);
          return true;
        });
  }

  // Fake components by id. A node-based map keeps the references returned by
  // CreateComponent() stable across later insertions.
  std::map<std::string, AITestUtils::FakeComponent> fake_components_;
};
// Updates in any component state other than kDownloading must not produce a
// progress event for the observer.
TEST_F(AICrxComponentTest, DoesntReceiveUpdatesForNonDownloadEvents) {
  AIModelDownloadProgressManager manager;
  AITestUtils::FakeMonitor monitor;
  AITestUtils::FakeComponent& component = CreateComponent("component_id", 100);
  manager.AddObserver(monitor.BindNewPipeAndPassRemote(),
                      AICrxComponent::FromComponentIds(
                          &component_update_service_, {component.id()}));

  constexpr ComponentState kNonDownloadStates[] = {
      ComponentState::kNew,         ComponentState::kChecking,
      ComponentState::kCanUpdate,   ComponentState::kUpdated,
      ComponentState::kUpdateError, ComponentState::kRun,
  };
  constexpr uint64_t kReportedBytes = 10;
  // NOTE(review): 51 ms is presumably just past the progress manager's update
  // throttling interval; confirm against AIModelDownloadProgressManager.
  constexpr base::TimeDelta kStep = base::Milliseconds(51);

  for (ComponentState non_download_state : kNonDownloadStates) {
    SendUpdate(component, non_download_state, kReportedBytes);
    monitor.ExpectNoUpdate();
    FastForwardBy(kStep);
  }
}
// A kDownloading update whose downloaded byte count is -1 must not produce a
// progress event for the observer.
TEST_F(AICrxComponentTest,
       DoesntReceiveUpdatesForEventsWithNegativeDownloadedBytes) {
  AIModelDownloadProgressManager manager;
  AITestUtils::FakeMonitor monitor;
  AITestUtils::FakeComponent& component = CreateComponent("component_id", 100);
  manager.AddObserver(monitor.BindNewPipeAndPassRemote(),
                      AICrxComponent::FromComponentIds(
                          &component_update_service_, {component.id()}));

  // NOTE(review): SendUpdate() takes uint64_t, so -1 is converted to
  // UINT64_MAX here. The test relies on CreateUpdateItem() storing it back
  // into a signed field as -1; confirm against AITestUtils::FakeComponent.
  SendUpdate(component, ComponentState::kDownloading, /*downloaded_bytes=*/-1);
  monitor.ExpectNoUpdate();
  FastForwardBy(base::Milliseconds(51));
}
// A component whose total byte count is -1 must not produce progress events,
// even for a kDownloading update.
TEST_F(AICrxComponentTest,
       DoesntReceiveUpdatesForEventsWithNegativeTotalBytes) {
  AIModelDownloadProgressManager manager;
  AITestUtils::FakeMonitor monitor;
  // NOTE(review): CreateComponent() takes uint64_t, so -1 is converted to
  // UINT64_MAX. The test assumes FakeComponent reports it as -1 in the update
  // item's signed total; confirm in AITestUtils.
  AITestUtils::FakeComponent& component =
      CreateComponent("component_id", /*total_bytes=*/-1);
  manager.AddObserver(monitor.BindNewPipeAndPassRemote(),
                      AICrxComponent::FromComponentIds(
                          &component_update_service_, {component.id()}));

  SendUpdate(component, ComponentState::kDownloading, /*downloaded_bytes=*/0);
  monitor.ExpectNoUpdate();
  FastForwardBy(base::Milliseconds(51));
}
// Progress for a component outside the observer's id list must not reach
// that observer.
TEST_F(AICrxComponentTest, DoesntReceiveUpdatesForComponentsNotObserving) {
  AIModelDownloadProgressManager manager;
  AITestUtils::FakeMonitor monitor;
  AITestUtils::FakeComponent& watched = CreateComponent("component_id1", 100);
  AITestUtils::FakeComponent& unwatched =
      CreateComponent("component_id2", 100);

  // Only `watched` is passed to the observer.
  manager.AddObserver(monitor.BindNewPipeAndPassRemote(),
                      AICrxComponent::FromComponentIds(
                          &component_update_service_, {watched.id()}));

  SendUpdate(unwatched, ComponentState::kDownloading, 10);
  monitor.ExpectNoUpdate();
  FastForwardBy(base::Milliseconds(51));
}
// A monitor that starts observing after a download is under way is expected
// to see progress relative to the bytes already downloaded at its first
// update. The monitor present from the start keeps seeing absolute progress.
TEST_F(AICrxComponentTest, ObservesComponentsMidDownload) {
  AIModelDownloadProgressManager manager;
  AITestUtils::FakeMonitor early_monitor;
  AITestUtils::FakeMonitor late_monitor;
  AITestUtils::FakeComponent& component = CreateComponent("component_id", 100);

  // Only `early_monitor` is observing when the download starts.
  manager.AddObserver(early_monitor.BindNewPipeAndPassRemote(),
                      AICrxComponent::FromComponentIds(
                          &component_update_service_, {component.id()}));
  SendUpdate(component, ComponentState::kDownloading, 0);
  early_monitor.ExpectReceivedNormalizedUpdate(0, component.total_bytes());
  late_monitor.ExpectNoUpdate();

  // `late_monitor` joins mid-download.
  manager.AddObserver(late_monitor.BindNewPipeAndPassRemote(),
                      AICrxComponent::FromComponentIds(
                          &component_update_service_, {component.id()}));

  // Bytes downloaded at the first update `late_monitor` receives. Its later
  // progress is expected relative to this baseline.
  constexpr int64_t kLateBaseline = 60;
  FastForwardBy(base::Milliseconds(51));
  SendUpdate(component, ComponentState::kDownloading, kLateBaseline);
  {
    base::RunLoop wait_for_both;
    base::RepeatingClosure on_update =
        base::BarrierClosure(2, wait_for_both.QuitClosure());
    early_monitor.ExpectReceivedNormalizedUpdate(
        kLateBaseline, component.total_bytes(), on_update);
    late_monitor.ExpectReceivedNormalizedUpdate(
        0, component.total_bytes() - kLateBaseline, on_update);
    wait_for_both.Run();
  }

  constexpr int64_t kLaterProgress = 75;
  FastForwardBy(base::Milliseconds(51));
  SendUpdate(component, ComponentState::kDownloading, kLaterProgress);
  {
    base::RunLoop wait_for_both;
    base::RepeatingClosure on_update =
        base::BarrierClosure(2, wait_for_both.QuitClosure());
    early_monitor.ExpectReceivedNormalizedUpdate(
        kLaterProgress, component.total_bytes(), on_update);
    late_monitor.ExpectReceivedNormalizedUpdate(
        kLaterProgress - kLateBaseline,
        component.total_bytes() - kLateBaseline, on_update);
    wait_for_both.Run();
  }
}
// A downloaded byte count larger than the component's total is expected to be
// reported to the observer as exactly the total.
TEST_F(AICrxComponentTest, DownloadedBytesWontExceedTotalBytes) {
  AIModelDownloadProgressManager manager;
  AITestUtils::FakeMonitor monitor;
  AITestUtils::FakeComponent& component = CreateComponent("component_id", 100);
  manager.AddObserver(monitor.BindNewPipeAndPassRemote(),
                      AICrxComponent::FromComponentIds(
                          &component_update_service_, {component.id()}));

  constexpr base::TimeDelta kStep = base::Milliseconds(51);

  // Start the download at zero bytes.
  SendUpdate(component, ComponentState::kDownloading, 0);
  monitor.ExpectReceivedNormalizedUpdate(0, component.total_bytes());
  FastForwardBy(kStep);

  // Report twice the total; the observer should see exactly the total.
  const auto oversized_bytes = component.total_bytes() * 2;
  SendUpdate(component, ComponentState::kDownloading, oversized_bytes);
  monitor.ExpectReceivedNormalizedUpdate(component.total_bytes(),
                                         component.total_bytes());
  FastForwardBy(kStep);
}
}