"""perf-breakdown 脚本共享工具函数。
各脚本以 `python scripts/<name>.py` 形式从 skill 根目录调用,scripts 目录位于
sys.path[0],故可直接 `from _common import ...`。
"""
import json
from pathlib import Path
def validate_file_exists(filepath: str) -> Path:
    """Convert *filepath* to a ``Path`` after confirming something exists there.

    Only existence is checked (``Path.exists``), so a directory path is
    accepted as well as a regular file.

    Args:
        filepath: Filesystem path given as a string.

    Returns:
        ``Path(filepath)``.

    Raises:
        FileNotFoundError: If nothing exists at *filepath*.
    """
    candidate = Path(filepath)
    if candidate.exists():
        return candidate
    raise FileNotFoundError(f"文件不存在: {filepath}")
def load_json(filepath: Path) -> dict:
    """Read *filepath* as UTF-8 text and parse it as JSON.

    Args:
        filepath: Path of the JSON file to read.

    Returns:
        The decoded JSON value. The annotation says ``dict``, but whatever
        top-level value the file holds is returned unchanged.

    Raises:
        ValueError: If the content is not valid JSON. The message includes
            the path and the decoder error, and the original
            ``json.JSONDecodeError`` is chained as ``__cause__``.
        OSError: Propagated from ``open`` (e.g. a missing file).
    """
    with open(filepath, 'r', encoding='utf-8') as handle:
        raw_text = handle.read()
    # json.load(fp) is json.loads(fp.read()); only the parse step can raise
    # JSONDecodeError, so only that step is guarded.
    try:
        return json.loads(raw_text)
    except json.JSONDecodeError as decode_err:
        raise ValueError(f"JSON 格式错误: {filepath}: {decode_err}") from decode_err
# Operator names for which shape_semantic is always mandatory. Matching is an
# exact, case-sensitive string comparison; names starting with 'AddRmsNorm'
# are additionally accepted by the prefix rule in is_shape_always_required.
# Grouped by operator-name family for readability only; the result is a
# plain (mutable) set.
SHAPE_SEMANTIC_ALWAYS_REQUIRED = set(
    # MatMul / GEMM variants
    ('MatMul', 'MatMulV2', 'BatchMatMul', 'GemmEx', 'GroupedMatmul',
     'QuantBatchMatmulV3')
    # Attention-score kernels
    + ('FlashAttentionScore', 'FusedInferAttentionScore',
       'KvQuantSparseFlashAttention')
    # Hcom communication ops (both naming spellings are listed)
    + ('HcomAllGather', 'HcomReduceScatter', 'HcomAllToAll',
       'HcomAllReduce', 'hcom_allReduce')
    # Normalization ops ('AddRmsNormDynamicQuant' is also matched by the
    # 'AddRmsNorm' prefix rule; listed explicitly as in the original set)
    + ('RmsNorm', 'LayerNormV3', 'InplaceAddRmsNorm', 'AddRmsNormDynamicQuant')
    # Other fused / specialised ops
    + ('MlaPrologV3', 'DequantSwigluQuant', 'LightningIndexerQuant',
       'MoeGatingTopKHash', 'RotaryMul')
    # Gather variants
    + ('GatherV2', 'GatherV3')
    # MoE distribute dispatch / combine
    + ('MoeDistributeDispatchV2', 'MoeDistributeCombineV2')
)
def is_shape_always_required(name: str) -> bool:
    """Report whether operator *name* must always carry shape_semantic.

    True when *name* is an exact member of SHAPE_SEMANTIC_ALWAYS_REQUIRED,
    or when it begins with 'AddRmsNorm' (covering that whole prefix family).

    Args:
        name: Operator name to check.

    Returns:
        ``True`` if shape_semantic is always required for *name*, else ``False``.
    """
    # Exact-membership check first, matching the original evaluation order.
    if name in SHAPE_SEMANTIC_ALWAYS_REQUIRED:
        return True
    return name.startswith('AddRmsNorm')