import pathlib
import argparse
import os
import stat
import re
from collections import namedtuple, Counter, defaultdict
from typing import List, Dict, Union, Sequence, Optional, Set, Callable
import yaml
import torchgen
from torchgen.code_template import CodeTemplate
from torchgen.gen import (parse_tags_yaml, FileManager, parse_native_yaml,
get_grouped_native_functions, error_check_native_functions)
from torchgen.model import (BackendIndex, DispatchKey,
NativeFunction, NativeFunctionsGroup, OperatorName,
BackendMetadata, is_cuda_dispatch_key)
from torchgen.native_function_generation import add_generated_native_functions
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import Target, concatMap, context, NamespaceHelper
import torchgen.dest as dest
import torchgen.api.dispatcher as dispatcher
import torchgen.api.native as native
from torchgen.api.cpp import JIT_TO_CPP_DEFAULT
from torchgen.gen_backend_stubs import gen_dispatchkey_nativefunc_headers
from codegen.gen_functionalization_type import gen_functionalization_definition, gen_functionalization_registration
from codegen.utils import (get_torchgen_dir, rename_privateuse1_dispatch_key, gen_unstructured, add_header_to_template_file,
get_grouped_native_functions_optional_out, parse_npu_yaml, get_opplugin_wrap_name,
get_target_functions, merge_custom_yaml, field_tag, gen_custom_yaml_path,
update_opapi_info, is_opapi, update_internal_format_opapi_info, PathManager, filt_exposed_api, get_target_native_registration,
NativeFunctionsGroupOptionalOut, gen_device_check, filt_compositeimplicitautograd_api,
DEVICE_NOCHECK_SET)
from codegen.custom_functions import (parse_custom_yaml, gen_custom_trace, gen_custom_ops_patch,
gen_custom_functions_dispatch)
torchgen.model.dispatch_keys.append(torchgen.model.DispatchKey.AutogradPrivateUse1)
def create_backend_index(backend_ops: List[str],
symint_ops: Set[str],
dispatch_key: DispatchKey,
native_funcs_map: Dict[OperatorName, NativeFunction],
cpp_namespace: str,
) -> BackendIndex:
metadata: Dict[OperatorName, BackendMetadata] = {}
for op in backend_ops:
op_name = OperatorName.parse(op)
if op_name not in native_funcs_map:
raise KeyError(f"Found an invalid operator name: {op_name}")
kernel_name = dispatcher.name(native_funcs_map[op_name].func)
if op in symint_ops:
kernel_name += "_symint"
m = BackendMetadata(kernel=kernel_name, structured=False, cpp_namespace=cpp_namespace)
metadata[op_name] = m
return BackendIndex(
dispatch_key=dispatch_key,
use_out_as_primary=False,
external=True,
device_guard=True,
index=metadata)
def check_grouped_native_functions(
backend_key: DispatchKey,
autograd_key: DispatchKey,
backend_indices: Dict[DispatchKey, BackendIndex],
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]]):
for g in grouped_native_functions:
if isinstance(g, NativeFunction):
forward_kernels = [] if backend_key is None else \
[m for m in [backend_indices[backend_key].get_kernel(g)] if m is not None]
backward_kernels = [] if autograd_key is None else \
[m for m in [backend_indices[autograd_key].get_kernel(g)] if m is not None]
else:
if backend_key is None:
forward_kernels = []
else:
forward_kernels = []
for f in g.functions():
kernel = backend_indices[backend_key].get_kernel(f)
if kernel is not None:
forward_kernels.append(kernel)
if autograd_key is None:
backward_kernels = []
else:
backward_kernels = []
for f in g.functions():
kernel = backend_indices[autograd_key].get_kernel(f)
if kernel is not None:
backward_kernels.append(kernel)
forward_kernels = [f for f in forward_kernels if f is not None]
backward_kernels = [f for f in backward_kernels if f is not None]
if len(forward_kernels) != 0 and len(backward_kernels) != 0:
raise ValueError(f'Currently, all variants of an op must either be registered to a backend key, \
or to a backend\'s autograd key. They cannot be mix and matched. If this is \
something you need, feel free to create an issue! {forward_kernels[0].kernel} \
is listed under "supported", but {backward_kernels[0].kernel} is listed under "autograd".')
_GLOBAL_PARSE_NATIVE_YAML_CACHE = {}
ParsedYaml = namedtuple('ParsedYaml', ['native_functions', 'backend_indices'])
def modify_func_in_native_yaml(func: str) -> str:
func_to_modify = {"matmul_backward(Tensor grad, Tensor self, Tensor other, bool[2] mask) -> (Tensor, Tensor)":
"matmul_backward(Tensor grad_out, Tensor self, Tensor other, bool[2] mask) -> (Tensor, Tensor)"}
if func in func_to_modify:
return func_to_modify[func]
return func
def parse_native_and_custom_yaml(path: str, tag_path: str, custom_path: str) -> ParsedYaml:
global _GLOBAL_PARSE_NATIVE_YAML_CACHE
if path not in _GLOBAL_PARSE_NATIVE_YAML_CACHE:
valid_tags = parse_tags_yaml(tag_path)
PathManager.check_directory_path_readable(path)
with open(path, 'r') as f:
es = yaml.safe_load(f)
if not isinstance(es, list):
raise TypeError("es is not list")
rs: List[NativeFunction] = []
bs: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = defaultdict(dict)
for e in es:
e["func"] = modify_func_in_native_yaml(e["func"])
func, m = NativeFunction.from_yaml(e, "Location", valid_tags)
rs.append(func)
BackendIndex.grow_index(bs, m)
source_es = parse_npu_yaml(custom_path)
custom_es = source_es.get('custom', []) + source_es.get('custom_autograd', [])
supported_es = source_es.get('supported', []) + source_es.get('autograd', []) + custom_es
for es in supported_es:
update_opapi_info(es)
update_internal_format_opapi_info(es)
custom_es = field_tag(custom_es)
for e in custom_es:
func, m = NativeFunction.from_yaml(e, "Location", valid_tags)
rs.append(func)
BackendIndex.grow_index(bs, m)
error_check_native_functions(rs)
indices: Dict[DispatchKey, BackendIndex] = defaultdict(lambda: BackendIndex(
dispatch_key=DispatchKey.Undefined,
use_out_as_primary=True,
device_guard=True,
external=False,
index={}))
add_generated_native_functions(rs, bs)
for k, v in bs.items():
indices[k] = BackendIndex(dispatch_key=k,
use_out_as_primary=True,
external=False,
device_guard=is_cuda_dispatch_key(k),
index=v)
_GLOBAL_PARSE_NATIVE_YAML_CACHE[path] = ParsedYaml(rs, indices)
return _GLOBAL_PARSE_NATIVE_YAML_CACHE[path]
ParsedExternalYaml = namedtuple('ParsedExternalYaml', [
'true_backend', 'backend_key', 'autograd_key', 'cpp_namespace', 'backend_indices'])
def parse_backend_yaml(
native_yaml_path:str,
backend_yaml_path: str,
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
backend_indices: Dict[DispatchKey, BackendIndex]
) -> ParsedExternalYaml:
native_functions_map = {}
for f in grouped_native_functions:
if isinstance(f, NativeFunction):
native_functions_map[f.func.name] = f
else:
for func in f.functions():
native_functions_map[func.func.name] = func
PathManager.check_directory_path_readable(backend_yaml_path)
with open(backend_yaml_path, 'r') as f:
yaml_values = yaml.safe_load(f)
if not isinstance(yaml_values, dict):
raise TypeError("yaml_values is not dict")
valid_keys = ['backend', 'cpp_namespace', 'supported', 'autograd', 'custom', 'custom_autograd', 'symint', 'quant']
yaml_backend = yaml_values.pop('backend', None)
true_backend = 'PrivateUse1' if yaml_backend == 'NPU' else yaml_backend
if true_backend is None:
raise ValueError("You must provide a value for 'backend'")
backend = "NPU"
cpp_namespace = yaml_values.pop('cpp_namespace', None)
if cpp_namespace is None:
raise ValueError("You must provide a value for 'cpp_namespace'")
supported = yaml_values.pop('supported', [])
if supported is None:
supported = []
if not isinstance(supported, list):
raise TypeError(f'expected "supported" to be a list, but got type {type(supported)}')
symint = yaml_values.pop("symint", [])
if symint is None:
symint = []
if not (isinstance(symint, list)):
raise RuntimeError(f'expected "symint" to be a list, but got: {supported} (of type {type(supported)})')
symint = [op['func'].split("(")[0] if isinstance(op, Dict) else op for op in symint]
symint_set = set(symint)
supported_autograd = yaml_values.pop('autograd', [])
if not isinstance(supported_autograd, list):
raise TypeError(f'expected "autograd" to be a list, but got: {supported_autograd}')
supported_list = []
for op in supported:
if isinstance(op, Dict) and op.get('device_check', None) == 'NoCheck':
DEVICE_NOCHECK_SET.add(op['func'].split("(")[0])
if isinstance(op, Dict) and ({"impl_ns", "op_api", "device_check"} & set(op.keys())):
supported_list.append(op['func'].split("(")[0])
elif not isinstance(op, Dict):
supported_list.append(op)
supported = supported_list
supported_autograd = [op['func'].split("(")[0] if isinstance(op, Dict) else op for op in supported_autograd]
supported_autograd += filt_compositeimplicitautograd_api(native_yaml_path, supported)
custom = yaml_values.pop('custom', [])
if not isinstance(custom, list):
raise TypeError(f'expected "autograd" to be a list, but got: {custom}')
for item in custom:
try:
supported.append(item['func'][:item['func'].index('(')])
except ValueError as e:
raise Exception(f'Wrong format for function: {item["func"]}') from e
custom_autograd = yaml_values.pop('custom_autograd', [])
if not isinstance(custom_autograd, list):
raise TypeError(f'expected "autograd" to be a list, but got: {custom_autograd}')
for item in custom_autograd:
supported_autograd.append(item['func'][:item['func'].index('(')])
quant = yaml_values.pop('quant', [])
if not isinstance(quant, list):
raise TypeError(f'expected "quant" to be a list, but got: {quant}')
quant = [op['func'].split("(")[0] if isinstance(op, Dict) else op for op in quant]
yaml_values.pop('custom_supported', [])
if (len(yaml_values.keys()) > 0):
raise KeyError(f'{backend_yaml_path} contains unexpected keys: {", ".join(yaml_values.keys())}. \
Only the following keys are supported: {", ".join(valid_keys)}')
backend_key: Optional[DispatchKey] = None
opapi_key = "OpApi"
if len(supported) > 0:
with context(lambda: f'The provided value for "backend" must be a valid DispatchKey, but got {backend}.'):
backend_key = DispatchKey.parse(backend)
backend_idx = create_backend_index(supported, symint_set, backend_key, native_functions_map, cpp_namespace)
opapi_backend_idx = create_backend_index([op for op in supported if is_opapi(op)],
symint_set, backend_key, native_functions_map, cpp_namespace)
if backend_key in backend_indices:
raise KeyError("backend_key should not be in backend_indices.")
backend_indices[backend_key] = backend_idx
backend_indices[str(backend_key) + opapi_key] = opapi_backend_idx
autograd_key: Optional[DispatchKey] = None
if len(supported_autograd) > 0:
with context(lambda: f'The "autograd" key was specified, which indicates that you would like to override \
the behavior of autograd for some operators on your backend. However "Autograd{backend}" is not a valid DispatchKey.'):
autograd_key = DispatchKey.parse(f'Autograd{backend}')
autograd_idx = create_backend_index(supported_autograd, symint_set,
autograd_key, native_functions_map, cpp_namespace)
opapi_autograd_idx = create_backend_index([op for op in supported_autograd if is_opapi(op)],
symint_set, autograd_key, native_functions_map, cpp_namespace)
backend_indices[autograd_key] = autograd_idx
backend_indices[str(autograd_key) + opapi_key] = opapi_autograd_idx
quant_key = "Quantize"
if len(quant) > 0:
quant_idx = create_backend_index(quant, symint_set, backend_key, native_functions_map, cpp_namespace)
if quant_key in backend_indices:
raise KeyError("quant_key should not be in backend_indices.")
backend_indices[str(backend_key) + quant_key] = quant_idx
return ParsedExternalYaml(true_backend, backend_key, autograd_key, cpp_namespace, backend_indices)
def op_plugin_kernel_conut(op_plugin_ops_dir: str):
actual_backend_kernel_name_counts = Counter()
file_path = os.path.join(op_plugin_ops_dir, "OpInterface.h")
PathManager.check_directory_path_readable(file_path)
try:
with open(file_path, 'r') as f:
backend_defns = f.read()
except IOError as e:
raise AssertionError(f'Unable to read from the specified impl_path file: {file_path}') from e
kernel_defn_regex = rf'\w+(?=\()'
actual_backend_kernel_name_counts += Counter(re.findall(kernel_defn_regex, backend_defns))
return actual_backend_kernel_name_counts
def pta_kernel_conut(class_name: str, pta_op_dir: str):
actual_backend_kernel_name_counts = Counter()
for cur_dir, _, filenames in os.walk(pta_op_dir):
for filename in filenames:
if not filename.endswith('.cpp'):
continue
file_path = os.path.join(cur_dir, filename)
PathManager.check_directory_path_readable(file_path)
try:
with open(file_path, 'r') as f:
backend_defns = f.read()
except IOError:
raise AssertionError(f'Unable to read from the specified impl_path file: {file_path}')
kernel_defn_regex = rf'{class_name}::([\w\d]*)\([^\)]*\)\s*{{'
actual_backend_kernel_name_counts += Counter(re.findall(kernel_defn_regex, backend_defns))
return actual_backend_kernel_name_counts
def check_op_plugin_kernels(
native_functions: Sequence[NativeFunction],
expected_kernel_counts: Dict[str, List[NativeFunction]],
actual_kernel_counts: Dict[str, List[NativeFunction]]):
for f in native_functions:
wrap_name = get_opplugin_wrap_name(f)
expect_op_plugin_kernel_count = len(expected_kernel_counts[wrap_name])
if expect_op_plugin_kernel_count > actual_kernel_counts[wrap_name]:
return False
return True
def main() -> None:
parser = argparse.ArgumentParser(description='Generate backend stub files')
parser.add_argument(
'-s',
'--source_yaml',
help='path to source yaml file containing operator external definitions')
parser.add_argument(
'-o', '--output_dir', help='output directory')
parser.add_argument(
'--dry_run', type=bool, default=False, help='output directory')
parser.add_argument(
'--impl_path', type=str, default=None, help='path to the source C++ file containing kernel definitions')
parser.add_argument(
'--op_plugin_impl_path', type=str, default=None,
help='path to the source C++ file containing kernel definitions in op_plugin')
parser.add_argument(
'--op_plugin_yaml_path', type=str, default=None,
help='path to the source yaml file containing kernel definitions in op_plugin')
options = parser.parse_args()
run(options.source_yaml, options.output_dir, options.dry_run,
options.impl_path, options.op_plugin_impl_path, options.op_plugin_yaml_path)
def gen_dispatcher_registrations(
fm: FileManager,
class_name: str,
backend_indices: Dict[DispatchKey, BackendIndex],
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
backend_dispatch_key: DispatchKey,
dispatch_key: DispatchKey,
selector: "SelectiveBuilder",
dispatch_key_name: str,
register_dispatch_key_func: Callable,
native_function_registrations: str = '',
):
backend_index = backend_indices[backend_dispatch_key]
ns_helper = NamespaceHelper(namespace_str="at")
native_func_header = """\
#include "torch_npu/csrc/core/npu/NPURecovery.h"
#include "torch_npu/csrc/core/npu/NpuVariables.h"
#include "torch_npu/csrc/core/npu/NPUException.h"
#ifndef BUILD_LIBTORCH
#include "torch_npu/csrc/profiler/utils.h"
#endif
#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
#include "torch_npu/csrc/framework/interface/EnvVariables.h"
#include "torch_npu/csrc/aten/NPUOpApiNativeFunctions.h"
#include "torch_npu/csrc/framework/FormatHelper.h"
#include "torch_npu/csrc/framework/utils/ForceAclnnList.h"
#include "torch_npu/csrc/framework/OpHook.h"
#include "op_plugin/OpInterface.h"
"""
static_template = CodeTemplate(
"""\
TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) {
$dispatch_registrations_body
};"""
)
static_init_dispatch_registrations = static_template.substitute(
dispatch_key=dispatch_key_name,
dispatch_registrations_body=list(
concatMap(
register_dispatch_key_func(
backend_index,
Target.REGISTRATION,
selector,
rocm=False,
symint=True,
class_method_name=f"{class_name}",
skip_dispatcher_op_registration=False,
),
grouped_native_functions,
)
),
)
fm.write_with_template(f'Register{dispatch_key}.cpp', 'RegisterDispatchKey.cpp', lambda: {
'extra_cuda_headers': '',
'external_backend_headers': native_func_header,
'namespaced_headers': '',
'DispatchKey': dispatch_key,
'dispatch_headers': dest.gen_registration_headers(
backend_index, per_operator_headers=False, rocm=False
),
'ops_headers': '',
'dispatch_helpers': dest.gen_registration_helpers(backend_index),
'dispatch_definitions': fm.substitute_with_template(
'RegisterDispatchDefinitions.ini',
lambda: {
'ns_prologue': ns_helper.prologue,
'ns_epilogue': ns_helper.epilogue,
'static_init_dispatch_registrations': static_init_dispatch_registrations,
'deferred_dispatch_registrations': '',
'dispatch_namespace': dispatch_key.lower(),
'dispatch_namespaced_definitions': native_function_registrations,
'dispatch_anonymous_definitions': list(
concatMap(
register_dispatch_key_func(
backend_index,
Target.ANONYMOUS_DEFINITION,
selector,
rocm=False,
symint=True,
class_method_name=f'{class_name}',
skip_dispatcher_op_registration=False,
),
grouped_native_functions,
)
),
},
).split('\n'),
})
def get_supported_grouped_native_functions(
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
backend_index: BackendIndex,
) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]:
supported_grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]] = []
for funcs in grouped_native_functions:
if isinstance(funcs, NativeFunctionsGroup) and not backend_index.has_kernel(funcs.out):
for f in funcs.functions():
if backend_index.has_kernel(f):
supported_grouped_native_functions.append(f)
continue
supported_grouped_native_functions.append(funcs)
return supported_grouped_native_functions
def gen_foreach_register(
fm: FileManager,
tags_yaml_path: str,
native_yaml_path: str,
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
backend_indices: BackendIndex,
):
cpu_backend_indices = parse_native_yaml(native_yaml_path, tags_yaml_path).backend_indices[DispatchKey.CPU]
foreach_dict: Dict[str, str] = {}
header_set = set()
def get_foreach_kernel(func: NativeFunction):
schema = func.func.name
if not str(schema).startswith("_foreach"):
return
if schema in cpu_backend_indices.index and schema not in backend_indices.index:
foreach_dict[str(schema)] = cpu_backend_indices.index[schema].kernel
for f in grouped_native_functions:
if isinstance(f, NativeFunctionsGroup):
header_set.add(str(f.signature().name.name.base))
for func in f.functions():
get_foreach_kernel(func)
else:
header_set.add(str(f.func.name.name.base))
get_foreach_kernel(f)
kernel_template = CodeTemplate(
"""\
m.impl("${schema}", TORCH_FN(at::native::${kernel}));"""
)
header_template = CodeTemplate(
"""\
#include <ATen/ops/${function}_native.h>"""
)
fm.write_with_template(f'ForeachRegister.cpp', 'ForeachRegister.cpp', lambda: {
'include_headers': [header_template.substitute(function=h) for h in header_set if h.startswith("_foreach")],
'foreach_kernel': [kernel_template.substitute(schema=kv[0], kernel=kv[1]) for kv in foreach_dict.items()]
})
def gen_quantize_register(
fm: FileManager,
backend_indices: BackendIndex,
):
ns_helper = NamespaceHelper(namespace_str="at")
quantize_dict: Dict[str, str] = {}
for op_name, metadata in backend_indices.index.items():
quantize_dict[op_name] = metadata.kernel
native_func_header = """\
#include <ATen/ops/quantize_per_tensor.h>
#include "op_plugin/OpInterface.h"
"""
static_template = CodeTemplate(
"""\
TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) {
$dispatch_registrations_body
m.impl("q_scale", TORCH_FN(at::native::q_scale_quant));
m.impl("q_per_channel_scales", TORCH_FN(at::native::q_per_channel_scales));
m.impl("q_zero_point", TORCH_FN(at::native::q_zero_point_quant));
m.impl("q_per_channel_zero_points", TORCH_FN(at::native::q_per_channel_zero_points));
m.impl("q_per_channel_axis", TORCH_FN(at::native::q_per_channel_axis));
m.impl("qscheme", TORCH_FN(at::native::qscheme_quant));
};"""
)
kernel_template = CodeTemplate(
"""\
m.impl("${schema}", TORCH_FN(op_plugin::${kernel}));"""
)
static_init_dispatch_registrations = static_template.substitute(
dispatch_key="QuantizedPrivateUse1",
dispatch_registrations_body=[kernel_template.substitute(schema=kv[0], kernel=kv[1]) for kv in quantize_dict.items()]
)
fm.write_with_template(f'QuantizedRegister.cpp', 'RegisterDispatchKey.cpp', lambda: {
'extra_cuda_headers': '',
'external_backend_headers': native_func_header,
'namespaced_headers': '',
'DispatchKey': 'NPU',
'dispatch_headers': '',
'ops_headers': '',
'dispatch_helpers': '',
'dispatch_definitions': fm.substitute_with_template(
'RegisterDispatchDefinitions.ini',
lambda: {
'ns_prologue': ns_helper.prologue,
'ns_epilogue': ns_helper.epilogue,
'static_init_dispatch_registrations': static_init_dispatch_registrations,
'deferred_dispatch_registrations': '',
'dispatch_namespace': '',
'dispatch_namespaced_definitions': '',
'dispatch_anonymous_definitions': '',
},
).split('\n'),
})
def gen_functionalization(fm: FileManager,
selector: "SelectiveBuilder",
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroupOptionalOut]],
):
def key_func(
fn: Union[NativeFunction, NativeFunctionsGroupOptionalOut]
) -> str:
return fn.root_name
def functionalization_env_callable(g):
definition = gen_functionalization_definition(selector, g)
register = gen_functionalization_registration(selector, g)
return {
"func_definitions": definition,
"func_registrations": register,
}
fm.write_sharded(
"RegisterFunctionalization.cpp",
grouped_native_functions,
key_fn=key_func,
env_callable=functionalization_env_callable,
num_shards=2,
sharded_keys={
"func_definitions",
"func_registrations",
},
)
return
def gen_target_registration(
target_op_type: str,
dispatch_key: DispatchKey,
backend_indices: Dict[DispatchKey, BackendIndex],
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
op_plugin_yaml_path: str,
fm: FileManager,
selector: "SelectiveBuilder",
native_functions: List[NativeFunction] = None,
):
target_ops = get_target_functions(op_plugin_yaml_path, target_op_type=target_op_type)
target_native_functions = []
for f in grouped_native_functions:
if isinstance(f, NativeFunctionsGroup):
for func in f.functions():
if func.func in target_ops:
target_native_functions.append(func)
elif f.func in target_ops:
target_native_functions.append(f)
metadata: Dict[OperatorName, BackendMetadata] = {}
for op in target_ops:
kernel_name = dispatcher.name(op)
metadata[op.name] = BackendMetadata(kernel=kernel_name, structured=False, cpp_namespace=target_op_type)
backend_indices[dispatch_key] = BackendIndex(
dispatch_key=dispatch_key,
use_out_as_primary=False,
external=True,
device_guard=True,
index=metadata)
native_registration = get_target_native_registration(dispatch_key, backend_indices, metadata, native_functions)
gen_dispatcher_registrations(
fm,
backend_indices[dispatch_key].native_function_class_name(),
backend_indices,
target_native_functions,
dispatch_key,
dispatch_key,
selector,
dispatch_key_name=dispatch_key.name,
register_dispatch_key_func=dest.RegisterDispatchKey,
native_function_registrations=native_registration,
)
def run(source_yaml: str, output_dir: str, dry_run: bool,
impl_path: Optional[str], op_plugin_impl_path: Optional[str], op_plugin_yaml_path: Optional[str]) -> None:
rename_privateuse1_dispatch_key()
torchgen_path = get_torchgen_dir()
template_dir = os.path.join(torchgen_path, "packaged/ATen/templates")
def make_file_manager(install_dir: str) -> FileManager:
return FileManager(install_dir=install_dir, template_dir=template_dir, dry_run=dry_run)
fm = make_file_manager(output_dir)
merge_custom_yaml(source_yaml, op_plugin_yaml_path)
source_yaml = gen_custom_yaml_path(source_yaml)
tags_yaml_path = os.path.join(torchgen_path, 'packaged/ATen/native/tags.yaml')
native_yaml_path = os.path.join(torchgen_path, 'packaged/ATen/native/native_functions.yaml')
parsed_yaml = parse_native_and_custom_yaml(native_yaml_path, tags_yaml_path, source_yaml)
get_target_functions(op_plugin_yaml_path)
native_functions, backend_indices = parsed_yaml.native_functions, parsed_yaml.backend_indices
grouped_native_functions = get_grouped_native_functions(native_functions)
parsed_backend_yaml = parse_backend_yaml(native_yaml_path, source_yaml, grouped_native_functions, backend_indices)
true_backend = parsed_backend_yaml.true_backend
backend_key = parsed_backend_yaml.backend_key
autograd_key = parsed_backend_yaml.autograd_key
cpp_namespace = parsed_backend_yaml.cpp_namespace
backend_indices = parsed_backend_yaml.backend_indices
selector = SelectiveBuilder.get_nop_selector()
if backend_key is not None:
backend_dispatch_key: DispatchKey = backend_key
autograd_dispatch_key: DispatchKey = autograd_key
class_name = backend_indices[backend_dispatch_key].native_function_class_name()
gen_dispatchkey_nativefunc_headers(
fm,
class_name,
cpp_namespace,
backend_indices,
grouped_native_functions,
backend_key,
None,
)
gen_dispatchkey_nativefunc_headers(
fm,
"NPUNativeOpApiFunctions",
cpp_namespace,
backend_indices,
grouped_native_functions,
str(backend_key) + "OpApi",
None,
)
for dispatch_key in [backend_dispatch_key, autograd_dispatch_key]:
if not dispatch_key:
continue
gen_dispatcher_registrations(
fm,
class_name,
backend_indices,
get_supported_grouped_native_functions(grouped_native_functions, backend_indices[dispatch_key]),
dispatch_key,
dispatch_key,
selector,
dispatch_key_name=dispatch_key.name.replace("NPU", true_backend),
register_dispatch_key_func=dest.RegisterDispatchKey,
)
gen_quantize_register(fm, backend_indices["NPUQuantize"])
pta_template_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "templates")
fm = FileManager(install_dir=output_dir, template_dir=pta_template_dir, dry_run=dry_run)
custom_functions, custom_backend_indices = parse_custom_yaml(source_yaml, tags_yaml_path)
grouped_custom_functions = get_grouped_native_functions_optional_out(custom_functions)
gen_functionalization(fm, selector, grouped_custom_functions)
gen_custom_trace(fm, custom_functions, custom_backend_indices)
gen_custom_functions_dispatch(fm, custom_functions)
gen_foreach_register(fm,
tags_yaml_path,
native_yaml_path,
grouped_native_functions,
backend_indices[backend_dispatch_key])
custom_ops_patch_dir = os.path.join(output_dir, "../../utils/")
fm = FileManager(install_dir=custom_ops_patch_dir, template_dir=pta_template_dir, dry_run=dry_run)
gen_custom_ops_patch(fm, custom_functions)
filt_exposed_list = filt_exposed_api(source_yaml)
exposed_path = pathlib.Path(__file__).parents[1].joinpath('torch_npu/utils/exposed_api.py')
PathManager.remove_path_safety(exposed_path)
with os.fdopen(os.open(exposed_path, os.O_RDWR | os.O_CREAT, stat.S_IWUSR | stat.S_IRUSR), 'w') as f:
f.write(f'public_npu_functions = {filt_exposed_list}')
os.chmod(exposed_path, stat.S_IRUSR | stat.S_IEXEC | stat.S_IRGRP | stat.S_IXGRP)
fm = make_file_manager(output_dir)
gen_target_registration(
"sparse",
DispatchKey.SparsePrivateUse1,
backend_indices,
grouped_native_functions,
op_plugin_yaml_path,
fm,
selector,
native_functions
)
gen_target_registration(
"sparse_csr",
DispatchKey.SparseCsrPrivateUse1,
backend_indices,
grouped_native_functions,
op_plugin_yaml_path,
fm,
selector,
native_functions
)
def apply_torchgen_patch():
dest.RegisterDispatchKey.gen_unstructured = gen_unstructured
dest.RegisterDispatchKey.gen_device_check = gen_device_check
JIT_TO_CPP_DEFAULT["contiguous_format"] = "c10::MemoryFormat::Contiguous"
add_header_to_template_file()
dispatcher.arguments = native.arguments
if __name__ == '__main__':
apply_torchgen_patch()
main()