"""
We can't use assert in our code for codecheck, so create this auxiliary function to wrap
the assert case in ut for ci.
"""
import os
import hashlib
import logging
import re
import json
import glob
import sys
from concurrent.futures import ProcessPoolExecutor
import subprocess
import torch
import torch_npu
import xxhash
import pytest
import megatron.core.parallel_state as mpu
from megatron.core.parallel_state import initialize_model_parallel
from mindspeed.core.parallel_state import initialize_model_parallel_wrapper
from mindspeed.core.context_parallel.model_parallel_utils import initialize_model_parallel_cp_wrapper
from mindspeed_llm.core.parallel_state import initialize_model_parallel_decorator
def judge_expression(expression):
    """Raise ``AssertionError`` when *expression* is falsy.

    Stand-in for a bare ``assert`` statement, which code check forbids.
    Unlike ``assert``, this check is not stripped when Python runs with ``-O``.
    """
    if expression:
        return
    raise AssertionError
def hash_tensor_in_chunks(tensor, chunk_size=1024 * 1024):
    """Hash a tensor's element bytes chunk by chunk with xxh3-64.

    The flattened tensor is processed ``chunk_size`` elements at a time, so only
    one chunk at a time is copied to host memory. BFloat16 data, which NumPy
    cannot represent, is widened to float32 before hashing.

    Args:
        tensor: tensor to hash. It may live on any device, need not be
            contiguous, and may require grad.
        chunk_size: number of elements hashed per chunk.

    Returns:
        The xxh3-64 digest as ``bytes``.
    """
    hasher = xxhash.xxh3_64()
    # reshape (unlike view) also accepts non-contiguous tensors.
    # detach lets .numpy() work on tensors that require grad.
    tensor_flat = tensor.detach().reshape(-1)
    is_bf16 = tensor.dtype == torch.bfloat16
    for start in range(0, tensor_flat.numel(), chunk_size):
        chunk = tensor_flat[start:start + chunk_size]
        if is_bf16:
            # Convert per chunk instead of materialising a full-size float32
            # copy of the tensor; the hashed bytes are identical either way.
            chunk = chunk.to(torch.float32)
        hasher.update(chunk.contiguous().cpu().numpy().tobytes())
    return hasher.digest()
def calculate_hash_for_model(data, chunk_size=1024 * 1024):
    """Return a hex xxh3-64 digest summarising a flat state dict.

    Tensor entries are folded in first, in sorted key order, as the UTF-8 key
    followed by the tensor's chunked digest (``hash_tensor_in_chunks``).
    Non-tensor entries follow, also in sorted key order, as the UTF-8 key
    followed by the value's text. A ``str`` is used as-is, an ``int``/``float``
    goes through ``str()``, and anything else goes through ``repr()``.
    """
    digest = xxhash.xxh3_64()
    tensor_keys = sorted(name for name, item in data.items() if torch.is_tensor(item))
    other_keys = sorted(name for name, item in data.items() if not torch.is_tensor(item))

    for name in tensor_keys:
        digest.update(name.encode('utf-8'))
        digest.update(hash_tensor_in_chunks(data[name], chunk_size))

    for name in other_keys:
        digest.update(name.encode('utf-8'))
        item = data[name]
        # Keep str() for int/float (subclasses may override repr differently).
        if isinstance(item, str):
            text = item
        elif isinstance(item, (int, float)):
            text = str(item)
        else:
            text = repr(item)
        digest.update(text.encode('utf-8'))

    return digest.hexdigest()
def compare_state_dicts(state_dict1, state_dict2):
    """Recursively compare two (possibly nested) state dicts.

    Two dicts match when they have the same keys, every pair of tensor values
    is ``torch.equal``, and every pair of nested dict values matches
    recursively. Values of any other type combination are deliberately not
    compared. This includes pairs whose types differ.

    A ``'rerun_state_machine'`` entry present only in ``state_dict2`` is
    tolerated, so a baseline without it can still match. Neither input dict is
    modified.

    Returns:
        True if no difference was found. Otherwise False, with a short reason
        printed to stdout.
    """
    keys1 = set(state_dict1.keys())
    keys2 = set(state_dict2.keys())
    # Treat the baseline as if it had 'rerun_state_machine' whenever the other
    # dict does, without injecting the entry into the caller's dict.
    if 'rerun_state_machine' in keys2:
        keys1.add('rerun_state_machine')
    if keys1 != keys2:
        print(f"base:{state_dict1.keys()} != save:{state_dict2.keys()}")
        return False
    for key, value1 in state_dict1.items():
        value2 = state_dict2[key]
        if isinstance(value1, torch.Tensor) and isinstance(value2, torch.Tensor):
            if not torch.equal(value1, value2):
                print(f"Difference found in key: {key}")
                return False
        elif isinstance(value1, dict) and isinstance(value2, dict):
            if not compare_state_dicts(value1, value2):
                return False
        # NOTE(review): every other combination (plain values, lists, or a
        # tensor paired with a non-tensor) is skipped on purpose.
    return True
def process_file(file_path):
    """Load a checkpoint on CPU and merge its top-level ``model*`` entries into one dict.

    Every top-level key starting with ``'model'`` is expected to hold a dict,
    whose items are merged in the checkpoint's key order. All other top-level
    entries are dropped.

    NOTE(review): if several ``model*`` entries share inner key names, later
    entries silently overwrite earlier ones in the merged result.
    ``weights_only=False`` unpickles arbitrary objects, so use this only on
    trusted files.
    """
    checkpoint = torch.load(file_path, map_location='cpu', weights_only=False)
    model_keys = [name for name in checkpoint.keys() if name.startswith('model')]
    merged = {}
    for name in model_keys:
        merged.update(checkpoint[name])
    return merged
def compare_with_base_hash(file_path, base_hash, file_type='pt'):
    """Check whether a file's hash equals an expected value.

    For ``file_type='pt'``, the file is loaded with ``process_file`` (top-level
    ``model*`` entries merged) and hashed with ``calculate_hash_for_model``.
    For ``'safetensors'`` and ``'bin'``, the file's MD5 hex digest
    (``get_md5sum``) is compared instead.

    Args:
        file_path: path of the file to check.
        base_hash: expected hash string.
        file_type: one of ``'pt'``, ``'safetensors'``, ``'bin'``.

    Returns:
        True if the hash matches. False if it differs or the file does not
        exist (the latter is logged as an error).

    Raises:
        ValueError: if a ``.pt`` file cannot be loaded/hashed, or if
            *file_type* is unsupported.
    """
    if not os.path.exists(file_path):
        # Return False rather than a (truthy) error string, so callers that
        # test the result's truthiness treat a missing file as a failure.
        logging.error("Error: File %s does not exist", file_path)
        return False
    if file_type == 'pt':
        try:
            current_hash = calculate_hash_for_model(process_file(file_path))
        except Exception as e:
            raise ValueError(f"Error: Failed to process file {file_path} - {str(e)}") from e
    elif file_type in ('safetensors', 'bin'):
        current_hash = get_md5sum(file_path)
    else:
        raise ValueError(f"Unsupported file type: {file_type}")
    return current_hash == base_hash
def weight_compare_hash(model_dir, base_hash, suffix="pt"):
    """Compare every ``*.{suffix}`` file under *model_dir* against expected hashes.

    Files are collected recursively and sorted by path. ``base_hash[i]`` is the
    expected hash of the i-th sorted file. Each file is checked with
    ``compare_with_base_hash`` in a pool of worker processes.

    Args:
        model_dir: directory searched recursively.
        base_hash: sequence of expected hashes, one per file in sorted-path order.
        suffix: file extension to collect (also passed as the file type).

    Returns:
        True if every file matches, otherwise False (at the first mismatch in
        sorted-path order).

    Raises:
        ValueError: if no files are found or the file count differs from
            ``len(base_hash)``. Exceptions raised in workers propagate.
    """
    models_path = sorted(glob.glob(os.path.join(model_dir, '**', f'*.{suffix}'), recursive=True))
    if not models_path:
        raise ValueError(f"Error: No .{suffix} files found in {model_dir}")
    if len(models_path) != len(base_hash):
        raise ValueError(f"Error: Number of files don't match ({len(models_path)} vs {len(base_hash)})")
    cpu_count = os.cpu_count() or 1
    max_workers = min(cpu_count, len(models_path))
    logging.info("Using %d workers based on CPU count: %d", max_workers, cpu_count)
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(compare_with_base_hash, path, expected, suffix)
            for path, expected in zip(models_path, base_hash)
        ]
        for future in futures:
            if not future.result():
                # Drop work that has not started yet. On exit the executor
                # still waits for tasks that are already running.
                for pending in futures:
                    pending.cancel()
                return False
    return True
def weight_compare(dir_1, dir_2, suffix="pt", use_md5=False):
    """Check that every ``*.{suffix}`` file under *dir_1* matches its counterpart under *dir_2*.

    Each counterpart is the file at the same relative path under *dir_2*. With
    ``use_md5`` the files are compared byte-for-byte via ``get_md5sum``.
    Otherwise both are loaded on CPU and compared with ``compare_state_dicts``.
    Files present only under *dir_2* are not detected.

    Returns:
        False if *dir_1* has no matching files or any pair differs, otherwise True.

    Raises:
        FileNotFoundError: if a counterpart file is missing (raised by
            ``get_md5sum`` or ``torch.load``).
    """
    models_path = glob.glob(os.path.join(dir_1, '**', f'*.{suffix}'), recursive=True)
    if not models_path:
        print(f"Can't find any weight files in {dir_1}.")
        return False
    dir_2_norm = os.path.normpath(dir_2)
    for path_1 in models_path:
        path_1 = os.path.normpath(path_1)
        # Map by relative path. str.replace would also rewrite any other
        # occurrence of dir_1's text further down the path.
        path_2 = os.path.join(dir_2_norm, os.path.relpath(path_1, dir_1))
        if use_md5:
            are_equal = get_md5sum(path_1) == get_md5sum(path_2)
        else:
            # Load on CPU (as process_file does) so no accelerator is required.
            state_dict1 = torch.load(path_1, map_location='cpu', weights_only=False)
            state_dict2 = torch.load(path_2, map_location='cpu', weights_only=False)
            are_equal = compare_state_dicts(state_dict1, state_dict2)
        if not are_equal:
            return False
    return True
def weight_compare_optim(dir_1, dir_2, suffix="pt", use_md5=False):
    """Compare checkpoint files under *dir_1* with their counterparts under *dir_2*.

    Each counterpart is the file at the same relative path under *dir_2*. The
    comparison mode is chosen per file:

    - ``distrib_optim.pt`` is always compared by MD5.
    - ``model_optim_rng.pt`` is always loaded on CPU and compared with
      ``compare_state_dicts``.
    - Any other file uses the mode selected by *use_md5*.

    Returns:
        True if every file matches, otherwise False at the first mismatch.

    Raises:
        FileNotFoundError: if *dir_1* contains no ``*.{suffix}`` files, or if a
            counterpart file is missing (from ``get_md5sum`` / ``torch.load``).
    """
    models_path = glob.glob(os.path.join(dir_1, '**', f'*.{suffix}'), recursive=True)
    if not models_path:
        raise FileNotFoundError(f"No .{suffix} files found in {dir_1}")
    dir_2_norm = os.path.normpath(dir_2)
    for path_1 in models_path:
        path_1 = os.path.normpath(path_1)
        # Map by relative path. str.replace would also rewrite any other
        # occurrence of dir_1's text further down the path.
        path_2 = os.path.join(dir_2_norm, os.path.relpath(path_1, dir_1))
        file_name = os.path.basename(path_1)
        # Decide per file without overwriting the caller's use_md5, so the
        # override for one file does not leak into later files.
        if file_name == 'distrib_optim.pt':
            file_use_md5 = True
        elif file_name == 'model_optim_rng.pt':
            file_use_md5 = False
        else:
            file_use_md5 = use_md5
        if file_use_md5:
            are_equal = get_md5sum(path_1) == get_md5sum(path_2)
        else:
            state_dict1 = torch.load(path_1, map_location='cpu', weights_only=False)
            state_dict2 = torch.load(path_2, map_location='cpu', weights_only=False)
            are_equal = compare_state_dicts(state_dict1, state_dict2)
        if not are_equal:
            return False
    return True
def compare_file_md5_same(file1, file2):
    """Return True when *file1* and *file2* have identical MD5 digests.

    Either path not being a regular file raises ``FileNotFoundError`` from
    ``get_md5sum``.
    """
    first_digest = get_md5sum(file1)
    second_digest = get_md5sum(file2)
    return first_digest == second_digest
def get_md5sum(fpath, chunk_size=1024 * 1024):
    """Return the hex MD5 digest of a file's contents.

    The file is streamed in ``chunk_size``-byte blocks, so large checkpoint
    files are never read into memory all at once. MD5 is used only as a
    content fingerprint here, not for security.

    Args:
        fpath: path of the file to hash.
        chunk_size: read size in bytes. Must be positive.

    Raises:
        FileNotFoundError: if *fpath* is not an existing regular file.
        ValueError: if *chunk_size* is not positive.
    """
    if not os.path.isfile(fpath):
        raise FileNotFoundError(f"{fpath} is not a file or not exists !")
    if chunk_size <= 0:
        # read(0) would return b'' immediately and yield the empty-file digest.
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    md5sum = hashlib.md5()
    with open(fpath, 'rb') as f:
        for block in iter(lambda: f.read(chunk_size), b''):
            md5sum.update(block)
    return md5sum.hexdigest()
def delete_distrib_optim_files(folder_path):
    """Recursively delete every file named ``distrib_optim.pt`` under *folder_path*.

    Each deletion is logged at INFO level. If a removal fails, the error is
    logged with its traceback and the exception is re-raised. Any matching
    files not yet visited by the walk are then left in place.
    """
    target_name = "distrib_optim.pt"
    for current_dir, _subdirs, file_names in os.walk(folder_path):
        if target_name not in file_names:
            continue
        victim = os.path.join(current_dir, target_name)
        try:
            os.remove(victim)
            logging.info(f"Deleted: {victim}")
        except Exception as e:
            logging.exception(f"Failed to delete {victim}: {e}")
            raise
@pytest.fixture
def build_args(request, monkeypatch):
    """Pytest fixture that rebuilds ``sys.argv`` from the test's ``params`` fixture.

    Every ``params`` item becomes a ``--key`` flag. A ``None`` value yields a
    bare flag. A list value is followed by each element as a string. Any other
    value is followed by ``str(value)``. ``sys.argv[0]`` is kept, and the new
    argv is installed through ``monkeypatch``.
    """
    params = request.getfixturevalue("params")
    new_argv = [sys.argv[0]]
    for name, value in params.items():
        new_argv.append(f"--{name}")
        if isinstance(value, list):
            new_argv.extend(str(item) for item in value)
        elif value is not None:
            new_argv.append(str(value))
    monkeypatch.setattr(sys, "argv", new_argv)
def create_testconfig(path: str, cmd: bool = False):
    """Load a JSON test config that maps test names to lists of parameter dicts.

    Each dict ``s`` in a test's list is collapsed. If it has several values, it
    becomes the tuple of its values (in insertion order). If it has one value,
    it becomes that value. An empty dict raises ``IndexError``.

    Args:
        path: path of the UTF-8 JSON file.
        cmd: if False, return the collapsed mapping. If True, turn each test's
            collapsed list into one flat command-line argument list. Every
            collapsed entry that is itself a dict contributes ``--key`` per
            item, followed by:

            - nothing, for ``None``;
            - the whitespace-split words, for a string;
            - each element as a string, for a list;
            - ``str(value)``, for anything else.

            Non-dict entries are skipped.

    Returns:
        ``{test_name: [collapsed entries]}`` or ``{test_name: [arg strings]}``.
    """
    with open(path, encoding='utf-8') as f:
        raw_data = json.load(f)
    res = {
        k: [tuple(s.values()) if len(s) > 1 else tuple(s.values())[0] for s in v]
        for k, v in raw_data.items()
    }
    if not cmd:
        return res

    def _dict_to_cmdlist(param_value):
        """Flatten the dict entries of one test's parameter list into CLI args."""
        # One accumulator, so each entry's args are added exactly once.
        cmdlist = []
        for target in param_value:
            if not isinstance(target, dict):
                continue
            for key, value in target.items():
                cmdlist.append(f"--{key}")
                if value is None:
                    continue  # bare flag
                if isinstance(value, str):
                    cmdlist.extend(value.split())
                elif isinstance(value, list):
                    cmdlist.extend(str(item) for item in value)
                else:
                    # subprocess/argparse require str arguments.
                    cmdlist.append(str(value))
        return cmdlist

    return {key: _dict_to_cmdlist(value) for key, value in res.items()}
class ListHandler(logging.Handler):
    """Logging handler that keeps, in memory, the formatted records matching a regex.

    Each emitted record is formatted with this handler's formatter and searched
    with ``re.search(pattern, text, re.DOTALL)``. Matching text is appended to
    ``log_capture`` in emit order.
    """

    def __init__(self, pattern):
        super().__init__()
        self.pattern = pattern  # regex pattern string searched in each formatted record
        self.log_capture = []  # formatted records that matched, oldest first

    def emit(self, record):
        message = self.format(record)
        if re.search(self.pattern, message, re.DOTALL) is None:
            return
        self.log_capture.append(message)
def setup_logger(pattern):
    """Attach a regex-filtering ``ListHandler`` to the root logger.

    The handler and the root logger are both set to INFO level. The handler is
    not removed automatically, so callers should call ``removeHandler`` when
    done.

    Returns:
        ``(handler, captured)``, where ``captured`` is the handler's list of
        matching formatted records.
    """
    capture_handler = ListHandler(pattern)
    capture_handler.setLevel(logging.INFO)
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    root_logger.addHandler(capture_handler)
    return capture_handler, capture_handler.log_capture
def run_cmd(cmd_strlist):
    """Run *cmd_strlist* (an argv-style list, executed without a shell) and return its exit code."""
    completed = subprocess.run(cmd_strlist)
    return completed.returncode