@@ -27,10 +27,14 @@ try:
except ModuleNotFoundError:
SAGE_ATTN_AVAILABLE = False
-from block_sparse_attn import block_sparse_attn_func
from PIL import Image
import numpy as np
+import torch_npu
+from torch_npu.contrib import transfer_to_npu
+import mindiesd
+from mindiesd.layers.flash_attn.sparse_flash_attn_rf_v2 import rain_fusion_attention
+
# ----------------------------
# Local / window masks
@@ -167,39 +171,56 @@ def generate_causal_block_mask(batch_size, nheads, seqlen, local_num, window_siz
causal_mask = causal_mask.unsqueeze(0).unsqueeze(0).repeat(batch_size, nheads, 1, 1)
return causal_mask
+def get_mask_index(mask):
+    """Pack the selected column indices of every mask row to the front.
+
+    For each row along the last axis, the positions where ``mask`` is True are
+    listed in ascending order at the start of the row; the remaining slots are
+    filled with -1.
+
+    Args:
+        mask: 4-D tensor of shape (b, n, s1, s2). It is used directly as the
+            condition of ``torch.where``, so it is expected to be boolean.
+
+    Returns:
+        int64 tensor of shape (b, n, s1, s2).
+    """
+    b, n, s1, s2 = mask.shape
+    flat_mask = mask.reshape(-1, s1, s2)
+    col_ids = torch.arange(s2, device=mask.device).expand(flat_mask.shape[0], s1, s2)
+
+    # Unselected columns get a large sentinel so that an ascending sort pushes
+    # them behind every real index of the row.
+    packed = torch.where(flat_mask, col_ids, 1e9).to(torch.float32)
+    packed = torch.sort(packed, dim=-1)[0]
+
+    # Only the first `n_selected` slots of a row hold real indices; pad the rest with -1.
+    n_selected = flat_mask.sum(dim=-1, keepdim=True)
+    packed = torch.where(col_ids < n_selected, packed, -1)
+
+    return packed.reshape(b, n, s1, s2).to(torch.int64)
# ----------------------------
# Attention kernels
# ----------------------------
-def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False, attention_mask=None, return_KV=False):
+def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False, attention_mask=None, return_KV=False, blockShape=None):
if attention_mask is not None:
- seqlen = q.shape[1]
- seqlen_kv = k.shape[1]
- q = rearrange(q, "b s (n d) -> (b s) n d", n=num_heads)
- k = rearrange(k, "b s (n d) -> (b s) n d", n=num_heads)
- v = rearrange(v, "b s (n d) -> (b s) n d", n=num_heads)
- cu_seqlens_q = torch.tensor([0, seqlen], device=q.device, dtype=torch.int32)
- cu_seqlens_k = torch.tensor([0, seqlen_kv], device=q.device, dtype=torch.int32)
- head_mask_type = torch.tensor([1]*num_heads, device=q.device, dtype=torch.int32)
- streaming_info = None
- base_blockmask = attention_mask
- max_seqlen_q_ = seqlen
- max_seqlen_k_ = seqlen_kv
- p_dropout = 0.0
- x = block_sparse_attn_func(
+ batch, _, headDims = q.shape
+ headDim = headDims // num_heads
+
+ selectIdx = get_mask_index(attention_mask)
+ selectIdx = selectIdx[0].transpose(0,1)
+ selectNumIdx = attention_mask[0].transpose(0,1).sum(dim=-1)
+
+ blockShape = [blockShape, blockShape]
+ actualSeqLengthsHost = [q.shape[1] for _ in range(batch)]
+ actualSeqLengthsKvHost = [k.shape[1] for _ in range(batch)]
+ scale = headDim ** -0.5
+
+ q = rearrange(q, "b s (n d) -> b s n d", n=num_heads).transpose(1,2).contiguous()
+ k = rearrange(k, "b s (n d) -> b s n d", n=num_heads).transpose(1,2).contiguous()
+ v = rearrange(v, "b s (n d) -> b s n d", n=num_heads).transpose(1,2).contiguous()
+
+ x = rain_fusion_attention(
q, k, v,
- cu_seqlens_q, cu_seqlens_k,
- head_mask_type,
- streaming_info,
- base_blockmask,
- max_seqlen_q_, max_seqlen_k_,
- p_dropout,
- deterministic=False,
- softmax_scale=None,
- is_causal=False,
- exact_streaming=False,
- return_attn_probs=False,
- ).unsqueeze(0)
+ scale=scale,
+ head_num=num_heads,
+ input_layout="BNSD",
+ select_idx=selectIdx,
+ select_num_idx=selectNumIdx,
+ blockshape=blockShape,
+ actual_seq_lengths=actualSeqLengthsHost,
+ actual_seq_lengths_kv=actualSeqLengthsKvHost,
+ )
+ x = rearrange(x, "b n s d -> b s n d", n=num_heads)
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
elif compatibility_mode:
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
@@ -264,10 +285,10 @@ def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
def rope_apply(x, freqs, num_heads):
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
- x_out = torch.view_as_complex(x.to(torch.float64).reshape(
- x.shape[0], x.shape[1], x.shape[2], -1, 2))
- x_out = torch.view_as_real(x_out * freqs).flatten(2)
- return x_out.to(x.dtype)
+ # Rotary embedding through the fused mindiesd op instead of a complex multiplication.
+ # view_as_real splits the complex `freqs` into (real, imag), used here as (cos, sin);
+ # each value is then duplicated along the last axis (expand to 2, then flatten),
+ # presumably to line up with the interleaved pair layout of x.
+ # NOTE(review): the removed code rotated in float64, flattened the heads back into
+ # the last axis (`.flatten(2)`) and cast to x.dtype. Confirm that the fused op's
+ # output shape and precision are what the callers of rope_apply expect.
+ cos, sin = torch.chunk(torch.view_as_real(freqs), 2, dim=-1)
+ cos = cos.unsqueeze(0).expand(-1, -1, -1, -1, 2).flatten(-2)
+ sin = sin.unsqueeze(0).expand(-1, -1, -1, -1, 2).flatten(-2)
+ return mindiesd.rotary_position_embedding(x, cos, sin, rotated_mode="rotated_interleaved", fused=True)
# ----------------------------
@@ -279,12 +300,8 @@ class RMSNorm(nn.Module):
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
- def norm(self, x):
- return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
-
def forward(self, x):
- dtype = x.dtype
- return self.norm(x.float()).to(dtype) * self.weight
+ # Delegates to torch_npu.npu_rms_norm with self.weight and self.eps and returns
+ # element [0] of its result (presumably the normalised tensor - TODO confirm
+ # against the torch_npu API reference).
+ # NOTE(review): the removed path normalised in float32 and cast back to x.dtype
+ # before multiplying by the weight; confirm the fused op is numerically acceptable.
+ return torch_npu.npu_rms_norm(x, self.weight, epsilon=self.eps)[0]
class AttentionModule(nn.Module):
@@ -292,8 +309,8 @@ class AttentionModule(nn.Module):
super().__init__()
self.num_heads = num_heads
- def forward(self, q, k, v, attention_mask=None):
- x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads, attention_mask=attention_mask)
+ def forward(self, q, k, v, attention_mask=None, blockShape=None):
+ # Thin wrapper: forwards everything, including blockShape, to flash_attention.
+ # NOTE(review): flash_attention builds [blockShape, blockShape] whenever
+ # attention_mask is given, so a mask without a blockShape yields [None, None];
+ # confirm every masked call site passes a block size.
+ x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads, attention_mask=attention_mask, blockShape=blockShape)
return x
@@ -361,8 +378,9 @@ class SelfAttention(nn.Module):
self.local_attn_mask_w = w//8
self.local_range = local_range
attention_mask = generate_draft_block_mask(B, self.num_heads, seqlen, q_w, k_w, topk=topk, local_attn_mask=self.local_attn_mask)
+ attention_mask[:, :, :, 0] = True
- x = self.attn(reorder_q, reorder_k, reorder_v, attention_mask)
+ x = self.attn(reorder_q, reorder_k, reorder_v, attention_mask, blockShape=block_s)
cur_block_n, cur_block_s, _ = k_w.shape
cache_num = cur_block_n // one_len
@@ -12,6 +12,9 @@ from einops import rearrange
from diffsynth import ModelManager, FlashVSRFullPipeline
from utils.utils import Causal_LQ4x_Proj
+import torch_npu
+from torch_npu.contrib import transfer_to_npu
+
def tensor2video(frames: torch.Tensor):
frames = rearrange(frames, "C T H W -> T H W C")
frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
@@ -63,7 +66,7 @@ def upscale_then_center_crop(img: Image.Image, scale: int, tW: int, tH: int) ->
l = max(0, (sW - tW) // 2); t = max(0, (sH - tH) // 2)
return up.crop((l, t, l + tW, t + tH))
-def prepare_input_tensor(path: str, scale: int = 4, dtype=torch.bfloat16, device='cuda'):
+def prepare_input_tensor(path: str, scale: int = 2, dtype=torch.bfloat16, device='cuda'):
if os.path.isdir(path):
paths0 = list_images_natural(path)
if not paths0:
@@ -177,6 +180,23 @@ def init_pipeline():
pipe.init_cross_kv(); pipe.load_models_to_device(["dit","vae"])
return pipe
+def warm_up(pipeline, path: str, scale: int = 2, dtype=torch.bfloat16, device='cuda', sparse_ratio: float=2.0, seed=0):
+    """Run one throw-away inference pass over the input at ``path``.
+
+    main() calls this once before its timed loop and the pipeline output is
+    discarded; presumably this keeps one-time start-up cost out of the
+    reported inference durations.
+
+    Args:
+        pipeline: Callable pipeline object, invoked once with the keyword
+            arguments below.
+        path: Input location handed to prepare_input_tensor.
+        scale: Forwarded to prepare_input_tensor.
+        dtype: Forwarded to prepare_input_tensor.
+        device: Forwarded to prepare_input_tensor.
+        sparse_ratio: Multiplied by 768*1280/(height*width) to form topk_ratio.
+        seed: Forwarded to the pipeline.
+
+    Returns:
+        None. If the input cannot be prepared, the error is printed and the
+        warm-up pass is skipped.
+    """
+    name = os.path.basename(path.rstrip('/'))
+    try:
+        LQ, th, tw, F, _fps = prepare_input_tensor(path, scale=scale, dtype=dtype, device=device)
+    except Exception as e:
+        # Warm-up is best-effort. Without this return the call below would run
+        # with LQ/th/tw/F unbound and crash with UnboundLocalError.
+        print(f"[Error] {name}: {e}")
+        return
+    pipeline(
+        prompt="", negative_prompt="", cfg_scale=1.0, num_inference_steps=1, seed=seed,
+        tiled=False,# Disable tiling: faster inference but higher VRAM usage.
+                    # Set to True for lower memory consumption at the cost of speed.
+        LQ_video=LQ, num_frames=F, height=th, width=tw, is_full_block=False, if_buffer=True,
+        topk_ratio=sparse_ratio*768*1280/(th*tw),
+        kv_ratio=3.0,
+        local_range=11, # Recommended: 9 or 11. local_range=9 → sharper details; 11 → more stable results.
+        color_fix = True,
+    )
+
def main():
RESULT_ROOT = "./results"
os.makedirs(RESULT_ROOT, exist_ok=True)
@@ -186,10 +206,14 @@ def main():
"./inputs/example2.mp4",
"./inputs/example3.mp4",
]
- seed, scale, dtype, device = 0, 4, torch.bfloat16, 'cuda'
+ warm_up_file = "./inputs/example0.mp4"
+ seed, scale, dtype, device = 0, 2, torch.bfloat16, 'cuda'
sparse_ratio = 2.0 # Recommended: 1.5 or 2.0. 1.5 → faster; 2.0 → more stable.
pipe = init_pipeline()
+ # warm up
+ warm_up(pipe, warm_up_file, scale, dtype, device, sparse_ratio, seed)
+
for p in inputs:
torch.cuda.empty_cache(); torch.cuda.ipc_collect()
name = os.path.basename(p.rstrip('/'))
@@ -201,6 +225,7 @@ def main():
print(f"[Error] {name}: {e}")
continue
+ inference_start_time = time.time()
video = pipe(
prompt="", negative_prompt="", cfg_scale=1.0, num_inference_steps=1, seed=seed,
tiled=False,# Disable tiling: faster inference but higher VRAM usage.
@@ -211,6 +236,10 @@ def main():
local_range=11, # Recommended: 9 or 11. local_range=9 → sharper details; 11 → more stable results.
color_fix = True,
)
+ inference_end_time = time.time()
+ inference_duration = inference_end_time - inference_start_time
+ print(f"Inference completed in {inference_duration:.2f} seconds")
+
video = tensor2video(video)
save_video(video, os.path.join(RESULT_ROOT, f"FlashVSR_v1.1_Full_{name.split('.')[0]}_seed{seed}.mp4"), fps=fps, quality=6)
print("Done.")
@@ -13,6 +13,9 @@ from diffsynth import ModelManager, FlashVSRTinyPipeline
from utils.utils import Causal_LQ4x_Proj
from utils.TCDecoder import build_tcdecoder
+import torch_npu
+from torch_npu.contrib import transfer_to_npu
+
def tensor2video(frames):
frames = rearrange(frames, "C T H W -> T H W C")
frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
@@ -84,7 +87,7 @@ def upscale_then_center_crop(img: Image.Image, scale: float, tW: int, tH: int) -
return up.crop((l, t, l + tW, t + tH))
-def prepare_input_tensor(path: str, scale: float = 4, dtype=torch.bfloat16, device='cuda'):
+def prepare_input_tensor(path: str, scale: float = 2, dtype=torch.bfloat16, device='cuda'):
if os.path.isdir(path):
paths0 = list_images_natural(path)
if not paths0:
@@ -193,6 +196,23 @@ def init_pipeline():
pipe.init_cross_kv(); pipe.load_models_to_device(["dit","vae"])
return pipe
+def warm_up(pipeline, path: str, scale: int = 2, dtype=torch.bfloat16, device='cuda', sparse_ratio: float=2.0, seed=0):
+    """Run one throw-away inference pass over the input at ``path``.
+
+    main() calls this once before its timed loop and the pipeline output is
+    discarded; presumably this keeps one-time start-up cost out of the
+    reported inference durations.
+
+    Args:
+        pipeline: Callable pipeline object, invoked once with the keyword
+            arguments below.
+        path: Input location handed to prepare_input_tensor.
+        scale: Forwarded to prepare_input_tensor.
+        dtype: Forwarded to prepare_input_tensor.
+        device: Forwarded to prepare_input_tensor.
+        sparse_ratio: Multiplied by 768*1280/(height*width) to form topk_ratio.
+        seed: Forwarded to the pipeline.
+
+    Returns:
+        None. If the input cannot be prepared, the error is printed and the
+        warm-up pass is skipped.
+    """
+    name = os.path.basename(path.rstrip('/'))
+    try:
+        LQ, th, tw, F, _fps = prepare_input_tensor(path, scale=scale, dtype=dtype, device=device)
+    except Exception as e:
+        # Warm-up is best-effort. Without this return the call below would run
+        # with LQ/th/tw/F unbound and crash with UnboundLocalError.
+        print(f"[Error] {name}: {e}")
+        return
+    # NOTE(review): the pipe(...) call in this script's main() does not pass
+    # `tiled`; confirm this pipeline accepts that keyword.
+    pipeline(
+        prompt="", negative_prompt="", cfg_scale=1.0, num_inference_steps=1, seed=seed,
+        tiled=False,# Disable tiling: faster inference but higher VRAM usage.
+                    # Set to True for lower memory consumption at the cost of speed.
+        LQ_video=LQ, num_frames=F, height=th, width=tw, is_full_block=False, if_buffer=True,
+        topk_ratio=sparse_ratio*768*1280/(th*tw),
+        kv_ratio=3.0,
+        local_range=11, # Recommended: 9 or 11. local_range=9 → sharper details; 11 → more stable results.
+        color_fix = True,
+    )
+
def main():
RESULT_ROOT = "./results"
os.makedirs(RESULT_ROOT, exist_ok=True)
@@ -202,10 +222,14 @@ def main():
"./inputs/example2.mp4",
"./inputs/example3.mp4",
]
- seed, scale, dtype, device = 0, 4.0, torch.bfloat16, 'cuda'
+ warm_up_file = "./inputs/example0.mp4"
+ seed, scale, dtype, device = 0, 2.0, torch.bfloat16, 'cuda'
sparse_ratio = 2.0 # Recommended: 1.5 or 2.0. 1.5 → faster; 2.0 → more stable.
pipe = init_pipeline()
+ # warm up
+ warm_up(pipe, warm_up_file, scale, dtype, device, sparse_ratio, seed)
+
for p in inputs:
torch.cuda.empty_cache(); torch.cuda.ipc_collect()
name = os.path.basename(p.rstrip('/'))
@@ -216,6 +240,7 @@ def main():
except Exception as e:
print(f"[Error] {name}: {e}"); continue
+ inference_start_time = time.time()
video = pipe(
prompt="", negative_prompt="", cfg_scale=1.0, num_inference_steps=1, seed=seed,
LQ_video=LQ, num_frames=F, height=th, width=tw, is_full_block=False, if_buffer=True,
@@ -224,6 +249,10 @@ def main():
local_range=11, # Recommended: 9 or 11. local_range=9 → sharper details; 11 → more stable results.
color_fix = True,
)
+ inference_end_time = time.time()
+ inference_duration = inference_end_time - inference_start_time
+ print(f"Inference completed in {inference_duration:.2f} seconds")
+
video = tensor2video(video)
save_video(video, os.path.join(RESULT_ROOT, f"FlashVSR_v1.1_Tiny_{name.split('.')[0]}_seed{seed}.mp4"), fps=fps, quality=6)
@@ -13,6 +13,9 @@ from diffsynth import ModelManager, FlashVSRTinyLongPipeline
from utils.utils import Causal_LQ4x_Proj
from utils.TCDecoder import build_tcdecoder
+import torch_npu
+from torch_npu.contrib import transfer_to_npu
+
def tensor2video(frames):
frames = rearrange(frames, "C T H W -> T H W C")
frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
@@ -84,7 +87,7 @@ def upscale_then_center_crop(img: Image.Image, scale: float, tW: int, tH: int) -
return up.crop((l, t, l + tW, t + tH))
-def prepare_input_tensor(path: str, scale: float = 4, dtype=torch.bfloat16, device='cuda'):
+def prepare_input_tensor(path: str, scale: float = 2, dtype=torch.bfloat16, device='cuda'):
if os.path.isdir(path):
paths0 = list_images_natural(path)
if not paths0:
@@ -197,16 +200,37 @@ def init_pipeline():
pipe.init_cross_kv(); pipe.load_models_to_device(["dit","vae"])
return pipe
+def warm_up(pipeline, path: str, scale: int = 2, dtype=torch.bfloat16, device='cuda', sparse_ratio: float=2.0,seed=0):
+    """Run one throw-away inference pass over the input at ``path``.
+
+    main() calls this once before its timed loop and the pipeline output is
+    discarded; presumably this keeps one-time start-up cost out of the
+    reported inference durations.
+
+    Args:
+        pipeline: Callable pipeline object, invoked once with the keyword
+            arguments below.
+        path: Input location handed to prepare_input_tensor.
+        scale: Forwarded to prepare_input_tensor.
+        dtype: Forwarded to prepare_input_tensor.
+        device: Forwarded to prepare_input_tensor.
+        sparse_ratio: Multiplied by 768*1280/(height*width) to form topk_ratio.
+        seed: Forwarded to the pipeline.
+
+    Returns:
+        None. If the input cannot be prepared, the error is printed and the
+        warm-up pass is skipped.
+    """
+    name = os.path.basename(path.rstrip('/'))
+    try:
+        LQ, th, tw, F, _fps = prepare_input_tensor(path, scale=scale, dtype=dtype, device=device)
+    except Exception as e:
+        # Warm-up is best-effort. Without this return the call below would run
+        # with LQ/th/tw/F unbound and crash with UnboundLocalError.
+        print(f"[Error] {name}: {e}")
+        return
+    # NOTE(review): the pipe(...) call in this script's main() does not pass
+    # `tiled`; confirm this pipeline accepts that keyword.
+    pipeline(
+        prompt="", negative_prompt="", cfg_scale=1.0, num_inference_steps=1, seed=seed,
+        tiled=False,# Disable tiling: faster inference but higher VRAM usage.
+                    # Set to True for lower memory consumption at the cost of speed.
+        LQ_video=LQ, num_frames=F, height=th, width=tw, is_full_block=False, if_buffer=True,
+        topk_ratio=sparse_ratio*768*1280/(th*tw),
+        kv_ratio=3.0,
+        local_range=11, # Recommended: 9 or 11. local_range=9 → sharper details; 11 → more stable results.
+        color_fix = True,
+    )
+
def main():
RESULT_ROOT = "./results"
os.makedirs(RESULT_ROOT, exist_ok=True)
inputs = [
"./inputs/example4.mp4",
]
- seed, scale, dtype, device = 0, 4.0, torch.bfloat16, 'cuda'
+ warm_up_file = "./inputs/example0.mp4"
+ seed, scale, dtype, device = 0, 2.0, torch.bfloat16, 'cuda'
sparse_ratio = 2.0 # Recommended: 1.5 or 2.0. 1.5 → faster; 2.0 → more stable.
pipe = init_pipeline()
+ # warm up
+ warm_up(pipe, warm_up_file, scale, dtype, device, sparse_ratio, seed)
+
for p in inputs:
torch.cuda.empty_cache(); torch.cuda.ipc_collect()
name = os.path.basename(p.rstrip('/'))
@@ -217,6 +241,7 @@ def main():
except Exception as e:
print(f"[Error] {name}: {e}"); continue
+ inference_start_time = time.time()
video = pipe(
prompt="", negative_prompt="", cfg_scale=1.0, num_inference_steps=1, seed=seed,
LQ_video=LQ, num_frames=F, height=th, width=tw, is_full_block=False, if_buffer=True,
@@ -225,6 +250,9 @@ def main():
local_range=11, # Recommended: 9 or 11. local_range=9 → sharper details; 11 → more stable results.
color_fix = True,
)
+ inference_end_time = time.time()
+ inference_duration = inference_end_time - inference_start_time
+ print(f"Inference completed in {inference_duration:.2f} seconds")
video = tensor2video(video)
save_video(video, os.path.join(RESULT_ROOT, f"FlashVSR_v1.1_Tiny_Long_{name.split('.')[0]}_seed{seed}.mp4"), fps=fps, quality=5)
@@ -1,8 +1,9 @@
-torch==2.6.0+cu124
-torchaudio==2.6.0+cu124
+torch==2.6.0
+torch-npu==2.6.0.post5
+torchaudio==2.6.0
torchmetrics==1.7.3
torchsde==0.2.6
-torchvision==0.21.0+cu124
+torchvision==0.21.0
accelerate==1.8.1
einops==0.8.1
huggingface-hub==0.34.4
@@ -22,4 +23,5 @@ protobuf==3.20.3
ftfy==6.3.1
pandas==2.3.0
tqdm
+modelscope
datasets
\ No newline at end of file