#include "xsched/utils/xassert.h"
#include "xsched/ascend/hal/event_pool.h"
#include "xsched/ascend/hal/acl_command.h"
using namespace xsched::ascend;
AclCommand::~AclCommand()
{
if (following_event_ == nullptr) return;
EventPool::Instance().Push(following_event_);
}
void AclCommand::Synchronize()
{
XASSERT(following_event_ != nullptr,
"following_event_ is nullptr, EnableSynchronization() should be called first");
ACL_ASSERT(Driver::rtSynchronizeEvent(following_event_));
}
bool AclCommand::Synchronizable()
{
return following_event_ != nullptr;
}
bool AclCommand::EnableSynchronization()
{
following_event_ = (aclrtEvent)EventPool::Instance().Pop();
return following_event_ != nullptr;
}
aclError AclCommand::LaunchWrapper(aclrtStream stream)
{
aclError ret = Launch(stream);
if (UNLIKELY(ret != ACL_SUCCESS)) return ret;
if (following_event_ != nullptr) ret = Driver::rtRecordEvent(following_event_, stream);
return ret;
}
AclEventRecordCommand::AclEventRecordCommand(aclrtEvent event): event_(event)
{
XASSERT(event_ != nullptr, "aclrtEvent should not be nullptr");
this->SetProps(preempt::kCommandPropertyIdempotent);
}
std::mutex TensorDesc::tensor_desc_mutex_;
std::mutex DataBuffer::data_buffer_mutex_;
std::mutex OpAttr::op_attr_mutex_;
std::unordered_map<const aclTensorDesc *, std::shared_ptr<TensorDesc>> TensorDesc::tensor_descs_;
std::unordered_map<const aclDataBuffer *, std::shared_ptr<DataBuffer>> DataBuffer::data_buffers_;
std::unordered_map<const aclopAttr *, std::shared_ptr<OpAttr>> OpAttr::op_attrs_;
std::shared_ptr<TensorDesc> TensorDesc::Create(const aclTensorDesc *desc)
{
std::lock_guard<std::mutex> lock(tensor_desc_mutex_);
auto it = tensor_descs_.find(desc);
if (it != tensor_descs_.end()) return it->second;
auto tensor_desc = std::make_shared<TensorDesc>();
tensor_desc->desc_ = desc;
tensor_descs_[desc] = tensor_desc;
return tensor_desc;
}
bool TensorDesc::Destroy(const aclTensorDesc *desc)
{
std::unique_lock<std::mutex> lock(tensor_desc_mutex_);
auto it = tensor_descs_.find(desc);
if (it == tensor_descs_.end()) return false;
auto tensor_desc = it->second;
tensor_descs_.erase(it);
lock.unlock();
tensor_desc = nullptr;
return true;
}
std::shared_ptr<DataBuffer> DataBuffer::Create(const aclDataBuffer *buffer)
{
std::lock_guard<std::mutex> lock(data_buffer_mutex_);
auto it = data_buffers_.find(buffer);
if (it != data_buffers_.end()) return it->second;
auto data_buffer = std::make_shared<DataBuffer>();
data_buffer->buffer_ = buffer;
data_buffers_[buffer] = data_buffer;
return data_buffer;
}
bool DataBuffer::Destroy(const aclDataBuffer *buffer)
{
std::unique_lock<std::mutex> lock(data_buffer_mutex_);
auto it = data_buffers_.find(buffer);
if (it == data_buffers_.end()) return false;
auto data_buffer = it->second;
data_buffers_.erase(it);
lock.unlock();
data_buffer = nullptr;
return true;
}
std::shared_ptr<OpAttr> OpAttr::Create(const aclopAttr *attr)
{
std::lock_guard<std::mutex> lock(op_attr_mutex_);
auto it = op_attrs_.find(attr);
if (it != op_attrs_.end()) return it->second;
auto op_attr = std::make_shared<OpAttr>();
op_attr->attr_ = attr;
op_attrs_[attr] = op_attr;
return op_attr;
}
bool OpAttr::Destroy(const aclopAttr *attr)
{
std::unique_lock<std::mutex> lock(op_attr_mutex_);
auto it = op_attrs_.find(attr);
if (it == op_attrs_.end()) return false;
auto op_attr = it->second;
op_attrs_.erase(it);
lock.unlock();
op_attr = nullptr;
return true;
}
aclError AclOpCompileAndExecuteCommand::Launch(aclrtStream stream)
{
int num_inputs = inputDesc_->size();
int num_outputs = outputDesc_->size();
const aclTensorDesc **input_desc = (const aclTensorDesc**)malloc(num_inputs * sizeof(aclTensorDesc*));
const aclDataBuffer **inputs = (const aclDataBuffer**)malloc(num_inputs * sizeof(aclDataBuffer*));
const aclTensorDesc **output_desc = (const aclTensorDesc**)malloc(num_outputs * sizeof(aclTensorDesc*));
aclDataBuffer **outputs = ( aclDataBuffer**)malloc(num_outputs * sizeof(aclDataBuffer*));
for (int i = 0; i < num_inputs; ++i) {
input_desc[i] = inputDesc_->at(i)->desc();
inputs[i] = inputs_->at(i)->buffer();
}
for (int i = 0; i < num_outputs; ++i) {
output_desc[i] = outputDesc_->at(i)->desc();
outputs[i] = (aclDataBuffer *)outputs_->at(i)->buffer();
}
return OpCompiler::opCompileAndExecute(opType_,
num_inputs, input_desc, inputs,
num_outputs, output_desc, outputs,
attr_->attr(), engineType_, compileFlag_, opPath_,
stream);
}