diff --git a/batch_eval.py b/batch_eval.py

index 535f5f6..4291e15 100644

--- a/batch_eval.py

+++ b/batch_eval.py

@@ -1,11 +1,12 @@

 import logging

 import os

+import time

 from pathlib import Path

 

 import hydra

 import torch

+import torch_npu

 import torch.distributed as distributed

-import torchaudio

 from hydra.core.hydra_config import HydraConfig

 from omegaconf import DictConfig

 from tqdm import tqdm

@@ -15,20 +16,46 @@ from mmaudio.eval_utils import ModelConfig, all_model_cfg, generate

 from mmaudio.model.flow_matching import FlowMatching

 from mmaudio.model.networks import MMAudio, get_my_mmaudio

 from mmaudio.model.utils.features_utils import FeaturesUtils

+from mmaudio.utils.audio_io import safe_torchaudio_save

 

-torch.backends.cuda.matmul.allow_tf32 = True

-torch.backends.cudnn.allow_tf32 = True

+if torch.cuda.is_available():

+    torch.backends.cuda.matmul.allow_tf32 = True

+    torch.backends.cudnn.allow_tf32 = True

+

+# Disable Transformer fast path to avoid unsupported op fallback on some NPUs.

+torch.backends.mha.set_fastpath_enabled(False)

 

 local_rank = int(os.environ['LOCAL_RANK'])

 world_size = int(os.environ['WORLD_SIZE'])

 log = logging.getLogger()

 

 

+def sync_device(device: str):
+    """Block until work queued on the given accelerator type has completed.
+
+    The benchmark loop calls this before reading the clock.
+
+    Args:
+        device: Device type string. Only ``'npu'`` and ``'cuda'`` trigger a
+            synchronize call; any other value (e.g. ``'cpu'``) is a no-op.
+    """
+    if device == 'cuda':
+        torch.cuda.synchronize()
+        return
+    if device == 'npu':
+        torch.npu.synchronize()

+

+

+def percentile(values, q):
+    """Return the ``q``-quantile of ``values`` using the nearest-rank method.
+
+    Args:
+        values: Sequence of mutually comparable numbers; it is not modified.
+        q: Quantile as a fraction (0.5 is the median, 0.95 is p95). Values
+            outside ``[0, 1]`` are clamped to the smallest / largest element.
+
+    Returns:
+        The element at rank ``ceil(q * len(values))`` of the sorted values, or
+        ``0.0`` when ``values`` is empty.
+    """
+    import math  # local import keeps this helper self-contained
+
+    if not values:
+        return 0.0
+    ordered = sorted(values)
+    count = len(ordered)
+    # Nearest-rank needs the ceiling of q * n. Truncating picks one rank too
+    # low whenever q * n is fractional (the "median" of three samples would be
+    # the minimum). The epsilon keeps float noise such as
+    # 100 * 0.07 == 7.000000000000001 from bumping the rank up by one.
+    rank = math.ceil(count * q - 1e-9)
+    index = min(count - 1, max(0, rank - 1))
+    return ordered[index]

+

+

 @torch.inference_mode()

 @hydra.main(version_base='1.3.2', config_path='config', config_name='eval_config.yaml')

 def main(cfg: DictConfig):

-    device = 'cuda'

-    torch.cuda.set_device(local_rank)

+    if hasattr(torch, 'npu') and torch.npu.is_available():

+        device = 'npu'

+        torch.npu.set_device(local_rank)

+    elif torch.cuda.is_available():

+        device = 'cuda'

+        torch.cuda.set_device(local_rank)

+    else:

+        device = 'cpu'

 

     if cfg.model not in all_model_cfg:

         raise ValueError(f'Unknown model variant: {cfg.model}')

@@ -74,9 +101,32 @@ def main(cfg: DictConfig):

         feature_utils.compile()

 

     dataset, loader = setup_eval_dataset(cfg.dataset, cfg)

+    benchmark = bool(cfg.get('benchmark', False))

+    warmup_batches = max(int(cfg.get('benchmark_warmup_batches', 0)), 0)

+    log_every_n = max(int(cfg.get('benchmark_log_every_n_batches', 0)), 0)

+

+    timed_batches = 0

+    timed_samples = 0

+    timed_audio_seconds = 0.0

+    total_gen_time_s = 0.0

+    total_save_time_s = 0.0

+    total_batch_time_s = 0.0

+    batch_latency_s = []

+

+    if benchmark:

+        log.info('[Benchmark] enabled: warmup_batches=%d, log_every_n_batches=%d', warmup_batches,

+                 log_every_n)

 

     with torch.amp.autocast(enabled=cfg.amp, dtype=torch.bfloat16, device_type=device):

-        for batch in tqdm(loader):

+        for batch_idx, batch in enumerate(tqdm(loader)):

+            is_warmup = benchmark and batch_idx < warmup_batches

+            if benchmark and batch_idx == warmup_batches:

+                sync_device(device)

+                log.info('[Benchmark] warmup finished, start timing from batch index %d',

+                         warmup_batches)

+

+            batch_start = time.perf_counter()

+            gen_start = time.perf_counter()

             audios = generate(batch.get('clip_video', None),

                               batch.get('sync_video', None),

                               batch.get('caption', None),

@@ -87,14 +137,71 @@ def main(cfg: DictConfig):

                               cfg_strength=cfg.cfg_strength,

                               clip_batch_size_multiplier=64,

                               sync_batch_size_multiplier=64)

+            sync_device(device)

+            gen_time_s = time.perf_counter() - gen_start

+

+            save_start = time.perf_counter()

             audios = audios.float().cpu()

             names = batch['name']

             for audio, name in zip(audios, names):

-                torchaudio.save(output_dir / f'{name}.flac', audio, seq_cfg.sampling_rate)

+                safe_torchaudio_save(output_dir / f'{name}.flac', audio, seq_cfg.sampling_rate)

+            save_time_s = time.perf_counter() - save_start

+            batch_time_s = time.perf_counter() - batch_start

+

+            if benchmark and not is_warmup:

+                bs = len(names)

+                timed_batches += 1

+                timed_samples += bs

+                timed_audio_seconds += bs * float(seq_cfg.duration)

+                total_gen_time_s += gen_time_s

+                total_save_time_s += save_time_s

+                total_batch_time_s += batch_time_s

+                batch_latency_s.append(batch_time_s)

+

+                if log_every_n > 0 and timed_batches % log_every_n == 0:

+                    log.info(

+                        '[Benchmark][rank=%d] batches=%d samples=%d batch_time=%.4fs gen_time=%.4fs save_time=%.4fs',

+                        local_rank, timed_batches, timed_samples, batch_time_s, gen_time_s, save_time_s)

+

+    if benchmark:

+        global_counts = torch.tensor([

+            float(timed_batches),

+            float(timed_samples),

+            float(timed_audio_seconds),

+            float(total_gen_time_s),

+            float(total_save_time_s),

+        ],

+                                     device=device)

+        distributed.all_reduce(global_counts, op=distributed.ReduceOp.SUM)

+

+        global_wall_time = torch.tensor(float(total_batch_time_s), device=device)

+        distributed.all_reduce(global_wall_time, op=distributed.ReduceOp.MAX)

+

+        global_audio_seconds_per_rank = torch.tensor(float(timed_audio_seconds), device=device)

+        distributed.all_reduce(global_audio_seconds_per_rank, op=distributed.ReduceOp.MAX)

+

+        global_batches = int(global_counts[0].item())

+        global_samples = int(global_counts[1].item())

+        global_batch_time_s = global_wall_time.item()

+        max_rank_audio_seconds = global_audio_seconds_per_rank.item()

+

+        if local_rank == 0 and global_batches > 0:

+            throughput_samples_per_s = global_samples / max(global_batch_time_s, 1e-6)

+            rtf = global_batch_time_s / max(max_rank_audio_seconds, 1e-6)

+

+            log.info('[Benchmark][global] batches=%d samples=%d', global_batches, global_samples)

+            log.info('[Benchmark][global] RTF=%.6f', rtf)

+            log.info('[Benchmark][global] throughput=%.4f samples/s', throughput_samples_per_s)

 

 

 def distributed_setup():

-    distributed.init_process_group(backend="nccl")

+    if hasattr(torch, 'npu') and torch.npu.is_available():

+        backend = 'hccl'

+    elif torch.cuda.is_available():

+        backend = 'nccl'

+    else:

+        backend = 'gloo'

+    distributed.init_process_group(backend=backend)

     local_rank = distributed.get_rank()

     world_size = distributed.get_world_size()

     log.info(f'Initialized: local_rank={local_rank}, world_size={world_size}')

diff --git a/config/eval_config.yaml b/config/eval_config.yaml

index f8d015b..b7da597 100644

--- a/config/eval_config.yaml

+++ b/config/eval_config.yaml

@@ -14,4 +14,9 @@ duration_s: 8.0

 

 # for inference, this is the per-GPU batch size

 batch_size: 16

-output_name: null

\ No newline at end of file

+output_name: null

+

+# benchmark controls (for batch_eval.py)

+benchmark: true

+benchmark_warmup_batches: 2

+benchmark_log_every_n_batches: 0

diff --git a/demo.py b/demo.py

index 9f073d6..db06410 100644

--- a/demo.py

+++ b/demo.py

@@ -3,16 +3,21 @@ from argparse import ArgumentParser

 from pathlib import Path

 

 import torch

-import torchaudio

+import torch_npu

+

+# Disable the PyTorch Transformer fast path: aten::_transformer_encoder_layer_fwd is

+# not supported on Ascend NPU and would otherwise trigger a CPU fallback with D2H/H2D copies.

+torch.backends.mha.set_fastpath_enabled(False)

 

 from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,

                                 setup_eval_logging)

 from mmaudio.model.flow_matching import FlowMatching

 from mmaudio.model.networks import MMAudio, get_my_mmaudio

 from mmaudio.model.utils.features_utils import FeaturesUtils

+from mmaudio.utils.audio_io import safe_torchaudio_save

 

-torch.backends.cuda.matmul.allow_tf32 = True

-torch.backends.cudnn.allow_tf32 = True

+# torch.backends.cuda.matmul.allow_tf32 = True

+# torch.backends.cudnn.allow_tf32 = True

 

 log = logging.getLogger()

 

@@ -63,7 +68,9 @@ def main():

     mask_away_clip: bool = args.mask_away_clip

 

     device = 'cpu'

-    if torch.cuda.is_available():

+    if torch.npu.is_available():

+        device = 'npu'

+    elif torch.cuda.is_available():

         device = 'cuda'

     elif torch.backends.mps.is_available():

         device = 'mps'

@@ -126,7 +133,7 @@ def main():

     else:

         safe_filename = prompt.replace(' ', '_').replace('/', '_').replace('.', '')

         save_path = output_dir / f'{safe_filename}.flac'

-    torchaudio.save(save_path, audio, seq_cfg.sampling_rate)

+    safe_torchaudio_save(save_path, audio, seq_cfg.sampling_rate)

 

     log.info(f'Audio saved to {save_path}')

     if video_path is not None and not skip_video_composite:

@@ -134,7 +141,10 @@ def main():

         make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)

         log.info(f'Video saved to {output_dir / video_save_path}')

 

-    log.info('Memory usage: %.2f GB', torch.cuda.max_memory_allocated() / (2**30))

+    if device == 'npu':

+        log.info('Memory usage: %.2f GB', torch.npu.max_memory_allocated() / (2**30))

+    else:

+        log.info('Memory usage: %.2f GB', torch.cuda.max_memory_allocated() / (2**30))

 

 

 if __name__ == '__main__':

diff --git a/gradio_demo.py b/gradio_demo.py

index 7bbdbf5..b0dc47f 100644

--- a/gradio_demo.py

+++ b/gradio_demo.py

@@ -7,7 +7,11 @@ from pathlib import Path

 

 import gradio as gr

 import torch

-import torchaudio

+import torch_npu

+

+# Disable the PyTorch Transformer fast path: aten::_transformer_encoder_layer_fwd is

+# not supported on Ascend NPU and may otherwise trigger a CPU fallback with D2H/H2D copies.

+torch.backends.mha.set_fastpath_enabled(False)

 

 from mmaudio.eval_utils import (ModelConfig, VideoInfo, all_model_cfg, generate, load_image,

                                 load_video, make_video, setup_eval_logging)

@@ -15,14 +19,19 @@ from mmaudio.model.flow_matching import FlowMatching

 from mmaudio.model.networks import MMAudio, get_my_mmaudio

 from mmaudio.model.sequence_config import SequenceConfig

 from mmaudio.model.utils.features_utils import FeaturesUtils

+from mmaudio.utils.audio_io import safe_torchaudio_save

 

-torch.backends.cuda.matmul.allow_tf32 = True

-torch.backends.cudnn.allow_tf32 = True

+# The NPU inference path does not use the CUDA TF32 switches; they stay disabled to avoid a misleading or incorrect configuration.

+# torch.backends.cuda.matmul.allow_tf32 = True

+# torch.backends.cudnn.allow_tf32 = True

 

 log = logging.getLogger()

 

 device = 'cpu'

-if torch.cuda.is_available():

+if torch.npu.is_available():

+    # Prefer the Ascend NPU when it is available, keeping inference behavior consistent with demo.py.

+    device = 'npu'

+elif torch.cuda.is_available():

     device = 'cuda'

 elif torch.backends.mps.is_available():

     device = 'mps'

@@ -163,7 +172,7 @@ def text_to_audio(prompt: str, negative_prompt: str, seed: int, num_steps: int,

     current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')

     output_dir.mkdir(exist_ok=True, parents=True)

     audio_save_path = output_dir / f'{current_time_string}.flac'

-    torchaudio.save(audio_save_path, audio, seq_cfg.sampling_rate)

+    safe_torchaudio_save(audio_save_path, audio, seq_cfg.sampling_rate)

     gc.collect()

     return audio_save_path

 

diff --git a/mmaudio/data/eval/video_dataset.py b/mmaudio/data/eval/video_dataset.py

index 0b84a96..992da72 100644

--- a/mmaudio/data/eval/video_dataset.py

+++ b/mmaudio/data/eval/video_dataset.py

@@ -8,8 +8,8 @@ import pandas as pd

 import torch

 from torch.utils.data.dataset import Dataset

 from torchvision.transforms import v2

-from torio.io import StreamingMediaDecoder

 

+from mmaudio.data.av_utils import read_frames

 from mmaudio.utils.dist_utils import local_rank

 

 log = logging.getLogger()

@@ -58,23 +58,12 @@ class VideoDataset(Dataset):

         video_id = self.videos[idx]

         caption = self.captions[video_id]

 

-        reader = StreamingMediaDecoder(self.video_root / (video_id + '.mp4'))

-        reader.add_basic_video_stream(

-            frames_per_chunk=int(_CLIP_FPS * self.duration_sec),

-            frame_rate=_CLIP_FPS,

-            format='rgb24',

-        )

-        reader.add_basic_video_stream(

-            frames_per_chunk=int(_SYNC_FPS * self.duration_sec),

-            frame_rate=_SYNC_FPS,

-            format='rgb24',

-        )

-

-        reader.fill_buffer()

-        data_chunk = reader.pop_chunks()

-

-        clip_chunk = data_chunk[0]

-        sync_chunk = data_chunk[1]

+        output_frames, _, _ = read_frames(self.video_root / (video_id + '.mp4'),

+                                          list_of_fps=[_CLIP_FPS, _SYNC_FPS],

+                                          start_sec=0,

+                                          end_sec=self.duration_sec,

+                                          need_all_frames=False)

+        clip_chunk, sync_chunk = output_frames

         if clip_chunk is None:

             raise RuntimeError(f'CLIP video returned None {video_id}')

         if clip_chunk.shape[0] < self.clip_expected_length:

diff --git a/mmaudio/ext/autoencoder/autoencoder.py b/mmaudio/ext/autoencoder/autoencoder.py

index b77db4e..3181913 100644

--- a/mmaudio/ext/autoencoder/autoencoder.py

+++ b/mmaudio/ext/autoencoder/autoencoder.py

@@ -49,4 +49,6 @@ class AutoEncoderModule(nn.Module):

 

     @torch.inference_mode()

     def vocode(self, spec: torch.Tensor) -> torch.Tensor:

+        """Run the vocoder on ``spec`` and return its output.
+
+        When the vocoder is a ``BigVGANv2`` instance, ``spec`` is cast to
+        float32 before the call; other vocoders receive ``spec`` unchanged.
+        """

+        # NOTE(review): presumably required because the BigVGAN loader pins its
+        # weights to fp32 (``model.keep_fp32()``), so a lower-precision input
+        # would not match the weight dtype - confirm.

+        if isinstance(self.vocoder, BigVGANv2):

+            spec = spec.float()

         return self.vocoder(spec)

diff --git a/mmaudio/ext/bigvgan_v2/bigvgan.py b/mmaudio/ext/bigvgan_v2/bigvgan.py

index d92e9b8..4c28f0c 100644

--- a/mmaudio/ext/bigvgan_v2/bigvgan.py

+++ b/mmaudio/ext/bigvgan_v2/bigvgan.py

@@ -29,6 +29,17 @@ def load_hparams_from_json(path) -> AttrDict:

     return AttrDict(json.loads(data))

 

 

+def _floating_tensor_to_fp32(tensor: torch.Tensor) -> torch.Tensor:
+    """Return ``tensor`` as float32 if it is a floating-point tensor of another dtype.
+
+    Non-floating tensors and tensors that are already float32 are returned
+    unchanged (the same object).
+    """
+    needs_cast = tensor.is_floating_point() and tensor.dtype != torch.float32
+    return tensor.float() if needs_cast else tensor

+

+

+def _keep_module_fp32(module: torch.nn.Module) -> torch.nn.Module:
+    """Cast every floating-point parameter/buffer of ``module`` to fp32 and return it.
+
+    The base ``torch.nn.Module._apply`` is called directly instead of
+    ``module._apply``, so an ``_apply`` override defined on ``module`` itself
+    is not used for this conversion.
+    """
+    base_apply = torch.nn.Module._apply
+    return base_apply(module, _floating_tensor_to_fp32)

+

+

 class AMPBlock1(torch.nn.Module):

     """

     AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.

@@ -233,6 +244,7 @@ class BigVGAN(

         super().__init__()

         self.h = h

         self.h["use_cuda_kernel"] = use_cuda_kernel

+        self._keep_fp32 = False

 

         # Select which Activation1d, lazy-load cuda version to ensure backward compatibility

         if self.h.get("use_cuda_kernel", False):

@@ -304,6 +316,19 @@ class BigVGAN(

         # Final tanh activation. Defaults to True for backward compatibility

         self.use_tanh_at_final = h.get("use_tanh_at_final", True)

 

+    def keep_fp32(self, enabled: bool = True):
+        """Turn the fp32 pin for this module's parameters/buffers on or off.
+
+        Args:
+            enabled: When true, floating parameters/buffers are converted to
+                fp32 right away and the flag read by ``_apply`` is set so they
+                stay fp32 after a later parent ``.to(dtype)``. When false, only
+                the flag is cleared; existing tensors are left as they are.
+
+        Returns:
+            ``self``, so the call can be chained.
+        """
+        self._keep_fp32 = enabled
+        if not enabled:
+            return self
+        _keep_module_fp32(self)
+        return self

+

+    def _apply(self, fn, recurse=True):
+        """Apply ``fn`` to parameters/buffers while honouring the ``keep_fp32`` pin.
+
+        With the pin off this is the stock ``torch.nn.Module._apply``. With the
+        pin on, every floating-point result ends up as float32, and tensors that
+        were already float32 keep their exact values: they are only moved to the
+        device ``fn`` selected instead of being rounded to a lower-precision
+        dtype and cast back.
+
+        Args:
+            fn: Tensor-to-tensor conversion supplied by ``to()``, ``half()``,
+                ``cuda()`` and similar ``torch.nn.Module`` methods.
+            recurse: Forwarded unchanged to ``torch.nn.Module._apply``.
+
+        Returns:
+            ``self``, as returned by ``torch.nn.Module._apply``.
+        """
+        # getattr: stay usable on instances where the flag was never set.
+        if not getattr(self, '_keep_fp32', False):
+            return super()._apply(fn, recurse=recurse)
+
+        def _apply_keeping_fp32(tensor):
+            converted = fn(tensor)
+            if not converted.is_floating_point() or converted.dtype == torch.float32:
+                return converted
+            if tensor.dtype == torch.float32:
+                # fn down-cast an fp32 tensor. Up-casting `converted` would keep
+                # the rounded values, so rebuild from the untouched source and
+                # take only the device change from fn.
+                return tensor.to(device=converted.device)
+            # Source was not fp32 to begin with: nothing to preserve, just up-cast.
+            return converted.float()
+
+        return super()._apply(_apply_keeping_fp32, recurse=recurse)

+

     def forward(self, x):

         # Pre-conv

         x = self.conv_pre(x)

@@ -441,4 +466,5 @@ class BigVGAN(

             model.remove_weight_norm()

             model.load_state_dict(checkpoint_dict["generator"])

 

+        model.keep_fp32()

         return model

diff --git a/mmaudio/ext/synchformer/motionformer.py b/mmaudio/ext/synchformer/motionformer.py

index f02141e..31b064f 100644

--- a/mmaudio/ext/synchformer/motionformer.py

+++ b/mmaudio/ext/synchformer/motionformer.py

@@ -273,6 +273,9 @@ class BaseEncoderLayer(nn.TransformerEncoderLayer):

                  *args_transformer_enc,

                  **kwargs_transformer_enc):

         super().__init__(*args_transformer_enc, **kwargs_transformer_enc)

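+        # NPU: turn off the BetterTransformer nested-tensor fast path.
+        # NOTE(review): `enable_nested_tensor` is documented as an nn.TransformerEncoder
+        # option; confirm that setting it on an encoder layer has any effect.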

+        self.enable_nested_tensor = False

+        # =========================================================================

         self.cls_token = nn.Parameter(torch.zeros(1, 1, self.self_attn.embed_dim))

         trunc_normal_(self.cls_token, std=.02)

 

diff --git a/mmaudio/model/utils/features_utils.py b/mmaudio/model/utils/features_utils.py

index 7e63c55..3d8b7f0 100644

--- a/mmaudio/model/utils/features_utils.py

+++ b/mmaudio/model/utils/features_utils.py

@@ -46,8 +46,10 @@ class FeaturesUtils(nn.Module):

         super().__init__()

 

         if enable_conditions:

-            self.clip_model = create_model_from_pretrained('hf-hub:apple/DFN5B-CLIP-ViT-H-14-384',

-                                                           return_transform=False)

+            self.clip_model = create_model_from_pretrained(

+                model_name='ViT-H-14-378',

+                pretrained='./DFN5B-CLIP-ViT-H-14-378/open_clip_pytorch_model.bin',

+                return_transform=False)

             self.clip_preprocess = Normalize(mean=[0.48145466, 0.4578275, 0.40821073],

                                              std=[0.26862954, 0.26130258, 0.27577711])

             self.clip_model = patch_clip(self.clip_model)

diff --git a/pyproject.toml b/pyproject.toml

index db0c689..ba418f3 100644

--- a/pyproject.toml

+++ b/pyproject.toml

@@ -50,6 +50,9 @@ dependencies = [

   'av >= 14.0.1',

   'timm >= 1.0.12',

   'python-dotenv',

+  'matplotlib',

+  'soundfile',

+  'datasets',

 ]

 

 [tool.hatch.build.targets.wheel]