{
 "cells": [
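  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2 Running the Baseline\n",
    "\n",
    "This section gets `Sana-Video` running end to end and covers only the minimal compatibility changes the model needs to run on an NPU.\n",
    "\n",
    "**Hardware requirements**:\n",
    "- NPU memory: 32GB (the 910B meets this requirement)\n",
    "- Host memory: **at least 16GB**, **32GB recommended** (model loading peaks at about 20GB)\n",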
    "\n",
    "---"
   ]
  },
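  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1 Environment Setup\n",
    "Locate the tutorial directory, prepare the run workspace, and import the CANN environment variables into the current Notebook process. To speed up model downloads, the tutorial uses the HF-Mirror endpoint by default. For the online experience, run the cells directly in the GitCode Notebook environment; for local runs, use a dedicated Python/conda environment."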
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import shutil\n",
    "import subprocess\n",
    "import sys\n",
    "from pathlib import Path\n",
    "\n",
    "# Download Hugging Face files through the HF-Mirror endpoint.\n",
    "os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'\n",
    "\n",
    "\n",
    "def locate_repo_root():\n",
    "    candidates = []\n",
    "    if os.environ.get('GITCODE_REPO_ROOT'):\n",
    "        candidates.append(Path(os.environ['GITCODE_REPO_ROOT']))\n",
    "    candidates.extend([\n",
    "        Path('/opt/atomgit/cann-learning-hub'),\n",
    "    ])\n",
    "    try:\n",
    "        cwd = Path.cwd()\n",
    "        candidates.extend([cwd, *cwd.parents])\n",
    "    except FileNotFoundError:\n",
    "        pass\n",
    "\n",
    "    seen = set()\n",
    "    for candidate in candidates:\n",
    "        key = str(candidate)\n",
    "        if key in seen:\n",
    "            continue\n",
    "        seen.add(key)\n",
    "        if (candidate / 'reference_practice/model_inference_optimization/sana_video/src').exists():\n",
    "            return candidate\n",
    "    raise FileNotFoundError('Cannot locate cann-learning-hub repository root.')\n",
    "\n",
    "\n",
    "REPO_ROOT = locate_repo_root()\n",
    "os.chdir(REPO_ROOT)\n",
    "TUTORIAL_DIR = REPO_ROOT / 'reference_practice' / 'model_inference_optimization' / 'sana_video'\n",
    "SOURCE_DIR = REPO_ROOT / 'Sources' / 'model_inference_optimization' / 'sana_video'\n",
    "UPSTREAM_SANA_DIR = SOURCE_DIR / 'Sana_upstream'\n",
    "WORKSPACE = SOURCE_DIR / 'Sana'\n",
    "SOURCE_DIR.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "if not os.environ.get('ASCEND_TOOLKIT_HOME'):\n",
    "    raise EnvironmentError('ASCEND_TOOLKIT_HOME is not set. Please activate the CANN environment before running this notebook.')\n",
    "cann_script = Path(os.environ['ASCEND_TOOLKIT_HOME']) / 'set_env.sh'\n",
    "# Source set_env.sh in a child shell and copy the resulting variables into this process.\n",
    "# Passing an argument list avoids a second layer of shell quoting around the script path.\n",
    "env = subprocess.check_output(\n",
    "    ['bash', '-lc', f'source \"{cann_script}\" && env'],\n",
    "    text=True,\n",
    "    cwd=REPO_ROOT,\n",
    ")\n",
    "for line in env.splitlines():\n",
    "    if '=' in line:\n",
    "        key, value = line.split('=', 1)\n",
    "        os.environ[key] = value\n",
    "\n",
    "print('REPO_ROOT =', REPO_ROOT)\n",
    "print('UPSTREAM_SANA_DIR =', UPSTREAM_SANA_DIR)\n",
    "print('WORKSPACE =', WORKSPACE)"
   ]
  },
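  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2 Fetch the Upstream Sana Source and Build the Tutorial Workspace\n",
    "The Notebook first clones upstream `Sana` and checks out a pinned commit, then builds the workspace with the tutorial's own files taking precedence. The workspace `pyproject.toml` is the tutorial's version, and upstream files are copied in without overwriting anything already there."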
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "if not UPSTREAM_SANA_DIR.exists():\n",
    "    max_retries = 3\n",
    "    clone_success = False\n",
    "    for attempt in range(max_retries):\n",
    "        try:\n",
    "            print(f'Cloning Sana repository (attempt {attempt + 1}/{max_retries})...')\n",
    "            subprocess.run(\n",
    "                ['git', 'clone', 'https://github.com/NVlabs/Sana.git', str(UPSTREAM_SANA_DIR)],\n",
    "                check=True,\n",
    "                timeout=300,\n",
    "            )\n",
    "            clone_success = True\n",
    "            break\n",
    "        except subprocess.TimeoutExpired:\n",
    "            print('Clone timed out after 300 seconds')\n",
    "        except subprocess.CalledProcessError as e:\n",
    "            print(f'Clone failed with exit code {e.returncode}')\n",
    "        # Remove any partial checkout so the next attempt (or a later re-run) starts clean.\n",
    "        shutil.rmtree(UPSTREAM_SANA_DIR, ignore_errors=True)\n",
    "        if attempt < max_retries - 1:\n",
    "            print('Retrying in 5 seconds...')\n",
    "            time.sleep(5)\n",
    "    if not clone_success:\n",
    "        print(f'\\nERROR: Failed to clone after {max_retries} attempts.')\n",
    "        print('Please try running manually in a terminal:')\n",
    "        print(f'  git clone https://github.com/NVlabs/Sana.git {UPSTREAM_SANA_DIR}')\n",
    "        raise RuntimeError('Failed to clone Sana repository')\n",
    "subprocess.run(['git', 'checkout', '08c656c3'], cwd=UPSTREAM_SANA_DIR, check=True)\n",
    "subprocess.run(['git', 'rev-parse', 'HEAD'], cwd=UPSTREAM_SANA_DIR, check=True)\n",
    "\n",
    "WORKSPACE.mkdir(parents=True, exist_ok=True)\n",
    "(WORKSPACE / 'inference_video_scripts').mkdir(parents=True, exist_ok=True)\n",
    "(WORKSPACE / 'asset' / 'samples').mkdir(parents=True, exist_ok=True)\n",
    "shutil.copy2(TUTORIAL_DIR / 'src' / 'pyproject.toml', WORKSPACE / 'pyproject.toml')\n",
    "shutil.copy2(TUTORIAL_DIR / 'src' / 'sana_npu_adaptation.py', WORKSPACE / 'sana_npu_adaptation.py')\n",
    "shutil.copy2(\n",
    "    TUTORIAL_DIR / 'src' / 'inference_video_scripts' / 'inference_sana_video.py',\n",
    "    WORKSPACE / 'inference_video_scripts' / 'inference_sana_video.py',\n",
    ")\n",
    "shutil.copy2(\n",
    "    TUTORIAL_DIR / 'src' / 'samples' / 'video_prompts_samples.txt',\n",
    "    WORKSPACE / 'asset' / 'samples' / 'tutorial_video_prompts_samples.txt',\n",
    ")\n",
    "subprocess.run(\n",
    "    ['bash', '-lc', f'cp -rn \"{UPSTREAM_SANA_DIR}/.\" \"{WORKSPACE}\"'],\n",
    "    check=True,\n",
    ")\n",
    "print('Workspace ready:', WORKSPACE)"
   ]
  },
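  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3 Install Dependencies\n",
    "The first run can take a while (building mmcv takes about 10 minutes). To speed up downloads, the tutorial uses the Tsinghua PyPI mirror by default. The GitCode Notebook environment reuses the preinstalled `torch`, `torch_npu`, `torchvision`, and `torchaudio`, so this step installs only the remaining Python dependencies of the tutorial workspace.\n",
    "\n",
    "**Note**: Some dependencies (such as mmcv and clip) are fetched from GitHub as source, so installation may fail on an unstable network. **Re-running the cell** usually resolves this."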
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "PIP_INDEX_URL = 'https://pypi.tuna.tsinghua.edu.cn/simple'\n",
    "PIP_TRUSTED_HOST = 'pypi.tuna.tsinghua.edu.cn'\n",
    "\n",
    "subprocess.run([\n",
    "    sys.executable, '-m', 'pip', 'install', '-e', str(WORKSPACE),\n",
    "    '-i', PIP_INDEX_URL,\n",
    "    '--trusted-host', PIP_TRUSTED_HOST,\n",
    "], check=True)\n",
    "subprocess.run([sys.executable, '-m', 'pip', 'uninstall', '-y', 'opencv-python'], check=False)\n",
    "subprocess.run([\n",
    "    sys.executable, '-m', 'pip', 'install', 'opencv-python-headless==4.8.0.76',\n",
    "    '-i', PIP_INDEX_URL,\n",
    "    '--trusted-host', PIP_TRUSTED_HOST,\n",
    "], check=True)\n",
    "\n",
    "MMCV_DIR = SOURCE_DIR / 'mmcv-1x'\n",
    "if not MMCV_DIR.exists():\n",
    "    subprocess.run(['git', 'clone', '-b', '1.x', 'https://github.com/open-mmlab/mmcv.git', str(MMCV_DIR)], check=True)\n",
    "# Install mmcv with the same interpreter as the steps above; the build flags are passed as environment variables.\n",
    "subprocess.run(\n",
    "    [sys.executable, '-m', 'pip', 'install', '-e', '.', '--no-build-isolation'],\n",
    "    cwd=MMCV_DIR,\n",
    "    env={**os.environ, 'MMCV_WITH_OPS': '1', 'FORCE_NPU': '1'},\n",
    "    check=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4 Run Baseline Inference\n",
    "The script below loads the model weights and generates a 5-second 480p video. This tutorial uses 20 sampling steps, and inference takes about 2.5 minutes in the current environment.\n",
    "\n",
    "The first run downloads the model files automatically:\n",
    "- **VAE**: `vae/Wan2.1_VAE.pth` (508 MB)\n",
    "- **Main model**: `checkpoints/SANA_Video_2B_480p.pth` (8.25 GB)\n",
    "\n",
    "**Estimated download time**: about 10-20 minutes, depending on network speed\n",
    "\n",
    "**Monitor download progress** (run in a GitCode terminal):\n",
    "```bash\n",
    "# Show the size of the files downloaded so far\n",
    "du -sh ~/.cache/huggingface/hub/models--Efficient-Large-Model--SANA-Video_2B_480p/\n",
    "\n",
    "# Refresh the size every 5 seconds\n",
    "watch -n 5 'du -sh ~/.cache/huggingface/hub/models--Efficient-Large-Model--SANA-Video_2B_480p/'\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "BASELINE_WORK_DIR = SOURCE_DIR / 'run_outputs' / 'baseline_demo'\n",
    "BASELINE_WORK_DIR.mkdir(parents=True, exist_ok=True)\n",
    "cmd = [\n",
    "    sys.executable,\n",
    "    str(WORKSPACE / 'inference_video_scripts' / 'inference_sana_video.py'),\n",
    "    '--config', str(WORKSPACE / 'configs' / 'sana_video_config' / 'Sana_2000M_480px_AdamW_fsdp.yaml'),\n",
    "    '--model_path', 'hf://Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth',\n",
    "    '--txt_file', str(WORKSPACE / 'asset' / 'samples' / 'tutorial_video_prompts_samples.txt'),\n",
    "    '--cfg_scale', '6',\n",
    "    '--motion_score', '30',\n",
    "    '--flow_shift', '8',\n",
    "    '--work_dir', str(BASELINE_WORK_DIR),\n",
    "    '--sample_nums', '1',\n",
    "    '--step', '20',\n",
    "    '--metrics_tag', 'baseline_demo',\n",
    "    '--model.fp32_attention', 'False',\n",
    "]\n",
    "subprocess.run(cmd, cwd=WORKSPACE, check=True)"
   ]
  },
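  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5 View the Results\n",
    "The tutorial inference entry point saves the average total sampling time and the average per-step latency to `metrics/<tag>_summary.json`. The generated videos are saved in the directory given by `metrics['save_root']`."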
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "metrics_path = BASELINE_WORK_DIR / 'metrics' / 'baseline_demo_summary.json'\n",
    "metrics = json.loads(metrics_path.read_text(encoding='utf-8'))\n",
    "metrics"
   ]
  },
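  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The platform cannot yet display the generated .mp4 files directly, so run the script below to pack the videos into an archive, then download it and extract it locally to view them."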
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "video_dir = Path(metrics['save_root'])\n",
    "print('video_dir =', video_dir)\n",
    "\n",
    "mp4_files = sorted(str(path) for path in video_dir.glob('*.mp4'))\n",
    "print('Generated video files:')\n",
    "for f in mp4_files:\n",
    "    print(f)\n",
    "\n",
    "archive_path = (video_dir.parent / \"baseline_demo_vis.tar.gz\").resolve()\n",
    "!tar -czf \"{archive_path}\" -C \"{video_dir.parent}\" \"{video_dir.name}\"\n",
    "\n",
    "print('\\nArchived to:', archive_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}