#include "xsched/protocol/device.h"
#include "xsched/ascend/hal.h"
#include "xsched/ascend/hal/driver.h"
#include "xsched/ascend/hal/acl_queue.h"
#include "xsched/ascend/hal/acl_assert.h"
#include "xsched/ascend/hal/acl_command.h"
using namespace xsched::ascend;
using namespace xsched::preempt;
using namespace xsched::protocol;
// Wraps an existing ACL stream as an XSched hardware queue.
//
// The ACL context that is current on the constructing thread and the device
// id reported by the driver are recorded. The XSched device descriptor for
// this queue is then built as an NPU device with that id. Any driver failure
// is routed through ACL_ASSERT.
// NOTE(review): the stream is assumed to belong to the context/device that
// are current at construction time; this is not checked here.
AclQueue::AclQueue(aclrtStream stream)
    : kStream(stream)
{
    // Snapshot the calling thread's context and device binding.
    ACL_ASSERT(Driver::rtGetCurrentContext(&context_));
    ACL_ASSERT(Driver::rtGetDevice(&device_id_));

    // Describe this queue's device to XSched.
    device_ = MakeDevice(kDeviceTypeNPU, XDeviceId(device_id_));
}
// Submits a hardware command to this queue's ACL stream.
//
// The command must actually be an AclCommand; anything else trips XASSERT.
// The command enqueues itself via AclCommand::LaunchWrapper(kStream), and a
// non-success status is routed through ACL_ASSERT.
void AclQueue::Launch(std::shared_ptr<preempt::HwCommand> hw_cmd)
{
    // Only AclCommand knows how to put itself onto an ACL stream.
    const std::shared_ptr<AclCommand> acl_cmd =
        std::dynamic_pointer_cast<AclCommand>(hw_cmd);
    XASSERT(acl_cmd != nullptr, "hw_cmd is not an AclCommand");

    ACL_ASSERT(acl_cmd->LaunchWrapper(kStream));
}
// Waits on this queue's ACL stream via rtSynchronizeStream(kStream).
//
// Before waiting, the calling thread's current context is compared with the
// context recorded at construction. On a mismatch, a warning is logged and
// the recorded context is made current.
// NOTE: that context switch is left in place; the thread's previous context
// is not restored after synchronization.
void AclQueue::Synchronize()
{
    aclrtContext cur_ctx = nullptr;
    ACL_ASSERT(Driver::rtGetCurrentContext(&cur_ctx));

    const bool ctx_mismatch = (cur_ctx != context_);
    if (ctx_mismatch) {
        XWARN("stream context (%p) != current context (%p), override current", context_, cur_ctx);
        ACL_ASSERT(Driver::rtSetCurrentContext(context_));
    }

    // Presumably blocks until all work enqueued on kStream has completed.
    ACL_ASSERT(Driver::rtSynchronizeStream(kStream));
}
// Hook run when the owning XQueue is created: makes the context recorded in
// the constructor (context_) current on the calling thread.
// NOTE(review): presumably called on the thread that will later drive this
// queue, so its driver calls target the stream's context. Confirm against
// the preempt::XQueue implementation.
void AclQueue::OnXQueueCreate()
{
// A failing rtSetCurrentContext status is routed through ACL_ASSERT.
ACL_ASSERT(Driver::rtSetCurrentContext(context_));
}
// C-exported entry point that registers an ACL stream with XSched.
//
// @param hwq    Out-parameter. Receives the hardware-queue handle, and is
//               written only when registration succeeds.
// @param stream ACL stream to wrap. The handle is derived from it via
//               GetHwQueueHandle.
// @return kXSchedErrorInvalidValue if either argument is null. Otherwise,
//         whatever HwQueueManager::Add returns for the derived handle.
// NOTE(review): whether Add invokes the factory when the handle is already
// registered is not visible here.
EXPORT_C_FUNC XResult AclQueueCreate(HwQueueHandle *hwq, aclrtStream stream)
{
    // Validate arguments before touching the queue manager.
    if (hwq == nullptr) {
        XWARN("AclQueueCreate failed: hwq is nullptr");
        return kXSchedErrorInvalidValue;
    }
    if (stream == nullptr) {
        XWARN("AclQueueCreate failed: stream is nullptr");
        return kXSchedErrorInvalidValue;
    }

    HwQueueHandle handle = GetHwQueueHandle(stream);

    // The factory captures the stream by value, so it does not depend on
    // this stack frame.
    const auto result = HwQueueManager::Add(
        handle, [stream]() { return std::make_shared<AclQueue>(stream); });
    if (result != kXSchedSuccess) return result;

    *hwq = handle;
    return result;
}