import ast
import importlib
import inspect
import os
import sys
import textwrap
from typing import List, Tuple, Dict
import astor
from tinker.model.block_adapters import BlockAdapter, legacy_block_adapters, mcore_block_adapters
from tinker.utils.config import TINKER_DIR
from tinker.utils.logger import logger
from tinker.utils.utils import write_lines, project_root, find_keywords_line_idx, get_lines, read_file, path_to_package
from tinker.utils.constant import MODULE_NAME, PYTHON_STANDARD_INDENT
# Path of the generated adapter module: <TINKER_DIR>/model/<MODULE_NAME>.py.
# gen_block_adapter() removes any existing file at this path and rewrites it.
block_adapter_file_path = os.path.join(TINKER_DIR, f'model/{MODULE_NAME}.py')
def find_source_code(location_list: List[List[str]]) -> Tuple[List[str], List]:
    """
    Resolve, for every group of candidate locations, the source of the forward method.

    Each inner list holds alternative dotted paths ("pkg.module.ClassOrFunc") for the
    same block; the first one that can be imported and resolved wins. When the resolved
    attribute is a class its ``forward`` attribute is used, otherwise the attribute itself.

    :param location_list: groups of candidate dotted paths, one group per block
    :return: (list of forward source strings, list of the modules they were found in),
             both in the order of ``location_list``
    :raises RuntimeError: when none of the candidates of a group can be resolved
    """
    sources = []
    owner_modules = []
    for candidates in location_list:
        resolved = False
        for candidate in candidates:
            split_at = candidate.rfind('.')
            module_name = candidate[:split_at]
            attr_name = candidate[split_at + 1:]
            try:
                owner_module = importlib.import_module(module_name)
                target = getattr(owner_module, attr_name)
                if inspect.isclass(target):
                    target = getattr(target, 'forward')
                forward_source = inspect.getsource(target)
            except (ImportError, AttributeError):
                # This candidate does not exist in the installed version; try the next one.
                logger.debug('location: %s is error', candidate, exc_info=True)
                continue
            sources.append(forward_source)
            owner_modules.append(owner_module)
            logger.info(f'Successfully find location {candidate}')
            resolved = True
            break
        if not resolved:
            location_text = "\n".join(candidates)
            raise RuntimeError(f'The package is not supported in the current version:\n{location_text}')
    return sources, owner_modules
def get_top_level_import(tinker_patch_dict: dict, module_source_code: str) -> List[str]:
    """
    Collect the top-level ``import`` / ``from ... import`` statements of a module source.

    Relative imports (``from . import x``) are skipped. For ``from M import name``, when
    ``name`` is a key of ``tinker_patch_dict`` the module ``M`` is replaced by the mapped value.
    Every imported name yields its own single-name statement.

    :param tinker_patch_dict: mapping of imported name -> module to import it from instead
    :param module_source_code: source text of the module to scan
    :return: list of import statements, in source order
    """
    def _render(prefix, alias):
        # Append the "as <alias>" suffix only when the original statement had one.
        return f"{prefix} as {alias.asname}" if alias.asname else prefix

    statements = []
    for top_node in ast.parse(module_source_code).body:
        if isinstance(top_node, ast.Import):
            statements.extend(_render(f"import {alias.name}", alias) for alias in top_node.names)
        elif isinstance(top_node, ast.ImportFrom) and top_node.level == 0:
            default_module = top_node.module or ''
            for alias in top_node.names:
                source_module = tinker_patch_dict.get(alias.name, default_module)
                statements.append(_render(f"from {source_module} import {alias.name}", alias))
    return statements
def scan_tinker_megatron_patch(scan_path: str):
    """
    Scan the patch files under ``scan_path`` and map each top-level function to its package.

    Only ``*.py`` files other than ``__init__.py`` are read, and only plain top-level
    ``def`` statements are recorded.

    :param scan_path: directory to scan, relative to the project root
    :return: dict of function name -> dotted module path ("<package of scan_path>.<file stem>")
    :raises FileNotFoundError: when the directory does not exist (logged, then re-raised)
    """
    patch_dir = os.path.join(project_root(), scan_path)
    try:
        entries = os.listdir(patch_dir)
    except FileNotFoundError:
        logger.error('Cannot find path: %s', patch_dir)
        raise

    method_to_package = {}
    for entry in entries:
        if not (entry.endswith('.py') and entry != '__init__.py'):
            continue
        stem = os.path.splitext(entry)[0]
        content = read_file(os.path.join(patch_dir, entry))
        for top_stmt in ast.parse(content).body:
            if not isinstance(top_stmt, ast.FunctionDef):
                continue
            method_to_package[top_stmt.name] = '.'.join([path_to_package(scan_path), stem])
    return method_to_package
class ImportTracker(ast.NodeVisitor):
    """
    Record every imported name found in a tree.

    ``self.imports`` maps the locally bound name to its origin: the root package for
    ``import`` statements, or ``<root package>.<name>`` for ``from ... import`` statements.
    """

    def __init__(self):
        self.imports: Dict[str, str] = {}

    def visit_Import(self, node):
        for item in node.names:
            # "import a.b" binds "a"; "import a.b as c" binds "c". Either way the origin is "a".
            root_package = item.name.partition('.')[0]
            bound_name = item.asname if item.asname else root_package
            self.imports[bound_name] = root_package

    def visit_ImportFrom(self, node):
        root_package = node.module.partition('.')[0] if node.module else ''
        for item in node.names:
            bound_name = item.asname or item.name
            if root_package:
                self.imports[bound_name] = f"{root_package}.{item.name}"
            else:
                # "from . import x" has no module part.
                self.imports[bound_name] = item.name
class FuncUsageFinder(ast.NodeVisitor):
    """
    Find which candidate functions are referenced by name in a tree.

    A name counts as used when it is a key of ``target_modules`` and is not bound by an
    import (i.e. not a key of ``imports``). The mapped value (the function source) is
    collected in ``self.used_funcs_code``, once per distinct source, in first-use order.
    """

    def __init__(self, target_modules: Dict[str, str], imports: Dict[str, str]):
        """
        :param target_modules: function name -> function source code
        :param imports: locally bound imported name -> origin (as built by ImportTracker)
        """
        self.targets = target_modules
        self.import_map = imports
        self.used_funcs_code = []

    def visit_Name(self, node):
        self._check_usage(node.id)

    def visit_Attribute(self, node):
        """Inspect the base object of an attribute chain."""
        # Visiting the base handles every shape uniformly: a Name is checked directly,
        # a nested Attribute recurses, and any other expression (call, subscript, ...)
        # is traversed so names inside it, e.g. ``helper(x).shape``, are not missed.
        self.visit(node.value)

    def _check_usage(self, name: str):
        """Record the source of ``name`` when it is a non-imported candidate function."""
        if name in self.import_map or name not in self.targets:
            return
        code = self.targets[name]
        # The same helper may be referenced many times; emit its definition only once.
        if code not in self.used_funcs_code:
            self.used_funcs_code.append(code)
def get_import_code_str(module_obj_list):
    """
    Extract and merge the top-level import statements of all given modules.

    Statements are de-duplicated and emitted in a deterministic order: ``from __future__``
    imports first (Python requires them before any other import), then the rest sorted
    alphabetically. Previously the order followed set iteration and changed between runs.

    :param module_obj_list: module objects whose source is scanned for imports
    :return: the import statements joined by newlines
    """
    tinker_patch_dict = scan_tinker_megatron_patch('tinker/megatron_patch')
    all_imports = set()
    for module_obj in module_obj_list:
        module_source_code = inspect.getsource(module_obj)
        all_imports.update(get_top_level_import(tinker_patch_dict, module_source_code))
    ordered_imports = sorted(
        all_imports,
        key=lambda statement: (not statement.startswith('from __future__ import '), statement),
    )
    return '\n'.join(ordered_imports)
def get_module_methods(module_obj):
    """
    Map every function reachable as an attribute of the module to its source code.

    :param module_obj: module object to inspect
    :return: dict of attribute name -> source text, for attributes that are Python functions
    """
    function_sources = {}
    for attr_name in dir(module_obj):
        member = getattr(module_obj, attr_name)
        if not inspect.isfunction(member):
            continue
        function_sources[attr_name] = inspect.getsource(member)
    return function_sources
def error_free_import():
    """
    Import the first available patch framework so that megatron gets patched.

    ``ML_PATH`` (when set) is appended to ``sys.path`` first. The candidates are tried in
    order: ``modellink``, ``mindspeed_llm``, ``ascendspeed.megatron_adaptor``. When
    ``mindspeed_llm`` is the one that imports, it is also registered in ``sys.modules``
    under the name ``modellink``.

    :raises RuntimeError: when none of the candidates can be imported
    """
    extra_path = os.getenv('ML_PATH', None)
    if extra_path:
        sys.path.append(extra_path)

    for candidate in ('modellink', 'mindspeed_llm', 'ascendspeed.megatron_adaptor'):
        logger.debug(f'try to import {candidate}')
        try:
            importlib.import_module(candidate)
        except ImportError:
            logger.debug(f'import {candidate} failed', exc_info=True)
            continue
        if candidate == 'mindspeed_llm':
            # Make "import modellink" resolve to mindspeed_llm from here on.
            sys.modules['modellink'] = sys.modules['mindspeed_llm']
        logger.debug(f'import {candidate} success')
        return
    raise RuntimeError('No available patch framework')
def gen_block_adapter(use_mcore_models):
    """
    Generate the block adapter module for the framework found under ``ML_PATH``.

    The generated file consists of the merged imports followed, for each block adapter,
    by the module-level helper functions its forward uses and the forward function itself.
    Any previous file at ``block_adapter_file_path`` is deleted before writing.

    :param use_mcore_models: truthy to use the mcore block adapters, otherwise the legacy ones
    :return: None
    :raises RuntimeError: when ``ML_PATH`` is unset or points to a non-existent path
    :raises OSError: when the previously generated file cannot be removed
    """
    ml_path = os.getenv('ML_PATH', None)
    if not ml_path:
        raise RuntimeError("ML_PATH is not set")
    if not os.path.exists(ml_path):
        raise RuntimeError(f'The package path is not exist: {ml_path}')
    logger.info('The package_path is: %s', ml_path)
    sys.path.append(ml_path)
    error_free_import()

    if use_mcore_models:
        adapters = mcore_block_adapters
    else:
        adapters = legacy_block_adapters
    candidate_paths = [adapter.source_method_path for adapter in adapters]
    forward_sources, owner_modules = find_source_code(candidate_paths)
    import_block = get_import_code_str(owner_modules)

    sections = [import_block]
    for forward_source, adapter, owner_module in zip(forward_sources, adapters, owner_modules):
        forward_code = gen_method_forward(forward_source, adapter)
        # Helper definitions go before the forward that uses them.
        sections.extend(find_used_top_func(import_block, forward_code, owner_module))
        sections.append(forward_code)
    generated_text = '\n\n\n'.join(sections)

    try:
        if os.path.exists(block_adapter_file_path):
            os.remove(block_adapter_file_path)
    except OSError:
        logger.error("Error occurred when attempting to delete the historical old file during the preparation "
                     "for dynamic block forward method generation. Please check file permissions, "
                     "whether the file is occupied by other processes, or the disk status.")
        raise
    write_lines(generated_text.splitlines(), block_adapter_file_path)
def find_used_top_func(import_code_str, method_forward_str, module_obj):
    """
    Find module-level functions that the forward code uses but the imports do not provide.

    :param import_code_str: merged import statements of the generated file
    :param method_forward_str: source of the generated forward function
    :param module_obj: module the forward came from; its functions are the candidates
    :return: list of source strings of the functions that must be added to the generated code
    """
    candidate_functions = get_module_methods(module_obj)
    combined_tree = ast.parse('\n'.join((import_code_str, method_forward_str)))

    tracker = ImportTracker()
    tracker.visit(combined_tree)

    finder = FuncUsageFinder(candidate_functions, tracker.imports)
    finder.visit(combined_tree)
    return finder.used_funcs_code
def modify_method(method_forward_head_body, function_args: list, block_name: str) -> str:
    """
    Rewrite the signature of a forward function and return the regenerated source.

    The first top-level node is renamed for the block, its parameters get default
    values filled in, and the requested extra parameters are appended when missing.

    :param method_forward_head_body: source of the function (header plus body)
    :param function_args: names of parameters to add when absent
    :param block_name: block name used to build the new function name
    :return: source regenerated from the modified tree
    """
    module_tree = ast.parse(method_forward_head_body)
    forward_def = module_tree.body[0]

    change_func_name(forward_def, block_name)
    set_method_param_default_none(forward_def)
    add_params_if_not_exist(forward_def, function_args)

    return astor.to_source(module_tree)
def has_return_statement(func_code):
    """
    Tell whether the function(s) defined at the top level of ``func_code`` return explicitly.

    Only ``return`` statements that belong to a top-level function itself are counted.
    Returns inside nested functions or classes belong to those inner scopes and are
    ignored, so a forward that merely defines an inner helper with a ``return`` is
    correctly reported as having no return of its own.

    :param func_code: source code of the function
    :return: True when such a ``return`` statement exists, otherwise False
    """
    nested_scopes = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)

    pending = []
    for top_stmt in ast.parse(func_code).body:
        if isinstance(top_stmt, (ast.FunctionDef, ast.AsyncFunctionDef)):
            # Enter the top-level function: its own statements are what we inspect.
            pending.extend(top_stmt.body)
        else:
            pending.append(top_stmt)

    while pending:
        node = pending.pop()
        if isinstance(node, ast.Return):
            return True
        if isinstance(node, nested_scopes):
            # A return in here exits the inner scope, not the function being inspected.
            continue
        pending.extend(ast.iter_child_nodes(node))
    return False
def gen_method_forward(source_code: str, block_adapter: BlockAdapter) -> str:
    """
    Build the standalone forward function for one block.

    The part of ``source_code`` selected by the adapter is taken; when it is not already
    a function definition it is wrapped with the header of the original function. The
    signature is then rewritten, and a ``return`` of the adapter's return values is
    appended when the code has none. For the ``post_process`` block an
    ``output_weight = None`` assignment is inserted as the first statement (best effort).

    :param source_code: source of the original forward method
    :param block_adapter: adapter describing how to extract and adapt the block
    :return: source of the generated forward function
    :raises RuntimeError: when the extracted code cannot be parsed
    """
    snippet = textwrap.dedent(get_effective_part(block_adapter, source_code))
    try:
        snippet_tree = ast.parse(snippet)
    except SyntaxError as e:
        logger.error(f'Cannot parse target forward method code for {block_adapter.block_name}, '
                     f'please check keywords and source code')
        raise RuntimeError('Cannot parse target code') from e

    leading_node = snippet_tree.body[0]
    if isinstance(leading_node, ast.FunctionDef):
        function_text = snippet
    else:
        # Only a body fragment was extracted: reuse the original header and indent the fragment under it.
        header = textwrap.dedent(get_function_header(textwrap.dedent(source_code)))
        indented_body = textwrap.indent(snippet, PYTHON_STANDARD_INDENT)
        function_text = '\n'.join([header, indented_body])
    function_text = modify_method(function_text,
                                  block_adapter.append_method_signatures,
                                  block_adapter.block_name)

    if has_return_statement(function_text):
        trailing_return = ''
    else:
        joined_values = ', '.join(block_adapter.return_values)
        trailing_return = textwrap.indent(f'return {joined_values}', PYTHON_STANDARD_INDENT)
    method_forward = '\n'.join([function_text, trailing_return])

    if block_adapter.block_name != 'post_process':
        return method_forward

    try:
        forward_tree = ast.parse(method_forward)
        forward_def = forward_tree.body[0]
        if isinstance(forward_def, ast.FunctionDef):
            none_assignment = ast.Assign(
                targets=[ast.Name(id='output_weight', ctx=ast.Store())],
                value=ast.Constant(value=None)
            )
            forward_def.body.insert(0, none_assignment)
            method_forward = astor.to_source(forward_tree).strip()
    except Exception as e:
        # Best effort: on failure the forward is returned without the extra assignment.
        logger.debug(f'Error adding output_weight definition: {e}')
    return method_forward
def cut_lines(source_code: str, start_idx: int, end_idx: int):
    """
    Remove the lines ``start_idx`` through ``end_idx`` (0-based, both inclusive).

    :param source_code: source text
    :param start_idx: index of the first line to drop
    :param end_idx: index of the last line to drop
    :return: the remaining lines joined by newlines
    """
    all_lines = source_code.splitlines()
    head = all_lines[:start_idx]
    tail = all_lines[end_idx + 1:]
    return '\n'.join([*head, *tail])
def get_effective_part(block_adapter: BlockAdapter, source_code: str):
    """
    Select the part of ``source_code`` delimited by the adapter's keywords.

    The start is the first match of the start keyword (line 0 when no start keyword);
    the end is the last match of the end keyword (the last line when no end keyword).
    In cut mode the delimited range is removed and the remainder returned; otherwise
    the delimited range itself is returned.

    :param block_adapter: adapter holding ``key_words`` and ``method_location.cut_mode``
    :param source_code: source the target code lives in
    :return: the selected code
    """
    start_key_word, end_key_word = block_adapter.key_words

    first_idx = find_keywords_line_idx(source_code, start_key_word)[0] if start_key_word else 0
    if end_key_word:
        last_idx = find_keywords_line_idx(source_code, end_key_word)[-1]
    else:
        last_idx = len(source_code.splitlines()) - 1

    # NOTE(review): both helpers receive last_idx + 1, but cut_lines treats its end index
    # as inclusive - confirm that dropping one extra line in cut mode is intended.
    extract = cut_lines if block_adapter.method_location.cut_mode else get_lines
    return extract(source_code, first_idx, last_idx + 1)
def get_function_header(function_code: str):
    """
    Return the header of a function: everything before its first body statement.

    :param function_code: source of the function; its first top-level node must be a ``def``
    :return: the header lines as returned by ``get_lines``
    :raises RuntimeError: when the first top-level node is not a function definition
    """
    leading_node = ast.parse(function_code).body[0]
    if not isinstance(leading_node, ast.FunctionDef):
        raise RuntimeError('When parsing function head line, the first line should be func.')
    # lineno is 1-based, so lineno - 1 is the 0-based index of the first body line.
    body_start_line = leading_node.body[0].lineno
    return get_lines(function_code, 0, body_start_line - 1)
def add_params_if_not_exist(function_node, function_args: list):
    """
    Append parameters (defaulting to ``None``) to a function node when they are missing.

    A name counts as existing when it is already any kind of parameter: positional-only,
    positional, keyword-only, ``*args`` or ``**kwargs``. Names repeated in
    ``function_args`` are added only once, so the resulting signature never contains
    duplicate parameter names.

    :param function_node: function definition node, modified in place
    :param function_args: names of the parameters to add; may be empty or None
    :return: None
    """
    if not function_args:
        return

    signature = function_node.args
    known_names = {param.arg for param in signature.posonlyargs}
    known_names.update(param.arg for param in signature.args)
    known_names.update(param.arg for param in signature.kwonlyargs)
    for variadic in (signature.vararg, signature.kwarg):
        if variadic is not None:
            known_names.add(variadic.arg)

    for name in function_args:
        if name in known_names:
            continue
        # Defaults are aligned with the tail of args, so appending both keeps them paired.
        signature.args.append(ast.arg(arg=name, annotation=None))
        signature.defaults.append(ast.Constant(value=None))
        known_names.add(name)
def set_method_param_default_none(function_node):
    """
    Give a ``None`` default to the positional parameters that have no default.

    Enough ``None`` defaults are put in front of the existing ones so that the number of
    defaults equals the number of parameters in ``args.args`` not named ``self``.

    :param function_node: function definition node, modified in place
    :return: None
    """
    signature = function_node.args
    non_self_count = len([param for param in signature.args if param.arg != 'self'])
    missing_count = non_self_count - len(signature.defaults)
    # range() is empty for a non-positive count, leaving the defaults untouched.
    for _ in range(missing_count):
        signature.defaults.insert(0, ast.Constant(value=None))
def change_func_name(function_node, block_name):
    """
    Rename a function node to the forward function name of the given block.

    :param function_node: function definition node, modified in place
    :param block_name: block name the new function name is derived from
    :return: None
    """
    forward_name = get_forward_func_name(block_name)
    function_node.name = forward_name
def get_forward_func_name(block_name: str) -> str:
    """
    Build the name of the generated forward function for a block.

    :param block_name: block name
    :return: ``tinker_<block_name>_forward``
    """
    return 'tinker_{}_forward'.format(block_name)