diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py
index 8045117..9e50296 100644
--- a/diffsynth/models/wan_video_dit.py
+++ b/diffsynth/models/wan_video_dit.py
@@ -27,10 +27,13 @@ try:
 except ModuleNotFoundError:
     SAGE_ATTN_AVAILABLE = False
 
-from block_sparse_attn import block_sparse_attn_func
 from PIL import Image
 import numpy as np
 
+import torch_npu
+from torch_npu.contrib import transfer_to_npu
+import mindiesd
+from mindiesd.layers.flash_attn.sparse_flash_attn_rf_v2 import rain_fusion_attention 
 
 # ----------------------------
 # Local / window masks
@@ -193,38 +196,57 @@ def generate_causal_block_mask(batch_size, nheads, seqlen, local_num, window_siz
     return causal_mask
 
 
+def get_mask_index(mask):  # 4-D mask -> per-row ascending positions of the selected entries, padded with -1
+    b, n, s1, s2 = mask.shape
+    device = mask.device
+    # Merge the two leading dims so every (s1, s2) slice is processed the same way.
+    mask_reshaped = mask.reshape(-1, s1, s2)
+    batch_size = mask_reshaped.shape[0]
+    # Selected entries keep their last-dim position; the rest get the 1e9 sentinel so an ascending sort pushes them to the end.
+    row_indices = torch.arange(s2, device=device).expand(batch_size, s1, s2)  # positions 0..s2-1 along the last dim
+    sorted_vals = torch.where(mask_reshaped, row_indices, 1e9).to(torch.float32)  # assumes mask is a bool tensor — TODO confirm
+    sorted_vals, _ = torch.sort(sorted_vals, dim=-1)
+    valid_count = mask_reshaped.sum(dim=-1, keepdim=True)  # number of selected entries per row
+    keep_mask = row_indices < valid_count  # first valid_count slots of each sorted row
+    result = torch.where(keep_mask, sorted_vals, -1)  # remaining slots (the sentinels) become -1
+    # Back to the input's 4-D shape, as int64.
+    pos_matrix = result.reshape(b, n, s1, s2).to(torch.int64)
+    return pos_matrix
+
+
 # ----------------------------
 # Attention kernels
 # ----------------------------
-def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False, attention_mask=None, return_KV=False):
+def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False, attention_mask=None, return_KV=False, blockShape=None):
     if attention_mask is not None:
-        seqlen = q.shape[1]
-        seqlen_kv = k.shape[1]
-        q = rearrange(q, "b s (n d) -> (b s) n d", n=num_heads)
-        k = rearrange(k, "b s (n d) -> (b s) n d", n=num_heads)
-        v = rearrange(v, "b s (n d) -> (b s) n d", n=num_heads)
-        cu_seqlens_q = torch.tensor([0, seqlen], device=q.device, dtype=torch.int32)
-        cu_seqlens_k = torch.tensor([0, seqlen_kv], device=q.device, dtype=torch.int32)
-        head_mask_type = torch.tensor([1]*num_heads, device=q.device, dtype=torch.int32)
-        streaming_info = None
-        base_blockmask = attention_mask
-        max_seqlen_q_ = seqlen
-        max_seqlen_k_ = seqlen_kv
-        p_dropout = 0.0
-        x = block_sparse_attn_func(
+        batch, _, headDims = q.shape
+        headDim = headDims // num_heads
+
+        selectIdx = get_mask_index(attention_mask)
+        selectIdx = selectIdx[0].transpose(0,1)
+        selectNumIdx = attention_mask[0].transpose(0,1).sum(dim=-1)
+
+        blockShape = [blockShape, blockShape]
+        actualSeqLengthsHost = [q.shape[1] for _ in range(batch)]
+        actualSeqLengthsKvHost = [k.shape[1] for _ in range(batch)]
+        scale = headDim ** -0.5
+
+        q = rearrange(q, "b s (n d) -> b s n d", n=num_heads).transpose(1,2).contiguous()
+        k = rearrange(k, "b s (n d) -> b s n d", n=num_heads).transpose(1,2).contiguous()
+        v = rearrange(v, "b s (n d) -> b s n d", n=num_heads).transpose(1,2).contiguous()
+
+        x = rain_fusion_attention(
             q, k, v,
-            cu_seqlens_q, cu_seqlens_k,
-            head_mask_type,
-            streaming_info,
-            base_blockmask,
-            max_seqlen_q_, max_seqlen_k_,
-            p_dropout,
-            deterministic=False,
-            softmax_scale=None,
-            is_causal=False,
-            exact_streaming=False,
-            return_attn_probs=False,
-        ).unsqueeze(0)
+            scale=scale,
+            head_num=num_heads,
+            input_layout="BNSD",
+            select_idx=selectIdx,
+            select_num_idx=selectNumIdx,
+            blockshape=blockShape,
+            actual_seq_lengths=actualSeqLengthsHost,
+            actual_seq_lengths_kv=actualSeqLengthsKvHost,
+        )
+        x = rearrange(x, "b n s d -> b s n d", n=num_heads)
         x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
     elif compatibility_mode:
         q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
@@ -289,10 +311,10 @@ def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
 
 def rope_apply(x, freqs, num_heads):
     x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
-    x_out = torch.view_as_complex(x.to(torch.float64).reshape(
-        x.shape[0], x.shape[1], x.shape[2], -1, 2))
-    x_out = torch.view_as_real(x_out * freqs).flatten(2)
-    return x_out.to(x.dtype)
+    cos, sin = torch.chunk(torch.view_as_real(freqs), 2, dim=-1)  # real / imaginary parts of the complex freqs, each with a trailing dim of 1
+    cos = cos.unsqueeze(0).expand(-1, -1, -1, -1, 2).flatten(-2)  # add a leading dim and repeat every value twice: (c0, c0, c1, c1, ...)
+    sin = sin.unsqueeze(0).expand(-1, -1, -1, -1, 2).flatten(-2)  # same pairwise repetition for the imaginary parts
+    return mindiesd.rotary_position_embedding(x, cos, sin, rotated_mode="rotated_interleaved", fused=True).flatten(2)  # keep the replaced code's "b s (n d)" return layout (no-op if the result is already 3-D)
 
 
 # ----------------------------
@@ -304,12 +326,8 @@ class RMSNorm(nn.Module):
         self.eps = eps
         self.weight = nn.Parameter(torch.ones(dim))
 
-    def norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
-
     def forward(self, x):
-        dtype = x.dtype
-        return self.norm(x.float()).to(dtype) * self.weight
+        return torch_npu.npu_rms_norm(x, self.weight, epsilon=self.eps)[0]  # [0] keeps only the first element of the op's result — presumably the normalized tensor; NOTE(review): the removed code normalized in float32 first, confirm the op's precision
 
 
 class AttentionModule(nn.Module):
@@ -317,8 +335,8 @@ class AttentionModule(nn.Module):
         super().__init__()
         self.num_heads = num_heads
         
-    def forward(self, q, k, v, attention_mask=None):
-        x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads, attention_mask=attention_mask)
+    def forward(self, q, k, v, attention_mask=None, blockShape=None):  # thin wrapper: forwards its arguments to flash_attention together with self.num_heads
+        x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads, attention_mask=attention_mask, blockShape=blockShape)  # NOTE(review): with a mask, flash_attention builds [blockShape, blockShape], so the None default yields [None, None]
         return x
 
 
@@ -386,8 +404,9 @@ class SelfAttention(nn.Module):
             self.local_attn_mask_w = w//8
             self.local_range = local_range
         attention_mask = generate_draft_block_mask(B, self.num_heads, seqlen, q_w, k_w, topk=topk, local_attn_mask=self.local_attn_mask)
+        attention_mask[:, :, :, 0] = True
 
-        x = self.attn(reorder_q, reorder_k, reorder_v, attention_mask)
+        x = self.attn(reorder_q, reorder_k, reorder_v, attention_mask, blockShape=block_s)
 
         cur_block_n, cur_block_s, _ = k_w.shape
         cache_num = cur_block_n // one_len
diff --git a/infer.py b/infer.py
index 4b09a30..6c338df 100644
--- a/infer.py
+++ b/infer.py
@@ -19,6 +19,9 @@ from tqdm import tqdm
 import torch

 from einops import rearrange

 

+import torch_npu

+from torch_npu.contrib import transfer_to_npu

+

 # Add project path

 current_dir = os.path.dirname(os.path.abspath(__file__))

 sys.path.append(current_dir)

@@ -142,6 +145,8 @@ def parse_args():
                        help="Output FPS (default: match input or 30 for images)")

     parser.add_argument("--quality", type=int, default=10,

                        help="Output video quality (0-10)")

+    parser.add_argument("--warmup_file", type=str, default="./inputs/example0.mp4",

+                       help="Warm up file path")

     

     # Other parameters

     parser.add_argument("--device", type=str, default="cuda",

@@ -197,7 +202,6 @@ def save_video_with_audio_piped(frames, output_path, audio_source, fps=30, quali
 

     # Approximation of quality to CRF/CQ: quality 10 -> crf 3, quality 5 -> crf 13, quality 0 -> crf 23

     # For NVENC, we use -cq (Constant Quality) and -rc vbr

-    

     if NVENC_AVAILABLE:

         # NVENC settings - optimized for speed

         vcodec = 'h264_nvenc'

@@ -245,12 +249,20 @@ def save_video_with_audio_piped(frames, output_path, audio_source, fps=30, quali
 

     try:

         process = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

+

+        for frame in tqdm(frames, desc="Saving"):

+            frame = np.array(frame)

+            if frame.shape[:2] != (h, w):

+                warnings.warn(f"帧尺寸不匹配,跳过!期望: {w}x{h}, 实际: {frame.shape[1]}x{frame.shape[0]}")

+                continue

+            process.stdin.write(frame.tobytes())

+

     except FileNotFoundError:

         print("Error: ffmpeg not found for piping.")

         return False

-        

-    for frame in tqdm(frames, desc="Saving"):

-        process.stdin.write(np.array(frame).tobytes())

+    except BrokenPipeError:

+        print("\n[错误] FFmpeg 管道断开!通常是:帧尺寸不匹配、编码不支持、无写入权限")

+        return False

         

     out, err = process.communicate()

     

@@ -666,6 +678,113 @@ def init_pipeline(args):
     

     return pipe, vae_system_instance

 

+def warm_up(pipe, args):  # runs one 1-step pipeline call on args.warmup_file; the call's result is not kept

+    print()

+    print("warm up start!!!")

+    # Map args.dtype to a torch dtype; unrecognized values fall back to bfloat16

+    dtype_map = {

+        "fp32": torch.float32,

+        "fp16": torch.float16,

+        "bf16": torch.bfloat16,

+    }

+    dtype = dtype_map.get(args.dtype, torch.bfloat16)

+

+    LQ, th, tw, F, fps, input_video_path, total_frames_orig, exact_h, exact_w = prepare_input_tensor(

+        args.warmup_file, 

+        scale=args.scale, 

+        dtype=dtype,

+        device=args.device

+    )  # only LQ, th, tw and F are used below; the other unpacked values are ignored

+    

+    pipeline_kwargs = {

+        "prompt": "", 

+        "negative_prompt": "", 

+        "cfg_scale": 1.0, 

+        "num_inference_steps": 1, 

+        "seed": args.seed,

+        "LQ_video": LQ, 

+        "num_frames": F, 

+        "height": th, 

+        "width": tw, 

+        "is_full_block": False, 

+        "if_buffer": True,

+        "topk_ratio": args.sparse_ratio * 768 * 1280 / (th * tw),  # sparse_ratio rescaled by the ratio of a 768x1280 area to th*tw

+        "kv_ratio": args.kv_ratio,

+        "local_range": args.local_range,

+        "color_fix": args.color_fix,

+    }

+

+    # Add VAE tiling parameters (for full and tiny mode)

+    if args.tile_vae:

+        pipeline_kwargs["tiled"] = True

+        # Ensure we pass sensible tile_size and tile_stride for VAE

+        # args.tile_size is from CLI (default 256 for simple Tile Utils, but here for VAE it is latent size?)

+        # FlashVSRTiny default is (60, 104) ~= 480x832 pixels / 8

+        # If user provides explicit tile size, use it. Otherwise, let's pick a reasonable default or respect args.tile_size

+        

+        # NOTE: args.tile_size is typically 256. 256 * 8 = 2048 pixels. This is a very large tile for VAE.

+        # It's likely args.tile_size is meant for DiT pixel blocks in other contexts? 

+        # But 'apply_tiled_inference_simple' uses it as pixel size for DiT.

+        # For VAE here, 'tile_size' argument to __call__ expects Latent Size.

+        # If we use 256 output pixels -> 32 latent size.

+        # Let's assume if tile-dit is OFF, args.tile_size might be irrelevant or we should interpret it.

+        # If tile-dit is ON, args.tile_size is used for DiT.

+        

+        # Let's interpret args.tile_size as PIXEL size for consistency if tile-dit is ON?

+        # No, for VAE tiling, usually we want bigger chunks than DiT tiling. 

+        # Let's use a safe default if not specified, or derive from args.tile_size / 8

+        # If user didn't change default 256... 256/8 = 32. 

+        

+        # Actually, let's trust the defaults in the pipeline if the user didn't specify anything for VAE?

+        # But user might want to control it using CLI args.

+        # Let's pass args.tile_size // 8 if plausible.

+        

+        # Current logic: If tile-vae is ON, we want to enforce tiling.

+        # Let's update tile_size and tile_stride in pipeline_kwargs

+        

+        # Safely convert pixel tile size (args.tile_size) to latent size

+        # Assuming args.tile_size is "pixel size"

+        vae_tile_size_latent = max(32, args.tile_size // 8)

+        vae_overlap_latent = max(4, args.overlap // 8)

+        

+        pipeline_kwargs["tile_size"] = (vae_tile_size_latent, vae_tile_size_latent)

+        pipeline_kwargs["tile_stride"] = (vae_tile_size_latent - vae_overlap_latent, vae_tile_size_latent - vae_overlap_latent)

+        

+        print(f"VAE Tiling Enabled: tile_size (latent)={pipeline_kwargs['tile_size']}, stride={pipeline_kwargs['tile_stride']}")

+

+    if args.tile_dit:

+        print(f"Tiled DiT: tile_size={args.tile_size}, overlap={args.overlap}")

+

+        # Create a copy of pipeline_kwargs and remove LQ_video

+        tile_kwargs = pipeline_kwargs.copy()

+        tile_kwargs.pop('LQ_video', None)  # Remove LQ_video because it's already passed as a positional argument

+

+        # Handle potential collision of 'tile_size' argument

+        # pipeline_kwargs might contain 'tile_size' (tuple) for VAE if tile-vae is on.

+        # apply_tiled_inference_simple takes 'tile_size' (int) for DiT.

+        # To avoid error, we pop 'tile_size' from kwargs and pass it as 'tile_size_vae' if present.

+        vae_tile_size_tuple = tile_kwargs.pop('tile_size', None)

+        

+        # Tiled inference

+        apply_tiled_inference_simple(

+            pipe,

+            LQ,

+            tile_size=args.tile_size,

+            overlap=args.overlap,

+            tile_size_vae=vae_tile_size_tuple,

+            **tile_kwargs

+        )

+    else:

+        # print("Running inference...") # Controlled inside pipeline or tqdm

+        if args.mode == 'tiny-long':

+            msg = "Running inference (Streaming"

+            if args.tile_vae:

+                msg += " & VAE-tiled"

+            msg += ")..."  # NOTE(review): msg is never printed or otherwise used in this function

+        pipe(**pipeline_kwargs)

+    print("warm up end!!!")

+    print()

+

 def main():

     total_start_time = time.time()

     

@@ -751,6 +870,12 @@ def main():
     print(f"Mode: {args.mode}")

     # Prepare input

     print(f"Processing: {args.input}")

+

+    # Initialize pipeline with VAE manager

+    pipe, vae_instance = init_pipeline(args)

+    

+    # Warm up

+    warm_up(pipe, args)

     

     # Set dtype

     if args.dtype == "fp16":

@@ -771,15 +896,6 @@ def main():
     if args.fps is not None:

         fps = args.fps

     

-    # Initialize pipeline with VAE manager

-    pipe, vae_instance = init_pipeline(args)

-    

-    # Ensure LQ is on the correct device (it should be already if prepare_input_tensor kept it there)

-    # This is a no-op if already on device, but safe to keep.

-    if LQ.device.type != args.device:

-         print(f"Moving input to {args.device}...")

-         LQ = LQ.to(args.device)

-    

     # Determine output file name

     input_name = os.path.basename(args.input.rstrip('/')).split('.')[0]

     if os.path.isdir(args.output):

diff --git a/requirements.txt b/requirements.txt
index 7211450..f9de4b9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,9 @@
-torch==2.6.0+cu124
-torchaudio==2.6.0+cu124
+torch==2.6.0
+torch-npu==2.6.0.post5
+torchaudio==2.6.0
 torchmetrics==1.7.3
 torchsde==0.2.6
-torchvision==0.21.0+cu124
+torchvision==0.21.0
 accelerate==1.8.1
 einops==0.8.1
 huggingface-hub==0.34.4