#include "net/base/test_completion_callback.h"
#include "base/check_op.h"
#include "base/functional/bind.h"
#include "base/location.h"
#include "base/memory/raw_ptr.h"
#include "base/notreached.h"
#include "base/task/single_thread_task_runner.h"
#include "net/base/completion_once_callback.h"
#include "net/test/test_with_task_environment.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "testing/platform_test.h"
namespace net {
namespace {
// Arbitrary sentinel that ExampleWorker reports on completion; the tests check
// that exactly this value reaches the completion callback.
constexpr int kMagicResult = 8888;
// Completion callback used by the Closure test. It verifies that the operation
// reported kMagicResult, records that the check happened in
// |*did_check_result|, and then runs |closure| to signal completion.
//
// EXPECT_EQ is used rather than DCHECK_EQ: DCHECKs are compiled out when
// DCHECK_IS_ON() is false, which would let the test set |*did_check_result|
// (and pass) without the result ever having been compared.
void CallClosureAfterCheckingResult(base::OnceClosure closure,
                                    bool* did_check_result,
                                    int result) {
  EXPECT_EQ(result, kMagicResult);
  *did_check_result = true;
  std::move(closure).Run();
}
// Toy asynchronous API used to exercise the test completion helpers.
// DoSomething() queues work on the current thread; the supplied callback is
// later run with kMagicResult.
class ExampleEmployer {
 public:
  ExampleEmployer();
  ~ExampleEmployer();

  ExampleEmployer(const ExampleEmployer&) = delete;
  ExampleEmployer& operator=(const ExampleEmployer&) = delete;

  // Starts one asynchronous operation and returns true if it was queued.
  // Must not be called while a previous operation is still pending
  // (DCHECKed in the definition).
  bool DoSomething(CompletionOnceCallback callback);

 private:
  // Performs the queued work. As a nested class it already has access to
  // ExampleEmployer's private members (such as |request_|), so no friend
  // declaration is needed.
  class ExampleWorker;

  // The in-flight operation, or null when idle.
  scoped_refptr<ExampleWorker> request_;
};
// Ref-counted state for a single ExampleEmployer::DoSomething() call. It holds
// the caller's completion callback and the task runner of the thread that
// created it, and clears the employer's |request_| before reporting the
// result.
class ExampleEmployer::ExampleWorker
    : public base::RefCountedThreadSafe<ExampleWorker> {
 public:
  ExampleWorker(ExampleEmployer* employer, CompletionOnceCallback callback)
      : employer_(employer),
        callback_(std::move(callback)),
        origin_task_runner_(
            base::SingleThreadTaskRunner::GetCurrentDefault()) {}

  // Posts DoCallback() to |origin_task_runner_|.
  void DoWork();

  // Detaches from the employer and runs |callback_| with kMagicResult.
  void DoCallback();

 private:
  // The destructor is private; RefCountedThreadSafe must be able to call it
  // when the last reference goes away.
  friend class base::RefCountedThreadSafe<ExampleWorker>;

  ~ExampleWorker() = default;

  raw_ptr<ExampleEmployer> employer_;
  CompletionOnceCallback callback_;
  // Default task runner of the thread that constructed this worker.
  const scoped_refptr<base::SingleThreadTaskRunner> origin_task_runner_;
};
void ExampleEmployer::ExampleWorker::DoWork() {
origin_task_runner_->PostTask(
FROM_HERE, base::BindOnce(&ExampleWorker::DoCallback, this));
}
// Completes the operation. The employer's reference to this worker is dropped
// first, so the employer is idle (DoSomething() can be called again) by the
// time the caller's callback observes kMagicResult.
void ExampleEmployer::ExampleWorker::DoCallback() {
  employer_->request_.reset();
  std::move(callback_).Run(kMagicResult);
}
ExampleEmployer::ExampleEmployer() = default;

ExampleEmployer::~ExampleEmployer() = default;

// Creates a worker for |callback| and posts ExampleWorker::DoWork() to the
// current thread's default task runner. Returns true if the task was posted.
bool ExampleEmployer::DoSomething(CompletionOnceCallback callback) {
  DCHECK(!request_) << "already in use";

  request_ = base::MakeRefCounted<ExampleWorker>(this, std::move(callback));

  const bool posted =
      base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
          FROM_HERE, base::BindOnce(&ExampleWorker::DoWork, request_));
  if (posted) {
    return true;
  }

  // Posting to the current default task runner is not expected to fail. The
  // cleanup below only matters if NOTREACHED() returns, which depends on how
  // base/notreached.h defines it for this build.
  NOTREACHED();
  request_ = nullptr;
  return false;
}
}
// Fixture for the tests below. WithTaskEnvironment presumably supplies the
// task environment that SingleThreadTaskRunner::GetCurrentDefault() (used by
// ExampleEmployer) and the Wait*() helpers rely on; confirm in
// net/test/test_with_task_environment.h.
class TestCompletionCallbackTest : public PlatformTest,
public WithTaskEnvironment {};
// Verifies that TestCompletionCallback::WaitForResult() returns the value the
// asynchronous operation reported.
TEST_F(TestCompletionCallbackTest, Simple) {
  ExampleEmployer boss;
  TestCompletionCallback callback;

  bool queued = boss.DoSomething(callback.callback());
  // Fatal: if nothing was queued the callback never runs, and WaitForResult()
  // below would presumably wait until the test times out instead of failing.
  ASSERT_TRUE(queued);

  int result = callback.WaitForResult();
  EXPECT_EQ(result, kMagicResult);
}
// Verifies that TestClosure::WaitForResult() waits for the posted completion,
// and that the completion is not run synchronously from DoSomething().
TEST_F(TestCompletionCallbackTest, Closure) {
  ExampleEmployer boss;
  TestClosure closure;
  bool did_check_result = false;

  CompletionOnceCallback completion_callback =
      base::BindOnce(&CallClosureAfterCheckingResult, closure.closure(),
                     base::Unretained(&did_check_result));
  bool queued = boss.DoSomething(std::move(completion_callback));
  // Fatal: if nothing was queued the closure never runs, and WaitForResult()
  // below would presumably wait until the test times out instead of failing.
  ASSERT_TRUE(queued);
  // The work is posted as a task, so the completion must not have run yet.
  EXPECT_FALSE(did_check_result);

  closure.WaitForResult();
  EXPECT_TRUE(did_check_result);
}
}