#ifndef SERVICES_WEBNN_DML_COMMAND_QUEUE_H_
#define SERVICES_WEBNN_DML_COMMAND_QUEUE_H_
#include <deque>
#include <vector>
#include "base/component_export.h"
#include "base/containers/span.h"
#include "base/functional/callback_forward.h"
#include "base/gtest_prod_util.h"
#include "base/memory/ref_counted.h"
#include "base/sequence_checker.h"
#include "base/win/object_watcher.h"
#include "base/win/scoped_handle.h"
#include "third_party/microsoft_dxheaders/src/include/directx/d3d12.h"
#include <wrl.h>
namespace webnn::dml {
// CommandQueue pairs an ID3D12CommandQueue with an ID3D12Fence that is used to
// track submitted work (see command_queue_ / fence_ below). Instances are
// reference counted with a thread-safe refcount and can only be created via
// Create(). The class is also a base::win::ObjectWatcher::Delegate, so it is
// notified through OnObjectSignaled() when a watched handle is signaled.
// Most members are annotated GUARDED_BY_CONTEXT(sequence_checker_), i.e. they
// are expected to be accessed from a single sequence.
class COMPONENT_EXPORT(WEBNN_SERVICE) CommandQueue
: public base::win::ObjectWatcher::Delegate,
public base::RefCountedThreadSafe<CommandQueue> {
public:
// Creates a CommandQueue for `d3d12_device`.
// NOTE(review): presumably returns a null scoped_refptr when creating the
// underlying queue/fence fails — confirm in the .cc.
static scoped_refptr<CommandQueue> Create(ID3D12Device* d3d12_device);
// Not copyable (and, being ref-counted, it is shared via scoped_refptr).
CommandQueue(const CommandQueue&) = delete;
CommandQueue& operator=(const CommandQueue&) = delete;
// Executes a single command list, or a batch of command lists, on the
// wrapped queue. The HRESULT reports success or failure.
// NOTE(review): presumably also signals fence_ and advances
// last_fence_value_ — verify in the implementation.
HRESULT ExecuteCommandList(ID3D12CommandList* command_list);
HRESULT ExecuteCommandLists(base::span<ID3D12CommandList*> command_lists);
// Synchronous wait for submitted work; returns the outcome as an HRESULT.
// NOTE(review): presumably blocks the calling thread until fence_ reaches
// last_fence_value_ — confirm in the .cc.
HRESULT WaitSync();
// Asynchronous wait: `callback` is run later with an HRESULT.
// NOTE(review): presumably bound into a QueuedCallback keyed by the current
// fence value and run from OnObjectSignaled() — confirm in the .cc.
void WaitAsync(base::OnceCallback<void(HRESULT hr)> callback);
// Holds a reference to `object` so it outlives the work that uses it.
// NOTE(review): presumably stored as a QueuedObject tagged with the last
// fence value and dropped by ReleaseCompletedResources() — confirm.
void ReferenceUntilCompleted(Microsoft::WRL::ComPtr<IUnknown> object);
// Fence-value accessors.
// NOTE(review): GetCompletedValue() presumably queries fence_ and
// GetLastFenceValue() presumably returns last_fence_value_ — confirm.
uint64_t GetCompletedValue() const;
uint64_t GetLastFenceValue() const;
// Returns the fence used to track this queue's submissions. The pointer is
// not AddRef'd (fence_.Get()); it stays owned by this CommandQueue.
ID3D12Fence* submission_fence() const { return fence_.Get(); }
// Makes this queue depend on `wait_fence` reaching `wait_fence_value`.
// NOTE(review): presumably a GPU-side wait (ID3D12CommandQueue::Wait) rather
// than a CPU block — confirm in the .cc.
HRESULT WaitForFence(Microsoft::WRL::ComPtr<ID3D12Fence> wait_fence,
uint64_t wait_fence_value);
private:
// Lets the ReferenceAndRelease unit test inspect private state.
FRIEND_TEST_ALL_PREFIXES(WebNNCommandQueueTest, ReferenceAndRelease);
// The destructor is private; destruction goes through the refcount.
friend class base::RefCountedThreadSafe<CommandQueue>;
// Private: use Create().
CommandQueue(Microsoft::WRL::ComPtr<ID3D12CommandQueue> command_queue,
Microsoft::WRL::ComPtr<ID3D12Fence> fence);
~CommandQueue() override;
// NOTE(review): presumably drops entries of queued_objects_ whose fence
// value has completed — confirm in the .cc.
void ReleaseCompletedResources();
// A COM object reference tagged with the fence value it is associated with.
// No default constructor. It is move-only: declaring the move operations
// suppresses the implicit copy operations.
struct QueuedObject {
QueuedObject() = delete;
QueuedObject(uint64_t fence_value, Microsoft::WRL::ComPtr<IUnknown> object);
QueuedObject(QueuedObject&& other);
QueuedObject& operator=(QueuedObject&& other);
~QueuedObject();
uint64_t fence_value = 0;
Microsoft::WRL::ComPtr<IUnknown> object;
};
// Objects kept alive on behalf of in-flight work.
std::deque<QueuedObject> queued_objects_
GUARDED_BY_CONTEXT(sequence_checker_);
// Test-only read access to queued_objects_.
const std::deque<QueuedObject>& GetQueuedObjectsForTesting() const;
// A closure tagged with a fence value. Move-only, like QueuedObject.
struct QueuedCallback {
QueuedCallback() = delete;
QueuedCallback(uint64_t fence_value, base::OnceClosure callback);
QueuedCallback(QueuedCallback&& other);
QueuedCallback& operator=(QueuedCallback&& other);
~QueuedCallback();
uint64_t fence_value = 0;
base::OnceClosure callback;
};
// Pending callbacks, each tagged with the fence value it waits on.
std::deque<QueuedCallback> queued_callbacks_
GUARDED_BY_CONTEXT(sequence_checker_);
// Standalone ObjectWatcher delegate that takes ownership, by value, of
// leftover queued objects, the queue, the fence and a fence event.
// NOTE(review): presumably created by ScheduleCleanupForPendingWork() so
// that objects still referenced by in-flight GPU work outlive the
// CommandQueue — confirm in the .cc.
class PendingWorkDelegate : public base::win::ObjectWatcher::Delegate {
public:
PendingWorkDelegate(
std::deque<CommandQueue::QueuedObject> queued_objects,
Microsoft::WRL::ComPtr<ID3D12CommandQueue> command_queue,
uint64_t last_fence_value,
Microsoft::WRL::ComPtr<ID3D12Fence> fence,
base::win::ScopedHandle fence_event);
~PendingWorkDelegate() override;
PendingWorkDelegate(const PendingWorkDelegate&) = delete;
PendingWorkDelegate& operator=(const PendingWorkDelegate&) = delete;
private:
// base::win::ObjectWatcher::Delegate: called when the watched handle is
// signaled.
void OnObjectSignaled(HANDLE object) override;
std::deque<CommandQueue::QueuedObject> queued_objects_;
Microsoft::WRL::ComPtr<ID3D12CommandQueue> command_queue_;
const uint64_t last_fence_value_;
Microsoft::WRL::ComPtr<ID3D12Fence> fence_;
// Keep object_watcher_ declared after fence_event_. Members are destroyed
// in reverse order, so the watcher is torn down before the handle closes.
base::win::ScopedHandle fence_event_;
base::win::ObjectWatcher object_watcher_;
};
// base::win::ObjectWatcher::Delegate: called when object_watcher_'s handle
// is signaled.
void OnObjectSignaled(HANDLE object) override;
// Hands the given leftover state over for deferred cleanup.
// NOTE(review): presumably wraps it in a PendingWorkDelegate — confirm.
static void ScheduleCleanupForPendingWork(
std::deque<CommandQueue::QueuedObject> queued_objects,
Microsoft::WRL::ComPtr<ID3D12CommandQueue> command_queue,
uint64_t last_fence_value,
Microsoft::WRL::ComPtr<ID3D12Fence> fence);
// Data-member declaration order determines construction order, which must
// match the constructor's initializer list. It also determines the reverse
// destruction order. Do not reorder.
Microsoft::WRL::ComPtr<ID3D12CommandQueue> command_queue_
GUARDED_BY_CONTEXT(sequence_checker_);
uint64_t last_fence_value_ GUARDED_BY_CONTEXT(sequence_checker_) = 0;
// Not sequence-guarded; read through the const submission_fence() accessor.
Microsoft::WRL::ComPtr<ID3D12Fence> fence_;
base::win::ScopedHandle fence_event_ GUARDED_BY_CONTEXT(sequence_checker_);
// Declared after fence_event_, so it is destroyed before the handle closes.
base::win::ObjectWatcher object_watcher_
GUARDED_BY_CONTEXT(sequence_checker_);
SEQUENCE_CHECKER(sequence_checker_);
};
}  // namespace webnn::dml
#endif