{
"cells": [
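{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 2 Running the Baseline\n",
"\n",
"This section gets `Sana-Video` running end to end, covering only the minimal compatibility adaptations the model needs to run on an NPU.\n",
"\n",
"**Hardware requirements**:\n",
"- NPU memory: 32 GB (the 910B meets this requirement)\n",
"- Host memory: **16 GB minimum**, **32 GB recommended** (peak usage during model loading is about 20 GB)\n",
"\n",
"---"
]
},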
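{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1 Environment Setup\n",
"Locate the tutorial directory, prepare the workspace paths, and import the CANN environment variables into the current Notebook process. To speed up model downloads, the tutorial uses the HF-Mirror endpoint by default. For the online experience, run the cells directly in the GitCode Notebook environment; when running locally, use a dedicated Python/conda environment."
]
},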
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import os\n",
"import shutil\n",
"import subprocess\n",
"import sys\n",
"from pathlib import Path\n",
"\n",
"# Route Hugging Face downloads through the HF-Mirror endpoint.\n",
"os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'\n",
"\n",
"\n",
"def locate_repo_root():\n",
"    # Candidates, in priority order: explicit override, default install path, then cwd and its parents.\n",
"    candidates = []\n",
"    if os.environ.get('GITCODE_REPO_ROOT'):\n",
"        candidates.append(Path(os.environ['GITCODE_REPO_ROOT']))\n",
"    candidates.append(Path('/opt/atomgit/cann-learning-hub'))\n",
"    try:\n",
"        cwd = Path.cwd()\n",
"        candidates.extend([cwd, *cwd.parents])\n",
"    except FileNotFoundError:\n",
"        # The current working directory no longer exists; rely on the other candidates.\n",
"        pass\n",
"\n",
"    seen = set()\n",
"    for candidate in candidates:\n",
"        key = str(candidate)\n",
"        if key in seen:\n",
"            continue\n",
"        seen.add(key)\n",
"        if (candidate / 'reference_practice/model_inference_optimization/sana_video/src').exists():\n",
"            return candidate\n",
"    raise FileNotFoundError('Cannot locate cann-learning-hub repository root.')\n",
"\n",
"\n",
"REPO_ROOT = locate_repo_root()\n",
"os.chdir(REPO_ROOT)\n",
"TUTORIAL_DIR = REPO_ROOT / 'reference_practice' / 'model_inference_optimization' / 'sana_video'\n",
"SOURCE_DIR = REPO_ROOT / 'Sources' / 'model_inference_optimization' / 'sana_video'\n",
"UPSTREAM_SANA_DIR = SOURCE_DIR / 'Sana_upstream'\n",
"WORKSPACE = SOURCE_DIR / 'Sana'\n",
"SOURCE_DIR.mkdir(parents=True, exist_ok=True)\n",
"\n",
"if not os.environ.get('ASCEND_TOOLKIT_HOME'):\n",
"    raise EnvironmentError('ASCEND_TOOLKIT_HOME is not set. Please activate the CANN environment before running this notebook.')\n",
"cann_script = Path(os.environ['ASCEND_TOOLKIT_HOME']) / 'set_env.sh'\n",
"# Source set_env.sh in a login shell and dump the resulting environment.\n",
"# Passing an argument list avoids a second layer of shell quoting around the path.\n",
"env = subprocess.check_output(['bash', '-lc', f'source \"{cann_script}\" && env'], text=True, cwd=REPO_ROOT)\n",
"# Copy every KEY=VALUE line into this process so later cells and subprocesses see the CANN settings.\n",
"for line in env.splitlines():\n",
"    key, sep, value = line.partition('=')\n",
"    if sep and key:\n",
"        os.environ[key] = value\n",
"\n",
"print('REPO_ROOT =', REPO_ROOT)\n",
"print('UPSTREAM_SANA_DIR =', UPSTREAM_SANA_DIR)\n",
"print('WORKSPACE =', WORKSPACE)"
]
},
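{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2 Fetching the Upstream Sana Source and Building the Tutorial Workspace\n",
"The Notebook first fetches the pinned version of upstream `Sana`, then builds the workspace with the tutorial directory's own files as the base. The workspace `pyproject.toml` is the tutorial's own configuration, and the upstream code is copied in without overwriting existing files."
]
},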
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"\n",
"SANA_REPO_URL = 'https://github.com/NVlabs/Sana.git'\n",
"\n",
"if not UPSTREAM_SANA_DIR.exists():\n",
"    max_retries = 3\n",
"    clone_timeout = 300  # seconds\n",
"    clone_success = False\n",
"    for attempt in range(max_retries):\n",
"        try:\n",
"            print(f'Cloning Sana repository (attempt {attempt + 1}/{max_retries})...')\n",
"            subprocess.run(\n",
"                ['git', 'clone', SANA_REPO_URL, str(UPSTREAM_SANA_DIR)],\n",
"                check=True,\n",
"                timeout=clone_timeout,\n",
"            )\n",
"            clone_success = True\n",
"            break\n",
"        except subprocess.TimeoutExpired:\n",
"            print(f'Clone timed out after {clone_timeout} seconds')\n",
"        except subprocess.CalledProcessError as e:\n",
"            print(f'Clone failed with exit code {e.returncode}')\n",
"        # Remove any partial clone so the next attempt (or a later re-run of this cell) starts clean.\n",
"        shutil.rmtree(UPSTREAM_SANA_DIR, ignore_errors=True)\n",
"        if attempt < max_retries - 1:\n",
"            print('Retrying in 5 seconds...')\n",
"            time.sleep(5)\n",
"    if not clone_success:\n",
"        print(f'\\nERROR: Failed to clone after {max_retries} attempts.')\n",
"        print('Please try running manually in a terminal:')\n",
"        print(f'  git clone {SANA_REPO_URL} {UPSTREAM_SANA_DIR}')\n",
"        raise RuntimeError('Failed to clone Sana repository')\n",
"# Pin the upstream checkout to the version this tutorial targets, then print the resolved commit hash.\n",
"subprocess.run(['git', 'checkout', '08c656c3'], cwd=UPSTREAM_SANA_DIR, check=True)\n",
"subprocess.run(['git', 'rev-parse', 'HEAD'], cwd=UPSTREAM_SANA_DIR, check=True)\n",
"\n",
"# Place the tutorial's own files into the workspace first.\n",
"WORKSPACE.mkdir(parents=True, exist_ok=True)\n",
"(WORKSPACE / 'inference_video_scripts').mkdir(parents=True, exist_ok=True)\n",
"(WORKSPACE / 'asset' / 'samples').mkdir(parents=True, exist_ok=True)\n",
"shutil.copy2(TUTORIAL_DIR / 'src' / 'pyproject.toml', WORKSPACE / 'pyproject.toml')\n",
"shutil.copy2(TUTORIAL_DIR / 'src' / 'sana_npu_adaptation.py', WORKSPACE / 'sana_npu_adaptation.py')\n",
"shutil.copy2(\n",
"    TUTORIAL_DIR / 'src' / 'inference_video_scripts' / 'inference_sana_video.py',\n",
"    WORKSPACE / 'inference_video_scripts' / 'inference_sana_video.py',\n",
")\n",
"shutil.copy2(\n",
"    TUTORIAL_DIR / 'src' / 'samples' / 'video_prompts_samples.txt',\n",
"    WORKSPACE / 'asset' / 'samples' / 'tutorial_video_prompts_samples.txt',\n",
")\n",
"# Copy the upstream tree without overwriting the tutorial files placed above (cp -n does not clobber).\n",
"subprocess.run(\n",
"    ['cp', '-rn', f'{UPSTREAM_SANA_DIR}/.', str(WORKSPACE)],\n",
"    check=True,\n",
")\n",
"print('Workspace ready:', WORKSPACE)"
]
},
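{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3 Installing Dependencies\n",
"The first run may take a while (compiling mmcv takes about 10 minutes). To speed up downloads, the tutorial uses the Tsinghua PyPI mirror by default. The GitCode Notebook environment reuses the preinstalled `torch`, `torch_npu`, `torchvision`, and `torchaudio`, so this step installs only the remaining Python dependencies of the tutorial workspace.\n",
"\n",
"**Note**: Some dependencies (such as mmcv and clip) are fetched as source from GitHub during installation, and the fetch may fail on an unstable network. **Re-running the cell** usually resolves this."
]
},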
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"PIP_INDEX_URL = 'https://pypi.tuna.tsinghua.edu.cn/simple'\n",
"PIP_TRUSTED_HOST = 'pypi.tuna.tsinghua.edu.cn'\n",
"\n",
"# Install the tutorial workspace in editable mode from the mirror.\n",
"subprocess.run([\n",
"    sys.executable, '-m', 'pip', 'install', '-e', str(WORKSPACE),\n",
"    '-i', PIP_INDEX_URL,\n",
"    '--trusted-host', PIP_TRUSTED_HOST,\n",
"], check=True)\n",
"# Swap opencv-python for the pinned headless build; the uninstall may be a no-op, so its exit code is ignored.\n",
"subprocess.run([sys.executable, '-m', 'pip', 'uninstall', '-y', 'opencv-python'], check=False)\n",
"subprocess.run([\n",
"    sys.executable, '-m', 'pip', 'install', 'opencv-python-headless==4.8.0.76',\n",
"    '-i', PIP_INDEX_URL,\n",
"    '--trusted-host', PIP_TRUSTED_HOST,\n",
"], check=True)\n",
"\n",
"MMCV_DIR = SOURCE_DIR / 'mmcv-1x'\n",
"if not MMCV_DIR.exists():\n",
"    subprocess.run(['git', 'clone', '-b', '1.x', 'https://github.com/open-mmlab/mmcv.git', str(MMCV_DIR)], check=True)\n",
"# Build mmcv with its ops for NPU, using the same interpreter as the installs above\n",
"# rather than whichever pip happens to be first on PATH.\n",
"subprocess.run(\n",
"    [sys.executable, '-m', 'pip', 'install', '-e', '.', '--no-build-isolation'],\n",
"    cwd=MMCV_DIR,\n",
"    env={**os.environ, 'MMCV_WITH_OPS': '1', 'FORCE_NPU': '1'},\n",
"    check=True,\n",
")"
]
},
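{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4 Running Baseline Inference\n",
"The script below loads the model weights and generates a 5-second 480p video. This tutorial example uses 20 sampling steps; in the current environment, inference takes about 2.5 minutes.\n",
"\n",
"The first run downloads the model files automatically:\n",
"- **VAE**: `vae/Wan2.1_VAE.pth` (508 MB)\n",
"- **Main model**: `checkpoints/SANA_Video_2B_480p.pth` (8.25 GB)\n",
"\n",
"**Estimated download time**: about 10 to 20 minutes, depending on network speed\n",
"\n",
"**Monitor download progress in real time** (run in a GitCode terminal):\n",
"```bash\n",
"# Check the size of the files downloaded so far\n",
"du -sh ~/.cache/huggingface/hub/models--Efficient-Large-Model--SANA-Video_2B_480p/\n",
"\n",
"# Monitor download progress live (refreshes every 5 seconds)\n",
"watch -n 5 'du -sh ~/.cache/huggingface/hub/models--Efficient-Large-Model--SANA-Video_2B_480p/'\n",
"```"
]
},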
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"BASELINE_WORK_DIR = SOURCE_DIR / 'run_outputs' / 'baseline_demo'\n",
"BASELINE_WORK_DIR.mkdir(parents=True, exist_ok=True)\n",
"cmd = [\n",
" sys.executable,\n",
" str(WORKSPACE / 'inference_video_scripts' / 'inference_sana_video.py'),\n",
" '--config', str(WORKSPACE / 'configs' / 'sana_video_config' / 'Sana_2000M_480px_AdamW_fsdp.yaml'),\n",
" '--model_path', 'hf://Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth',\n",
" '--txt_file', str(WORKSPACE / 'asset' / 'samples' / 'tutorial_video_prompts_samples.txt'),\n",
" '--cfg_scale', '6',\n",
" '--motion_score', '30',\n",
" '--flow_shift', '8',\n",
" '--work_dir', str(BASELINE_WORK_DIR),\n",
" '--sample_nums', '1',\n",
" '--step', '20',\n",
" '--metrics_tag', 'baseline_demo',\n",
" '--model.fp32_attention', 'False',\n",
"]\n",
"subprocess.run(cmd, cwd=WORKSPACE, check=True)"
]
},
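{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5 Viewing the Results\n",
"The tutorial's inference entry point saves the average total sampling time and the average per-step latency to `metrics/<tag>_summary.json`. The generated videos are saved in the directory that `metrics['save_root']` points to."
]
},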
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"metrics_path = BASELINE_WORK_DIR / 'metrics' / 'baseline_demo_summary.json'\n",
"metrics = json.loads(metrics_path.read_text(encoding='utf-8'))\n",
"metrics"
]
},
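{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because the platform cannot yet display the generated .mp4 files directly, run the script below to pack the videos into an archive, then download it and extract it locally to view them."
]
},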
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"video_dir = Path(metrics['save_root'])\n",
"print('video_dir =', video_dir)\n",
"\n",
"mp4_files = sorted(str(path) for path in video_dir.glob('*.mp4'))\n",
"print('Generated video files:')\n",
"for f in mp4_files:\n",
"    print(f)\n",
"\n",
"# Pack the whole output directory into a tar.gz placed next to it, ready for download.\n",
"archive_path = (video_dir.parent / \"baseline_demo_vis.tar.gz\").resolve()\n",
"!tar -czf \"{archive_path}\" -C \"{video_dir.parent}\" \"{video_dir.name}\"\n",
"\n",
"print('\\nArchived to:', archive_path)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10"
}
},
"nbformat": 4,
"nbformat_minor": 4
}