from pathlib import Path
from typing import List, Dict
import yaml
from torchgen.model import NativeFunction, FunctionSchema
from torchgen.api.autograd import (
match_differentiability_info, NativeFunctionWithDifferentiabilityInfo,
DifferentiabilityInfo
)
from torchgen.packaged.autograd.load_derivatives import load_derivatives
from codegen.utils import get_torchgen_dir, CUSTOM_YAML_NAME, PathManager
from codegen.gen_backend_stubs import parse_native_and_custom_yaml
# Operator names (compared against ``str(f.func.func.name)``) that parse_derivatives()
# removes from its returned list of functions with differentiability info.
AUTOGRAD_BLACK_LIST = {'npu_format_cast.Tensor', 'npu_format_cast_', 'npu_format_cast_.acl_format'}
# Third ancestor of this file; expected to contain version.txt.
torch_npu_root = Path(__file__).parent.parent.parent
PathManager.check_directory_path_readable(torch_npu_root / "version.txt")
# Decode explicitly so the result does not depend on the platform's locale encoding.
with open(torch_npu_root / "version.txt", encoding="utf-8") as version_f:
    version = version_f.read().strip()
# Dot-separated version parts; parse_derivatives() uses the first two to pick the
# op-plugin ``v<major>r<minor>`` config directory.
VERSION_PART = version.split('.')
def parse_derivatives(
    native_functions_path: str,
    tags_path: str,
    autograd_dir: str,
    npu_native_functions_path: str
):
    """Load derivative definitions and match them against the native functions.

    The derivatives file is read from
    ``Path(autograd_dir).parents[1] / third_party/op-plugin/op_plugin/config/v<X>r<Y>/derivatives.yaml``,
    where ``<X>`` and ``<Y>`` are the first two dot-separated parts of ``version.txt``.

    Args:
        native_functions_path: path to the upstream ``native_functions.yaml``.
        tags_path: path to ``tags.yaml``.
        autograd_dir: directory whose second ancestor contains ``third_party/op-plugin``.
        npu_native_functions_path: path to the NPU custom functions yaml.

    Returns:
        ``(differentiability_infos, native_funcs, filt_funcs_with_diff_infos)``. The last
        item is the matched functions with every name in ``AUTOGRAD_BLACK_LIST`` removed.
    """
    derivatives_path = str(Path(autograd_dir).parents[1].joinpath(
        f'third_party/op-plugin/op_plugin/config/v{VERSION_PART[0]}r{VERSION_PART[1]}/derivatives.yaml'
    ))
    differentiability_infos, _ = load_derivatives(
        derivatives_path, native_functions_path, tags_path)
    native_funcs = parse_native_and_custom_yaml(
        native_functions_path, tags_path, npu_native_functions_path).native_functions
    # Only functions that have an entry in derivatives.yaml are passed on for matching.
    funcs = filte_out_native_autograd_function(native_funcs, differentiability_infos)
    funcs_with_diff_infos: List[NativeFunctionWithDifferentiabilityInfo] = match_differentiability_info(
        funcs, differentiability_infos)
    filt_funcs_with_diff_infos = [
        f for f in funcs_with_diff_infos
        if str(f.func.func.name) not in AUTOGRAD_BLACK_LIST
    ]
    return (differentiability_infos, native_funcs, filt_funcs_with_diff_infos)
def filt_npu_autograd_functions(
    native_functions_path: str,
    funcs_with_diff_infos: List[NativeFunctionWithDifferentiabilityInfo]
):
    """Split differentiable functions into NPU-only ones and ones declared by torch.

    A name counts as a torch function when it equals the part before ``(`` of some
    ``func`` entry in ``native_functions_path``. Entries of ``funcs_with_diff_infos``
    whose ``info`` is falsy are ignored entirely.

    Returns:
        ``(npu_funcs_with_diff_infos, npu_autograd_functions, torch_derivatives_functions)``:
        the NPU-only entries in input order, the set of their names, and the set of
        names of differentiable functions that also appear in ``native_functions_path``.
    """
    PathManager.check_directory_path_readable(native_functions_path)
    # Decode explicitly so parsing does not depend on the platform's locale encoding.
    with open(native_functions_path, 'r', encoding='utf-8') as yaml_file:
        entries = yaml.safe_load(yaml_file)
    # Strip the argument list from each schema string, keeping only "name[.overload]".
    torch_functions = {entry.get('func').split('(')[0] for entry in entries}

    npu_funcs_with_diff_infos: List[NativeFunctionWithDifferentiabilityInfo] = []
    npu_autograd_functions = set()
    torch_derivatives_functions = set()
    for func_info in funcs_with_diff_infos:
        if not func_info.info:
            continue
        name = str(func_info.func.func.name)
        if name in torch_functions:
            torch_derivatives_functions.add(name)
        else:
            npu_funcs_with_diff_infos.append(func_info)
            npu_autograd_functions.add(name)
    return npu_funcs_with_diff_infos, npu_autograd_functions, torch_derivatives_functions
def filte_out_native_autograd_function(
    native_funcs: List[NativeFunction],
    differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
) -> List[NativeFunction]:
    """Keep only the native functions that have a derivatives entry.

    A function is kept when either its full operator name (``str(func.name)``) or its
    base name (``func.name.name.base``) equals ``str(info.func.func.name)`` for some
    ``DifferentiabilityInfo`` in ``differentiability_infos``.

    Returns:
        The matching functions, in the same order as ``native_funcs``.
    """
    # A set gives O(1) membership tests, since the check runs once per native function.
    derivative_names = {
        str(info.func.func.name)
        for info_by_key in differentiability_infos.values()
        for info in info_by_key.values()
    }
    return [
        native_func for native_func in native_funcs
        if str(native_func.func.name) in derivative_names
        or str(native_func.func.name.name.base) in derivative_names
    ]
def _load_autograd_function_sets():
    """Compute the NPU-only and torch-declared differentiable function name sets.

    Reads the packaged torchgen ``native_functions.yaml``/``tags.yaml`` and the NPU
    custom yaml, then returns ``(npu_autograd_functions, torch_autograd_functions)``
    as produced by ``filt_npu_autograd_functions``.
    """
    torchgen_dir = Path(get_torchgen_dir())
    native_functions_path = str(torchgen_dir.joinpath('packaged/ATen/native/native_functions.yaml'))
    tags_path = str(torchgen_dir.joinpath('packaged/ATen/native/tags.yaml'))
    npu_native_functions_path = str(
        Path(__file__).parents[2].joinpath(f'torch_npu/csrc/aten/{CUSTOM_YAML_NAME}'))
    _, _, filt_funcs_with_diff_infos = parse_derivatives(
        native_functions_path,
        tags_path,
        str(Path(__file__).parent),
        npu_native_functions_path)
    _, npu_autograd_functions, torch_autograd_functions = filt_npu_autograd_functions(
        native_functions_path, filt_funcs_with_diff_infos)
    return npu_autograd_functions, torch_autograd_functions


# Names of differentiable functions absent from / present in upstream native_functions.yaml.
NPU_AUTOGRAD_FUNCTION, TORCH_AUTOGRAD_FUNCTION = _load_autograd_function_sets()