"""步骤3:验证权重一致性(精简版)。"""
import argparse
import glob
import os
import sys
import torch
def _load_weights(model_path):
try:
from safetensors.torch import load_file
files = sorted(glob.glob(os.path.join(model_path, "*.safetensors")))
if files:
merged = {}
for file in files:
merged.update(load_file(file))
return merged
except Exception:
pass
pt_path = os.path.join(model_path, "pytorch_model.bin")
if os.path.exists(pt_path):
return torch.load(pt_path, map_location="cpu")
return {}
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--original-path", required=True)
parser.add_argument("--quantized-path", required=True)
parser.add_argument("--tolerance", type=float, default=1e-5)
args = parser.parse_args()
left = _load_weights(args.original_path)
right = _load_weights(args.quantized_path)
if not left or not right:
print("[ERROR] step3失败: 权重加载失败")
return 1
left_keys = set(left.keys())
right_keys = set(right.keys())
if left_keys != right_keys:
print("[ERROR] step3失败: 权重键不一致")
print(f"[INFO] 仅左侧数量: {len(left_keys - right_keys)}")
print(f"[INFO] 仅右侧数量: {len(right_keys - left_keys)}")
return 1
max_diff = 0.0
for key in sorted(left_keys):
l = left[key]
r = right[key]
if l.shape != r.shape:
print(f"[ERROR] step3失败: 形状不一致 {key}")
return 1
diff = torch.abs(l.float() - r.float()).max().item()
if diff > max_diff:
max_diff = diff
if diff > args.tolerance:
print(f"[ERROR] step3失败: 权重差异超阈值 {key} diff={diff:.2e}")
return 1
print(f"[OK] step3完成: max_diff={max_diff:.2e}")
return 0
if __name__ == "__main__":
sys.exit(main())