{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3 Profiling 分析\n",
    "\n",
    "本节使用 `torch_npu.profiler` 采集 Baseline 路径的性能数据。默认已完成上一节的依赖安装。\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1 环境与工作区准备\n",
    "如需独立执行本 Notebook,会重复完成环境初始化、源码拉取和教程文件覆盖。为加速模型下载,教程默认使用 HF-Mirror 镜像源。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import shutil\n",
    "import subprocess\n",
    "import sys\n",
    "from pathlib import Path\n",
    "\n",
    "os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'\n",
    "def locate_repo_root():\n",
    "    candidates = []\n",
    "    if os.environ.get('GITCODE_REPO_ROOT'):\n",
    "        candidates.append(Path(os.environ['GITCODE_REPO_ROOT']))\n",
    "    candidates.extend([\n",
    "        Path('/opt/atomgit/cann-learning-hub'),\n",
    "    ])\n",
    "    try:\n",
    "        cwd = Path.cwd()\n",
    "        candidates.extend([cwd, *cwd.parents])\n",
    "    except FileNotFoundError:\n",
    "        pass\n",
    "\n",
    "    seen = set()\n",
    "    for candidate in candidates:\n",
    "        key = str(candidate)\n",
    "        if key in seen:\n",
    "            continue\n",
    "        seen.add(key)\n",
    "        if (candidate / 'reference_practice/model_inference_optimization/sana_video/src').exists():\n",
    "            return candidate\n",
    "    raise FileNotFoundError('Cannot locate cann-learning-hub repository root.')\n",
    "\n",
    "REPO_ROOT = locate_repo_root()\n",
    "os.chdir(REPO_ROOT)\n",
    "TUTORIAL_DIR = REPO_ROOT / 'reference_practice' / 'model_inference_optimization' / 'sana_video'\n",
    "SOURCE_DIR = REPO_ROOT / 'Sources' / 'model_inference_optimization' / 'sana_video'\n",
    "UPSTREAM_SANA_DIR = SOURCE_DIR / 'Sana_upstream'\n",
    "WORKSPACE = SOURCE_DIR / 'Sana'\n",
    "SOURCE_DIR.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "if not os.environ.get('ASCEND_TOOLKIT_HOME'):\n",
    "    raise EnvironmentError('ASCEND_TOOLKIT_HOME is not set. Please activate the CANN environment before running this notebook.')\n",
    "cann_script = Path(os.environ['ASCEND_TOOLKIT_HOME']) / 'set_env.sh'\n",
    "env_cmd = f'source {cann_script} && env'\n",
    "env = subprocess.check_output(f\"bash -lc '{env_cmd}'\", shell=True, text=True, cwd=REPO_ROOT)\n",
    "for line in env.splitlines():\n",
    "    if '=' in line:\n",
    "        os.environ.__setitem__(*line.split('=', 1))\n",
    "\n",
    "import time\n",
    "\n",
    "if not UPSTREAM_SANA_DIR.exists():\n",
    "    max_retries = 3\n",
    "    clone_success = False\n",
    "    for attempt in range(max_retries):\n",
    "        try:\n",
    "            print(f'Cloning Sana repository (attempt {attempt + 1}/{max_retries})...')\n",
    "            subprocess.run(\n",
    "                ['git', 'clone', 'https://github.com/NVlabs/Sana.git', str(UPSTREAM_SANA_DIR)],\n",
    "                check=True,\n",
    "                timeout=300,\n",
    "            )\n",
    "            clone_success = True\n",
    "            break\n",
    "        except (subprocess.TimeoutExpired, subprocess.CalledProcessError) as e:\n",
    "            print(f'Clone failed: {e}')\n",
    "            # Remove any partial checkout; git refuses to clone into a non-empty directory.\n",
    "            shutil.rmtree(UPSTREAM_SANA_DIR, ignore_errors=True)\n",
    "            if attempt < max_retries - 1:\n",
    "                print('Retrying in 5 seconds...')\n",
    "                time.sleep(5)\n",
    "    if not clone_success:\n",
    "        print(f'\\nERROR: Failed to clone after {max_retries} attempts.')\n",
    "        print('Please try running manually in a terminal:')\n",
    "        print(f'  git clone https://github.com/NVlabs/Sana.git {UPSTREAM_SANA_DIR}')\n",
    "        raise RuntimeError('Failed to clone Sana repository')\n",
    "subprocess.run(['git', 'checkout', '08c656c3'], cwd=UPSTREAM_SANA_DIR, check=True)\n",
    "WORKSPACE.mkdir(parents=True, exist_ok=True)\n",
    "(WORKSPACE / 'inference_video_scripts').mkdir(parents=True, exist_ok=True)\n",
    "(WORKSPACE / 'asset' / 'samples').mkdir(parents=True, exist_ok=True)\n",
    "shutil.copy2(TUTORIAL_DIR / 'src' / 'pyproject.toml', WORKSPACE / 'pyproject.toml')\n",
    "shutil.copy2(TUTORIAL_DIR / 'src' / 'sana_npu_adaptation.py', WORKSPACE / 'sana_npu_adaptation.py')\n",
    "shutil.copy2(TUTORIAL_DIR / 'src' / 'inference_video_scripts' / 'inference_sana_video.py', WORKSPACE / 'inference_video_scripts' / 'inference_sana_video.py')\n",
    "shutil.copy2(TUTORIAL_DIR / 'src' / 'samples' / 'video_prompts_samples.txt', WORKSPACE / 'asset' / 'samples' / 'tutorial_video_prompts_samples.txt')\n",
    "subprocess.run(\n",
    "    ['bash', '-lc', f'cp -rn \"{UPSTREAM_SANA_DIR}/.\" \"{WORKSPACE}\"'],\n",
    "    check=True,\n",
    ")\n",
    "print('Workspace ready:', WORKSPACE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2 采集 Baseline Profiling\n",
    "这里启用 `torch_npu.profiler`,同时关闭视频保存,避免编码阶段干扰模型采样分析,这里使用 10 个 sample step,本节的时间结果主要用于profiling分析。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "PROFILE_WORK_DIR = SOURCE_DIR / 'run_outputs' / 'baseline_profile'\n",
    "PROFILE_DIR = SOURCE_DIR / 'profiler' / 'baseline'\n",
    "PROFILE_WORK_DIR.mkdir(parents=True, exist_ok=True)\n",
    "PROFILE_DIR.mkdir(parents=True, exist_ok=True)\n",
    "cmd = [\n",
    "    sys.executable,\n",
    "    str(WORKSPACE / 'inference_video_scripts' / 'inference_sana_video.py'),\n",
    "    '--config', str(WORKSPACE / 'configs' / 'sana_video_config' / 'Sana_2000M_480px_AdamW_fsdp.yaml'),\n",
    "    '--model_path', 'hf://Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth',\n",
    "    '--txt_file', str(WORKSPACE / 'asset' / 'samples' / 'tutorial_video_prompts_samples.txt'),\n",
    "    '--cfg_scale', '6',\n",
    "    '--motion_score', '30',\n",
    "    '--flow_shift', '8',\n",
    "    '--work_dir', str(PROFILE_WORK_DIR),\n",
    "    '--sample_nums', '1',\n",
    "    '--step', '10',\n",
    "    '--metrics_tag', 'baseline_profile',\n",
    "    '--skip_save', 'True',\n",
    "    '--enable_torch_profiler', 'True',\n",
    "    '--profiler_dir', str(PROFILE_DIR),\n",
    "    '--profiler_active', '1',\n",
    "    '--profiler_with_stack', 'True',\n",
    "    '--model.fp32_attention', 'False',\n",
    "]\n",
    "subprocess.run(cmd, cwd=WORKSPACE, check=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3 查看 Profiler 摘要\n",
    "先观察整网 step 级时间与全局热点,再结合下一节的 `RMSNorm` 源码理解为什么值得尝试融合替换。\n",
    "\n",
    "说明:\n",
    "- `mean_single_step_latency_s` 统计的是 `solver.sample(...)` 总采样时间除以 `sample_steps`,只反映采样阶段。\n",
    "- `Step trace` 中的 `Stage / Computing / Free` 对应一次 profiler step 的整段 batch 时间,会包含采样前后的其他 batch 内开销。\n",
    "- 运行时 tqdm 显示的 `xx s/it` 只反映 DPM-Solver multistep 主循环中可见迭代的耗时;由于前面还有初始化和 warmup,因此不会与最终平均值完全一致。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "import json\n",
    "from collections import defaultdict\n",
    "\n",
    "profile_metrics_path = PROFILE_WORK_DIR / 'metrics' / 'baseline_profile_summary.json'\n",
    "profile_metrics = json.loads(profile_metrics_path.read_text(encoding='utf-8'))\n",
    "prof_output = next(PROFILE_DIR.rglob('ASCEND_PROFILER_OUTPUT'))\n",
    "print('profiler output =', prof_output)\n",
    "print()\n",
    "print('Single-step latency baseline:')\n",
    "print(f\"  sample_steps: {profile_metrics['sample_steps']}\")\n",
    "print(f\"  mean_sampling_time_s: {profile_metrics['mean_sampling_time_s']:.2f} s\")\n",
    "print(f\"  mean_single_step_latency_s: {profile_metrics['mean_single_step_latency_s']:.4f} s\")\n",
    "print()\n",
    "with (prof_output / 'step_trace_time.csv').open(newline='', encoding='utf-8') as f:\n",
    "    step_trace = next(csv.DictReader(f))\n",
    "\n",
    "print('Step trace:')\n",
    "for key in ['Stage', 'Computing', 'Free', 'Preparing']:\n",
    "    print(f\"  {key}: {float(step_trace[key]) / 1000:.2f} ms\")\n",
    "\n",
    "agg = defaultdict(lambda: {'count': 0, 'total_us': 0.0, 'ratio': 0.0})\n",
    "with (prof_output / 'op_statistic.csv').open(newline='', encoding='utf-8') as f:\n",
    "    reader = csv.DictReader(f)\n",
    "    for row in reader:\n",
    "        op = row['OP Type']\n",
    "        agg[op]['count'] += int(row['Count'])\n",
    "        agg[op]['total_us'] += float(row['Total Time(us)'])\n",
    "        agg[op]['ratio'] += float(row['Ratio(%)'])\n",
    "\n",
    "rows = sorted(\n",
    "    ({'op': op, **stats} for op, stats in agg.items()),\n",
    "    key=lambda r: r['total_us'],\n",
    "    reverse=True,\n",
    ")\n",
    "\n",
    "print()\n",
    "print('Top 10 ops by total time:')\n",
    "for row in rows[:10]:\n",
    "    print(\n",
    "        f\"{row['op']:<15} count={row['count']:>5} total={row['total_us']/1000:>10.2f} ms ratio={row['ratio']:.3f}%\"\n",
    "    )\n",
    "\n",
    "focus_ops = ['TransData', 'Mul', 'Cast', 'Pows', 'RealDiv', 'Rsqrt']\n",
    "print()\n",
    "print('Ops relevant to decomposed RMSNorm:')\n",
    "for name in focus_ops:\n",
    "    row = next((r for r in rows if r['op'] == name), None)\n",
    "    if row:\n",
    "        print(\n",
    "            f\"{row['op']:<15} count={row['count']:>5} total={row['total_us']/1000:>10.2f} ms ratio={row['ratio']:.3f}%\"\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('profiler_dir =', PROFILE_DIR)\n",
    "sorted(str(path.relative_to(PROFILE_DIR)) for path in PROFILE_DIR.rglob('*') if path.is_file())[:20]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4 查看 Baseline `RMSNorm` 源码\n",
    "上游 `Sana` 的 `RMSNorm` 在 Baseline 中仍是原始实现,会拆解为 `pow`、`mean`、`rsqrt`、`mul` 等小算子。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "norms_lines = (WORKSPACE / 'diffusion' / 'model' / 'norms.py').read_text(encoding='utf-8').splitlines()\n",
    "for line_no in range(182, 232):\n",
    "    print(f'{line_no}: {norms_lines[line_no - 1]}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5 分析结论\n",
    "从 Baseline 的 Profiling 结果和 `RMSNorm` 源码可以看到:\n",
    "- 当前 `mean_single_step_latency_s` 仅作为 profiling 场景下的采样参考值;正式性能对比以下一节关闭 profiler 后的同口径指标为准。\n",
    "- 整网中 `Mul`、`Cast`、`Pows`、`TransData` 等小算子存在可见开销。\n",
    "- 结合 Baseline `RMSNorm` 源码,可以看到它仍是分解实现,因此适合在保持其余流程不变的前提下,尝试替换为 `torch_npu.npu_rms_norm`。\n",
    "\n",
    "下一节会在相同 prompt、seed 和 step 设置下,对比优化前后的性能变化。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}