#include "ios/web/public/test/web_task_environment.h"
#include <memory>
#include "base/notreached.h"
#include "base/run_loop.h"
#include "ios/web/public/test/test_web_thread.h"
#include "ios/web/web_thread_impl.h"
namespace web {
namespace {

// Translates a WebTaskEnvironment main-thread type into the corresponding
// base::test::TaskEnvironment main-thread type. UI maps to UI and IO maps to
// IO; any value outside the enumerated cases is treated as unreachable.
base::test::TaskEnvironment::MainThreadType ConvertMainThreadType(
    WebTaskEnvironment::MainThreadType main_thread_type) {
  using WebType = WebTaskEnvironment::MainThreadType;
  using BaseType = base::test::TaskEnvironment::MainThreadType;

  // Keep this a switch with no default so the compiler flags any enumerator
  // added to WebType later without being handled here.
  switch (main_thread_type) {
    case WebType::IO:
      return BaseType::IO;
    case WebType::UI:
      return BaseType::UI;
  }
  NOTREACHED();
}

}  // namespace
// Builds the underlying base::test::TaskEnvironment with the translated main
// thread type. It then installs the WebThread task executor and registers a
// UI TestWebThread bound to the main thread's task runner. The IO thread is
// started right away unless |io_thread_type| is REAL_THREAD_DELAYED; in that
// case the test must call StartIOThread() itself. |tag| is not used in the
// body.
WebTaskEnvironment::WebTaskEnvironment(TimeSource time_source,
                                       MainThreadType main_thread_type,
                                       IOThreadType io_thread_type,
                                       base::trait_helpers::NotATraitTag tag)
    : TaskEnvironment(time_source, ConvertMainThreadType(main_thread_type)),
      io_thread_type_(io_thread_type) {
  // Set up the executor before any TestWebThread exists. NOTE(review): the
  // thread registration presumably relies on it being present; confirm.
  WebThreadImpl::CreateTaskExecutor();

  // The UI WebThread posts to this environment's main-thread task runner.
  ui_thread_ = std::make_unique<TestWebThread>(WebThread::UI,
                                               GetMainThreadTaskRunner());

  // Delayed real IO threads are started later by StartIOThread().
  const bool start_io_thread_now =
      io_thread_type_ != IOThreadType::REAL_THREAD_DELAYED;
  if (start_io_thread_now)
    StartIOThreadInternal();
}
// Tears down the test WebThreads in a fixed order: IO first, then UI. It then
// drains the environment and finally resets the WebThread task executor that
// the constructor created. The ordering below is deliberate; do not reorder.
WebTaskEnvironment::~WebTaskEnvironment() {
  // Flush pending main-thread tasks before stopping each thread.
  // NOTE(review): presumably stopping a TestWebThread does not flush its queue
  // by itself, hence the explicit RunUntilIdle() around each Stop(); confirm.
  base::RunLoop().RunUntilIdle();
  // |io_thread_| is null when the IO thread type is REAL_THREAD_DELAYED and
  // StartIOThread() was never called.
  if (io_thread_) {
    io_thread_->Stop();
  }
  base::RunLoop().RunUntilIdle();
  ui_thread_->Stop();
  base::RunLoop().RunUntilIdle();
  // Final drain through the TaskEnvironment's own RunUntilIdle() rather than a
  // plain main-thread RunLoop. NOTE(review): presumably this also covers tasks
  // on other sequences (e.g. the thread pool) before the executor goes away;
  // confirm.
  RunUntilIdle();
  // Undo WebThreadImpl::CreateTaskExecutor() from the constructor.
  WebThreadImpl::ResetTaskExecutorForTesting();
}
// Starts the IO thread for an environment constructed with
// IOThreadType::REAL_THREAD_DELAYED. The constructor skips starting the IO
// thread only for that type. Calling this with any other type fails the
// DCHECK_EQ. Calling it twice fails the DCHECK(!io_thread_) in
// StartIOThreadInternal().
void WebTaskEnvironment::StartIOThread() {
  DCHECK_EQ(io_thread_type_, IOThreadType::REAL_THREAD_DELAYED);
  StartIOThreadInternal();
}
// Creates the IO TestWebThread according to |io_thread_type_|. It must be
// called at most once. FAKE_THREAD binds the IO thread to the main thread's
// task runner. Every other type creates a standalone TestWebThread and starts
// it via StartIOThread().
void WebTaskEnvironment::StartIOThreadInternal() {
  DCHECK(!io_thread_);

  if (io_thread_type_ != IOThreadType::FAKE_THREAD) {
    // Real (possibly delayed) IO thread.
    io_thread_ = std::make_unique<TestWebThread>(WebThread::IO);
    io_thread_->StartIOThread();
    return;
  }

  // Fake IO thread: IO work is posted to the main-thread task runner.
  io_thread_ = std::make_unique<TestWebThread>(WebThread::IO,
                                               GetMainThreadTaskRunner());
}
}