"""
验证流程步骤4:验证量化描述文件
根据规则文件检查 quant_model_description.json 中的层量化类型是否符合预期。
"""
import argparse
import json
import os
import sys
from typing import Any, Optional
def load_json(path: str) -> Any:
    """Read the UTF-8 encoded JSON file at ``path`` and return the decoded value.

    Errors from opening the file or from JSON decoding propagate to the caller.
    """
    with open(path, mode="r", encoding="utf-8") as handle:
        parsed = json.load(handle)
    return parsed
def find_description_file(path: str) -> Optional[str]:
    """Locate the quantization description file for ``path``.

    Args:
        path: Either the description file itself, or a directory expected to
            contain ``quant_model_description.json``.

    Returns:
        ``path`` unchanged when it is an existing regular file; otherwise the
        joined path to ``quant_model_description.json`` inside ``path`` when
        that is a regular file; otherwise ``None``.
    """
    if os.path.isfile(path):
        return path
    candidate = os.path.join(path, "quant_model_description.json")
    # Require a regular file: a directory that happens to carry this name
    # cannot be loaded as JSON and must not be reported as "found".
    if os.path.isfile(candidate):
        return candidate
    return None
def _extract_keywords(raw: Any) -> list:
    """Normalize a rule's ``keywords`` value to a list of strings.

    A bare non-empty string is treated as a single keyword, because iterating
    it directly would test each character against the key. Any other non-list
    value, or a list holding non-string entries, yields an empty list so the
    caller can reject the rule as invalid.
    """
    if isinstance(raw, str):
        return [raw] if raw else []
    if not isinstance(raw, list):
        return []
    if not all(isinstance(kw, str) for kw in raw):
        return []
    return raw


def verify_description(desc_path: str, rules_path: str) -> bool:
    """Check the quantization description file against a list of rules.

    The description file must decode to a JSON object and the rules file to a
    JSON array. Each rule is an object with a ``quant_type`` value and a
    ``keywords`` list (a single string is also accepted). For every key of the
    description whose name contains any of the rule's keywords, the value must
    equal ``quant_type``.

    Progress and results are printed to stdout.

    Args:
        desc_path: Path to the description file, or to a directory containing
            ``quant_model_description.json``.
        rules_path: Path to the rules JSON file.

    Returns:
        ``True`` only when at least one key was checked and every checked key
        had the expected value. ``False`` on any load/format error, on any
        mismatch, or when no rule matched any key.
    """
    print("=" * 60)
    print("步骤4: 验证量化描述文件")
    print("=" * 60)

    real_desc_path = find_description_file(desc_path)
    if not real_desc_path:
        print(f"[ERROR] 未找到量化描述文件 (在路径: {desc_path})")
        # The original message listed the same filename twice ("X 或 X").
        print(" 期望文件: quant_model_description.json")
        return False
    print(f"[INFO] 描述文件: {real_desc_path}")

    try:
        desc_data = load_json(real_desc_path)
    except Exception as e:
        print(f"[ERROR] 加载描述文件失败: {e}")
        return False
    if not isinstance(desc_data, dict):
        print("[ERROR] 描述文件格式错误: 期望为 JSON Object (dict)")
        return False

    print(f"[INFO] 规则文件: {rules_path}")
    if not os.path.exists(rules_path):
        print(f"[ERROR] 规则文件不存在: {rules_path}")
        return False
    try:
        rules = load_json(rules_path)
    except Exception as e:
        print(f"[ERROR] 加载规则文件失败: {e}")
        return False
    if not isinstance(rules, list):
        print("[ERROR] 规则文件格式错误: 期望为 JSON Array (list)")
        return False

    print("\n[CHECK] 开始匹配规则...")
    all_passed = True
    total_checked_keys = 0

    for index, rule in enumerate(rules, start=1):
        # A non-dict entry (string, number, list, null) has no fields to read;
        # treat it as an invalid rule instead of crashing on ``.get``.
        if isinstance(rule, dict):
            quant_type = rule.get("quant_type")
            keywords = _extract_keywords(rule.get("keywords", []))
        else:
            quant_type, keywords = None, []
        if not quant_type or not keywords:
            print(f"[WARNING] 规则 #{index} 格式无效 (缺少 quant_type 或 keywords),跳过")
            continue

        print(f" > 规则 #{index}: 期望包含 {keywords} 的权重为 '{quant_type}'")
        matched_keys = []
        failed_keys = []
        for key, value in desc_data.items():
            if not isinstance(key, str):
                continue
            if not any(kw in key for kw in keywords):
                continue
            if value != quant_type:
                failed_keys.append((key, value))
            else:
                matched_keys.append(key)

        total_checked_keys += len(matched_keys) + len(failed_keys)

        if failed_keys:
            all_passed = False
            print(f" [FAILED] 发现 {len(failed_keys)} 个不匹配项 (展示前10个):")
            for k, v in failed_keys[:10]:
                print(f" - {k}: 实际值='{v}', 期望值='{quant_type}'")
            if len(failed_keys) > 10:
                print(f" ... 还有 {len(failed_keys) - 10} 个")
        elif not matched_keys:
            # A rule that matches nothing is only a warning here; the overall
            # run still fails below if no rule matched any key at all.
            print(" [WARNING] 未找到匹配该规则关键字的任何权重键 (可能是关键字有误?)")
        else:
            print(f" [OK] {len(matched_keys)} 个权重项验证通过")

    print("-" * 60)
    if all_passed and total_checked_keys > 0:
        print("[SUCCESS] 验证通过!所有匹配项均符合预期量化类型。")
        return True
    if total_checked_keys == 0:
        print("[FAILED] 验证失败:未匹配到任何符合规则的权重项,请检查规则关键字。")
        return False
    print("[FAILED] 验证失败:存在量化类型不匹配的权重项。")
    return False
def main():
    """Command-line entry point: parse arguments, verify, and exit.

    Exits the process with status 0 when verification succeeds and 1 otherwise.
    """
    cli = argparse.ArgumentParser(description="验证量化描述文件内容")
    cli.add_argument("--desc-path", required=True, help="量化输出目录或描述文件路径")
    cli.add_argument("--rules-path", required=True, help="校验规则JSON文件路径")
    options = cli.parse_args()

    if verify_description(options.desc_path, options.rules_path):
        sys.exit(0)
    sys.exit(1)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()