import os
import shutil
import logging
import gemm_operation
import library
# Module-level logger, named after this module via __name__.
LOGGER = logging.getLogger(__name__)
class OperationRegistry:
    """Class-level tables of registration callbacks, grouped by architecture.

    Two tables are kept, each shaped ``{arch: {name: func}}`` with one
    (initially empty) inner dict per member of ``library.ArchTag``:

    * ``register_functions`` - filled by :meth:`register`.
    * ``register_functions_high_priority`` - filled by
      :meth:`register_high_priority`.

    Registering the same ``name`` twice for one arch in the same table
    replaces the earlier callback.
    """

    archs = list(library.ArchTag)
    register_functions = {arch: {} for arch in archs}
    register_functions_high_priority = {arch: {} for arch in archs}

    @classmethod
    def _registrar(cls, table_attr, name, arch_list):
        """Build a decorator storing the decorated function in ``cls.<table_attr>``."""

        def decorator(func):
            # Resolved when the decorator is applied; a None/empty arch_list
            # falls back to the A2 architecture.
            table = getattr(cls, table_attr)
            for target in arch_list or [library.ArchTag.A2]:
                table[target][name] = func
            # The function is handed back unchanged so it stays usable as-is.
            return func

        return decorator

    @classmethod
    def register(cls, name, arch_list=None):
        """Return a decorator that records a function under ``name``.

        Args:
            name: Key under which the function is stored.
            arch_list: Architectures to register for; ``None`` or an empty
                sequence means ``[library.ArchTag.A2]``.

        Returns:
            A decorator that stores the function in ``register_functions``
            for every requested arch and returns the function unmodified.
        """
        return cls._registrar('register_functions', name, arch_list)

    @classmethod
    def register_high_priority(cls, name, arch_list=None):
        """Return a decorator that records a function in the high-priority table.

        Same contract as :meth:`register`, except that the function is
        stored in ``register_functions_high_priority``.
        """
        return cls._registrar('register_functions_high_priority', name, arch_list)
class Manifest:
    """Collect operations for one architecture and emit their C++ registration sources.

    On construction the manifest resolves the target arch, then calls every
    callback stored in ``OperationRegistry`` for that arch with the manifest
    itself as the only argument (presumably so the callbacks can call
    :meth:`append` - confirm against the registered functions).
    :meth:`generate_code` then writes the generated sources below
    ``<args.workspace_dir>/generated``.
    """

    # Upper bound on collected operations; the error message attributes it
    # to limits of the bisheng compiler.
    _MAX_OPERATIONS = 10000

    def __init__(self, args):
        """Resolve the arch and run the registered callbacks.

        Args:
            args: Options object. This class reads ``args.kernels``,
                ``args.arch`` and, in :meth:`generate_code`,
                ``args.workspace_dir``.

        Raises:
            ValueError: ``args.arch`` is not a key of
                ``library.ARCH_TAG_DICT``.
            RuntimeError: More than 10000 operations were collected.
        """
        self.args = args
        self.operations = []
        # operation_type -> {operation name -> operation}
        self.operations_dict = {}
        self.enable_filter_out = True
        # Substrings matched against operation names by filter_out().
        self.filtered_kernels = [args.kernels]
        self.arch = library.ArchTag.A2
        # operation_type -> generator class used by generate_code().
        self.target_generator = {
            'gemm': gemm_operation.GemmOperationGenerator
        }

        if args.arch in library.ARCH_TAG_DICT:
            self.arch = library.ARCH_TAG_DICT[args.arch]
        else:
            raise ValueError(f'unknown arch {args.arch}')
        LOGGER.info('arch tag is %s', self.arch.to_code())

        # Only the callbacks registered for the selected arch are relevant,
        # so look the two inner tables up directly.
        high_priority = OperationRegistry.register_functions_high_priority.get(self.arch, {})
        regular = OperationRegistry.register_functions.get(self.arch, {})

        for func in high_priority.values():
            func(self)
        for name, func in regular.items():
            # A high-priority registration under the same name wins.
            if name in high_priority:
                LOGGER.info(
                    'skip search space registration of %s in search_space.py'
                    ' due to a duplicate registration in search_space_config.py',
                    name,
                )
            else:
                func(self)

        LOGGER.info('operations that will be generated in total: %s', len(self.operations))
        if len(self.operations) > self._MAX_OPERATIONS:
            raise RuntimeError(
                'Due to limits of bisheng compiler, compiling more than 10,000 operations are not guaranteed'
            )

        # The templates below are str.format templates; doubled braces are
        # literal braces in the emitted C++.
        self.register_all_operations_template = """
#include "catlass/library/operation.h"
#include "catlass/library/manifest.h"
namespace Catlass {{
namespace Library {{
{api_decl_src}
void RegisterAllKernels(Manifest &manifest)
{{
{api_call_src}
}}
}}
}}
"""
        self.function_decl_template = """void Register_{kernel_name}(Manifest &manifest);\n"""
        self.function_call_template = """ Register_{kernel_name}(manifest);\n"""
        self.register_template = """
#include "catlass/library/operation.h"
#include "catlass/library/manifest.h"
namespace Catlass {{
namespace Library {{
{function_decls}
void RegisterCatlass{operation_type}Operations(Manifest &manifest)
{{
{function_calls}
}}
}}
}}
"""

    def append(self, operation):
        """Record ``operation`` unless :meth:`filter_out` rejects it.

        The operation is appended to ``self.operations`` and stored in
        ``self.operations_dict[operation.operation_type][operation.get_name()]``;
        an existing entry with the same type and name is replaced there.
        """
        if self.filter_out(operation):
            return
        self.operations.append(operation)
        by_name = self.operations_dict.setdefault(operation.operation_type, {})
        by_name[operation.get_name()] = operation

    def filter_out(self, operation):
        """Return True when ``operation`` should be dropped.

        With filtering enabled, an operation is kept only if at least one
        entry of ``self.filtered_kernels`` is a substring of its name.
        """
        if not self.enable_filter_out:
            return False
        operation_name = operation.get_name()
        return not any(kernel_name in operation_name for kernel_name in self.filtered_kernels)

    @staticmethod
    def _reset_generated_dir(generated_dir):
        """Replace ``generated_dir`` with a fresh empty directory; refuse symlinks."""
        # Check for a link first so rmtree is never pointed at one (this
        # also covers dangling links, for which os.path.exists is False).
        if os.path.islink(generated_dir):
            raise PermissionError(
                f'generated directory {generated_dir} is a soft link, which is not allowed to be removed.'
            )
        if os.path.exists(generated_dir):
            shutil.rmtree(generated_dir)
        os.mkdir(generated_dir)

    def generate_code(self):
        """Write all generated sources under ``<workspace_dir>/generated``.

        Any existing ``generated`` directory is deleted first. For each
        operation type in ``self.operations_dict`` this creates a
        sub-directory, runs the matching generator from
        ``self.target_generator`` over every operation, and writes
        ``register_all_<type>_operations.cpp``. Finally it writes
        ``register_all_kernels_generated.cpp`` in the ``generated``
        directory.

        Raises:
            PermissionError: The ``generated`` path is a symbolic link.
            KeyError: An operation type has no entry in
                ``self.target_generator``.
        """
        generated_dir = os.path.join(self.args.workspace_dir, 'generated')
        LOGGER.debug('generated_dir=%s', generated_dir)
        self._reset_generated_dir(generated_dir)

        api_decl_src = []
        api_call_src = []
        for operation_type, names in self.operations_dict.items():
            api_decl_src.append('void RegisterCatlass{}Operations(Manifest &manifest);'.format(operation_type))
            api_call_src.append(' RegisterCatlass{}Operations(manifest);'.format(operation_type))

            operation_subdir = os.path.join(generated_dir, operation_type)
            os.makedirs(operation_subdir, exist_ok=True)

            with self.target_generator[operation_type](operation_type, generated_dir) as generator:
                for name, operation in names.items():
                    LOGGER.info('generating kernel: %s', name)
                    generator.gen(name, operation)

            # One declaration and one call per kernel, in generation order.
            function_calls = ''.join(
                self.function_call_template.format(kernel_name=name) for name in names
            )
            function_decls = ''.join(
                self.function_decl_template.format(kernel_name=name) for name in names
            )
            operation_register_src = self.register_template.format(
                operation_type=operation_type,
                function_calls=function_calls,
                function_decls=function_decls,
            )
            self._write_to_register_file(
                os.path.join(operation_subdir, f'register_all_{operation_type}_operations.cpp'),
                operation_register_src,
            )

        register_all_kernels_src = self.register_all_operations_template.format(
            api_decl_src='\n'.join(api_decl_src), api_call_src='\n'.join(api_call_src)
        )
        self._write_to_register_file(
            os.path.join(generated_dir, 'register_all_kernels_generated.cpp'),
            register_all_kernels_src,
        )

    @staticmethod
    def _write_to_register_file(reg_filename, content):
        """Create ``reg_filename`` afresh with mode 0o550 and write ``content``.

        Args:
            reg_filename: Path of the file to (re)create.
            content: Text to write.

        Any error raised while writing propagates unchanged.
        """
        # Remove a previous file first; presumably needed because a file
        # created with mode 0o550 has no write bit and could not be
        # reopened for writing by a non-privileged owner.
        try:
            os.remove(reg_filename)
        except FileNotFoundError:
            pass

        fd = os.open(reg_filename, os.O_CREAT | os.O_WRONLY | os.O_TRUNC, 0o550)
        try:
            stream = os.fdopen(fd, "w")
        except Exception:
            # The descriptor was never wrapped, so release it here.
            os.close(fd)
            raise
        # From here on the file object owns the descriptor and closes it
        # exactly once, even when write() fails.
        with stream:
            stream.write(content)