from typing import Tuple, List
from codegen.api.types import (BaseCType, TupleCType, OptionalCType, BaseCppType)
from codegen.api.signature import DispatcherSignature, NativeSignature
from codegen.model import SchemaKind, NativeFunction
from codegen.api.native import arguments
# Name of the target backend. It is interpolated into the generated C++ as
# ``Backend::{backend}`` by transfer_ret_of_wrapper_func_to_xla, so it is
# presumably assigned by the code-generation driver before that function is
# called (TODO confirm against the caller); while it is still None the
# generated code would contain the literal ``Backend::None``.
backend = None
def transfer_args_of_wrapper_func_to_cpu(sig: DispatcherSignature, func: NativeFunction) -> Tuple[str, List[str]]:
    """Generate C++ code that copies every tensor argument of a wrapper to CPU.

    Args:
        sig: Dispatcher signature of the operator being wrapped. Its
            ``func`` schema supplies the kernel name for the warning message
            and the argument list that is walked.
        func: Unused here; kept so the signature stays compatible with
            existing callers.

    Returns:
        A ``(convert, args_names)`` tuple. ``convert`` is the C++ snippet:
        a one-time warning followed by one conversion statement per tensor
        argument. ``args_names`` lists, in argument order, the names to pass
        on to the CPU kernel: the original name for non-tensor arguments and
        ``<name>_cpu`` for converted ones.

    Raises:
        TypeError: If an argument whose C++ type mentions ``Tensor`` is not
            one of the supported tensor types.
    """
    # The static lambda runs once per generated wrapper, so the warning is
    # printed only on the first call of that wrapper.
    pieces: List[str] = [
        f"// Convert args to cpu in order to use at::native kernel \n "
        f"static auto warn_once = [](){{ \n "
        f"std::cout << \"Warning: kernel [{sig.func.name}] is not supported by NPU currently. "
        f"Now this kernel is running on CPU.\" << std::endl; \n "
        f"return true; \n "
        f"}}(); \n "
    ]
    args_names: List[str] = []

    for arg in arguments(sig.func):
        arg_type = str(arg.type)
        name = arg.nctype.name

        # Non-tensor arguments are forwarded unchanged.
        if 'Tensor' not in arg_type:
            args_names.append(f"{name}")
            continue

        cpu_name = f"{name}_cpu"
        if arg_type == 'at::Tensor &':
            pieces.append(f"auto {cpu_name} = {name}.cpu();\n ")
        elif arg_type == 'const at::Tensor &':
            pieces.append(f"const auto {cpu_name} = {name}.cpu();\n ")
        elif arg_type == 'const c10::optional<at::Tensor> &':
            # An optional tensor may be empty (e.g. ``bias=None``).
            # Dereferencing it unconditionally is undefined behaviour, so the
            # generated code only converts when a value is present and
            # otherwise forwards an empty optional.
            pieces.append(
                f"c10::optional<at::Tensor> {cpu_name} = {name}.has_value() ? "
                f"c10::optional<at::Tensor>({name}->cpu()) : c10::optional<at::Tensor>();\n "
            )
        elif arg_type == 'at::TensorList':
            pieces.append(
                f"::std::vector<at::Tensor> {cpu_name}({name}.size());\n"
                f"::std::transform({name}.begin(), {name}.end(), {cpu_name}.begin(),\n"
                f"[](const Tensor & temp) {{ return temp.cpu(); }});\n"
            )
        else:
            raise TypeError(f'Do not support cur type {arg.type}')
        args_names.append(cpu_name)

    return ''.join(pieces), args_names
def _functional_ret_code(sig: DispatcherSignature, func_call: str) -> str:
    """Build the return code for a functional operator (fresh outputs)."""
    returns = sig.returns_type()
    cpp_type = returns.cpp_type()

    if cpp_type == 'at::Tensor':
        return f"return {func_call}.toBackend(Backend::{backend});"

    if cpp_type == '::std::vector<at::Tensor>':
        return (
            f"auto cpu_ret = {func_call};\n"
            f"::std::vector<at::Tensor> ret_xla(cpu_ret.size());\n"
            f"::std::transform(cpu_ret.begin(), cpu_ret.end(), ret_xla.begin(),\n"
            f"[](const Tensor & temp) {{return temp.toBackend(Backend::{backend}); }});\n"
            f"return ret_xla;\n"
        )

    # Any other base type is returned as-is (no tensor to move).
    if isinstance(returns, BaseCType):
        return f"return {func_call};"

    if isinstance(returns, TupleCType):
        pieces: List[str] = [f"auto cpu_ret = {func_call}; \n "]
        element_names: List[str] = []
        for i, elem in enumerate(returns.elems):
            is_tensor = elem.cpp_type() == 'at::Tensor'
            # Tuple elements are CType wrappers, so the supported non-tensor
            # case is a plain BaseCType. Comparing against BaseCppType (the
            # wrapped type, never the element itself) rejected every
            # non-tensor element and left the pass-through branch unreachable.
            if not is_tensor and not isinstance(elem, BaseCType):
                raise TypeError(f'do not support cur type {elem.cpp_type()}')
            if is_tensor:
                pieces.append(
                    f"auto xla_tuple_ele_{i} = ::std::get<{i}>(cpu_ret).toBackend(Backend::{backend}); \n "
                )
                element_names.append(f"xla_tuple_ele_{i}")
            else:
                pieces.append(f"const auto & tuple_ele_{i} = ::std::get<{i}>(cpu_ret); \n ")
                element_names.append(f"tuple_ele_{i}")
        pieces.append(f"return ::std::make_tuple({','.join(element_names)});")
        return ''.join(pieces)

    raise TypeError(f'Do not support cur type {returns}')


def _out_ret_code(sig: DispatcherSignature, func_call: str) -> str:
    """Build the return code for an ``out=`` operator (results copied back)."""
    returns = sig.returns_type()
    out_names = [out.name for out in sig.func.arguments.out]

    pieces: List[str] = [f"{func_call};\n "]
    # The CPU kernel wrote into the ``*_cpu`` copies; copy each result back
    # into the caller's out tensor.
    pieces.extend(f"{out_name}.copy_({out_name}_cpu);\n " for out_name in out_names)

    if returns.cpp_type() == 'at::Tensor &':
        pieces.append(f"return {out_names[0]};")
    elif isinstance(returns, TupleCType):
        return_types: List[str] = []
        for elem in returns.elems:
            if elem.cpp_type() != 'at::Tensor &':
                raise TypeError(f'Do not support cur type {elem.cpp_type()}')
            return_types.append(elem.cpp_type())
        pieces.append(
            f"::std::tuple<{','.join(return_types)}> ret_xla({','.join(out_names)});\n "
            f"return ret_xla;"
        )
    else:
        raise TypeError(f'Do not support cur type {returns}')
    return ''.join(pieces)


def _inplace_ret_code(sig: DispatcherSignature, func_call: str) -> str:
    """Build the return code for an in-place operator (``self`` updated)."""
    self_arg = sig.func.arguments.self_arg.argument
    self_arg_name = self_arg.name
    cpp_type = sig.returns_type().cpp_type()

    pieces: List[str] = [f"{func_call};\n "]
    if cpp_type == 'at::Tensor &':
        pieces.append(f"{self_arg_name}.copy_({self_arg_name}_cpu);\n ")
        pieces.append(f"return {self_arg_name};")
    elif cpp_type == 'void':
        if str(self_arg.type) == 'Tensor[]':
            # A tensor list is copied back element by element.
            pieces.append(
                f"\n for (size_t i = 0; i < {self_arg_name}.size(); i++) {{\n"
                f"{self_arg_name}[i].copy_({self_arg_name}_cpu[i]);\n"
                f"}}\n\n"
            )
            pieces.append(" return;")
        else:
            pieces.append(f"{self_arg_name}.copy_({self_arg_name}_cpu);\n ")
            pieces.append("return;")
    else:
        raise TypeError(f'Do not support cur type {sig.returns_type()}')
    return ''.join(pieces)


def transfer_ret_of_wrapper_func_to_xla(sig: DispatcherSignature, func_call: str) -> str:
    """Generate the C++ tail of a wrapper: run the CPU call and hand back results.

    Args:
        sig: Dispatcher signature of the operator being wrapped; its schema
            kind and return type select the code that is generated.
        func_call: C++ expression that invokes the CPU kernel.

    Returns:
        A C++ snippet that evaluates ``func_call`` and returns the result.
        Functional operators move tensor results to ``Backend::{backend}``
        (the module-level ``backend`` value at call time); ``out=`` and
        in-place operators copy the ``<name>_cpu`` results back into the
        original tensors before returning them.

    Raises:
        TypeError: If the schema kind or the return type is not supported.
    """
    kind = sig.func.kind()
    if kind == SchemaKind.functional:
        return _functional_ret_code(sig, func_call)
    if kind == SchemaKind.out:
        return _out_ret_code(sig, func_call)
    if kind == SchemaKind.inplace:
        return _inplace_ret_code(sig, func_call)
    raise TypeError(f'Do not support cur func type {sig.func.kind()}')