import glob
import json
import os
import re
from pathlib import Path
from typing import List
from tinker.utils.constant import Version, version_parse
from tinker.utils.logger import logger
from tinker.version import optimizer_version
def project_root():
    """
    Return the directory three levels above this file (the parent of the tinker package).

    Returns:
        Absolute path string.
    """
    here = os.path.abspath(__file__)
    return os.path.dirname(os.path.dirname(os.path.dirname(here)))
def read_file(file_path: str):
    """
    Read a whole text file decoded as UTF-8.

    :param file_path: path of the file to read
    Returns: the file contents as a string
    """
    with open(file_path, mode='r', encoding='utf-8') as handle:
        text = handle.read()
    return text
def extract_line(reserve_str, content):
    """
    Collect the lines of ``content`` that start with ``reserve_str``.

    :param reserve_str: required line prefix (matched from column 0)
    :param content: text to scan
    Returns: list of matching lines, in original order
    """
    return [row for row in content.splitlines() if row.startswith(reserve_str)]
def extract_line_ignore_blank(reserve_str, content):
    """
    Collect the lines of ``content`` that start with ``reserve_str`` once
    leading whitespace is ignored.

    :param reserve_str: required prefix (checked after ``lstrip()``)
    :param content: text to scan
    Returns: list of matching lines, returned unstripped, in original order
    """
    return [row for row in content.splitlines() if row.lstrip().startswith(reserve_str)]
def extract_between(start_word, end_word, content):
    """
    Return the first substring running from ``start_word`` to the nearest
    following ``end_word``, both delimiters included. Matching spans newlines.

    :param start_word: literal opening delimiter
    :param end_word: literal closing delimiter
    :param content: text to search
    Returns: the matched substring, or None when no such span exists
    """
    regex = re.escape(start_word) + r'(.*?)' + re.escape(end_word)
    found = re.search(regex, content, re.DOTALL)
    return found.group(0) if found else None
def del_line(del_params: list, content: str):
    """
    Drop every line of ``content`` that contains any of the given keywords.

    Note: matching is substring containment anywhere in the line, not a
    prefix check.

    :param del_params: keywords; a line containing any of them is removed
    :param content: text to filter
    Returns: remaining lines joined with '\\n' (no trailing newline)
    """
    kept = []
    for row in content.splitlines():
        if any(keyword in row for keyword in del_params):
            continue
        kept.append(row)
    return '\n'.join(kept)
def del_content(start_word, end_word, content: str):
    """
    Remove the line span that begins with a line starting with ``start_word``
    and ends with the first later line starting with ``end_word``, both
    inclusive.

    The start marker is the nearest one before that end line: a later start
    line replaces an earlier one. A line that matches ``start_word`` is never
    treated as the end line. End-marker lines seen before any start line are
    ignored.

    :param start_word: prefix identifying the first line to delete
    :param end_word: prefix identifying the last line to delete
    :param content: text to edit
    :return: the remaining lines joined with '\\n'
    :raises RuntimeError: if no start/end pair is found
    """
    rows = content.splitlines()
    begin = None
    for position, row in enumerate(rows):
        if row.startswith(start_word):
            begin = position
        elif row.startswith(end_word) and begin is not None:
            del rows[begin: position + 1]
            return '\n'.join(rows)
    raise RuntimeError('cannot find del word pair in content')
def write_lines(final_res: list, dest_file: str):
    """
    Write each string in ``final_res`` to ``dest_file`` (UTF-8), one per line.
    The file is overwritten if it exists.

    :param final_res: lines to write; each gets a trailing '\\n'
    :param dest_file: destination path
    Raises:
        RuntimeError: wrapping any failure (open error, non-str item, ...)
    """
    try:
        with open(dest_file, mode='w', encoding='utf-8') as handle:
            handle.writelines(entry + '\n' for entry in final_res)
    except Exception as err:
        raise RuntimeError(f'write to file: {dest_file} failed.') from err
def load_infos(args):
    """
    Load the model-structure and task information saved by earlier pipeline
    stages from ``args.profiled_data_path`` and store it on ``args``.

    - ``model_info*.json``: every key/value is copied onto ``args``, except
      ``num_layers`` when its value is 1. If the file is missing,
      ``args.seq_length`` defaults to 4096.
    - ``VERSION*.json``: sets ``version_profiler`` (and, for the legacy format
      with only a ``version`` key, ``version_framework``), plus ``model_name``
      and ``model_size``. ``pretrain_script_path`` is also taken from it when
      ``args.pretrain_script_path_search`` is None.
    - ``version_optimizer`` is always set from ``optimizer_version()``.

    :param args: mutable namespace-like object; attributes are set via setattr
    """
    model_info = find_files(args.profiled_data_path, 'model_info*.json')
    if not model_info:
        logger.info('model_info未找到,seq_length取4096')
        args.seq_length = 4096
    else:
        # Explicit UTF-8 so decoding does not depend on the platform locale.
        with open(model_info, 'r', encoding='utf-8') as file:
            data = json.load(file)
        for k, v in data.items():
            # A stored num_layers of 1 is not copied onto args.
            # NOTE(review): presumably a reduced profiling model; confirm.
            if k == 'num_layers' and v == 1:
                continue
            setattr(args, k, v)

    task_info = find_files(args.profiled_data_path, 'VERSION*.json')
    if task_info:
        with open(task_info, 'r', encoding='utf-8') as file:
            data = json.load(file)
        if 'version_profiler' not in data:
            # Legacy VERSION file: only a generic 'version' key is present.
            args.version_profiler = data['version']
            args.version_framework = Version.MindSpeed_LLM_1_0_rc3
        else:
            args.version_profiler = data['version_profiler']
        args.model_name = data.get('model_name')
        args.model_size = data.get('model_size')
        if args.pretrain_script_path_search is None:
            args.pretrain_script_path = data.get('pretrain_script_path')
    args.version_optimizer = optimizer_version()
def find_files(dir_path, pattern):
    """
    Return the first file under ``dir_path`` matching the glob ``pattern``.

    ``glob.glob`` returns matches in filesystem-dependent order. The matches
    are therefore sorted, so the same file is picked every time when several
    match.

    :param dir_path: directory to search (non-recursive)
    :param pattern: glob pattern such as ``'model_info*.json'``
    :return: path of the lexicographically first match, or None if nothing matches
    """
    matches = sorted(glob.glob(os.path.join(dir_path, pattern)))
    return matches[0] if matches else None
def extract_and_format_model_size(model_size: str):
    """
    Normalize a user-supplied model size to ``'<number>b'``.

    The first number found (integer or decimal) is kept. An optional ``b`` or
    ``B`` suffix is accepted and normalized to lowercase ``b``. Examples:
    ``'7B' -> '7b'``, ``'13' -> '13b'``, ``'1.5b' -> '1.5b'``.

    :param model_size: model size as typed by the user
    Returns: normalized size string ending in 'b'
    Raises:
        RuntimeError: if ``model_size`` contains no number
    """
    model_size_search = re.search(r'(\d+(?:\.\d+)?)[bB]?', model_size)
    if model_size_search is None:
        raise RuntimeError(f'The model size {model_size} is not valid, accept pattern like xxb, xxB or xx.')
    # Take only the numeric capture group. Blindly dropping the last character
    # would truncate bare numbers ('7' -> '', '13' -> '1').
    return f'{model_size_search.group(1)}b'
def byte_to_mb(x):
    """
    Convert a memory size in bytes to mebibytes (1 MB = 1024 * 1024 bytes).

    :param x: memory size in bytes
    :return: memory size in MB, as a float
    """
    kib = 1024.0
    return x / kib / kib
def find_keywords_line_idx(source_code: str, key_word: str):
    """
    Find the 0-based indices of every line of ``source_code`` that contains
    ``key_word`` (substring match).

    :param source_code: source text to scan
    :param key_word: substring to look for
    Returns: list of 0-based line indices, ascending
    Raises:
        RuntimeError: if no line contains ``key_word``
    """
    hits = [idx for idx, text in enumerate(source_code.splitlines()) if key_word in text]
    if hits:
        return hits
    raise RuntimeError(f'Cannot find key word: {key_word} in source code')
def get_lines(module_code: str, start_idx: int, end_idx: int):
    """
    Return lines ``[start_idx, end_idx)`` of ``module_code`` (Python slice
    semantics), joined with '\\n'.

    :param module_code: source text
    :param start_idx: first line index (inclusive)
    :param end_idx: stop line index (exclusive)
    :return: the selected lines as one string
    """
    return '\n'.join(module_code.splitlines()[start_idx:end_idx])
def path_to_package(file_system_path):
    """
    Turn a filesystem path into a dotted package string by joining its
    ``pathlib`` parts with '.'.

    Note: file suffixes are kept, and an absolute path's root becomes the
    first part.

    :param file_system_path: path to convert
    :return: dotted string, e.g. 'a/b/c' -> 'a.b.c'
    """
    return '.'.join(Path(file_system_path).parts)
def extract_arg_value_from_json(json_path: str):
    """
    Load the parameter values stored in a JSON config file (e.g.
    parameter_config.json).

    :param json_path: path to the UTF-8 JSON file
    :return: the decoded JSON value
    """
    with open(json_path, mode='r', encoding='utf-8') as handle:
        return json.load(handle)
def check_path_exist(path: str):
    """
    Ensure ``path`` exists; otherwise log an error and raise.

    :param path: file or directory path to check
    :raises Exception: if the path does not exist
    """
    if os.path.exists(path):
        return
    logger.error(f'The file path {path} does not exist.')
    raise Exception
def check_path_type(path: str, path_type: str):
    """
    Ensure ``path`` is of the requested kind; otherwise log an error and raise.

    :param path: path to check
    :param path_type: 'file' or 'dir'; any other value performs no check
    :raises Exception: if ``path`` is not a regular file (for 'file') or not
        a directory (for 'dir')
    """
    # Compare the parameter, not the builtin ``type``. Comparing ``type``
    # against a string is always False, which silently skipped every check.
    if path_type == 'file':
        if not os.path.isfile(path):
            logger.error(f'The {path} should be a file!')
            raise Exception
    elif path_type == 'dir':
        if not os.path.isdir(path):
            logger.error(f'The {path} should be a directory!')
            raise Exception
def check_file_suffix(path: str, suffix: str):
    """
    Ensure ``path`` ends with ``suffix``; otherwise log an error and raise.

    :param path: file path to check
    :param suffix: required ending (e.g. '.json'); a falsy value skips the check
    :raises Exception: if ``path`` does not end with ``suffix``
    """
    if suffix:
        if not path.endswith(suffix):
            # f-string so the actual path and suffix are interpolated
            # (previously the braces were logged literally).
            logger.error(f'The {path} should be a {suffix} file!')
            raise Exception
def check_path_before_create(path: str):
    """
    Validate the target location before creating a file or directory: its
    parent must exist and must be a directory.

    :param path: path that is about to be created
    :raises Exception: propagated from the existence/type checks
    """
    parent_directory = os.path.dirname(path)
    check_path_exist(parent_directory)
    check_path_type(parent_directory, 'dir')
def check_files_in_dir(path: str):
    """
    Ensure a directory is not empty.

    Paths that are not directories (including nonexistent ones) pass silently.

    :param path: directory to check
    :raises Exception: if ``path`` is a directory with no entries
    """
    if not os.path.isdir(path):
        return
    if os.listdir(path):
        return
    logger.error(f'No files in {path}')
    raise Exception
def convert_to_pp_stage_block_idx(num_layer_list: List[int], num_all_blocks_len: int):
    """
    Convert per-stage layer counts into inclusive block-index intervals.

    Layers are numbered from 1. The first interval is widened to start at
    block 0, and the last interval is widened to end at
    ``num_all_blocks_len - 1``. This covers the extra head/tail blocks.

    :param num_layer_list: one candidate partition; each element is the number
        of layers in that pipeline stage
    :param num_all_blocks_len: total block count including the head/tail blocks
    :return: list of ``(start, end)`` tuples, one per stage, inclusive
    :raises IndexError: if ``num_layer_list`` is empty
    """
    bounds = []
    end = 0
    for length in num_layer_list:
        start = end + 1
        end += length
        bounds.append((start, end))
    bounds[0] = (0, bounds[0][1])
    bounds[-1] = (bounds[-1][0], num_all_blocks_len - 1)
    return bounds
def convert_to_num_layers(interval_layer_list):
    """
    Turn inclusive block-index intervals back into a comma-separated string
    of per-stage layer counts.

    One block is removed from the first interval and two from the last,
    matching the head/tail blocks. With a single interval, both adjustments
    apply to it.

    :param interval_layer_list: sequence of ``(start, end)`` inclusive intervals
    :return: string such as '2,3'
    """
    sizes = [span[1] - span[0] + 1 for span in interval_layer_list]
    sizes[0] -= 1
    sizes[-1] -= 2
    return ','.join(str(size) for size in sizes)