@@ -27,10 +27,13 @@ try:
except ModuleNotFoundError:
SAGE_ATTN_AVAILABLE = False
-from block_sparse_attn import block_sparse_attn_func
from PIL import Image
import numpy as np
+import torch_npu
+from torch_npu.contrib import transfer_to_npu
+import mindiesd
+from mindiesd.layers.flash_attn.sparse_flash_attn_rf_v2 import rain_fusion_attention
# ----------------------------
# Local / window masks
@@ -193,38 +196,57 @@ def generate_causal_block_mask(batch_size, nheads, seqlen, local_num, window_siz
return causal_mask
+def get_mask_index(mask):
+ b, n, s1, s2 = mask.shape
+ device = mask.device
+
+ mask_reshaped = mask.reshape(-1, s1, s2)
+ batch_size = mask_reshaped.shape[0]
+
+ row_indices = torch.arange(s2, device=device).expand(batch_size, s1, s2)
+ sorted_vals = torch.where(mask_reshaped, row_indices, 1e9).to(torch.float32)
+ sorted_vals, _ = torch.sort(sorted_vals, dim=-1)
+ valid_count = mask_reshaped.sum(dim=-1, keepdim=True)
+ keep_mask = row_indices < valid_count
+ result = torch.where(keep_mask, sorted_vals, -1)
+
+ pos_matrix = result.reshape(b, n, s1, s2).to(torch.int64)
+ return pos_matrix
+
+
# ----------------------------
# Attention kernels
# ----------------------------
-def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False, attention_mask=None, return_KV=False):
+def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False, attention_mask=None, return_KV=False, blockShape=None):
if attention_mask is not None:
- seqlen = q.shape[1]
- seqlen_kv = k.shape[1]
- q = rearrange(q, "b s (n d) -> (b s) n d", n=num_heads)
- k = rearrange(k, "b s (n d) -> (b s) n d", n=num_heads)
- v = rearrange(v, "b s (n d) -> (b s) n d", n=num_heads)
- cu_seqlens_q = torch.tensor([0, seqlen], device=q.device, dtype=torch.int32)
- cu_seqlens_k = torch.tensor([0, seqlen_kv], device=q.device, dtype=torch.int32)
- head_mask_type = torch.tensor([1]*num_heads, device=q.device, dtype=torch.int32)
- streaming_info = None
- base_blockmask = attention_mask
- max_seqlen_q_ = seqlen
- max_seqlen_k_ = seqlen_kv
- p_dropout = 0.0
- x = block_sparse_attn_func(
+ batch, _, headDims = q.shape
+ headDim = headDims // num_heads
+
+ selectIdx = get_mask_index(attention_mask)
+ selectIdx = selectIdx[0].transpose(0,1)
+ selectNumIdx = attention_mask[0].transpose(0,1).sum(dim=-1)
+
+ blockShape = [blockShape, blockShape]
+ actualSeqLengthsHost = [q.shape[1] for _ in range(batch)]
+ actualSeqLengthsKvHost = [k.shape[1] for _ in range(batch)]
+ scale = headDim ** -0.5
+
+ q = rearrange(q, "b s (n d) -> b s n d", n=num_heads).transpose(1,2).contiguous()
+ k = rearrange(k, "b s (n d) -> b s n d", n=num_heads).transpose(1,2).contiguous()
+ v = rearrange(v, "b s (n d) -> b s n d", n=num_heads).transpose(1,2).contiguous()
+
+ x = rain_fusion_attention(
q, k, v,
- cu_seqlens_q, cu_seqlens_k,
- head_mask_type,
- streaming_info,
- base_blockmask,
- max_seqlen_q_, max_seqlen_k_,
- p_dropout,
- deterministic=False,
- softmax_scale=None,
- is_causal=False,
- exact_streaming=False,
- return_attn_probs=False,
- ).unsqueeze(0)
+ scale=scale,
+ head_num=num_heads,
+ input_layout="BNSD",
+ select_idx=selectIdx,
+ select_num_idx=selectNumIdx,
+ blockshape=blockShape,
+ actual_seq_lengths=actualSeqLengthsHost,
+ actual_seq_lengths_kv=actualSeqLengthsKvHost,
+ )
+ x = rearrange(x, "b n s d -> b s n d", n=num_heads)
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
elif compatibility_mode:
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
@@ -289,10 +311,10 @@ def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
def rope_apply(x, freqs, num_heads):
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
- x_out = torch.view_as_complex(x.to(torch.float64).reshape(
- x.shape[0], x.shape[1], x.shape[2], -1, 2))
- x_out = torch.view_as_real(x_out * freqs).flatten(2)
- return x_out.to(x.dtype)
+ cos, sin = torch.chunk(torch.view_as_real(freqs), 2, dim=-1)
+ cos = cos.unsqueeze(0).expand(-1, -1, -1, -1, 2).flatten(-2)
+ sin = sin.unsqueeze(0).expand(-1, -1, -1, -1, 2).flatten(-2)
+ return mindiesd.rotary_position_embedding(x, cos, sin, rotated_mode="rotated_interleaved", fused=True)
# ----------------------------
@@ -304,12 +326,8 @@ class RMSNorm(nn.Module):
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
- def norm(self, x):
- return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
-
def forward(self, x):
- dtype = x.dtype
- return self.norm(x.float()).to(dtype) * self.weight
+ return torch_npu.npu_rms_norm(x, self.weight, epsilon=self.eps)[0]
class AttentionModule(nn.Module):
@@ -317,8 +335,8 @@ class AttentionModule(nn.Module):
super().__init__()
self.num_heads = num_heads
- def forward(self, q, k, v, attention_mask=None):
- x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads, attention_mask=attention_mask)
+ def forward(self, q, k, v, attention_mask=None, blockShape=None):
+ x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads, attention_mask=attention_mask, blockShape=blockShape)
return x
@@ -386,8 +404,9 @@ class SelfAttention(nn.Module):
self.local_attn_mask_w = w//8
self.local_range = local_range
attention_mask = generate_draft_block_mask(B, self.num_heads, seqlen, q_w, k_w, topk=topk, local_attn_mask=self.local_attn_mask)
+ attention_mask[:, :, :, 0] = True
- x = self.attn(reorder_q, reorder_k, reorder_v, attention_mask)
+ x = self.attn(reorder_q, reorder_k, reorder_v, attention_mask, blockShape=block_s)
cur_block_n, cur_block_s, _ = k_w.shape
cache_num = cur_block_n // one_len
@@ -19,6 +19,9 @@ from tqdm import tqdm
import torch
from einops import rearrange
+import torch_npu
+from torch_npu.contrib import transfer_to_npu
+
# Add project path
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)
@@ -142,6 +145,8 @@ def parse_args():
help="Output FPS (default: match input or 30 for images)")
parser.add_argument("--quality", type=int, default=10,
help="Output video quality (0-10)")
+ parser.add_argument("--warmup_file", type=str, default="./inputs/example0.mp4",
+ help="Warm up file path")
# Other parameters
parser.add_argument("--device", type=str, default="cuda",
@@ -197,7 +202,6 @@ def save_video_with_audio_piped(frames, output_path, audio_source, fps=30, quali
# Approximation of quality to CRF/CQ: quality 10 -> crf 3, quality 5 -> crf 13, quality 0 -> crf 23
# For NVENC, we use -cq (Constant Quality) and -rc vbr
-
if NVENC_AVAILABLE:
# NVENC settings - optimized for speed
vcodec = 'h264_nvenc'
@@ -245,12 +249,20 @@ def save_video_with_audio_piped(frames, output_path, audio_source, fps=30, quali
try:
process = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+
+ for frame in tqdm(frames, desc="Saving"):
+ frame = np.array(frame)
+ if frame.shape[:2] != (h, w):
+ warnings.warn(f"帧尺寸不匹配,跳过!期望: {w}x{h}, 实际: {frame.shape[1]}x{frame.shape[0]}")
+ continue
+ process.stdin.write(frame.tobytes())
+
except FileNotFoundError:
print("Error: ffmpeg not found for piping.")
return False
-
- for frame in tqdm(frames, desc="Saving"):
- process.stdin.write(np.array(frame).tobytes())
+ except BrokenPipeError:
+ print("\n[错误] FFmpeg 管道断开!通常是:帧尺寸不匹配、编码不支持、无写入权限")
+ return False
out, err = process.communicate()
@@ -666,6 +678,113 @@ def init_pipeline(args):
return pipe, vae_system_instance
+def warm_up(pipe, args):
+ print()
+ print("warm up start!!!")
+ # Setup dtype
+ dtype_map = {
+ "fp32": torch.float32,
+ "fp16": torch.float16,
+ "bf16": torch.bfloat16,
+ }
+ dtype = dtype_map.get(args.dtype, torch.bfloat16)
+
+ LQ, th, tw, F, fps, input_video_path, total_frames_orig, exact_h, exact_w = prepare_input_tensor(
+ args.warmup_file,
+ scale=args.scale,
+ dtype=dtype,
+ device=args.device
+ )
+
+ pipeline_kwargs = {
+ "prompt": "",
+ "negative_prompt": "",
+ "cfg_scale": 1.0,
+ "num_inference_steps": 1,
+ "seed": args.seed,
+ "LQ_video": LQ,
+ "num_frames": F,
+ "height": th,
+ "width": tw,
+ "is_full_block": False,
+ "if_buffer": True,
+ "topk_ratio": args.sparse_ratio * 768 * 1280 / (th * tw),
+ "kv_ratio": args.kv_ratio,
+ "local_range": args.local_range,
+ "color_fix": args.color_fix,
+ }
+
+ # Add VAE tiling parameters (for full and tiny mode)
+ if args.tile_vae:
+ pipeline_kwargs["tiled"] = True
+ # Ensure we pass sensible tile_size and tile_stride for VAE
+ # args.tile_size is from CLI (default 256 for simple Tile Utils, but here for VAE it is latent size?)
+ # FlashVSRTiny default is (60, 104) ~= 480x832 pixels / 8
+ # If user provides explicit tile size, use it. Otherwise, let's pick a reasonable default or respect args.tile_size
+
+ # NOTE: args.tile_size is typically 256. 256 * 8 = 2048 pixels. This is a very large tile for VAE.
+ # It's likely args.tile_size is meant for DiT pixel blocks in other contexts?
+ # But 'apply_tiled_inference_simple' uses it as pixel size for DiT.
+ # For VAE here, 'tile_size' argument to __call__ expects Latent Size.
+ # If we use 256 output pixels -> 32 latent size.
+ # Let's assume if tile-dit is OFF, args.tile_size might be irrelevant or we should interpret it.
+ # If tile-dit is ON, args.tile_size is used for DiT.
+
+ # Let's interpret args.tile_size as PIXEL size for consistency if tile-dit is ON?
+ # No, for VAE tiling, usually we want bigger chunks than DiT tiling.
+ # Let's use a safe default if not specified, or derive from args.tile_size / 8
+ # If user didn't change default 256... 256/8 = 32.
+
+ # Actually, let's trust the defaults in the pipeline if user didn't specifying anything specific for VAE?
+ # But user might want to control it using CLI args.
+ # Let's pass args.tile_size // 8 if plausible.
+
+ # Current logic: If tile-vae is ON, we want to enforce tiling.
+ # Let's update tile_size and tile_stride in pipeline_kwargs
+
+ # Safely convert pixel tile size (args.tile_size) to latent size
+ # Assuming args.tile_size is "pixel size"
+ vae_tile_size_latent = max(32, args.tile_size // 8)
+ vae_overlap_latent = max(4, args.overlap // 8)
+
+ pipeline_kwargs["tile_size"] = (vae_tile_size_latent, vae_tile_size_latent)
+ pipeline_kwargs["tile_stride"] = (vae_tile_size_latent - vae_overlap_latent, vae_tile_size_latent - vae_overlap_latent)
+
+ print(f"VAE Tiling Enabled: tile_size (latent)={pipeline_kwargs['tile_size']}, stride={pipeline_kwargs['tile_stride']}")
+
+ if args.tile_dit:
+ print(f"Tiled DiT: tile_size={args.tile_size}, overlap={args.overlap}")
+
+ # Create a copy of pipeline_kwargs and remove LQ_video
+ tile_kwargs = pipeline_kwargs.copy()
+ tile_kwargs.pop('LQ_video', None) # Remove LQ_video because it's already passed as a positional argument
+
+ # Handle potential collision of 'tile_size' argument
+ # pipeline_kwargs might contain 'tile_size' (tuple) for VAE if tile-vae is on.
+ # apply_tiled_inference_simple takes 'tile_size' (int) for DiT.
+ # To avoid error, we pop 'tile_size' from kwargs and pass it as 'tile_size_vae' if present.
+ vae_tile_size_tuple = tile_kwargs.pop('tile_size', None)
+
+ # Tiled inference
+ apply_tiled_inference_simple(
+ pipe,
+ LQ,
+ tile_size=args.tile_size,
+ overlap=args.overlap,
+ tile_size_vae=vae_tile_size_tuple,
+ **tile_kwargs
+ )
+ else:
+ # print("Running inference...") # Controlled inside pipeline or tqdm
+ if args.mode == 'tiny-long':
+ msg = "Running inference (Streaming"
+ if args.tile_vae:
+ msg += " & VAE-tiled"
+ msg += ")..."
+ pipe(**pipeline_kwargs)
+ print("warm up end!!!")
+ print()
+
def main():
total_start_time = time.time()
@@ -751,6 +870,12 @@ def main():
print(f"Mode: {args.mode}")
# Prepare input
print(f"Processing: {args.input}")
+
+ # Initialize pipeline with VAE manager
+ pipe, vae_instance = init_pipeline(args)
+
+ # Warm up
+ warm_up(pipe, args)
# Set dtype
if args.dtype == "fp16":
@@ -771,15 +896,6 @@ def main():
if args.fps is not None:
fps = args.fps
- # Initialize pipeline with VAE manager
- pipe, vae_instance = init_pipeline(args)
-
- # Ensure LQ is on the correct device (it should be already if prepare_input_tensor kept it there)
- # This is a no-op if already on device, but safe to keep.
- if LQ.device.type != args.device:
- print(f"Moving input to {args.device}...")
- LQ = LQ.to(args.device)
-
# Determine output file name
input_name = os.path.basename(args.input.rstrip('/')).split('.')[0]
if os.path.isdir(args.output):
@@ -1,8 +1,9 @@
-torch==2.6.0+cu124
-torchaudio==2.6.0+cu124
+torch==2.6.0
+torch-npu==2.6.0.post5
+torchaudio==2.6.0
torchmetrics==1.7.3
torchsde==0.2.6
-torchvision==0.21.0+cu124
+torchvision==0.21.0
accelerate==1.8.1
einops==0.8.1
huggingface-hub==0.34.4