diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py
index a69e2bb..735bf43 100644
--- a/diffsynth/models/wan_video_dit.py
+++ b/diffsynth/models/wan_video_dit.py
@@ -27,10 +27,14 @@ try:
 except ModuleNotFoundError:
     SAGE_ATTN_AVAILABLE = False
 
-from block_sparse_attn import block_sparse_attn_func
 from PIL import Image
 import numpy as np
 
+import torch_npu
+from torch_npu.contrib import transfer_to_npu
+import mindiesd
+from mindiesd.layers.flash_attn.sparse_flash_attn_rf_v2 import rain_fusion_attention 
+
 
 # ----------------------------
 # Local / window masks
@@ -167,39 +171,56 @@ def generate_causal_block_mask(batch_size, nheads, seqlen, local_num, window_siz
     causal_mask = causal_mask.unsqueeze(0).unsqueeze(0).repeat(batch_size, nheads, 1, 1)
     return causal_mask
 
+def get_mask_index(mask):
+    """Pack the selected column indices of every mask row to the front, padding with -1.
+
+    mask is (b, n, s1, s2); True / non-zero marks a selected column. Returns an int64
+    tensor of the same shape: per row the selected column ids ascending, then -1.
+    """
+    b, n, s1, s2 = mask.shape
+    flat = mask.reshape(-1, s1, s2).bool()  # torch.where needs a bool condition; no-op for bool masks
+    cols = torch.arange(s2, device=mask.device).expand(flat.shape[0], s1, s2)
+    # Unselected slots get a huge sentinel so the ascending sort pushes them to the end.
+    keyed = torch.where(flat, cols, 1e9).to(torch.float32)
+    packed, _ = torch.sort(keyed, dim=-1)
+    # Only the first `count` sorted entries of a row are real indices; the rest become -1.
+    count = flat.sum(dim=-1, keepdim=True)
+    packed = torch.where(cols < count, packed, -1)
+    return packed.reshape(b, n, s1, s2).to(torch.int64)
 
 # ----------------------------
 # Attention kernels
 # ----------------------------
-def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False, attention_mask=None, return_KV=False):
+def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False, attention_mask=None, return_KV=False, blockShape=None):
     if attention_mask is not None:
-        seqlen = q.shape[1]
-        seqlen_kv = k.shape[1]
-        q = rearrange(q, "b s (n d) -> (b s) n d", n=num_heads)
-        k = rearrange(k, "b s (n d) -> (b s) n d", n=num_heads)
-        v = rearrange(v, "b s (n d) -> (b s) n d", n=num_heads)
-        cu_seqlens_q = torch.tensor([0, seqlen], device=q.device, dtype=torch.int32)
-        cu_seqlens_k = torch.tensor([0, seqlen_kv], device=q.device, dtype=torch.int32)
-        head_mask_type = torch.tensor([1]*num_heads, device=q.device, dtype=torch.int32)
-        streaming_info = None
-        base_blockmask = attention_mask
-        max_seqlen_q_ = seqlen
-        max_seqlen_k_ = seqlen_kv
-        p_dropout = 0.0
-        x = block_sparse_attn_func(
+        batch, _, headDims = q.shape
+        headDim = headDims // num_heads
+
+        selectIdx = get_mask_index(attention_mask)
+        selectIdx = selectIdx[0].transpose(0,1)
+        selectNumIdx = attention_mask[0].transpose(0,1).sum(dim=-1)
+
+        blockShape = [blockShape, blockShape]
+        actualSeqLengthsHost = [q.shape[1] for _ in range(batch)]
+        actualSeqLengthsKvHost = [k.shape[1] for _ in range(batch)]
+        scale = headDim ** -0.5
+
+        q = rearrange(q, "b s (n d) -> b s n d", n=num_heads).transpose(1,2).contiguous()
+        k = rearrange(k, "b s (n d) -> b s n d", n=num_heads).transpose(1,2).contiguous()
+        v = rearrange(v, "b s (n d) -> b s n d", n=num_heads).transpose(1,2).contiguous()
+
+        x = rain_fusion_attention(
             q, k, v,
-            cu_seqlens_q, cu_seqlens_k,
-            head_mask_type,
-            streaming_info,
-            base_blockmask,
-            max_seqlen_q_, max_seqlen_k_,
-            p_dropout,
-            deterministic=False,
-            softmax_scale=None,
-            is_causal=False,
-            exact_streaming=False,
-            return_attn_probs=False,
-        ).unsqueeze(0)
+            scale=scale,
+            head_num=num_heads,
+            input_layout="BNSD",
+            select_idx=selectIdx,
+            select_num_idx=selectNumIdx,
+            blockshape=blockShape,
+            actual_seq_lengths=actualSeqLengthsHost,
+            actual_seq_lengths_kv=actualSeqLengthsKvHost,
+        )
+        x = rearrange(x, "b n s d -> b s n d", n=num_heads)
         x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
     elif compatibility_mode:
         q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
@@ -264,10 +285,10 @@ def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
 
 def rope_apply(x, freqs, num_heads):
     x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
-    x_out = torch.view_as_complex(x.to(torch.float64).reshape(
-        x.shape[0], x.shape[1], x.shape[2], -1, 2))
-    x_out = torch.view_as_real(x_out * freqs).flatten(2)
-    return x_out.to(x.dtype)
+    # Complex freqs -> cos (real) / sin (imag), each duplicated for its (even, odd) channel pair, plus a batch axis.
+    cos, sin = (part.unsqueeze(0).expand(-1, -1, -1, -1, 2).flatten(-2)
+                for part in torch.chunk(torch.view_as_real(freqs), 2, dim=-1))
+    return mindiesd.rotary_position_embedding(x, cos, sin, rotated_mode="rotated_interleaved", fused=True).flatten(2)  # merge heads back to "b s (n d)" as the replaced code did (assumes the op keeps x's 4-D shape — confirm)
 
 
 # ----------------------------
@@ -279,12 +300,8 @@ class RMSNorm(nn.Module):
         self.eps = eps
         self.weight = nn.Parameter(torch.ones(dim))
 
-    def norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
-
     def forward(self, x):
-        dtype = x.dtype
-        return self.norm(x.float()).to(dtype) * self.weight
+        return torch_npu.npu_rms_norm(x, self.weight, epsilon=self.eps)[0]  # fused NPU RMSNorm; [0] keeps only the op's first output (presumably the normalized tensor — confirm in torch_npu docs)
 
 
 class AttentionModule(nn.Module):
@@ -292,8 +309,8 @@ class AttentionModule(nn.Module):
         super().__init__()
         self.num_heads = num_heads
         
-    def forward(self, q, k, v, attention_mask=None):
-        x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads, attention_mask=attention_mask)
+    def forward(self, q, k, v, attention_mask=None, blockShape=None):  # blockShape is only forwarded; flash_attention reads it in its attention_mask branch
+        x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads, attention_mask=attention_mask, blockShape=blockShape)
         return x
 
 
@@ -361,8 +378,9 @@ class SelfAttention(nn.Module):
             self.local_attn_mask_w = w//8
             self.local_range = local_range
         attention_mask = generate_draft_block_mask(B, self.num_heads, seqlen, q_w, k_w, topk=topk, local_attn_mask=self.local_attn_mask)
+        attention_mask[:, :, :, 0] = True
 
-        x = self.attn(reorder_q, reorder_k, reorder_v, attention_mask)
+        x = self.attn(reorder_q, reorder_k, reorder_v, attention_mask, blockShape=block_s)
 
         cur_block_n, cur_block_s, _ = k_w.shape
         cache_num = cur_block_n // one_len
diff --git a/examples/WanVSR/infer_flashvsr_v1.1_full.py b/examples/WanVSR/infer_flashvsr_v1.1_full.py
index 1e5a535..3597bff 100644
--- a/examples/WanVSR/infer_flashvsr_v1.1_full.py
+++ b/examples/WanVSR/infer_flashvsr_v1.1_full.py
@@ -12,6 +12,9 @@ from einops import rearrange
 from diffsynth import ModelManager, FlashVSRFullPipeline
 from utils.utils import Causal_LQ4x_Proj
 
+import torch_npu
+from torch_npu.contrib import transfer_to_npu
+
 def tensor2video(frames: torch.Tensor):
     frames = rearrange(frames, "C T H W -> T H W C")
     frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
@@ -63,7 +66,7 @@ def upscale_then_center_crop(img: Image.Image, scale: int, tW: int, tH: int) ->
     l = max(0, (sW - tW) // 2); t = max(0, (sH - tH) // 2)
     return up.crop((l, t, l + tW, t + tH))
 
-def prepare_input_tensor(path: str, scale: int = 4, dtype=torch.bfloat16, device='cuda'):
+def prepare_input_tensor(path: str, scale: int = 2, dtype=torch.bfloat16, device='cuda'):
     if os.path.isdir(path):
         paths0 = list_images_natural(path)
         if not paths0:
@@ -177,6 +180,23 @@ def init_pipeline():
     pipe.init_cross_kv(); pipe.load_models_to_device(["dit","vae"])
     return pipe
 
+def warm_up(pipeline, path: str, scale: int = 2, dtype=torch.bfloat16, device='cuda', sparse_ratio: float=2.0, seed=0):
+    """Run one throwaway inference on the clip at `path`; the result is discarded. Skipped if the clip cannot be loaded."""
+    try:
+        LQ, th, tw, F, _fps = prepare_input_tensor(path, scale=scale, dtype=dtype, device=device)
+    except Exception as e:
+        print(f"[Error] {os.path.basename(path.rstrip('/'))}: {e}")
+        return  # nothing was loaded: calling the pipeline would only raise NameError on LQ
+    pipeline(
+        prompt="", negative_prompt="", cfg_scale=1.0, num_inference_steps=1, seed=seed,
+        tiled=False,  # Disable tiling: faster inference but higher VRAM usage; True lowers memory at the cost of speed.
+        LQ_video=LQ, num_frames=F, height=th, width=tw, is_full_block=False, if_buffer=True,
+        topk_ratio=sparse_ratio*768*1280/(th*tw),
+        kv_ratio=3.0,
+        local_range=11, # Recommended: 9 or 11. local_range=9 → sharper details; 11 → more stable results.
+        color_fix = True,
+    )
+
 def main():
     RESULT_ROOT = "./results"
     os.makedirs(RESULT_ROOT, exist_ok=True)
@@ -186,10 +206,14 @@ def main():
         "./inputs/example2.mp4",
         "./inputs/example3.mp4",
     ]
-    seed, scale, dtype, device = 0, 4, torch.bfloat16, 'cuda'
+    warm_up_file = "./inputs/example0.mp4"
+    seed, scale, dtype, device = 0, 2, torch.bfloat16, 'cuda'
     sparse_ratio = 2.0      # Recommended: 1.5 or 2.0. 1.5 → faster; 2.0 → more stable.
     pipe = init_pipeline()
 
+    # warm up
+    warm_up(pipe, warm_up_file, scale, dtype, device, sparse_ratio, seed)
+
     for p in inputs:
         torch.cuda.empty_cache(); torch.cuda.ipc_collect()
         name = os.path.basename(p.rstrip('/'))
@@ -201,6 +225,7 @@ def main():
             print(f"[Error] {name}: {e}")
             continue
 
+        inference_start_time = time.time()
         video = pipe(
             prompt="", negative_prompt="", cfg_scale=1.0, num_inference_steps=1, seed=seed, 
             tiled=False,# Disable tiling: faster inference but higher VRAM usage. 
@@ -211,6 +236,10 @@ def main():
             local_range=11, # Recommended: 9 or 11. local_range=9 → sharper details; 11 → more stable results.
             color_fix = True,
         )
+        inference_end_time = time.time()
+        inference_duration = inference_end_time - inference_start_time
+        print(f"Inference completed in {inference_duration:.2f} seconds")
+
         video = tensor2video(video)
         save_video(video, os.path.join(RESULT_ROOT, f"FlashVSR_v1.1_Full_{name.split('.')[0]}_seed{seed}.mp4"), fps=fps, quality=6)
     print("Done.")
diff --git a/examples/WanVSR/infer_flashvsr_v1.1_tiny.py b/examples/WanVSR/infer_flashvsr_v1.1_tiny.py
index e847c3c..b22c366 100644
--- a/examples/WanVSR/infer_flashvsr_v1.1_tiny.py
+++ b/examples/WanVSR/infer_flashvsr_v1.1_tiny.py
@@ -13,6 +13,9 @@ from diffsynth import ModelManager, FlashVSRTinyPipeline
 from utils.utils import Causal_LQ4x_Proj
 from utils.TCDecoder import build_tcdecoder
 
+import torch_npu
+from torch_npu.contrib import transfer_to_npu
+
 def tensor2video(frames):
     frames = rearrange(frames, "C T H W -> T H W C")
     frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
@@ -84,7 +87,7 @@ def upscale_then_center_crop(img: Image.Image, scale: float, tW: int, tH: int) -
     return up.crop((l, t, l + tW, t + tH))
 
 
-def prepare_input_tensor(path: str, scale: float = 4, dtype=torch.bfloat16, device='cuda'):
+def prepare_input_tensor(path: str, scale: float = 2, dtype=torch.bfloat16, device='cuda'):
     if os.path.isdir(path):
         paths0 = list_images_natural(path)
         if not paths0:
@@ -193,6 +196,23 @@ def init_pipeline():
     pipe.init_cross_kv(); pipe.load_models_to_device(["dit","vae"])
     return pipe
 
+def warm_up(pipeline, path: str, scale: int = 2, dtype=torch.bfloat16, device='cuda', sparse_ratio: float=2.0, seed=0):
+    """Run one throwaway inference on the clip at `path`; the result is discarded. Skipped if the clip cannot be loaded."""
+    try:
+        LQ, th, tw, F, _fps = prepare_input_tensor(path, scale=scale, dtype=dtype, device=device)
+    except Exception as e:
+        print(f"[Error] {os.path.basename(path.rstrip('/'))}: {e}")
+        return  # nothing was loaded: calling the pipeline would only raise NameError on LQ
+    pipeline(
+        prompt="", negative_prompt="", cfg_scale=1.0, num_inference_steps=1, seed=seed,
+        tiled=False,  # Disable tiling: faster inference but higher VRAM usage; True lowers memory at the cost of speed.
+        LQ_video=LQ, num_frames=F, height=th, width=tw, is_full_block=False, if_buffer=True,
+        topk_ratio=sparse_ratio*768*1280/(th*tw),
+        kv_ratio=3.0,
+        local_range=11, # Recommended: 9 or 11. local_range=9 → sharper details; 11 → more stable results.
+        color_fix = True,
+    )
+
 def main():
     RESULT_ROOT = "./results"
     os.makedirs(RESULT_ROOT, exist_ok=True)
@@ -202,10 +222,14 @@ def main():
         "./inputs/example2.mp4",
         "./inputs/example3.mp4",
     ]
-    seed, scale, dtype, device = 0, 4.0, torch.bfloat16, 'cuda'
+    warm_up_file = "./inputs/example0.mp4"
+    seed, scale, dtype, device = 0, 2.0, torch.bfloat16, 'cuda'
     sparse_ratio = 2.0      # Recommended: 1.5 or 2.0. 1.5 → faster; 2.0 → more stable.
     pipe = init_pipeline()
 
+    # warm up
+    warm_up(pipe, warm_up_file, scale, dtype, device, sparse_ratio, seed)
+
     for p in inputs:
         torch.cuda.empty_cache(); torch.cuda.ipc_collect()
         name = os.path.basename(p.rstrip('/'))
@@ -216,6 +240,7 @@ def main():
         except Exception as e:
             print(f"[Error] {name}: {e}"); continue
 
+        inference_start_time = time.time()
         video = pipe(
             prompt="", negative_prompt="", cfg_scale=1.0, num_inference_steps=1, seed=seed,
             LQ_video=LQ, num_frames=F, height=th, width=tw, is_full_block=False, if_buffer=True,
@@ -224,6 +249,10 @@ def main():
             local_range=11,  # Recommended: 9 or 11. local_range=9 → sharper details; 11 → more stable results.
             color_fix = True,
         )
+        inference_end_time = time.time()
+        inference_duration = inference_end_time - inference_start_time
+        print(f"Inference completed in {inference_duration:.2f} seconds")
+
         video = tensor2video(video)
         save_video(video, os.path.join(RESULT_ROOT, f"FlashVSR_v1.1_Tiny_{name.split('.')[0]}_seed{seed}.mp4"), fps=fps, quality=6)
 
diff --git a/examples/WanVSR/infer_flashvsr_v1.1_tiny_long_video.py b/examples/WanVSR/infer_flashvsr_v1.1_tiny_long_video.py
index 972d5a2..1225f98 100644
--- a/examples/WanVSR/infer_flashvsr_v1.1_tiny_long_video.py
+++ b/examples/WanVSR/infer_flashvsr_v1.1_tiny_long_video.py
@@ -13,6 +13,9 @@ from diffsynth import ModelManager, FlashVSRTinyLongPipeline
 from utils.utils import Causal_LQ4x_Proj
 from utils.TCDecoder import build_tcdecoder
 
+import torch_npu
+from torch_npu.contrib import transfer_to_npu
+
 def tensor2video(frames):
     frames = rearrange(frames, "C T H W -> T H W C")
     frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
@@ -84,7 +87,7 @@ def upscale_then_center_crop(img: Image.Image, scale: float, tW: int, tH: int) -
     return up.crop((l, t, l + tW, t + tH))
 
 
-def prepare_input_tensor(path: str, scale: float = 4, dtype=torch.bfloat16, device='cuda'):
+def prepare_input_tensor(path: str, scale: float = 2, dtype=torch.bfloat16, device='cuda'):
     if os.path.isdir(path):
         paths0 = list_images_natural(path)
         if not paths0:
@@ -197,16 +200,37 @@ def init_pipeline():
     pipe.init_cross_kv(); pipe.load_models_to_device(["dit","vae"])
     return pipe
 
+def warm_up(pipeline, path: str, scale: int = 2, dtype=torch.bfloat16, device='cuda', sparse_ratio: float=2.0, seed=0):
+    """Run one throwaway inference on the clip at `path`; the result is discarded. Skipped if the clip cannot be loaded."""
+    try:
+        LQ, th, tw, F, _fps = prepare_input_tensor(path, scale=scale, dtype=dtype, device=device)
+    except Exception as e:
+        print(f"[Error] {os.path.basename(path.rstrip('/'))}: {e}")
+        return  # nothing was loaded: calling the pipeline would only raise NameError on LQ
+    pipeline(
+        prompt="", negative_prompt="", cfg_scale=1.0, num_inference_steps=1, seed=seed,
+        tiled=False,  # Disable tiling: faster inference but higher VRAM usage; True lowers memory at the cost of speed.
+        LQ_video=LQ, num_frames=F, height=th, width=tw, is_full_block=False, if_buffer=True,
+        topk_ratio=sparse_ratio*768*1280/(th*tw),
+        kv_ratio=3.0,
+        local_range=11, # Recommended: 9 or 11. local_range=9 → sharper details; 11 → more stable results.
+        color_fix = True,
+    )
+
 def main():
     RESULT_ROOT = "./results"
     os.makedirs(RESULT_ROOT, exist_ok=True)
     inputs = [
         "./inputs/example4.mp4",
     ]
-    seed, scale, dtype, device = 0, 4.0, torch.bfloat16, 'cuda'
+    warm_up_file = "./inputs/example0.mp4"
+    seed, scale, dtype, device = 0, 2.0, torch.bfloat16, 'cuda'
     sparse_ratio = 2.0      # Recommended: 1.5 or 2.0. 1.5 → faster; 2.0 → more stable.
     pipe = init_pipeline()
 
+    # warm up
+    warm_up(pipe, warm_up_file, scale, dtype, device, sparse_ratio, seed)
+
     for p in inputs:
         torch.cuda.empty_cache(); torch.cuda.ipc_collect()
         name = os.path.basename(p.rstrip('/'))
@@ -217,6 +241,7 @@ def main():
         except Exception as e:
             print(f"[Error] {name}: {e}"); continue
 
+        inference_start_time = time.time()
         video = pipe(
             prompt="", negative_prompt="", cfg_scale=1.0, num_inference_steps=1, seed=seed,
             LQ_video=LQ, num_frames=F, height=th, width=tw, is_full_block=False, if_buffer=True,
@@ -225,6 +250,9 @@ def main():
             local_range=11,  # Recommended: 9 or 11. local_range=9 → sharper details; 11 → more stable results.
             color_fix = True,
         )
+        inference_end_time = time.time()
+        inference_duration = inference_end_time - inference_start_time
+        print(f"Inference completed in {inference_duration:.2f} seconds")
 
         video = tensor2video(video)
         save_video(video, os.path.join(RESULT_ROOT, f"FlashVSR_v1.1_Tiny_Long_{name.split('.')[0]}_seed{seed}.mp4"), fps=fps, quality=5)
diff --git a/requirements.txt b/requirements.txt
index 99c216f..37c3064 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,9 @@
-torch==2.6.0+cu124
-torchaudio==2.6.0+cu124
+torch==2.6.0
+torch-npu==2.6.0.post5
+torchaudio==2.6.0
 torchmetrics==1.7.3
 torchsde==0.2.6
-torchvision==0.21.0+cu124
+torchvision==0.21.0
 accelerate==1.8.1
 einops==0.8.1
 huggingface-hub==0.34.4
@@ -22,4 +23,5 @@ protobuf==3.20.3
 ftfy==6.3.1
 pandas==2.3.0
 tqdm
+modelscope
 datasets
\ No newline at end of file