import ast
import sys
import os
import logging
# Module-level logger shared by the functions below; records are emitted under the 'PATCH_SOURCE' logger name.
logger = logging.getLogger('PATCH_SOURCE')
def check_ast_if_node_is_main_chunk(node: ast.If) -> bool:
    """Return True if *node* is an ``if __name__ == '__main__':`` guard.

    Matches a single ``==`` comparison between the name ``__name__`` and the
    string constant ``'__main__'``, in either operand order, so the
    ``if '__main__' == __name__:`` spelling is recognised too. Anything else is
    not a match: ``!=``, chained comparisons, or extra ``and``/``or`` conditions.

    :param node: an ``ast.If`` statement (typically a top-level one)
    :return: whether the node's test is exactly the main-module guard
    """
    test = node.test
    if not (isinstance(test, ast.Compare)
            and len(test.ops) == 1 and isinstance(test.ops[0], ast.Eq)
            and len(test.comparators) == 1):
        return False

    def is_dunder_name(expr):
        return isinstance(expr, ast.Name) and expr.id == '__name__'

    def is_main_literal(expr):
        return isinstance(expr, ast.Constant) and expr.value == '__main__'

    left, right = test.left, test.comparators[0]
    return ((is_dunder_name(left) and is_main_literal(right))
            or (is_main_literal(left) and is_dunder_name(right)))
def extra_replacements(new_source_path: str):
    """Collect replacement snippets from the file at *new_source_path*.

    Parses the file and maps the name of every top-level function (sync or
    async) to its source text. It also maps the key ``'__main__'`` to the source
    of the top-level ``if __name__ == '__main__':`` block. When a key occurs more
    than once, the later statement wins.

    NOTE(review): since Python 3.8 a decorated function's AST node starts at
    the ``def`` line, so decorators are not part of the extracted snippet.

    :param new_source_path: path of the .py file holding the new implementations
    :return: ``(replacements, new_source)``, where ``replacements`` is ``None``
        if the file has a syntax error
    """
    with open(new_source_path, 'r', encoding='utf-8') as handle:
        source_text = handle.read()
    try:
        module = ast.parse(source_text, filename=new_source_path)
    except SyntaxError as err:
        logger.error(f"Syntax error in new_source_path: {new_source_path}, {err}")
        return None, source_text

    snippets = {}
    for stmt in module.body:
        if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)):
            key = stmt.name
        elif isinstance(stmt, ast.If) and check_ast_if_node_is_main_chunk(stmt):
            key = '__main__'
        else:
            continue
        segment = ast.get_source_segment(source_text, stmt)
        if segment is not None:
            snippets[key] = segment
    return snippets, source_text
def _line_start_offsets(source: str) -> list:
    """Return the str index at which each line of *source* starts (index 0 = line 1)."""
    import re
    # Same terminators the Python tokenizer / ast.get_source_segment recognise.
    return [0] + [match.end() for match in re.finditer(r'\r\n|\r|\n', source)]


def _ast_pos_to_index(source: str, line_starts: list, lineno: int, byte_col: int) -> int:
    """Convert an ast position (1-based line, UTF-8 byte column) into a str index of *source*."""
    line_start = line_starts[lineno - 1]
    line_end = line_starts[lineno] if lineno < len(line_starts) else len(source)
    # ast column offsets count UTF-8 bytes, not characters, so map through bytes.
    prefix = source[line_start:line_end].encode('utf-8')[:byte_col].decode('utf-8')
    return line_start + len(prefix)


def patch_old_source_file(old_source_path: str, replacements: dict):
    """Read *old_source_path* and splice the replacement snippets into it.

    Every top-level function (sync or async) whose name is a key of
    *replacements* is replaced by the corresponding snippet. Every top-level
    ``if __name__ == '__main__':`` block is replaced too, when *replacements* has
    a ``'__main__'`` key. Duplicate definitions are each replaced. All other text
    is kept unchanged.

    Spans are taken directly from the AST node positions instead of searching
    for the node's text. A text search could hit an earlier identical copy and
    leave the real definition unpatched.

    :param old_source_path: path of the .py file to patch (read as UTF-8)
    :param replacements: mapping of function name / ``'__main__'`` to new source
        text; ``None`` or empty leaves the source unchanged
    :return: the patched source text, or the unmodified source if it cannot be parsed
    """
    with open(old_source_path, 'r', encoding='utf-8') as f:
        old_source = f.read()
    try:
        old_tree = ast.parse(old_source, filename=old_source_path)
    except SyntaxError as e:
        logger.error(f"Syntax error in old_source_path: {old_source_path}, {e}")
        return old_source
    if not replacements:
        return old_source

    line_starts = _line_start_offsets(old_source)
    replace_ranges = []
    for node in old_tree.body:
        key = None
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            if node.name in replacements:
                key = node.name
        elif isinstance(node, ast.If):
            if check_ast_if_node_is_main_chunk(node) and '__main__' in replacements:
                key = '__main__'
        if key is None:
            continue
        if node.end_lineno is None or node.end_col_offset is None:
            logger.warning(f"Failed to find segment[{key}] in {old_source_path}")
            continue
        start = _ast_pos_to_index(old_source, line_starts, node.lineno, node.col_offset)
        end = _ast_pos_to_index(old_source, line_starts, node.end_lineno, node.end_col_offset)
        replace_ranges.append((start, end, replacements[key]))

    # Splice from the back so earlier offsets stay valid; top-level spans never overlap.
    new_source = old_source
    for start, end, new_src_seg in sorted(replace_ranges, key=lambda x: x[0], reverse=True):
        new_source = new_source[:start] + new_src_seg + new_source[end:]
    return new_source
def patch_source(old_source_path: str, new_source_path: str, output_path: str = None) -> bool:
    """
    Replace the functions (including the ``if __name__ == '__main__':`` block) of the old .py
    file with the same-named implementations from the new .py file, and write the result
    to output_path.

    :param old_source_path: path of the .py file to be patched
    :param new_source_path: path of the .py file providing the new implementations
    :param output_path: output path; defaults to overwriting old_source_path
    :return: True if the patched source was written; False if an input file is missing
        or the new source file cannot be parsed
    """
    if output_path is None:
        output_path = old_source_path
    # isfile (not exists): a directory would pass exists() and then fail in open().
    if not os.path.isfile(old_source_path):
        logger.error(f"The old source file not found: {old_source_path}")
        return False
    if not os.path.isfile(new_source_path):
        logger.error(f"The new source file not found: {new_source_path}")
        return False
    replacements, _ = extra_replacements(new_source_path)
    if replacements is None:
        # The new source failed to parse (extra_replacements logs the error).
        # There is nothing valid to patch in, so report failure instead of
        # handing None to the patcher and claiming success.
        return False
    new_source = patch_old_source_file(old_source_path, replacements)
    with open(output_path, 'w', encoding='utf-8', newline='') as f:
        f.write(new_source)
    logger.info(f"Successfully patched functions from new source path {new_source_path} into {output_path}")
    return True