@@ -1,11 +1,12 @@
import logging
import os
+import time
from pathlib import Path
import hydra
import torch
+import torch_npu
import torch.distributed as distributed
-import torchaudio
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig
from tqdm import tqdm
@@ -15,20 +16,46 @@ from mmaudio.eval_utils import ModelConfig, all_model_cfg, generate
from mmaudio.model.flow_matching import FlowMatching
from mmaudio.model.networks import MMAudio, get_my_mmaudio
from mmaudio.model.utils.features_utils import FeaturesUtils
+from mmaudio.utils.audio_io import safe_torchaudio_save
-torch.backends.cuda.matmul.allow_tf32 = True
-torch.backends.cudnn.allow_tf32 = True
+if torch.cuda.is_available():
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+
+# Disable Transformer fast path to avoid unsupported op fallback on some NPUs.
+torch.backends.mha.set_fastpath_enabled(False)
local_rank = int(os.environ['LOCAL_RANK'])
world_size = int(os.environ['WORLD_SIZE'])
log = logging.getLogger()
+def sync_device(device: str):
+    """Wait for outstanding work on the selected accelerator to finish.
+
+    Called next to the ``time.perf_counter()`` reads in the benchmark loop.
+    Only the exact strings ``'cuda'`` and ``'npu'`` trigger a synchronize
+    call; every other value (for example ``'cpu'``) is a no-op.
+
+    Args:
+        device: Device type string chosen by the caller.
+    """
+    if device == 'cuda':
+        torch.cuda.synchronize()
+        return
+    if device == 'npu':
+        torch.npu.synchronize()
+
+
+def percentile(values, q):
+    """Return the nearest-rank ``q``-quantile of ``values``.
+
+    The nearest-rank definition picks the ``ceil(q * n)``-th smallest sample
+    (1-based). The previous ``int(q * n)`` truncated instead of rounding up,
+    which under-reported the percentile whenever ``q * n`` was not an
+    integer (e.g. the median of five samples came back as the 2nd value).
+
+    Args:
+        values: Sequence of comparable samples; it is not modified.
+        q: Quantile as a fraction. Values outside ``[0, 1]`` are clamped to
+            the smallest / largest sample.
+
+    Returns:
+        The selected sample, or ``0.0`` when ``values`` is empty.
+    """
+    import math  # local import keeps this helper self-contained
+
+    if not values:
+        return 0.0
+    ordered = sorted(values)
+    # round() absorbs float noise such as 0.7 * 10 == 7.000000000000001,
+    # which would otherwise push ceil() one rank too high.
+    rank = math.ceil(round(len(ordered) * q, 9))
+    index = min(len(ordered) - 1, max(0, rank - 1))
+    return ordered[index]
+
+
@torch.inference_mode()
@hydra.main(version_base='1.3.2', config_path='config', config_name='eval_config.yaml')
def main(cfg: DictConfig):
- device = 'cuda'
- torch.cuda.set_device(local_rank)
+ if hasattr(torch, 'npu') and torch.npu.is_available():
+ device = 'npu'
+ torch.npu.set_device(local_rank)
+ elif torch.cuda.is_available():
+ device = 'cuda'
+ torch.cuda.set_device(local_rank)
+ else:
+ device = 'cpu'
if cfg.model not in all_model_cfg:
raise ValueError(f'Unknown model variant: {cfg.model}')
@@ -74,9 +101,32 @@ def main(cfg: DictConfig):
feature_utils.compile()
dataset, loader = setup_eval_dataset(cfg.dataset, cfg)
+ benchmark = bool(cfg.get('benchmark', False))
+ warmup_batches = max(int(cfg.get('benchmark_warmup_batches', 0)), 0)
+ log_every_n = max(int(cfg.get('benchmark_log_every_n_batches', 0)), 0)
+
+ timed_batches = 0
+ timed_samples = 0
+ timed_audio_seconds = 0.0
+ total_gen_time_s = 0.0
+ total_save_time_s = 0.0
+ total_batch_time_s = 0.0
+ batch_latency_s = []
+
+ if benchmark:
+ log.info('[Benchmark] enabled: warmup_batches=%d, log_every_n_batches=%d', warmup_batches,
+ log_every_n)
with torch.amp.autocast(enabled=cfg.amp, dtype=torch.bfloat16, device_type=device):
- for batch in tqdm(loader):
+ for batch_idx, batch in enumerate(tqdm(loader)):
+ is_warmup = benchmark and batch_idx < warmup_batches
+ if benchmark and batch_idx == warmup_batches:
+ sync_device(device)
+ log.info('[Benchmark] warmup finished, start timing from batch index %d',
+ warmup_batches)
+
+ batch_start = time.perf_counter()
+ gen_start = time.perf_counter()
audios = generate(batch.get('clip_video', None),
batch.get('sync_video', None),
batch.get('caption', None),
@@ -87,14 +137,71 @@ def main(cfg: DictConfig):
cfg_strength=cfg.cfg_strength,
clip_batch_size_multiplier=64,
sync_batch_size_multiplier=64)
+ sync_device(device)
+ gen_time_s = time.perf_counter() - gen_start
+
+ save_start = time.perf_counter()
audios = audios.float().cpu()
names = batch['name']
for audio, name in zip(audios, names):
- torchaudio.save(output_dir / f'{name}.flac', audio, seq_cfg.sampling_rate)
+ safe_torchaudio_save(output_dir / f'{name}.flac', audio, seq_cfg.sampling_rate)
+ save_time_s = time.perf_counter() - save_start
+ batch_time_s = time.perf_counter() - batch_start
+
+ if benchmark and not is_warmup:
+ bs = len(names)
+ timed_batches += 1
+ timed_samples += bs
+ timed_audio_seconds += bs * float(seq_cfg.duration)
+ total_gen_time_s += gen_time_s
+ total_save_time_s += save_time_s
+ total_batch_time_s += batch_time_s
+ batch_latency_s.append(batch_time_s)
+
+ if log_every_n > 0 and timed_batches % log_every_n == 0:
+ log.info(
+ '[Benchmark][rank=%d] batches=%d samples=%d batch_time=%.4fs gen_time=%.4fs save_time=%.4fs',
+ local_rank, timed_batches, timed_samples, batch_time_s, gen_time_s, save_time_s)
+
+ if benchmark:
+ global_counts = torch.tensor([
+ float(timed_batches),
+ float(timed_samples),
+ float(timed_audio_seconds),
+ float(total_gen_time_s),
+ float(total_save_time_s),
+ ],
+ device=device)
+ distributed.all_reduce(global_counts, op=distributed.ReduceOp.SUM)
+
+ global_wall_time = torch.tensor(float(total_batch_time_s), device=device)
+ distributed.all_reduce(global_wall_time, op=distributed.ReduceOp.MAX)
+
+ global_audio_seconds_per_rank = torch.tensor(float(timed_audio_seconds), device=device)
+ distributed.all_reduce(global_audio_seconds_per_rank, op=distributed.ReduceOp.MAX)
+
+ global_batches = int(global_counts[0].item())
+ global_samples = int(global_counts[1].item())
+ global_batch_time_s = global_wall_time.item()
+ max_rank_audio_seconds = global_audio_seconds_per_rank.item()
+
+ if local_rank == 0 and global_batches > 0:
+ throughput_samples_per_s = global_samples / max(global_batch_time_s, 1e-6)
+ rtf = global_batch_time_s / max(max_rank_audio_seconds, 1e-6)
+
+ log.info('[Benchmark][global] batches=%d samples=%d', global_batches, global_samples)
+ log.info('[Benchmark][global] RTF=%.6f', rtf)
+ log.info('[Benchmark][global] throughput=%.4f samples/s', throughput_samples_per_s)
def distributed_setup():
- distributed.init_process_group(backend="nccl")
+ if hasattr(torch, 'npu') and torch.npu.is_available():
+ backend = 'hccl'
+ elif torch.cuda.is_available():
+ backend = 'nccl'
+ else:
+ backend = 'gloo'
+ distributed.init_process_group(backend=backend)
local_rank = distributed.get_rank()
world_size = distributed.get_world_size()
log.info(f'Initialized: local_rank={local_rank}, world_size={world_size}')
@@ -14,4 +14,9 @@ duration_s: 8.0
# for inference, this is the per-GPU batch size
batch_size: 16
-output_name: null
\ No newline at end of file
+output_name: null
+
+# benchmark controls (for batch_eval.py)
+benchmark: true
+benchmark_warmup_batches: 2
+benchmark_log_every_n_batches: 0
@@ -3,16 +3,21 @@ from argparse import ArgumentParser
from pathlib import Path
import torch
-import torchaudio
+import torch_npu
+
+# Disable the PyTorch Transformer fast path: aten::_transformer_encoder_layer_fwd
+# is not supported on Ascend NPU, and leaving it enabled triggers a CPU fallback with D2H/H2D data transfers.
+torch.backends.mha.set_fastpath_enabled(False)
from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
setup_eval_logging)
from mmaudio.model.flow_matching import FlowMatching
from mmaudio.model.networks import MMAudio, get_my_mmaudio
from mmaudio.model.utils.features_utils import FeaturesUtils
+from mmaudio.utils.audio_io import safe_torchaudio_save
-torch.backends.cuda.matmul.allow_tf32 = True
-torch.backends.cudnn.allow_tf32 = True
+# torch.backends.cuda.matmul.allow_tf32 = True
+# torch.backends.cudnn.allow_tf32 = True
log = logging.getLogger()
@@ -63,7 +68,9 @@ def main():
mask_away_clip: bool = args.mask_away_clip
device = 'cpu'
- if torch.cuda.is_available():
+ if torch.npu.is_available():
+ device = 'npu'
+ elif torch.cuda.is_available():
device = 'cuda'
elif torch.backends.mps.is_available():
device = 'mps'
@@ -126,7 +133,7 @@ def main():
else:
safe_filename = prompt.replace(' ', '_').replace('/', '_').replace('.', '')
save_path = output_dir / f'{safe_filename}.flac'
- torchaudio.save(save_path, audio, seq_cfg.sampling_rate)
+ safe_torchaudio_save(save_path, audio, seq_cfg.sampling_rate)
log.info(f'Audio saved to {save_path}')
if video_path is not None and not skip_video_composite:
@@ -134,7 +141,10 @@ def main():
make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
log.info(f'Video saved to {output_dir / video_save_path}')
- log.info('Memory usage: %.2f GB', torch.cuda.max_memory_allocated() / (2**30))
+ if device == 'npu':
+ log.info('Memory usage: %.2f GB', torch.npu.max_memory_allocated() / (2**30))
+ else:
+ log.info('Memory usage: %.2f GB', torch.cuda.max_memory_allocated() / (2**30))
if __name__ == '__main__':
@@ -7,7 +7,11 @@ from pathlib import Path
import gradio as gr
import torch
-import torchaudio
+import torch_npu
+
+# Disable the PyTorch Transformer fast path: aten::_transformer_encoder_layer_fwd
+# is not supported on Ascend NPU, and leaving it enabled may trigger a CPU fallback with D2H/H2D data transfers.
+torch.backends.mha.set_fastpath_enabled(False)
from mmaudio.eval_utils import (ModelConfig, VideoInfo, all_model_cfg, generate, load_image,
load_video, make_video, setup_eval_logging)
@@ -15,14 +19,19 @@ from mmaudio.model.flow_matching import FlowMatching
from mmaudio.model.networks import MMAudio, get_my_mmaudio
from mmaudio.model.sequence_config import SequenceConfig
from mmaudio.model.utils.features_utils import FeaturesUtils
+from mmaudio.utils.audio_io import safe_torchaudio_save
-torch.backends.cuda.matmul.allow_tf32 = True
-torch.backends.cudnn.allow_tf32 = True
+# The CUDA TF32 switches are not used on the NPU inference path; they stay disabled to avoid confusion and misconfiguration.
+# torch.backends.cuda.matmul.allow_tf32 = True
+# torch.backends.cudnn.allow_tf32 = True
log = logging.getLogger()
device = 'cpu'
-if torch.cuda.is_available():
+if torch.npu.is_available():
+    # Prefer the NPU when an Ascend NPU is available, keeping inference behavior consistent with demo.py.
+ device = 'npu'
+elif torch.cuda.is_available():
device = 'cuda'
elif torch.backends.mps.is_available():
device = 'mps'
@@ -163,7 +172,7 @@ def text_to_audio(prompt: str, negative_prompt: str, seed: int, num_steps: int,
current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
output_dir.mkdir(exist_ok=True, parents=True)
audio_save_path = output_dir / f'{current_time_string}.flac'
- torchaudio.save(audio_save_path, audio, seq_cfg.sampling_rate)
+ safe_torchaudio_save(audio_save_path, audio, seq_cfg.sampling_rate)
gc.collect()
return audio_save_path
@@ -8,8 +8,8 @@ import pandas as pd
import torch
from torch.utils.data.dataset import Dataset
from torchvision.transforms import v2
-from torio.io import StreamingMediaDecoder
+from mmaudio.data.av_utils import read_frames
from mmaudio.utils.dist_utils import local_rank
log = logging.getLogger()
@@ -58,23 +58,12 @@ class VideoDataset(Dataset):
video_id = self.videos[idx]
caption = self.captions[video_id]
- reader = StreamingMediaDecoder(self.video_root / (video_id + '.mp4'))
- reader.add_basic_video_stream(
- frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
- frame_rate=_CLIP_FPS,
- format='rgb24',
- )
- reader.add_basic_video_stream(
- frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
- frame_rate=_SYNC_FPS,
- format='rgb24',
- )
-
- reader.fill_buffer()
- data_chunk = reader.pop_chunks()
-
- clip_chunk = data_chunk[0]
- sync_chunk = data_chunk[1]
+ output_frames, _, _ = read_frames(self.video_root / (video_id + '.mp4'),
+ list_of_fps=[_CLIP_FPS, _SYNC_FPS],
+ start_sec=0,
+ end_sec=self.duration_sec,
+ need_all_frames=False)
+ clip_chunk, sync_chunk = output_frames
if clip_chunk is None:
raise RuntimeError(f'CLIP video returned None {video_id}')
if clip_chunk.shape[0] < self.clip_expected_length:
@@ -49,4 +49,6 @@ class AutoEncoderModule(nn.Module):
@torch.inference_mode()
def vocode(self, spec: torch.Tensor) -> torch.Tensor:
+ if isinstance(self.vocoder, BigVGANv2):
+ spec = spec.float()
return self.vocoder(spec)
@@ -29,6 +29,17 @@ def load_hparams_from_json(path) -> AttrDict:
return AttrDict(json.loads(data))
+def _floating_tensor_to_fp32(tensor: torch.Tensor) -> torch.Tensor:
+    """Cast a floating-point tensor to float32; return anything else as-is.
+
+    Non-floating tensors and tensors that are already float32 are returned
+    unchanged (the same object), so this is safe to map over every
+    parameter and buffer of a module.
+    """
+    needs_cast = tensor.is_floating_point() and tensor.dtype != torch.float32
+    return tensor.float() if needs_cast else tensor
+
+
+def _keep_module_fp32(module: torch.nn.Module) -> torch.nn.Module:
+    """Cast every floating parameter/buffer of ``module`` to float32.
+
+    The base ``torch.nn.Module._apply`` is invoked explicitly (rather than
+    ``module._apply``) so that an ``_apply`` override on the top-level
+    module is bypassed for this call.
+
+    Returns:
+        Whatever ``torch.nn.Module._apply`` returns for ``module``.
+    """
+    base_apply = torch.nn.Module._apply
+    return base_apply(module, _floating_tensor_to_fp32)
+
+
class AMPBlock1(torch.nn.Module):
"""
AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
@@ -233,6 +244,7 @@ class BigVGAN(
super().__init__()
self.h = h
self.h["use_cuda_kernel"] = use_cuda_kernel
+ self._keep_fp32 = False
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
if self.h.get("use_cuda_kernel", False):
@@ -304,6 +316,19 @@ class BigVGAN(
# Final tanh activation. Defaults to True for backward compatibility
self.use_tanh_at_final = h.get("use_tanh_at_final", True)
+    def keep_fp32(self, enabled: bool = True):
+        """Turn the fp32 pin on or off and return ``self`` for chaining.
+
+        The flag is stored in ``self._keep_fp32``. When it is switched on,
+        all floating parameters and buffers are cast to float32 right away;
+        switching it off only clears the flag and converts nothing.
+
+        Args:
+            enabled: Whether parameters/buffers should be held in float32.
+        """
+        self._keep_fp32 = enabled
+        if not enabled:
+            return self
+        _keep_module_fp32(self)
+        return self
+
+    def _apply(self, fn, recurse=True):
+        """Apply ``fn`` to parameters/buffers while honouring the fp32 pin.
+
+        With the pin off this is the stock ``Module._apply``. With the pin
+        on, any floating tensor that ``fn`` would turn into a non-float32
+        dtype is instead produced from the *source* tensor, cast to float32
+        on the device ``fn`` selected. Converting down and then back up (as
+        the earlier two-pass version did) would permanently round the
+        weights to the lower precision, defeating the purpose of the pin.
+
+        Args:
+            fn: Tensor-to-tensor function supplied by ``Module.to`` & co.
+            recurse: Forwarded unchanged to ``Module._apply``.
+
+        Returns:
+            The result of ``Module._apply`` (the module itself).
+        """
+        # getattr: do not raise if the attribute is missing on an instance
+        # that was created without running the current __init__.
+        if not getattr(self, '_keep_fp32', False):
+            return super()._apply(fn, recurse=recurse)
+
+        def fp32_preserving_fn(tensor):
+            converted = fn(tensor)
+            if converted.is_floating_point() and converted.dtype != torch.float32:
+                # Take device placement from fn's result, but values from the
+                # untouched source tensor so no precision is lost.
+                return tensor.to(device=converted.device, dtype=torch.float32)
+            return converted
+
+        return super()._apply(fp32_preserving_fn, recurse=recurse)
+
def forward(self, x):
# Pre-conv
x = self.conv_pre(x)
@@ -441,4 +466,5 @@ class BigVGAN(
model.remove_weight_norm()
model.load_state_dict(checkpoint_dict["generator"])
+ model.keep_fp32()
return model
@@ -273,6 +273,9 @@ class BaseEncoderLayer(nn.TransformerEncoderLayer):
*args_transformer_enc,
**kwargs_transformer_enc):
super().__init__(*args_transformer_enc, **kwargs_transformer_enc)
+ # ===== NPU disable BetterTransformer nested tensor fast path =====
+ self.enable_nested_tensor = False
+ # =========================================================================
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.self_attn.embed_dim))
trunc_normal_(self.cls_token, std=.02)
@@ -46,8 +46,10 @@ class FeaturesUtils(nn.Module):
super().__init__()
if enable_conditions:
- self.clip_model = create_model_from_pretrained('hf-hub:apple/DFN5B-CLIP-ViT-H-14-384',
- return_transform=False)
+ self.clip_model = create_model_from_pretrained(
+ model_name='ViT-H-14-378',
+ pretrained='./DFN5B-CLIP-ViT-H-14-378/open_clip_pytorch_model.bin',
+ return_transform=False)
self.clip_preprocess = Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
std=[0.26862954, 0.26130258, 0.27577711])
self.clip_model = patch_clip(self.clip_model)
@@ -50,6 +50,9 @@ dependencies = [
'av >= 14.0.1',
'timm >= 1.0.12',
'python-dotenv',
+ 'matplotlib',
+ 'soundfile',
+ 'datasets',
]
[tool.hatch.build.targets.wheel]