#include "xsched/protocol/def.h"
#include "xsched/levelzero/hal/pool.h"
#include "xsched/levelzero/shim/cmd_list.h"
using namespace xsched::levelzero;
/// Looks up the slice bookkeeping registered for `cmd_list`.
/// @return the registered SlicedCommandList, or nullptr when
///         GetSliceCmdCnt() is 0 or no entry exists for the handle.
std::shared_ptr<SlicedCommandList> CommandListManager::Get(ze_command_list_handle_t cmd_list)
{
    if (GetSliceCmdCnt() == 0) {
        return nullptr;
    }

    std::lock_guard<std::mutex> guard(mtx_);
    const auto entry = slices_.find(cmd_list);
    return entry != slices_.end() ? entry->second : nullptr;
}
/// Creates a command list through the driver and, when GetSliceCmdCnt() is
/// non-zero, registers an empty slice record for the new handle.
/// @param ctx      context forwarded to the driver and stored in the record
/// @param dev      device forwarded to the driver and stored in the record
/// @param desc     descriptor forwarded to the driver; a copy is stored in the record
/// @param cmd_list receives the handle created by the driver
/// @return the result of Driver::CommandListCreate
ze_result_t CommandListManager::Create(ze_context_handle_t ctx, ze_device_handle_t dev,
                                       const ze_command_list_desc_t *desc,
                                       ze_command_list_handle_t *cmd_list)
{
    const ze_result_t status = Driver::CommandListCreate(ctx, dev, desc, cmd_list);
    // Nothing to track if the driver failed or the slice count is 0.
    if (status != ZE_RESULT_SUCCESS || GetSliceCmdCnt() == 0) {
        return status;
    }

    std::lock_guard<std::mutex> guard(mtx_);
    const ze_command_list_handle_t handle = *cmd_list;
    XASSERT(slices_.find(handle) == slices_.end(), "slice for cmd_list %p already exists", handle);

    // Start with no commands counted and no backing command lists.
    slices_[handle] = std::make_shared<SlicedCommandList>(SlicedCommandList{
        .cmd_cnt = 0, .ctx = ctx, .dev = dev, .desc = *desc, .cmd_lists = {}
    });
    return status;
}
/// Destroys `cmd_list` through the driver and, when GetSliceCmdCnt() is
/// non-zero, resets every command list recorded in its slice, pushes each one
/// to the (ctx, dev) pool, and drops the slice record.
/// @return the result of Driver::CommandListDestroy
ze_result_t CommandListManager::Destroy(ze_command_list_handle_t cmd_list)
{
    const ze_result_t status = Driver::CommandListDestroy(cmd_list);
    if (status != ZE_RESULT_SUCCESS || GetSliceCmdCnt() == 0) {
        return status;
    }

    std::lock_guard<std::mutex> guard(mtx_);
    const auto entry = slices_.find(cmd_list);
    XASSERT(entry != slices_.end(), "slice for cmd_list %p not found", cmd_list);

    const auto &slice = entry->second;
    for (auto sub_list : slice->cmd_lists) {
        // Reset before pushing the list to the pool.
        ZE_ASSERT(Driver::CommandListReset(sub_list));
        CommandListPool::Instance(slice->ctx, slice->dev)->Push(sub_list);
    }

    slices_.erase(entry);
    return status;
}
/// Closes `cmd_list` through the driver and, when GetSliceCmdCnt() is
/// non-zero, closes every command list recorded in its slice as well.
/// @return the first non-success driver result, otherwise ZE_RESULT_SUCCESS
ze_result_t CommandListManager::Close(ze_command_list_handle_t cmd_list)
{
    ze_result_t status = Driver::CommandListClose(cmd_list);
    if (status != ZE_RESULT_SUCCESS || GetSliceCmdCnt() == 0) {
        return status;
    }

    std::lock_guard<std::mutex> guard(mtx_);
    const auto entry = slices_.find(cmd_list);
    XASSERT(entry != slices_.end(), "slice for cmd_list %p not found", cmd_list);

    for (auto sub_list : entry->second->cmd_lists) {
        status = Driver::CommandListClose(sub_list);
        // Stop at the first failure and report it to the caller.
        if (status != ZE_RESULT_SUCCESS) {
            break;
        }
    }
    return status;
}
/// Resets `cmd_list` through the driver and, when GetSliceCmdCnt() is
/// non-zero, resets every command list recorded in its slice, pushes each one
/// to the pool, and clears the slice's bookkeeping (list of command lists and
/// command counter).
/// @return the first non-success driver result, otherwise ZE_RESULT_SUCCESS
ze_result_t CommandListManager::Reset(ze_command_list_handle_t cmd_list)
{
    ze_result_t res = Driver::CommandListReset(cmd_list);
    if (res != ZE_RESULT_SUCCESS) return res;
    if (GetSliceCmdCnt() == 0) return res;

    std::lock_guard<std::mutex> lock(mtx_);
    auto it = slices_.find(cmd_list);
    XASSERT(it != slices_.end(), "slice for cmd_list %p not found", cmd_list);

    auto &slice = *it->second;
    for (auto cl : slice.cmd_lists) {
        res = Driver::CommandListReset(cl);
        // NOTE(review): on failure, lists already pushed above stay recorded
        // in slice.cmd_lists; this early return is kept as in the original.
        if (res != ZE_RESULT_SUCCESS) return res;
        // Push to the same (ctx, dev) pool that Append() pops from and that
        // Destroy() pushes to; the pool is selected by context and device,
        // not by the command list handle.
        CommandListPool::Instance(slice.ctx, slice.dev)->Push(cl);
    }

    slice.cmd_lists.clear();
    slice.cmd_cnt = 0;
    return res;
}
/// Runs `append_func` against the command list that should receive the next
/// command.
///
/// When GetSliceCmdCnt() is 0 the callback is applied to `cmd_list` itself.
/// Otherwise the slice's command counter is incremented; each time the
/// pre-increment counter is a multiple of the slice size, a command list is
/// popped from the (ctx, dev) pool and recorded in the slice. The callback is
/// then applied to the most recently recorded command list.
/// @return whatever `append_func` returns
ze_result_t CommandListManager::Append(ze_command_list_handle_t cmd_list, std::function<ze_result_t(ze_command_list_handle_t)> append_func)
{
    const uint64_t slice_size = GetSliceCmdCnt();
    if (slice_size == 0) {
        return append_func(cmd_list);
    }

    std::lock_guard<std::mutex> guard(mtx_);
    const auto entry = slices_.find(cmd_list);
    XASSERT(entry != slices_.end(), "slice for cmd_list %p not found", cmd_list);

    auto &slice = *entry->second;
    const auto cmd_index = slice.cmd_cnt++;
    const bool starts_new_slice = (cmd_index % slice_size == 0);
    if (starts_new_slice) {
        auto *pool = &*CommandListPool::Instance(slice.ctx, slice.dev);
        slice.cmd_lists.push_back((ze_command_list_handle_t)pool->Pop());
    }
    return append_func(slice.cmd_lists.back());
}
/// Returns the slice size used by this manager; 0 means every caller in this
/// file takes its pass-through path.
///
/// The value is computed once, on first call, from the environment:
///  - 0 if XSCHED_AUTO_XQUEUE_ENV_NAME is unset, empty, "0", or "off"
///    (case-insensitive);
///  - 0 if XSCHED_LEVELZERO_SLICE_CNT_ENV_NAME is unset, cannot be parsed as
///    an integer, or is negative;
///  - otherwise the parsed value.
uint64_t CommandListManager::GetSliceCmdCnt()
{
    static const uint64_t slice_cmd_cnt = []() -> uint64_t {
        const char *env = std::getenv(XSCHED_AUTO_XQUEUE_ENV_NAME);
        if (env == nullptr || strlen(env) == 0 || strcmp(env, "0") == 0 || strcasecmp(env, "off") == 0) return 0;

        const char *str = std::getenv(XSCHED_LEVELZERO_SLICE_CNT_ENV_NAME);
        if (str == nullptr) return 0;

        long long parsed = 0;
        try { parsed = std::stoll(str); } catch (...) { return 0; }

        // std::stoll yields a signed value; converting a negative number to
        // uint64_t would wrap to a huge slice size, so treat it as invalid.
        if (parsed < 0) return 0;
        return static_cast<uint64_t>(parsed);
    }();
    return slice_cmd_cnt;
}
/// Creates an immediate command list through the driver and records its
/// handle in the set of known immediate command lists.
/// @return the result of Driver::CommandListCreateImmediate
ze_result_t ImmediateCommandListManager::Create(ze_context_handle_t ctx, ze_device_handle_t dev, const ze_command_queue_desc_t *altdesc, ze_command_list_handle_t *cmd_list)
{
    const ze_result_t status = Driver::CommandListCreateImmediate(ctx, dev, altdesc, cmd_list);
    if (status != ZE_RESULT_SUCCESS) {
        return status;
    }

    std::lock_guard<std::mutex> guard(mtx_);
    const ze_command_list_handle_t handle = *cmd_list;
    const bool is_new = (immediates_.find(handle) == immediates_.end());
    XASSERT(is_new, "immediate command list %p already exists", handle);
    immediates_.insert(handle);
    return status;
}
/// Reports whether `cmd_list` is currently recorded as an immediate command
/// list. The lookup is performed while holding mtx_.
bool ImmediateCommandListManager::Exists(ze_command_list_handle_t cmd_list)
{
    std::lock_guard<std::mutex> guard(mtx_);
    const bool known = (immediates_.find(cmd_list) != immediates_.end());
    return known;
}
/// Destroys an immediate command list through the driver and removes its
/// handle from the set of known immediate command lists.
/// @return the result of Driver::CommandListDestroy
ze_result_t ImmediateCommandListManager::Destroy(ze_command_list_handle_t cmd_list)
{
    const ze_result_t status = Driver::CommandListDestroy(cmd_list);
    if (status != ZE_RESULT_SUCCESS) {
        return status;
    }

    std::lock_guard<std::mutex> guard(mtx_);
    const auto entry = immediates_.find(cmd_list);
    XASSERT(entry != immediates_.end(), "immediate command list %p not found", cmd_list);
    immediates_.erase(entry);
    return status;
}