#include "xsched/utils/pci.h"
#include "xsched/protocol/device.h"
#include "xsched/levelzero/hal/ze_queue.h"
#include "xsched/levelzero/hal/ze_assert.h"
using namespace xsched::preempt;
using namespace xsched::protocol;
using namespace xsched::levelzero;
/// Wraps an existing Level Zero command queue as a hardware queue.
///
/// Queries the device properties and the PCI extension properties of `dev`,
/// derives `device_` from the reported device type and PCI address, and then
/// synchronizes on `cmdq` with UINT64_MAX as the timeout.
///
/// @param dev   Level Zero device handle, stored in kDev.
/// @param cmdq  Level Zero command queue handle, stored in kCmdq.
ZeQueue::ZeQueue(ze_device_handle_t dev, ze_command_queue_handle_t cmdq): kDev(dev), kCmdq(cmdq)
{
    // Only the struct headers (stype / pNext) are written here; every other
    // field is left untouched before the driver queries below.
    ze_pci_ext_properties_t pci_info;
    pci_info.pNext = nullptr;
    pci_info.stype = ZE_STRUCTURE_TYPE_PCI_EXT_PROPERTIES;

    ze_device_properties_t dev_info;
    dev_info.pNext = nullptr;
    dev_info.stype = ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES;

    ZE_ASSERT(Driver::DeviceGetProperties(kDev, &dev_info));
    ZE_ASSERT(Driver::DevicePciGetPropertiesExt(kDev, &pci_info));

    // The device id is built from the PCI domain/bus/device/function tuple.
    const auto &pci_addr = pci_info.address;
    XDeviceId pci_id = MakePciId(pci_addr.domain, pci_addr.bus,
                                 pci_addr.device, pci_addr.function);
    device_ = MakeDevice(GetXDeviceType(dev_info.type), pci_id);

    // Synchronize once on the queue before it is handed out.
    ZE_ASSERT(Driver::CommandQueueSynchronize(kCmdq, UINT64_MAX));
}
/// Launches a hardware command on the wrapped command queue.
///
/// @param hw_cmd  Command to launch; it must be a ZeListExecuteCommand,
///                otherwise the XASSERT below fires.
void ZeQueue::Launch(std::shared_ptr<preempt::HwCommand> hw_cmd)
{
    // Downcast to the only command type this queue can launch.
    std::shared_ptr<ZeListExecuteCommand> list_exec =
        std::dynamic_pointer_cast<ZeListExecuteCommand>(hw_cmd);
    XASSERT(list_exec != nullptr, "hw_cmd is not a ZeListExecuteCommand");

    ZE_ASSERT(list_exec->Launch(kCmdq));
}
/// Synchronizes on the wrapped command queue, passing UINT64_MAX as the
/// timeout; a failing result is reported through ZE_ASSERT.
void ZeQueue::Synchronize()
{
    constexpr uint64_t kSyncTimeout = UINT64_MAX;
    ZE_ASSERT(Driver::CommandQueueSynchronize(kCmdq, kSyncTimeout));
}
/// Constructs the NPU flavour of ZeQueue.
///
/// Runs the ZeQueue constructor first, then records the id returned by
/// get_kmd_cmdq_id() for `cmdq`; Deactivate() and Reactivate() pass that id
/// to the NPU scheduling helpers.
///
/// @param dev   Level Zero device handle, forwarded to ZeQueue.
/// @param cmdq  Level Zero command queue handle, forwarded to ZeQueue.
ZeIntelNpuQueue::ZeIntelNpuQueue(ze_device_handle_t dev, ze_command_queue_handle_t cmdq)
    : ZeQueue(dev, cmdq)
{
    const auto looked_up_id = get_kmd_cmdq_id(cmdq);
    kmd_cmdq_id_ = looked_up_id;
}
/// Forwards the stored kmd_cmdq_id_ to npu_sched_suspend_cmdq().
/// NOTE(review): judging by the helper's name this suspends scheduling of the
/// queue on the NPU — confirm against the helper's definition. Any value the
/// helper returns is not inspected here.
void ZeIntelNpuQueue::Deactivate()
{
npu_sched_suspend_cmdq(kmd_cmdq_id_);
}
/// Forwards the stored kmd_cmdq_id_ to npu_sched_resume_cmdq().
/// The CommandLog argument is accepted but unused (the parameter is unnamed).
/// NOTE(review): judging by the helper's name this resumes scheduling of the
/// queue on the NPU — confirm against the helper's definition. Any value the
/// helper returns is not inspected here.
void ZeIntelNpuQueue::Reactivate(const preempt::CommandLog &)
{
npu_sched_resume_cmdq(kmd_cmdq_id_);
}
/// Wraps an existing Level Zero command list as a hardware queue.
///
/// Queries the device properties and the PCI extension properties of `dev`,
/// derives `device_` from the reported device type and PCI address, and then
/// host-synchronizes on `cmdl` with UINT64_MAX as the timeout.
///
/// @param dev   Level Zero device handle, stored in kDev.
/// @param cmdl  Level Zero command list handle, stored in kCmdl.
ZeList::ZeList(ze_device_handle_t dev, ze_command_list_handle_t cmdl)
    : kDev(dev), kCmdl(cmdl)
{
    // Only the struct headers (stype / pNext) are written here; every other
    // field is left untouched before the driver queries below.
    ze_pci_ext_properties_t pci_info;
    pci_info.pNext = nullptr;
    pci_info.stype = ZE_STRUCTURE_TYPE_PCI_EXT_PROPERTIES;

    ze_device_properties_t dev_info;
    dev_info.pNext = nullptr;
    dev_info.stype = ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES;

    ZE_ASSERT(Driver::DeviceGetProperties(kDev, &dev_info));
    ZE_ASSERT(Driver::DevicePciGetPropertiesExt(kDev, &pci_info));

    // The device id is built from the PCI domain/bus/device/function tuple.
    const auto &pci_addr = pci_info.address;
    XDeviceId pci_id = MakePciId(pci_addr.domain, pci_addr.bus,
                                 pci_addr.device, pci_addr.function);
    device_ = MakeDevice(GetXDeviceType(dev_info.type), pci_id);

    // Synchronize once on the list before it is handed out.
    ZE_ASSERT(Driver::CommandListHostSynchronize(kCmdl, UINT64_MAX));
}
/// Launches a hardware command on the wrapped command list.
///
/// @param hw_cmd  Command to launch; it must be a ZeKernelCommand, otherwise
///                the XASSERT below fires.
void ZeList::Launch(std::shared_ptr<preempt::HwCommand> hw_cmd)
{
    auto cmd = std::dynamic_pointer_cast<ZeKernelCommand>(hw_cmd);
    // The message names the type actually required by the cast above (it used
    // to say ZeListExecuteCommand, which is the type ZeQueue::Launch expects).
    XASSERT(cmd != nullptr, "hw_cmd is not a ZeKernelCommand");
    ZE_ASSERT(cmd->Launch());
}
/// Host-synchronizes on the wrapped command list, passing UINT64_MAX as the
/// timeout; a failing result is reported through ZE_ASSERT.
void ZeList::Synchronize()
{
    constexpr uint64_t kSyncTimeout = UINT64_MAX;
    ZE_ASSERT(Driver::CommandListHostSynchronize(kCmdl, kSyncTimeout));
}
/// Registers a Level Zero command queue with HwQueueManager.
///
/// @param hwq   Out-parameter; receives the handle on success. Must not be null.
/// @param dev   Level Zero device handle. Must not be null.
/// @param cmdq  Level Zero command queue handle. Must not be null.
/// @return kXSchedErrorInvalidValue when any argument is null; otherwise the
///         result of HwQueueManager::Add(). `*hwq` is written only when that
///         result is kXSchedSuccess.
EXPORT_C_FUNC XResult ZeQueueCreate(HwQueueHandle *hwq, ze_device_handle_t dev, ze_command_queue_handle_t cmdq)
{
    if (!hwq) {
        XWARN("ZeQueueCreate failed: hwq is nullptr");
        return kXSchedErrorInvalidValue;
    }
    if (!dev || !cmdq) {
        XWARN("ZeQueueCreate failed: dev or cmdq is nullptr");
        return kXSchedErrorInvalidValue;
    }

    // The device type decides which queue class the factory builds.
    ze_device_properties_t dev_info;
    dev_info.pNext = nullptr;
    dev_info.stype = ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES;
    ZE_ASSERT(Driver::DeviceGetProperties(dev, &dev_info));

    // Factory passed to HwQueueManager::Add(): VPU devices get the NPU
    // subclass, every other device type gets the plain ZeQueue.
    auto make_queue = [&]() -> std::shared_ptr<ZeQueue> {
        if (dev_info.type != ZE_DEVICE_TYPE_VPU) {
            return std::make_shared<ZeQueue>(dev, cmdq);
        }
        return std::make_shared<ZeIntelNpuQueue>(dev, cmdq);
    };

    HwQueueHandle handle = GetHwQueueHandle(cmdq);
    const auto status = HwQueueManager::Add(handle, make_queue);
    if (status != kXSchedSuccess) {
        return status;
    }

    *hwq = handle;
    return status;
}
/// Registers a Level Zero command list with HwQueueManager.
///
/// NOTE(review): the exported name looks like a typo for "ZeListCreate"; it is
/// kept as-is because it is a public symbol — confirm against the declaring
/// header and callers before renaming.
///
/// @param hwq   Out-parameter; receives the handle on success. Must not be null.
/// @param dev   Level Zero device handle. Must not be null.
/// @param cmdl  Level Zero command list handle. Must not be null.
/// @return kXSchedErrorInvalidValue when any argument is null; otherwise the
///         result of HwQueueManager::Add(). `*hwq` is written only when that
///         result is kXSchedSuccess.
EXPORT_C_FUNC XResult ZeListreate(HwQueueHandle *hwq, ze_device_handle_t dev, ze_command_list_handle_t cmdl)
{
    if (!hwq) {
        XWARN("ZeListreate failed: hwq is nullptr");
        return kXSchedErrorInvalidValue;
    }
    if (!dev || !cmdl) {
        XWARN("ZeListreate failed: dev or cmdl is nullptr");
        return kXSchedErrorInvalidValue;
    }

    // The device type is needed by the factory's unsupported-device check.
    ze_device_properties_t dev_info;
    dev_info.pNext = nullptr;
    dev_info.stype = ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES;
    ZE_ASSERT(Driver::DeviceGetProperties(dev, &dev_info));

    // Factory passed to HwQueueManager::Add(). The VPU check deliberately
    // stays inside the factory so it is evaluated only if the factory runs.
    auto make_list = [&]() -> std::shared_ptr<ZeList> {
        const bool is_vpu = (dev_info.type == ZE_DEVICE_TYPE_VPU);
        if (is_vpu) {
            XERRO_UNSUPPORTED();
        }
        return std::make_shared<ZeList>(dev, cmdl);
    };

    HwQueueHandle handle = GetHwQueueHandle(cmdl);
    const auto status = HwQueueManager::Add(handle, make_list);
    if (status != kXSchedSuccess) {
        return status;
    }

    *hwq = handle;
    return status;
}