#include "AtbCommon.h"
namespace atb {
atb::Tensor AtTensor2AtbTensor(const at::Tensor at_tensor)
{
static std::map<at::ScalarType, aclDataType> dtype_map = {
{at::ScalarType::Bool, ACL_BOOL}, {at::ScalarType::Byte, ACL_UINT8},
{at::ScalarType::Char, ACL_INT8}, {at::ScalarType::Half, ACL_FLOAT16},
{at::ScalarType::Float, ACL_FLOAT}, {at::ScalarType::Int, ACL_INT32},
{at::ScalarType::Long, ACL_INT64}, {at::ScalarType::BFloat16, ACL_BF16},
{at::ScalarType::Double, ACL_DOUBLE}, {at::ScalarType::Short, ACL_INT16},
{at::ScalarType::ComplexHalf, ACL_COMPLEX32}, {at::ScalarType::ComplexFloat, ACL_COMPLEX64},
{at::ScalarType::ComplexDouble, ACL_COMPLEX128},
};
TORCH_CHECK(at_tensor.is_contiguous(), "at_tensor is not contiguous");
atb::Tensor tensor;
tensor.desc.format = atb::utils::GetFormatForAtb(at_tensor);
if (at_tensor.device().type() == at::kCPU) {
tensor.hostData = at_tensor.data_ptr();
} else {
tensor.deviceData = at_tensor.data_ptr();
}
tensor.desc.shape.dimNum = at_tensor.sizes().size();
for (uint64_t i = 0; i < at_tensor.sizes().size(); i++) {
tensor.desc.shape.dims[i] = at_tensor.sizes()[i];
}
auto dtype_iterator = dtype_map.find(at_tensor.scalar_type());
TORCH_CHECK(dtype_iterator != dtype_map.end(), "not support dtype: ", at_tensor.scalar_type());
tensor.desc.dtype = dtype_iterator->second;
tensor.dataSize = atb::Utils::GetTensorSize(tensor);
return tensor;
}
void RunAtbCmdV1(atb::Operation *op, const ParamSetter ¶msetter, const std::string &name)
{
aclrtStream stream = c10_npu::getCurrentNPUStream().stream(false);
auto context_ptr = atb::utils::GetContext(stream);
atb::VariantPack variant_pack = paramsetter.variant_pack_;
uint64_t workspace_size = OperationSetup(variant_pack, op, context_ptr);
at::Tensor workspace_tensor;
void *workspace_ptr = nullptr;
if (workspace_size != 0) {
at::TensorOptions options = at::TensorOptions(c10::DeviceType::PrivateUse1);
workspace_tensor = at::empty({workspace_size}, options.dtype(at::kByte));
workspace_ptr = const_cast<void *>(workspace_tensor.storage().data());
}
const c10::SmallVector<at::Tensor, N>& cpu_tensors = paramsetter.tensor_maintainer_.cpu_tensors;
auto acl_call = [variant_pack, workspace_ptr, workspace_size, context_ptr, op, cpu_tensors]() -> int {
auto st = op->Execute(variant_pack, (uint8_t *)workspace_ptr, workspace_size, context_ptr);
DestroyOperation(op);
return st;
};
at_npu::native::OpCommand::RunOpApiV2(name, acl_call);
}
void RunAtbCmdV2(atb::Operation *op, const ParamSetter ¶msetter, const std::string &name)
{
aclrtStream stream = c10_npu::getCurrentNPUStream().stream(false);
atb::VariantPack variant_pack = paramsetter.variant_pack_;
const c10::SmallVector<at::Tensor, N>& cpu_tensors = paramsetter.tensor_maintainer_.cpu_tensors;
auto acl_call = [op, variant_pack, stream, cpu_tensors]() -> int {
auto context_ptr = atb::utils::GetContext(stream);
uint64_t workspace_size = OperationSetup(variant_pack, op, context_ptr);
at::Tensor workspace_tensor;
void *workspace_ptr = nullptr;
if (workspace_size != 0) {
workspace_tensor = at_npu::native::allocate_workspace(workspace_size, stream);
workspace_ptr = const_cast<void *>(workspace_tensor.storage().data());
}
auto st = op->Execute(variant_pack, (uint8_t *)workspace_ptr, workspace_size, context_ptr);
return 0;
};
at_npu::native::OpCommand::RunOpApiV2(name, acl_call);
}
void RunAtbCmd(atb::Operation *op, const ParamSetter ¶msetter, const std::string &name)
{
const auto is_capturing = static_cast<int>(c10_npu::currentStreamCaptureStatusMayInitCtx());
if (is_capturing) {
RunAtbCmdV1(op, paramsetter, name);
} else {
RunAtbCmdV2(op, paramsetter, name);
}
}
uint64_t GetAtbWorkspaceSizeCmd(atb::Operation *op, const ParamSetter ¶msetter, const std::string &name)
{
aclrtStream stream = c10_npu::getCurrentNPUStream().stream(false);
auto context_ptr = atb::utils::GetContext(stream);
atb::VariantPack variant_pack = paramsetter.variant_pack_;
uint64_t workspace_size = OperationSetup(variant_pack, op, context_ptr);
return workspace_size;
}
void RunAtbCmdWithWorkspaceV1(atb::Operation *op, const ParamSetter ¶msetter, const std::string &name, const at::Tensor &workspace_tensor)
{
aclrtStream stream = c10_npu::getCurrentNPUStream().stream(false);
auto context_ptr = atb::utils::GetContext(stream);
atb::VariantPack variant_pack = paramsetter.variant_pack_;
uint64_t fake_workspace_size = OperationSetup(variant_pack, op, context_ptr);
uint64_t workspace_size = static_cast<uint64_t>(workspace_tensor.numel() * workspace_tensor.element_size());
void *workspace_ptr = const_cast<void *>(workspace_tensor.storage().data());
const c10::SmallVector<at::Tensor, N>& cpu_tensors = paramsetter.tensor_maintainer_.cpu_tensors;
auto acl_call = [variant_pack, workspace_ptr, workspace_size, context_ptr, op, cpu_tensors]() -> int {
auto st = op->Execute(variant_pack, (uint8_t *)workspace_ptr, workspace_size, context_ptr);
DestroyOperation(op);
return st;
};
at_npu::native::OpCommand::RunOpApiV2(name, acl_call);
}
void RunAtbCmdWithWorkspaceV2(atb::Operation *op, const ParamSetter ¶msetter, const std::string &name, const at::Tensor &workspace_tensor)
{
aclrtStream stream = c10_npu::getCurrentNPUStream().stream(false);
atb::VariantPack variant_pack = paramsetter.variant_pack_;
const c10::SmallVector<at::Tensor, N>& cpu_tensors = paramsetter.tensor_maintainer_.cpu_tensors;
uint64_t workspace_size = static_cast<uint64_t>(workspace_tensor.numel() * workspace_tensor.element_size());
void *workspace_ptr = const_cast<void *>(workspace_tensor.storage().data());
auto acl_call = [op, workspace_ptr, workspace_size, variant_pack, stream, cpu_tensors]() -> int {
auto context_ptr = atb::utils::GetContext(stream);
uint64_t fake_workspace_size = OperationSetup(variant_pack, op, context_ptr);
auto st = op->Execute(variant_pack, (uint8_t *)workspace_ptr, workspace_size, context_ptr);
return 0;
};
at_npu::native::OpCommand::RunOpApiV2(name, acl_call);
}
void RunAtbCmdWithWorkspace(atb::Operation *op, const ParamSetter ¶msetter, const std::string &name, const at::Tensor &workspace_tensor)
{
const auto is_capturing = static_cast<int>(c10_npu::currentStreamCaptureStatusMayInitCtx());
if (is_capturing) {
RunAtbCmdWithWorkspaceV1(op, paramsetter, name, workspace_tensor);
} else {
RunAtbCmdWithWorkspaceV2(op, paramsetter, name, workspace_tensor);
}
}
ParamSetter& ParamSetter::Input(const at::Tensor &tensor, const bool &format_trans)
{
if (!tensor.defined()) {
variant_pack_.inTensors.push_back(atb::Tensor());
return *this;
}
at::Tensor new_tensor = tensor.contiguous();
if (format_trans) {
new_tensor = atb::utils::FormatTrans(new_tensor);
}
atb::Tensor atb_tensor;
if (new_tensor.device().type() == at::kCPU) {
auto tensor_clone = new_tensor.clone();
atb_tensor = AtTensor2AtbTensor(tensor_clone);
tensor_maintainer_.cpu_tensors.emplace_back(std::move(tensor_clone));
} else {
atb_tensor = AtTensor2AtbTensor(new_tensor);
tensor_maintainer_.contiguous_tensors.emplace_back(std::move(new_tensor));
}
variant_pack_.inTensors.push_back(atb_tensor);
return *this;
}
ParamSetter& ParamSetter::Input(const c10::optional<at::Tensor> &tensor, const bool &format_trans)
{
if (!tensor.has_value()) {
variant_pack_.inTensors.push_back(atb::Tensor());
return *this;
}
return Input(tensor.value(), format_trans);
}
ParamSetter& ParamSetter::Output(at::Tensor &output)
{
auto atb_tensor = AtTensor2AtbTensor(output);
variant_pack_.outTensors.push_back(atb_tensor);
return *this;
}
uint64_t OperationSetup(atb::VariantPack variant_pack, atb::Operation *operation, atb::Context* context_ptr)
{
uint64_t workspace_size = 0;
atb::Status status = operation->Setup(variant_pack, workspace_size, context_ptr);
TORCH_CHECK(status == 0, operation -> GetName(), " setup failed!");
return workspace_size;
}
}