import torch
from torch_npu.npu import device_count
from torch_npu.utils._dynamo_device import NpuInterface, current_device, set_device
from torch_npu.utils._inductor import NPUDeviceOpOverrides
from . import config as npu_config
class NewNPUDeviceOpOverrides(NPUDeviceOpOverrides):
    """NPU-specific source snippets for generated wrapper code.

    Every method returns a string: either a line of Python, a C++ type or
    function name, or a larger block of C++ source. The strings are returned
    verbatim; nothing here compiles or executes them.
    """

    def import_get_raw_stream_as(self, name: str) -> str:
        """Return a Python import line binding ``_npu_getCurrentRawStream`` to ``name``."""
        return f"from torch_npu._C import _npu_getCurrentRawStream as {name}"

    def set_device(self, device_idx) -> str:
        """Return a Python statement that calls ``torch.npu.set_device`` with ``device_idx``."""
        return f"torch.npu.set_device({device_idx})"

    def synchronize(self) -> str:
        """Return Python statements that synchronize the current NPU stream.

        The snippet starts and ends with a newline and its statements carry
        no leading indentation.
        """
        return (
            "\n"
            "stream = torch.npu.current_stream()\n"
            "stream.synchronize()\n"
        )

    def device_guard(self, device_idx) -> str:
        """Return a Python expression constructing ``torch.npu.utils.device(device_idx)``."""
        return f"torch.npu.utils.device({device_idx})"

    def cpp_aoti_device_guard(self) -> str:
        """Return the C++ identifier ``AOTINpuGuard``."""
        return "AOTINpuGuard"

    def cpp_aoti_stream_guard(self) -> str:
        """Return the C++ identifier ``AOTINpuStreamGuard``."""
        return "AOTINpuStreamGuard"

    def kernel_driver(self) -> str:
        """Return the C++ helper source used to load and launch NPU kernels.

        The result is the concatenation of three snippets:

        * a ``Grid`` struct plus the ``ProfCtrlHandle`` callback, which
          updates two static profiling flags from ``profSwitch`` bits;
        * ``loadKernel``, which reads a kernel binary from disk, registers it
          with ``rtDevBinaryRegister`` / ``rtFunctionRegister`` and returns
          the function stub handle;
        * ``launchKernel``, whose body depends on
          ``npu_config.aot_inductor.debug_kernel`` at call time.

        The snippets use ``std::optional``, ``std::function``,
        ``std::runtime_error`` and ``uint32_t``; ``abi_compatible_header``
        lists the matching standard headers.
        """
        profiler_code = """
namespace {
struct Grid {
Grid(uint32_t x, uint32_t y, uint32_t z)
: grid_x(x), grid_y(y), grid_z(z) {}
uint32_t grid_x;
uint32_t grid_y;
uint32_t grid_z;
bool is_non_zero() {
return grid_x > 0 && grid_y > 0 && grid_z > 0;
}
};
} // anonymous namespace
extern "C" {
typedef int (* callback)(unsigned int type, void* data, unsigned int len);
extern int MsprofReportApi(unsigned int agingFlag, const MsprofApi *api);
extern unsigned long int MsprofSysCycleTime();
extern int MsprofRegisterCallback(unsigned int moduleId, callback handle);
static unsigned int __MsprofFlagL0 = 0;
static unsigned int __MsprofFlagL1 = 0;
int ProfCtrlHandle(unsigned int CtrlType, void* CtrlData, unsigned int DataLen) {
if ((CtrlData == nullptr) || (DataLen == 0U)) {
return 1;
}
if (CtrlType == 1) {
MsprofCommandHandle* handle = (MsprofCommandHandle *)(CtrlData);
if (handle->type >= 6) // 6 is not used here
return 1;
if (handle->type == 1) { // init - 0 , start - 1
__MsprofFlagL0 = ((0x00000800ULL & handle->profSwitch) == 0x00000800ULL) ? 1 : 0;
__MsprofFlagL1 = ((0x00000002ULL & handle->profSwitch) == 0x00000002ULL) ? 1 : 0;
}
}
return 0;
}
}
"""
        # The kernel name argument is split into function name and kernel
        # mode at the first space, or failing that at the last underscore.
        # The read buffer is freed only on the read-failure path; on success
        # it is handed to rtDevBinaryRegister and left allocated.
        # NOTE(review): rtSetDevice is always called with device 0 -- confirm
        # this is intended for multi-device setups.
        # NOTE(review): error codes are rendered with std::to_string (decimal)
        # although the messages prefix them with "0x".
        load_code = """
static std::unordered_map<std::string, size_t> registered_names;
static std::unordered_map<std::string, std::unique_ptr<size_t>> func_stubs;
static inline void * loadKernel(
std::string filePath,
const std::string &&nameFuncMode,
uint32_t sharedMemBytes,
const std::optional<std::string> &cubinDir = std::nullopt) {
if (cubinDir) {
std::filesystem::path p1{*cubinDir};
std::filesystem::path p2{filePath};
filePath = (p1 / p2.filename()).string();
}
std::string funcName;
std::string kernel_mode_str;
size_t spacePos = nameFuncMode.find(' ');
if (spacePos != std::string::npos) {
funcName = nameFuncMode.substr(0, spacePos);
kernel_mode_str = nameFuncMode.substr(spacePos + 1);
} else {
size_t underLinePos = nameFuncMode.find_last_of('_');
if (underLinePos != std::string::npos) {
funcName = nameFuncMode.substr(0, underLinePos);
kernel_mode_str = nameFuncMode.substr(underLinePos + 1);
} else {
throw std::runtime_error(std::string("Parse kernel name failed, expect "
"'kernel_name_kernel mode' or 'kernel_name_kernel_mode', but got: ") + nameFuncMode);
}
}
std::ifstream file(std::string(filePath), std::ios::binary | std::ios::ate);
if (!file.is_open()) {
throw std::runtime_error(std::string("open npubin failed"));
}
std::streamsize data_size = file.tellg();
file.seekg(0, std::ios::beg);
char* buffer = new char[data_size];
if (!file.read(buffer, data_size)) {
delete[] buffer;
throw std::runtime_error(std::string("read npubin failed"));
}
rtError_t rtRet;
rtDevBinary_t devbin;
devbin.data = buffer;
devbin.length = data_size;
const std::string kernel_mode{kernel_mode_str};
if (kernel_mode == "aiv") {
devbin.magic = RT_DEV_BINARY_MAGIC_ELF_AIVEC;
} else {
devbin.magic = RT_DEV_BINARY_MAGIC_ELF;
}
devbin.version = 0;
int device = 0;
rtRet = rtSetDevice(device);
if (rtRet != RT_ERROR_NONE) {
throw std::runtime_error(std::string("rtSetDevice failed, 0x") + std::to_string(rtRet));
}
void *devbinHandle = NULL;
rtRet = rtDevBinaryRegister(&devbin, &devbinHandle);
if (rtRet != RT_ERROR_NONE) {
throw std::runtime_error(std::string("rtDevBinaryRegister failed, 0x") + std::to_string(rtRet));
}
const char* name = funcName.c_str();
std::string stubName(name);
stubName += "_" + std::to_string(registered_names[name]);
registered_names[name]++;
auto registered = func_stubs.emplace(stubName, std::make_unique<size_t>(0));
void *func_stub_handle = registered.first->second.get();
rtRet = rtFunctionRegister(devbinHandle, func_stub_handle, stubName.c_str(),
(void *)name, 0);
if (rtRet != RT_ERROR_NONE) {
throw std::runtime_error(std::string("rtFunctionRegister failed, stubName = ") + stubName
+ std::string(" , 0x") + std::to_string(rtRet));
}
return func_stub_handle;
}
"""
        if npu_config.aot_inductor.debug_kernel:
            # Debug variant: invoke the launch callable directly.
            launch_code = """
static inline void launchKernel(
std::function<int()> launch_call,
std::string&& kernel_name) {
launch_call();
}
"""
        else:
            # Default variant: run the launch callable as the custom handler
            # of an at_npu::native::OpCommand named after the kernel.
            launch_code = """
static inline void launchKernel(
std::function<int()> launch_call,
std::string&& kernel_name) {
at_npu::native::OpCommand cmd;
cmd.Name(kernel_name.c_str())
.SetCustomHandler(launch_call)
.Run();
}
"""
        return profiler_code + load_code + launch_code

    def abi_compatible_header(self) -> str:
        """Return the ``#include`` block for the generated C++ source.

        Besides the torch_npu and runtime headers, it lists the standard
        headers for the facilities the ``kernel_driver`` snippets use
        directly (``<cstdint>``, ``<functional>``, ``<optional>``,
        ``<stdexcept>``), so they do not depend on transitive includes.
        """
        return """
#include <fstream>
#include <vector>
#include <iostream>
#include <string>
#include <tuple>
#include <unordered_map>
#include <memory>
#include <filesystem>
#include <cstdint>
#include <functional>
#include <optional>
#include <stdexcept>
#include <assert.h>
#include <stdbool.h>
#include <sys/syscall.h>
#include <torch_npu/csrc/framework/OpCommand.h>
#include <torch_npu/csrc/core/npu/NPUStream.h>
#include "runtime/runtime/rt.h"
"""

    def cpp_stream_type(self) -> str:
        """Return the C++ type name ``aclrtStream``."""
        return "aclrtStream"

    def aoti_get_stream(self) -> str:
        """Return the function name ``aoti_torch_get_current_npu_stream``."""
        return "aoti_torch_get_current_npu_stream"

    def cpp_kernel_type(self) -> str:
        """Return the C++ type spelling ``void *``."""
        return "void *"

    def cpp_device_ptr(self) -> str:
        """Return the C++ type spelling ``void*``."""
        return "void*"