#include "chrome/browser/ai/ai_model_download_progress_manager.h"
#include <cstdint>
#include <memory>
#include "base/task/current_thread.h"
#include "base/test/gtest_util.h"
#include "base/time/time.h"
#include "chrome/browser/ai/ai_test_utils.h"
#include "chrome/browser/ai/ai_utils.h"
#include "testing/gmock/include/gmock/gmock-matchers.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace on_device_ai {
// NOTE(review): `_` is not referenced by any test visible here; consider
// removing this using-declaration.
using testing::_;
// Owning set of components handed to AIModelDownloadProgressManager's
// AddObserver(); the tests fill it from FakeComponent::GetImpl().
using ComponentList =
base::flat_set<std::unique_ptr<AIModelDownloadProgressManager::Component>>;
namespace {
class FakeComponent {
public:
FakeComponent(std::optional<int64_t> downloaded_bytes,
std::optional<int64_t> total_bytes)
: downloaded_bytes_(downloaded_bytes), total_bytes_(total_bytes) {}
int64_t total_bytes() {
CHECK(total_bytes_.has_value());
return total_bytes_.value();
}
std::unique_ptr<AIModelDownloadProgressManager::Component> GetImpl() {
CHECK(!impl_);
std::unique_ptr<Impl> impl = std::make_unique<Impl>();
impl_ = impl->weak_ptr_factory_.GetWeakPtr();
if (total_bytes_) {
impl_->SetTotalBytes(total_bytes_.value());
}
if (downloaded_bytes_) {
impl_->SetDownloadedBytes(downloaded_bytes_.value());
}
return impl;
}
ComponentList GetImplAsList() {
ComponentList component_list;
component_list.insert(GetImpl());
return component_list;
}
void SetTotalBytes(int64_t total_bytes) {
total_bytes_ = total_bytes;
if (impl_) {
impl_->SetTotalBytes(total_bytes);
}
}
void SetDownloadedBytes(int64_t downloaded_bytes) {
downloaded_bytes_ = downloaded_bytes;
if (impl_) {
impl_->SetDownloadedBytes(downloaded_bytes);
}
}
private:
class Impl : public AIModelDownloadProgressManager::Component {
protected:
friend FakeComponent;
base::WeakPtrFactory<Impl> weak_ptr_factory_{this};
};
base::WeakPtr<Impl> impl_;
std::optional<int64_t> downloaded_bytes_;
std::optional<int64_t> total_bytes_;
};
}
class AIModelDownloadProgressManagerTest : public testing::Test {
public:
AIModelDownloadProgressManagerTest() = default;
~AIModelDownloadProgressManagerTest() override = default;
protected:
void FastForwardBy(base::TimeDelta delta) {
task_environment_.FastForwardBy(delta);
}
private:
base::test::SingleThreadTaskEnvironment task_environment_{
base::test::TaskEnvironment::TimeSource::MOCK_TIME};
};
// Each AddObserver() call adds a reporter. Destroying the FakeMonitor that owns
// the receiving end must eventually remove that reporter from the manager.
TEST_F(AIModelDownloadProgressManagerTest,
       ReporterIsDestroyedWhenRemoteIsDisconnected) {
  AIModelDownloadProgressManager manager;
  EXPECT_EQ(manager.GetNumberOfReporters(), 0);
  {
    FakeComponent component1(std::nullopt, std::nullopt);
    AITestUtils::FakeMonitor monitor1;
    manager.AddObserver(monitor1.BindNewPipeAndPassRemote(),
                        component1.GetImplAsList());
    EXPECT_EQ(manager.GetNumberOfReporters(), 1);
    {
      FakeComponent component2(std::nullopt, std::nullopt);
      AITestUtils::FakeMonitor monitor2;
      manager.AddObserver(monitor2.BindNewPipeAndPassRemote(),
                          component2.GetImplAsList());
      EXPECT_EQ(manager.GetNumberOfReporters(), 2);
    }
    // `monitor2` is gone; wait for the manager to process the disconnect.
    // RunUntil() returns false on timeout, so the result must be asserted or a
    // reporter that is never destroyed would not fail the test.
    ASSERT_TRUE(base::test::RunUntil(
        [&]() { return manager.GetNumberOfReporters() == 1; }));
  }
  ASSERT_TRUE(base::test::RunUntil(
      [&]() { return manager.GetNumberOfReporters() == 0; }));
}
// No progress event is sent while any component lacks either its downloaded
// or its total byte count. Once both are known, the first event reports zero.
TEST_F(AIModelDownloadProgressManagerTest,
       DoesntReceiveUpdateUntilAllBytesAreDetermined) {
  // Nothing known up front: downloaded bytes arrive first, then the total.
  {
    AIModelDownloadProgressManager progress_manager;
    AITestUtils::FakeMonitor fake_monitor;
    FakeComponent fake_component(std::nullopt, std::nullopt);
    progress_manager.AddObserver(fake_monitor.BindNewPipeAndPassRemote(),
                                 fake_component.GetImplAsList());
    fake_monitor.ExpectNoUpdate();
    FastForwardBy(base::Milliseconds(51));
    fake_component.SetDownloadedBytes(0);
    fake_monitor.ExpectNoUpdate();
    FastForwardBy(base::Milliseconds(51));
    fake_component.SetTotalBytes(100);
    fake_monitor.ExpectReceivedUpdate(0,
                                      AIUtils::kNormalizedDownloadProgressMax);
  }

  // Total known up front; downloaded bytes arrive later.
  {
    AIModelDownloadProgressManager progress_manager;
    AITestUtils::FakeMonitor fake_monitor;
    FakeComponent fake_component(std::nullopt, 100);
    progress_manager.AddObserver(fake_monitor.BindNewPipeAndPassRemote(),
                                 fake_component.GetImplAsList());
    fake_monitor.ExpectNoUpdate();
    FastForwardBy(base::Milliseconds(51));
    fake_component.SetDownloadedBytes(0);
    fake_monitor.ExpectReceivedUpdate(0,
                                      AIUtils::kNormalizedDownloadProgressMax);
  }

  // Downloaded bytes known up front; total arrives later.
  {
    AIModelDownloadProgressManager progress_manager;
    AITestUtils::FakeMonitor fake_monitor;
    FakeComponent fake_component(0, std::nullopt);
    progress_manager.AddObserver(fake_monitor.BindNewPipeAndPassRemote(),
                                 fake_component.GetImplAsList());
    fake_monitor.ExpectNoUpdate();
    FastForwardBy(base::Milliseconds(51));
    fake_component.SetTotalBytes(100);
    fake_monitor.ExpectReceivedUpdate(0,
                                      AIUtils::kNormalizedDownloadProgressMax);
  }

  // Two components: the first is fully determined, the second lacks
  // downloaded bytes until later.
  {
    AIModelDownloadProgressManager progress_manager;
    AITestUtils::FakeMonitor fake_monitor;
    FakeComponent determined(0, 100);
    FakeComponent undetermined(std::nullopt, 1000);
    ComponentList components;
    components.insert(determined.GetImpl());
    components.insert(undetermined.GetImpl());
    progress_manager.AddObserver(fake_monitor.BindNewPipeAndPassRemote(),
                                 std::move(components));
    fake_monitor.ExpectNoUpdate();
    FastForwardBy(base::Milliseconds(51));
    undetermined.SetDownloadedBytes(0);
    fake_monitor.ExpectReceivedUpdate(0,
                                      AIUtils::kNormalizedDownloadProgressMax);
  }
}
// When every component already has both byte counts at AddObserver() time,
// the zero-progress event is delivered without any further input.
TEST_F(AIModelDownloadProgressManagerTest,
       SendsUpdateIfBytesAreAlreadyDetermined) {
  // Single component.
  {
    AIModelDownloadProgressManager progress_manager;
    AITestUtils::FakeMonitor fake_monitor;
    FakeComponent fake_component(0, 100);
    progress_manager.AddObserver(fake_monitor.BindNewPipeAndPassRemote(),
                                 fake_component.GetImplAsList());
    fake_monitor.ExpectReceivedUpdate(0,
                                      AIUtils::kNormalizedDownloadProgressMax);
  }

  // Two components.
  {
    AIModelDownloadProgressManager progress_manager;
    AITestUtils::FakeMonitor fake_monitor;
    FakeComponent small_component(0, 100);
    FakeComponent large_component(0, 1000);
    ComponentList components;
    components.insert(small_component.GetImpl());
    components.insert(large_component.GetImpl());
    progress_manager.AddObserver(fake_monitor.BindNewPipeAndPassRemote(),
                                 std::move(components));
    fake_monitor.ExpectReceivedUpdate(0,
                                      AIUtils::kNormalizedDownloadProgressMax);
  }
}
// The first event reports zero even though 10 of the 100 bytes are already
// downloaded when observation starts.
TEST_F(AIModelDownloadProgressManagerTest, FirstUpdateIsReportedAsZero) {
  AIModelDownloadProgressManager progress_manager;
  AITestUtils::FakeMonitor fake_monitor;
  FakeComponent partially_downloaded(10, 100);
  progress_manager.AddObserver(fake_monitor.BindNewPipeAndPassRemote(),
                               partially_downloaded.GetImplAsList());
  fake_monitor.ExpectReceivedUpdate(0,
                                    AIUtils::kNormalizedDownloadProgressMax);
  FastForwardBy(base::Milliseconds(51));
}
// Reported progress matches AIUtils::NormalizeModelDownloadProgress(), with
// kNormalizedDownloadProgressMax as the reported maximum.
TEST_F(AIModelDownloadProgressManagerTest, ProgressIsNormalized) {
  AIModelDownloadProgressManager progress_manager;
  AITestUtils::FakeMonitor fake_monitor;
  FakeComponent fake_component(std::nullopt, 100);
  progress_manager.AddObserver(fake_monitor.BindNewPipeAndPassRemote(),
                               fake_component.GetImplAsList());

  fake_component.SetDownloadedBytes(0);
  fake_monitor.ExpectReceivedNormalizedUpdate(0, fake_component.total_bytes());

  // Step past the 50ms update interval so the next change is reported.
  FastForwardBy(base::Milliseconds(51));

  const uint64_t bytes_now = 15;
  const uint64_t expected_progress = AIUtils::NormalizeModelDownloadProgress(
      bytes_now, fake_component.total_bytes());
  fake_component.SetDownloadedBytes(bytes_now);
  fake_monitor.ExpectReceivedUpdate(expected_progress,
                                    AIUtils::kNormalizedDownloadProgressMax);
}
// Bytes present when the first event is sent form a baseline. That baseline is
// subtracted from both the downloaded and the total count in later events.
TEST_F(AIModelDownloadProgressManagerTest,
       AlreadyDownloadedBytesArentIncludedInProgress) {
  AIModelDownloadProgressManager progress_manager;
  AITestUtils::FakeMonitor fake_monitor;
  FakeComponent fake_component(std::nullopt, 100);
  const int64_t baseline_bytes = 10;
  progress_manager.AddObserver(fake_monitor.BindNewPipeAndPassRemote(),
                               fake_component.GetImplAsList());
  fake_component.SetDownloadedBytes(baseline_bytes);
  fake_monitor.ExpectReceivedNormalizedUpdate(0, fake_component.total_bytes());
  FastForwardBy(base::Milliseconds(51));

  const uint64_t bytes_now = baseline_bytes + 5;
  const uint64_t expected_progress = AIUtils::NormalizeModelDownloadProgress(
      bytes_now - baseline_bytes,
      fake_component.total_bytes() - baseline_bytes);
  fake_component.SetDownloadedBytes(bytes_now);
  fake_monitor.ExpectReceivedUpdate(expected_progress,
                                    AIUtils::kNormalizedDownloadProgressMax);
}
// One byte short of completion reports kNormalizedDownloadProgressMax - 1.
// Reaching total_bytes reports exactly the max, even with no time advanced
// since the previous event.
TEST_F(AIModelDownloadProgressManagerTest,
       MaxIsSentWhenDownloadedBytesEqualsTotalBytes) {
  AIModelDownloadProgressManager progress_manager;
  AITestUtils::FakeMonitor fake_monitor;
  FakeComponent fake_component(std::nullopt,
                               AIUtils::kNormalizedDownloadProgressMax * 5);
  progress_manager.AddObserver(fake_monitor.BindNewPipeAndPassRemote(),
                               fake_component.GetImplAsList());
  fake_component.SetDownloadedBytes(10);
  fake_monitor.ExpectReceivedNormalizedUpdate(0, fake_component.total_bytes());
  FastForwardBy(base::Milliseconds(51));

  const int64_t total = fake_component.total_bytes();
  fake_component.SetDownloadedBytes(total - 1);
  fake_monitor.ExpectReceivedUpdate(AIUtils::kNormalizedDownloadProgressMax - 1,
                                    AIUtils::kNormalizedDownloadProgressMax);
  fake_component.SetDownloadedBytes(total);
  fake_monitor.ExpectReceivedUpdate(AIUtils::kNormalizedDownloadProgressMax,
                                    AIUtils::kNormalizedDownloadProgressMax);
}
// If the first fully determined state is already complete, the zero event is
// immediately followed by the max event.
TEST_F(AIModelDownloadProgressManagerTest,
       MaxIsSentWhenDownloadedBytesEqualsTotalBytesForFirstUpdate) {
  AIModelDownloadProgressManager progress_manager;
  AITestUtils::FakeMonitor fake_monitor;
  FakeComponent fake_component(std::nullopt,
                               AIUtils::kNormalizedDownloadProgressMax * 5);
  progress_manager.AddObserver(fake_monitor.BindNewPipeAndPassRemote(),
                               fake_component.GetImplAsList());
  fake_component.SetDownloadedBytes(fake_component.total_bytes());
  fake_monitor.ExpectReceivedNormalizedUpdate(0, fake_component.total_bytes());
  fake_monitor.ExpectReceivedUpdate(AIUtils::kNormalizedDownloadProgressMax,
                                    AIUtils::kNormalizedDownloadProgressMax);
}
// With an empty component list, the observer still sees a zero event followed
// by a max event.
TEST_F(AIModelDownloadProgressManagerTest,
       ReceiveZeroAndHundredPercentForNoComponents) {
  AIModelDownloadProgressManager progress_manager;
  AITestUtils::FakeMonitor fake_monitor;
  progress_manager.AddObserver(fake_monitor.BindNewPipeAndPassRemote(), {});
  fake_monitor.ExpectReceivedUpdate(0,
                                    AIUtils::kNormalizedDownloadProgressMax);
  fake_monitor.ExpectReceivedUpdate(AIUtils::kNormalizedDownloadProgressMax,
                                    AIUtils::kNormalizedDownloadProgressMax);
}
// After an event is sent, further (non-final) progress within the next 50ms is
// not reported. The first change after the interval has elapsed is reported.
TEST_F(AIModelDownloadProgressManagerTest, OnlyReceivesUpdatesEvery50ms) {
  AIModelDownloadProgressManager manager;
  AITestUtils::FakeMonitor monitor;
  FakeComponent component(std::nullopt, 100);
  manager.AddObserver(monitor.BindNewPipeAndPassRemote(),
                      component.GetImplAsList());
  component.SetDownloadedBytes(0);
  monitor.ExpectReceivedNormalizedUpdate(0, component.total_bytes());

  // Inside the interval: must not be reported.
  component.SetDownloadedBytes(15);
  monitor.ExpectNoUpdate();

  // Past the interval: reported.
  FastForwardBy(base::Milliseconds(51));
  component.SetDownloadedBytes(20);
  monitor.ExpectReceivedNormalizedUpdate(20, component.total_bytes());

  // Inside the next interval: must not be reported either.
  component.SetDownloadedBytes(25);
  monitor.ExpectNoUpdate();
}
// Only new (higher) normalized progress is reported. Re-sending the same byte
// count, lowering it, or changing it by an amount that normalizes to the same
// value sends nothing, even after the update interval has elapsed.
TEST_F(AIModelDownloadProgressManagerTest, OnlyReceivesUpdatesForNewProgress) {
  AIModelDownloadProgressManager manager;
  AITestUtils::FakeMonitor monitor;
  FakeComponent component(std::nullopt,
                          AIUtils::kNormalizedDownloadProgressMax * 2);
  manager.AddObserver(monitor.BindNewPipeAndPassRemote(),
                      component.GetImplAsList());
  component.SetDownloadedBytes(0);
  monitor.ExpectReceivedNormalizedUpdate(0, component.total_bytes());
  FastForwardBy(base::Milliseconds(51));
  component.SetDownloadedBytes(10);
  monitor.ExpectReceivedNormalizedUpdate(10, component.total_bytes());
  FastForwardBy(base::Milliseconds(51));

  // Same and lower byte counts: no new progress, so no event.
  component.SetDownloadedBytes(10);
  monitor.ExpectNoUpdate();
  component.SetDownloadedBytes(9);
  monitor.ExpectNoUpdate();
  FastForwardBy(base::Milliseconds(51));

  // 11 bytes normalizes to the same value as 10, so nothing new either.
  // ASSERT_EQ (not CHECK_EQ) so a broken precondition fails this test instead
  // of crashing the whole test binary.
  ASSERT_EQ(
      AIUtils::NormalizeModelDownloadProgress(10, component.total_bytes()),
      AIUtils::NormalizeModelDownloadProgress(11, component.total_bytes()));
  component.SetDownloadedBytes(11);
  monitor.ExpectNoUpdate();
  FastForwardBy(base::Milliseconds(51));
}
// Reaching total_bytes is reported right away, without advancing time past
// the previous event.
TEST_F(AIModelDownloadProgressManagerTest, ShouldReceive100percent) {
  AIModelDownloadProgressManager progress_manager;
  AITestUtils::FakeMonitor fake_monitor;
  FakeComponent fake_component(std::nullopt, 100);
  progress_manager.AddObserver(fake_monitor.BindNewPipeAndPassRemote(),
                               fake_component.GetImplAsList());
  fake_component.SetDownloadedBytes(10);
  fake_monitor.ExpectReceivedNormalizedUpdate(0, fake_component.total_bytes());

  const int64_t total = fake_component.total_bytes();
  fake_component.SetDownloadedBytes(total);
  fake_monitor.ExpectReceivedNormalizedUpdate(total, total);
  FastForwardBy(base::Milliseconds(51));
}
// No event is sent until every listed component has reported downloaded
// bytes. The first event is then normalized against the combined total.
TEST_F(AIModelDownloadProgressManagerTest,
       AllComponentsMustBeObservedBeforeSendingEvents) {
  AIModelDownloadProgressManager progress_manager;
  AITestUtils::FakeMonitor fake_monitor;
  FakeComponent first(std::nullopt, 100);
  FakeComponent second(std::nullopt, 1000);
  ComponentList components;
  components.insert(first.GetImpl());
  components.insert(second.GetImpl());
  progress_manager.AddObserver(fake_monitor.BindNewPipeAndPassRemote(),
                               std::move(components));

  // Only one of the two components is determined: still silent.
  first.SetDownloadedBytes(0);
  fake_monitor.ExpectNoUpdate();
  FastForwardBy(base::Milliseconds(51));

  second.SetDownloadedBytes(10);
  const uint64_t combined_total = first.total_bytes() + second.total_bytes();
  fake_monitor.ExpectReceivedNormalizedUpdate(0, combined_total);
}
// With several components, progress is the sum of their downloaded bytes
// normalized against the sum of their total bytes.
TEST_F(AIModelDownloadProgressManagerTest,
       ProgressIsNormalizedAgainstTheSumOfTheComponentsTotalBytes) {
  AIModelDownloadProgressManager progress_manager;
  AITestUtils::FakeMonitor fake_monitor;
  FakeComponent first(std::nullopt, 100);
  FakeComponent second(std::nullopt, 1000);
  ComponentList components;
  components.insert(first.GetImpl());
  components.insert(second.GetImpl());
  progress_manager.AddObserver(fake_monitor.BindNewPipeAndPassRemote(),
                               std::move(components));

  uint64_t first_bytes = 0;
  first.SetDownloadedBytes(first_bytes);
  uint64_t second_bytes = 0;
  second.SetDownloadedBytes(second_bytes);
  fake_monitor.ExpectReceivedUpdate(0,
                                    AIUtils::kNormalizedDownloadProgressMax);
  FastForwardBy(base::Milliseconds(51));

  second_bytes += 5;
  second.SetDownloadedBytes(second_bytes);
  const uint64_t combined_downloaded = first_bytes + second_bytes;
  const uint64_t combined_total = first.total_bytes() + second.total_bytes();
  const uint64_t expected_progress = AIUtils::NormalizeModelDownloadProgress(
      combined_downloaded, combined_total);
  fake_monitor.ExpectReceivedUpdate(expected_progress,
                                    AIUtils::kNormalizedDownloadProgressMax);
}
// Across several components, every byte reported before the first event (10
// per component here) forms a baseline. That baseline is excluded from both
// the downloaded and the total count in later events.
TEST_F(AIModelDownloadProgressManagerTest,
       AlreadyDownloadedBytesArentIncludedInProgressForMultipleComponents) {
  AIModelDownloadProgressManager progress_manager;
  AITestUtils::FakeMonitor fake_monitor;
  FakeComponent first(std::nullopt, 100);
  FakeComponent second(std::nullopt, 1000);
  ComponentList components;
  components.insert(first.GetImpl());
  components.insert(second.GetImpl());
  int64_t baseline_bytes = 0;
  progress_manager.AddObserver(fake_monitor.BindNewPipeAndPassRemote(),
                               std::move(components));

  // Byte counts reported before the first event. Each one is tracked in
  // `baseline_bytes`.
  uint64_t first_bytes = 5;
  baseline_bytes += 5;
  first.SetDownloadedBytes(first_bytes);
  first_bytes += 5;
  baseline_bytes += 5;
  first.SetDownloadedBytes(first_bytes);
  uint64_t second_bytes = 10;
  baseline_bytes += 10;
  second.SetDownloadedBytes(second_bytes);
  fake_monitor.ExpectReceivedUpdate(0,
                                    AIUtils::kNormalizedDownloadProgressMax);
  FastForwardBy(base::Milliseconds(51));

  second_bytes += 5;
  second.SetDownloadedBytes(second_bytes);
  const uint64_t combined_downloaded = first_bytes + second_bytes;
  const uint64_t combined_total = first.total_bytes() + second.total_bytes();
  const uint64_t expected_progress = AIUtils::NormalizeModelDownloadProgress(
      combined_downloaded - baseline_bytes, combined_total - baseline_bytes);
  fake_monitor.ExpectReceivedUpdate(expected_progress,
                                    AIUtils::kNormalizedDownloadProgressMax);
}
// A component whose downloaded bytes already equal its total is not observed.
// Progress is normalized against the remaining component alone.
TEST_F(AIModelDownloadProgressManagerTest,
       AlreadyInstalledComponentsAreNotObserved) {
  AIModelDownloadProgressManager progress_manager;
  AITestUtils::FakeMonitor fake_monitor;
  FakeComponent installed(100, 100);
  FakeComponent pending(std::nullopt, 1000);
  ComponentList components;
  components.insert(installed.GetImpl());
  components.insert(pending.GetImpl());
  progress_manager.AddObserver(fake_monitor.BindNewPipeAndPassRemote(),
                               std::move(components));
  pending.SetDownloadedBytes(0);
  fake_monitor.ExpectReceivedNormalizedUpdate(0, pending.total_bytes());
}
// Totals from already-installed components are left out of the denominator.
// Only the uninstalled components' total bytes are summed.
TEST_F(AIModelDownloadProgressManagerTest,
       ProgressIsNormalizedAgainstOnlyUninstalledComponents) {
  AIModelDownloadProgressManager progress_manager;
  AITestUtils::FakeMonitor fake_monitor;
  FakeComponent installed(100, 100);
  FakeComponent pending_large(std::nullopt, 1000);
  FakeComponent pending_small(std::nullopt, 500);
  ComponentList components;
  components.insert(installed.GetImpl());
  components.insert(pending_large.GetImpl());
  components.insert(pending_small.GetImpl());
  progress_manager.AddObserver(fake_monitor.BindNewPipeAndPassRemote(),
                               std::move(components));
  pending_large.SetDownloadedBytes(0);
  pending_small.SetDownloadedBytes(0);
  fake_monitor.ExpectReceivedUpdate(0,
                                    AIUtils::kNormalizedDownloadProgressMax);
  FastForwardBy(base::Milliseconds(51));

  pending_large.SetDownloadedBytes(10);
  const uint64_t uninstalled_total =
      pending_large.total_bytes() + pending_small.total_bytes();
  fake_monitor.ExpectReceivedNormalizedUpdate(10, uninstalled_total);
}
// When every component is already fully downloaded, the observer gets a zero
// event immediately followed by a max event.
TEST_F(AIModelDownloadProgressManagerTest,
       ReceiveZeroAndHundredPercentWhenEverythingIsInstalled) {
  AIModelDownloadProgressManager progress_manager;
  AITestUtils::FakeMonitor fake_monitor;
  FakeComponent small_installed(100, 100);
  FakeComponent large_installed(1000, 1000);
  ComponentList components;
  components.insert(small_installed.GetImpl());
  components.insert(large_installed.GetImpl());
  progress_manager.AddObserver(fake_monitor.BindNewPipeAndPassRemote(),
                               std::move(components));
  fake_monitor.ExpectReceivedUpdate(0,
                                    AIUtils::kNormalizedDownloadProgressMax);
  fake_monitor.ExpectReceivedUpdate(AIUtils::kNormalizedDownloadProgressMax,
                                    AIUtils::kNormalizedDownloadProgressMax);
}
}