#include "xsched/utils/pci.h"
#include "xsched/protocol/device.h"
#include "xsched/opencl/hal.h"
#include "xsched/opencl/hal/ocl_queue.h"
#include "xsched/opencl/hal/ocl_assert.h"
#include "xsched/opencl/hal/ocl_command.h"
using namespace xsched::opencl;
using namespace xsched::preempt;
using namespace xsched::protocol;
/// Wraps an existing OpenCL command queue as an XSched hardware queue.
///
/// Builds the XSched device descriptor from the device behind @p cmdq:
/// - the device type comes from CL_DEVICE_TYPE (mapped via GetXDeviceType);
/// - the device id comes from the PCI bus location when the optional
///   CL_DEVICE_PCI_BUS_INFO_KHR query succeeds, and stays 0 otherwise.
/// Before returning, it calls Driver::Flush and then Driver::Finish on the
/// queue.
///
/// @param cmdq OpenCL command queue to wrap. This constructor does not
///             retain it (NOTE(review): confirm the expected ownership of
///             cmdq with callers).
OclQueue::OclQueue(cl_command_queue cmdq): kCmdq(cmdq)
{
    // Value-initialize every query target so nothing is ever read indeterminate.
    cl_device_id id = nullptr;
    cl_device_type type = 0;
    cl_device_pci_bus_info_khr pci{};
    XDeviceId dev_id = 0;

    OCL_ASSERT(Driver::GetCommandQueueInfo(kCmdq, CL_QUEUE_DEVICE, sizeof(id), &id, nullptr));
    OCL_ASSERT(Driver::GetDeviceInfo(id, CL_DEVICE_TYPE, sizeof(type), &type, nullptr));

    // PCI bus info is an optional extension query. A failure is not fatal:
    // the device id falls back to 0, but the failure is logged instead of
    // being silently ignored.
    const cl_int ret = Driver::GetDeviceInfo(id, CL_DEVICE_PCI_BUS_INFO_KHR, sizeof(pci), &pci, nullptr);
    if (ret == CL_SUCCESS) {
        dev_id = MakePciId(pci.pci_domain, pci.pci_bus, pci.pci_device, pci.pci_function);
    } else {
        XDEBG("PCI bus info unavailable for cl device (%p), ret = %d; using device id 0", id, ret);
    }

    device_ = MakeDevice(GetXDeviceType(type), dev_id);
    XDEBG("OclQueue (%p) created for cmdq (%p)", this, kCmdq);

    // Drain work already submitted to the queue before XSched takes it over.
    OCL_ASSERT(Driver::Flush(kCmdq));
    OCL_ASSERT(Driver::Finish(kCmdq));
}
/// Submits one XSched hardware command to the wrapped OpenCL queue.
///
/// @param hw_cmd Command to launch. It must be an OclCommand; any other
///               dynamic type (or a null pointer) trips the XASSERT below.
void OclQueue::Launch(std::shared_ptr<preempt::HwCommand> hw_cmd)
{
    // hw_cmd keeps the command alive for the whole call, so a borrowed raw
    // pointer is enough here; no extra owning shared_ptr is needed.
    OclCommand *const command = dynamic_cast<OclCommand *>(hw_cmd.get());
    XASSERT(command != nullptr, "hw_cmd is not an OclCommand");

    OCL_ASSERT(command->LaunchWrapper(kCmdq));
}
void OclQueue::Synchronize()
{
XDEBG("synchronize cmdq (%p)", kCmdq);
OCL_ASSERT(Driver::Flush(kCmdq));
OCL_ASSERT(Driver::Finish(kCmdq));
}
/// C entry point: registers an OpenCL command queue with the XSched
/// hardware-queue manager.
///
/// @param hwq  Output slot for the handle; written only when registration
///             succeeds.
/// @param cmdq OpenCL command queue to wrap in an OclQueue.
/// @return kXSchedErrorInvalidValue if either argument is null; otherwise
///         the result reported by HwQueueManager::Add.
EXPORT_C_FUNC XResult OclQueueCreate(HwQueueHandle *hwq, cl_command_queue cmdq)
{
    if (hwq == nullptr) {
        XWARN("OclQueueCreate failed: hwq is nullptr");
        return kXSchedErrorInvalidValue;
    }
    if (cmdq == nullptr) {
        XWARN("OclQueueCreate failed: cmdq is nullptr");
        return kXSchedErrorInvalidValue;
    }

    const HwQueueHandle handle = GetHwQueueHandle(cmdq);

    // The factory only needs the queue pointer, so capture it by value; this
    // stays valid even if the manager keeps the factory beyond this call.
    const auto result = HwQueueManager::Add(handle, [cmdq]() {
        return std::make_shared<OclQueue>(cmdq);
    });
    if (result != kXSchedSuccess) return result;

    *hwq = handle;
    return result;
}