{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
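"# 3 Profiling Analysis\n",
"\n",
"This section uses `torch_npu.profiler` to collect performance data for the Baseline path. It assumes the dependencies from the previous section are already installed.\n",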
"\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
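"## 1 Environment and Workspace Setup\n",
"So that this notebook can be run on its own, the cell below repeats the environment initialization, source checkout, and tutorial file overlay. To speed up model downloads, the tutorial uses the HF-Mirror endpoint by default."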
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import shutil\n",
"import subprocess\n",
"import sys\n",
"from pathlib import Path\n",
"\n",
"os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'\n",
"\n",
"def locate_repo_root():\n",
"    candidates = []\n",
"    if os.environ.get('GITCODE_REPO_ROOT'):\n",
"        candidates.append(Path(os.environ['GITCODE_REPO_ROOT']))\n",
"    candidates.extend([\n",
"        Path('/opt/atomgit/cann-learning-hub'),\n",
"    ])\n",
"    try:\n",
"        cwd = Path.cwd()\n",
"        candidates.extend([cwd, *cwd.parents])\n",
"    except FileNotFoundError:\n",
"        pass\n",
"\n",
"    seen = set()\n",
"    for candidate in candidates:\n",
"        key = str(candidate)\n",
"        if key in seen:\n",
"            continue\n",
"        seen.add(key)\n",
"        if (candidate / 'reference_practice/model_inference_optimization/sana_video/src').exists():\n",
"            return candidate\n",
"    raise FileNotFoundError('Cannot locate cann-learning-hub repository root.')\n",
"\n",
"REPO_ROOT = locate_repo_root()\n",
"os.chdir(REPO_ROOT)\n",
"TUTORIAL_DIR = REPO_ROOT / 'reference_practice' / 'model_inference_optimization' / 'sana_video'\n",
"SOURCE_DIR = REPO_ROOT / 'Sources' / 'model_inference_optimization' / 'sana_video'\n",
"UPSTREAM_SANA_DIR = SOURCE_DIR / 'Sana_upstream'\n",
"WORKSPACE = SOURCE_DIR / 'Sana'\n",
"SOURCE_DIR.mkdir(parents=True, exist_ok=True)\n",
"\n",
"if not os.environ.get('ASCEND_TOOLKIT_HOME'):\n",
"    raise EnvironmentError('ASCEND_TOOLKIT_HOME is not set. Please activate the CANN environment before running this notebook.')\n",
"cann_script = Path(os.environ['ASCEND_TOOLKIT_HOME']) / 'set_env.sh'\n",
"# Source the CANN environment script in a login shell and import the resulting variables.\n",
"env_cmd = f'source \"{cann_script}\" && env'\n",
"env = subprocess.check_output(['bash', '-lc', env_cmd], text=True, cwd=REPO_ROOT)\n",
"for line in env.splitlines():\n",
"    key, sep, value = line.partition('=')\n",
"    if sep and key:\n",
"        os.environ[key] = value\n",
"\n",
"import time\n",
"\n",
"if not UPSTREAM_SANA_DIR.exists():\n",
"    max_retries = 3\n",
"    clone_success = False\n",
"    for attempt in range(max_retries):\n",
"        try:\n",
"            print(f'Cloning Sana repository (attempt {attempt + 1}/{max_retries})...')\n",
"            subprocess.run(\n",
"                ['git', 'clone', 'https://github.com/NVlabs/Sana.git', str(UPSTREAM_SANA_DIR)],\n",
"                check=True,\n",
"                timeout=300,\n",
"            )\n",
"            clone_success = True\n",
"            break\n",
"        except subprocess.TimeoutExpired:\n",
"            print('Clone timed out after 300 seconds')\n",
"        except subprocess.CalledProcessError as e:\n",
"            print(f'Clone failed with exit code {e.returncode}')\n",
"        # A killed or failed clone can leave a partial directory that would block the next attempt.\n",
"        shutil.rmtree(UPSTREAM_SANA_DIR, ignore_errors=True)\n",
"        if attempt < max_retries - 1:\n",
"            print('Retrying in 5 seconds...')\n",
"            time.sleep(5)\n",
"    if not clone_success:\n",
"        print(f'\\nERROR: Failed to clone after {max_retries} attempts.')\n",
"        print('Please try running manually in a terminal:')\n",
"        print(f'  git clone https://github.com/NVlabs/Sana.git {UPSTREAM_SANA_DIR}')\n",
"        raise RuntimeError('Failed to clone Sana repository')\n",
"subprocess.run(['git', 'checkout', '08c656c3'], cwd=UPSTREAM_SANA_DIR, check=True)\n",
"WORKSPACE.mkdir(parents=True, exist_ok=True)\n",
"(WORKSPACE / 'inference_video_scripts').mkdir(parents=True, exist_ok=True)\n",
"(WORKSPACE / 'asset' / 'samples').mkdir(parents=True, exist_ok=True)\n",
"shutil.copy2(TUTORIAL_DIR / 'src' / 'pyproject.toml', WORKSPACE / 'pyproject.toml')\n",
"shutil.copy2(TUTORIAL_DIR / 'src' / 'sana_npu_adaptation.py', WORKSPACE / 'sana_npu_adaptation.py')\n",
"shutil.copy2(TUTORIAL_DIR / 'src' / 'inference_video_scripts' / 'inference_sana_video.py', WORKSPACE / 'inference_video_scripts' / 'inference_sana_video.py')\n",
"shutil.copy2(TUTORIAL_DIR / 'src' / 'samples' / 'video_prompts_samples.txt', WORKSPACE / 'asset' / 'samples' / 'tutorial_video_prompts_samples.txt')\n",
"subprocess.run(\n",
" ['bash', '-lc', f'cp -rn \"{UPSTREAM_SANA_DIR}/.\" \"{WORKSPACE}\"'],\n",
" check=True,\n",
")\n",
"print('Workspace ready:', WORKSPACE)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
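"## 2 Collecting the Baseline Profile\n",
"This run enables `torch_npu.profiler` and disables video saving so that the encoding stage does not interfere with the analysis of model sampling. It uses 10 sample steps; the timings from this section are intended mainly for profiling analysis."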
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"PROFILE_WORK_DIR = SOURCE_DIR / 'run_outputs' / 'baseline_profile'\n",
"PROFILE_DIR = SOURCE_DIR / 'profiler' / 'baseline'\n",
"PROFILE_WORK_DIR.mkdir(parents=True, exist_ok=True)\n",
"PROFILE_DIR.mkdir(parents=True, exist_ok=True)\n",
"cmd = [\n",
" sys.executable,\n",
" str(WORKSPACE / 'inference_video_scripts' / 'inference_sana_video.py'),\n",
" '--config', str(WORKSPACE / 'configs' / 'sana_video_config' / 'Sana_2000M_480px_AdamW_fsdp.yaml'),\n",
" '--model_path', 'hf://Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth',\n",
" '--txt_file', str(WORKSPACE / 'asset' / 'samples' / 'tutorial_video_prompts_samples.txt'),\n",
" '--cfg_scale', '6',\n",
" '--motion_score', '30',\n",
" '--flow_shift', '8',\n",
" '--work_dir', str(PROFILE_WORK_DIR),\n",
" '--sample_nums', '1',\n",
" '--step', '10',\n",
" '--metrics_tag', 'baseline_profile',\n",
" '--skip_save', 'True',\n",
" '--enable_torch_profiler', 'True',\n",
" '--profiler_dir', str(PROFILE_DIR),\n",
" '--profiler_active', '1',\n",
" '--profiler_with_stack', 'True',\n",
" '--model.fp32_attention', 'False',\n",
"]\n",
"subprocess.run(cmd, cwd=WORKSPACE, check=True)"
]
},
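{
"cell_type": "markdown",
"metadata": {},
"source": [
"The profiler flags passed above are consumed by the tutorial's inference script, whose profiler wiring is not shown in this notebook. As a hedged sketch of what such wiring typically looks like with `torch_npu.profiler`, the helper below wraps a loop in a profiler context. The function name `run_with_npu_profiler` and its parameters are hypothetical, and mapping them to `--profiler_active`, `--profiler_with_stack`, and `--profiler_dir` is an assumption. Calling it requires an Ascend environment with `torch_npu` installed.\n",
"\n",
"```python\n",
"def run_with_npu_profiler(step_fn, num_steps, trace_dir, active=1, with_stack=True):\n",
"    # Hypothetical helper: run step_fn num_steps times under torch_npu.profiler.\n",
"    # torch_npu is imported lazily so that defining the helper does not need an NPU.\n",
"    import torch_npu\n",
"\n",
"    with torch_npu.profiler.profile(\n",
"        activities=[\n",
"            torch_npu.profiler.ProfilerActivity.CPU,\n",
"            torch_npu.profiler.ProfilerActivity.NPU,\n",
"        ],\n",
"        schedule=torch_npu.profiler.schedule(wait=0, warmup=0, active=active, repeat=1),\n",
"        on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(str(trace_dir)),\n",
"        with_stack=with_stack,\n",
"    ) as prof:\n",
"        for _ in range(num_steps):\n",
"            step_fn()\n",
"            prof.step()  # marks the boundary of one profiler step\n",
"```\n",
"\n",
"With `active=1`, only one profiler step is recorded, which keeps the trace small at the cost of sampling a single batch."
]
},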
{
"cell_type": "markdown",
"metadata": {},
"source": [
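"## 3 Inspecting the Profiler Summary\n",
"First look at the step-level time of the whole network and the global hotspots, then read them together with the `RMSNorm` source in the next section to understand why a fused replacement is worth trying.\n",
"\n",
"Notes:\n",
"- `mean_single_step_latency_s` is the total sampling time of `solver.sample(...)` divided by `sample_steps`, so it reflects only the sampling stage.\n",
"- `Stage / Computing / Free` in `Step trace` cover the entire batch of one profiler step, including other in-batch overhead before and after sampling.\n",
"- The `xx s/it` shown by tqdm at runtime reflects only the visible iterations of the DPM-Solver multistep main loop; because initialization and warmup come before it, it will not exactly match the final average."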
]
},
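{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough illustration of the first note above, the sketch below derives a per-step latency from a total sampling time. The helper name `mean_single_step_latency` and the numbers are made up for illustration; they are not taken from the tutorial script or from a real run.\n",
"\n",
"```python\n",
"def mean_single_step_latency(total_sampling_time_s, sample_steps):\n",
"    # Average time per sampling step, in seconds.\n",
"    if sample_steps <= 0:\n",
"        raise ValueError('sample_steps must be positive')\n",
"    return total_sampling_time_s / sample_steps\n",
"\n",
"\n",
"# Hypothetical values: 10 sampling steps that took 25 seconds in total.\n",
"latency_s = mean_single_step_latency(25.0, 10)\n",
"print(f'{latency_s:.4f} s per step')\n",
"```\n",
"\n",
"Because the divisor is the number of sampling steps, this value excludes anything outside `solver.sample(...)`, which is why it can differ from the `Stage` time in the step trace."
]
},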
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import csv\n",
"import json\n",
"from collections import defaultdict\n",
"\n",
"profile_metrics_path = PROFILE_WORK_DIR / 'metrics' / 'baseline_profile_summary.json'\n",
"profile_metrics = json.loads(profile_metrics_path.read_text(encoding='utf-8'))\n",
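"# If earlier runs left several result directories, use the most recent one.\n",
"prof_output = max(PROFILE_DIR.rglob('ASCEND_PROFILER_OUTPUT'), key=lambda p: p.stat().st_mtime)\n",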
"print('profiler output =', prof_output)\n",
"print()\n",
"print('Single-step latency baseline:')\n",
"print(f\" sample_steps: {profile_metrics['sample_steps']}\")\n",
"print(f\" mean_sampling_time_s: {profile_metrics['mean_sampling_time_s']:.2f} s\")\n",
"print(f\" mean_single_step_latency_s: {profile_metrics['mean_single_step_latency_s']:.4f} s\")\n",
"print()\n",
"with (prof_output / 'step_trace_time.csv').open(newline='', encoding='utf-8') as f:\n",
"    step_trace = next(csv.DictReader(f))\n",
"\n",
"print('Step trace:')\n",
"for key in ['Stage', 'Computing', 'Free', 'Preparing']:\n",
"    print(f\" {key}: {float(step_trace[key]) / 1000:.2f} ms\")\n",
"\n",
"agg = defaultdict(lambda: {'count': 0, 'total_us': 0.0, 'ratio': 0.0})\n",
"with (prof_output / 'op_statistic.csv').open(newline='', encoding='utf-8') as f:\n",
"    reader = csv.DictReader(f)\n",
"    for row in reader:\n",
"        op = row['OP Type']\n",
"        agg[op]['count'] += int(row['Count'])\n",
"        agg[op]['total_us'] += float(row['Total Time(us)'])\n",
"        agg[op]['ratio'] += float(row['Ratio(%)'])\n",
"\n",
"rows = sorted(\n",
" ({'op': op, **stats} for op, stats in agg.items()),\n",
" key=lambda r: r['total_us'],\n",
" reverse=True,\n",
")\n",
"\n",
"print()\n",
"print('Top 10 ops by total time:')\n",
"for row in rows[:10]:\n",
"    print(\n",
"        f\"{row['op']:<15} count={row['count']:>5} total={row['total_us']/1000:>10.2f} ms ratio={row['ratio']:.3f}%\"\n",
"    )\n",
"\n",
"focus_ops = ['TransData', 'Mul', 'Cast', 'Pows', 'RealDiv', 'Rsqrt']\n",
"print()\n",
"print('Ops relevant to decomposed RMSNorm:')\n",
"for name in focus_ops:\n",
"    row = next((r for r in rows if r['op'] == name), None)\n",
"    if row:\n",
"        print(\n",
"            f\"{row['op']:<15} count={row['count']:>5} total={row['total_us']/1000:>10.2f} ms ratio={row['ratio']:.3f}%\"\n",
"        )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print('profiler_dir =', PROFILE_DIR)\n",
"sorted(str(path.relative_to(PROFILE_DIR)) for path in PROFILE_DIR.rglob('*') if path.is_file())[:20]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
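"## 4 Inspecting the Baseline `RMSNorm` Source\n",
"In the Baseline, the upstream `Sana` `RMSNorm` is still the original implementation, which decomposes into small operators such as `pow`, `mean`, `rsqrt`, and `mul`."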
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"norms_lines = (WORKSPACE / 'diffusion' / 'model' / 'norms.py').read_text(encoding='utf-8').splitlines()\n",
"for line_no in range(182, 232):\n",
" print(f'{line_no}: {norms_lines[line_no - 1]}')"
]
},
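{
"cell_type": "markdown",
"metadata": {},
"source": [
"The decomposition described in this section can be sketched in plain Python. This is a scalar reference over a single vector, not the code in `norms.py`; the function name `rms_norm_decomposed`, the `eps` default, and the step labels in the comments are assumptions chosen to show why one normalization call turns into several small operations.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"\n",
"def rms_norm_decomposed(x, weight, eps=1e-6):\n",
"    # Each line below corresponds to one separate elementwise or reduction step.\n",
"    squared = [v * v for v in x]                # pow\n",
"    mean_sq = sum(squared) / len(squared)       # mean\n",
"    inv_rms = 1.0 / math.sqrt(mean_sq + eps)    # rsqrt\n",
"    return [v * inv_rms * w for v, w in zip(x, weight)]  # mul\n",
"\n",
"\n",
"out = rms_norm_decomposed([3.0, 4.0], [1.0, 1.0], eps=0.0)\n",
"print([round(v, 4) for v in out])\n",
"```\n",
"\n",
"On an accelerator, each of these steps reads and writes the full tensor, so a fused kernel that performs them in one pass can reduce both launch overhead and memory traffic."
]
},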
{
"cell_type": "markdown",
"metadata": {},
"source": [
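"## 5 Conclusions\n",
"The Baseline profiling results and the `RMSNorm` source show that:\n",
"- The current `mean_single_step_latency_s` is only a sampling reference value measured under profiling; the formal performance comparison uses the same metric measured in the next section with the profiler disabled.\n",
"- Small operators such as `Mul`, `Cast`, `Pows`, and `TransData` account for visible overhead across the whole network.\n",
"- The Baseline `RMSNorm` source is still a decomposed implementation, so it is a good candidate for replacement with `torch_npu.npu_rms_norm` while the rest of the pipeline stays unchanged.\n",
"\n",
"The next section compares performance before and after the optimization under the same prompt, seed, and step settings."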
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10"
}
},
"nbformat": 4,
"nbformat_minor": 4
}