from abc import ABC, abstractmethod
import json
import math
import os
from pathlib import Path
import time
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.nn.functional as F
try:
import torch_bsa
except ImportError:
torch_bsa = None
import torch_npu
import yaml
from loguru import logger
from module.unified_sp.uaa import all_gather_anything, _maybe_pad_qkv_head, _maybe_unpad_qkv_head
DEFAULT_CONFIG_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "sparse_config.yaml")
def load_sparse_config_from_file(config_path=DEFAULT_CONFIG_PATH):
    """Load the sparse-attention configuration from a YAML file.

    Two layouts are supported:

    1. Inline (new): the launch YAML has a top-level ``sparse`` section with
       ``method``, ``block_size_Q/K``, ``model`` and per-method params nested
       under ``params.<Method>``. The section is flattened to the legacy shape
       expected downstream.
    2. Flat (legacy): a standalone ``sparse_config.yaml`` whose root is the
       flat structure (``block_size_Q`` / ``block_size_K`` / ``model`` at the
       top, method keys ``TopK`` / ``SVG`` holding per-method params).

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        dict: The flat configuration mapping.

    Raises:
        ValueError: If the file does not contain a YAML mapping, or a required
            key (``block_size_Q``, ``block_size_K``, ``model``) is missing.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        full_cfg = yaml.safe_load(f)
    # An empty file parses to None and a list-rooted file to a list; neither
    # can be a valid config, so fail here with a readable message.
    if not isinstance(full_cfg, dict):
        raise ValueError(
            f"Sparse config {config_path} must contain a YAML mapping, "
            f"got {type(full_cfg).__name__}"
        )
    sparse = full_cfg.get("sparse")
    if isinstance(sparse, dict):
        # Copy only required keys that are actually set. Writing absent keys
        # as None would make the "missing key" validation below vacuous.
        config = {
            key: sparse[key]
            for key in ("block_size_Q", "block_size_K", "model")
            if sparse.get(key) is not None
        }
        for method, params in (sparse.get("params") or {}).items():
            # ``params.<Method>:`` with no body parses to None.
            method_cfg = dict(params or {})
            method_cfg.setdefault("predictor_name", method)
            config[method] = method_cfg
    else:
        config = full_cfg
    logger.info(config)
    _validate_config_keys(config)
    return config
def _validate_config_keys(config: dict):
    """Check that the flat sparse config carries every required key.

    A key counts as missing when it is absent or when its value is ``None``
    (an empty YAML value), since later code converts these values with
    ``int()`` or compares them against model names.

    Args:
        config: Flat configuration mapping.

    Raises:
        ValueError: If ``config`` is not a mapping or a required key is
            missing.
    """
    required_keys = ("block_size_Q", "block_size_K", "model")
    if not isinstance(config, dict):
        raise ValueError(
            f"Sparse config must be a mapping, got {type(config).__name__}"
        )
    missed_key = [k for k in required_keys if config.get(k) is None]
    if missed_key:
        raise ValueError(f"Missing required key(s): {','.join(missed_key)}")
def parse_sparse_time_step(value):
    """Parse a sparse time-step specification into sorted, unique step indices.

    Args:
        value: One of
            * a comma separated string such as ``"20, 30-40, 50"`` where
              ``a-b`` denotes the inclusive range ``a..b``;
            * a single int (what YAML produces for ``sparse_time_step: 20``);
            * an iterable mixing ints and strings of the form above (what
              YAML produces for ``sparse_time_step: [20, "30-40"]``).

    Returns:
        list[int]: Sorted step indices without duplicates.

    Raises:
        ValueError: On a malformed range, a non-integer token, a range whose
            start exceeds its end, or an unsupported value type.
    """
    if isinstance(value, bool):
        # bool is an int subclass; a YAML true/false here is a config mistake.
        raise ValueError(f"Invalid integer: {value}")
    if isinstance(value, (int, str)):
        items = [value]
    else:
        try:
            items = list(value)
        except TypeError as e:
            raise ValueError(f"Invalid sparse time step spec: {value!r}") from e

    result = set()
    for item in items:
        if isinstance(item, int) and not isinstance(item, bool):
            result.add(item)
            continue
        if not isinstance(item, str):
            raise ValueError(f"Invalid integer: {item}")
        for part in (p.strip() for p in item.split(',')):
            if not part:
                continue
            if "-" in part:
                try:
                    start, end = map(int, part.split("-"))
                except ValueError as e:
                    raise ValueError(
                        f"Invalid range format: {part}"
                    ) from e
                if start > end:
                    raise ValueError(
                        f"Invalid range: {part} (start > end)"
                    )
                result.update(range(start, end + 1))
            else:
                try:
                    result.add(int(part))
                except ValueError as e:
                    raise ValueError(
                        f"Invalid integer: {part}"
                    ) from e
    return sorted(result)
class SparsePredictorManager:
    """Holds the loaded sparse config and the adapter selected for it."""

    def __init__(self) -> None:
        self.sparse_attn_mode = None
        self.config = None
        self.sparse_params = {}

    def from_config(self, config_path, sparse_method, sparse_params=None):
        """Load the config file and instantiate the matching adapter.

        The method's ``sparse_time_step`` entry is replaced in place by its
        parsed list form. ``sparse_params`` (if given) is merged into the
        parameters already held. ``sparse_attn_mode`` is only assigned for
        the TopK / SVG methods on the ``HunyuanVideo`` model; any other
        combination leaves it untouched.
        """
        cfg = load_sparse_config_from_file(config_path)
        self.config = cfg
        method_cfg = cfg[sparse_method]
        method_cfg['sparse_time_step'] = parse_sparse_time_step(method_cfg['sparse_time_step'])
        if sparse_params is not None:
            self.sparse_params.update(sparse_params)
        if sparse_method == "TopK":
            if cfg['model'] == "HunyuanVideo":
                self.sparse_attn_mode = HunyuanVideoTopKAdapter(cfg, self.sparse_params)
        elif sparse_method == "SVG":
            if cfg['model'] == "HunyuanVideo":
                self.sparse_attn_mode = HunyuanVideoSVGAdapter(cfg, self.sparse_params)


# Module-level singleton instance.
sparse_predictor_manager = SparsePredictorManager()
def sync_and_get_time(start_time=None, use_syn=True):
    """Return the current wall-clock time, or the time elapsed since ``start_time``.

    Args:
        start_time: Earlier ``time.time()`` value; when given, the difference
            to now is returned instead of the absolute timestamp.
        use_syn: When true, ``torch.npu.synchronize()`` is called before the
            clock is read.
    """
    if use_syn:
        torch.npu.synchronize()
    now = time.time()
    return now if start_time is None else now - start_time
class BaseSparsePredictor(ABC):
    """Common state and helpers shared by the sparse-attention predictors."""

    def __init__(self, sparse_config: Dict[str, Any], sparse_params: Optional[Dict[str, Any]] = None):
        """Read block sizes from ``sparse_config`` and runtime info from ``sparse_params``.

        ``total_layers_per_step`` is the sum of the double- and single-stream
        layer counts when both are provided, otherwise ``attn_layers``.
        """
        self.sparse_params = sparse_params or {}
        self.sparse_config = sparse_config
        self.block_size_q = int(sparse_config["block_size_Q"])
        self.block_size_k = int(sparse_config["block_size_K"])

        lookup = self.sparse_params.get
        self.double_stream_layers = lookup("double_stream_layers")
        self.single_stream_layers = lookup("single_stream_layers")
        self.attn_layers = lookup("attn_layers")
        self.total_steps = lookup("num_steps")
        self.device = lookup("device")

        if self.double_stream_layers is None or self.single_stream_layers is None:
            self.total_layers_per_step = self.attn_layers
        else:
            self.total_layers_per_step = self.double_stream_layers + self.single_stream_layers

        # Running position inside the denoising schedule.
        self.step = 0
        self.layer_counter = 0
        self.index_type = torch.int32
        self.current = {
            "step": self.step,
            "layer": 0,
            "num_steps": self.total_steps,
        }
@staticmethod
def _route_heads_bsnd(x: torch.Tensor, head_index: torch.Tensor) -> torch.Tensor:
head_index = head_index.to(device=x.device, dtype=torch.long)
return x.index_select(2, head_index).contiguous()
@staticmethod
def _slice_ulysses_local_heads_with_uaa(runtime_attn, q, k, v, global_head_num):
if runtime_attn.ulysses_anything:
h = q.shape[2]
padded_tensors = []
h_pad = 0
for idx, tensor in enumerate((q, k, v)):
padded_tensor, tensor_h_pad = _maybe_pad_qkv_head(tensor, h, runtime_attn.ulysses_world_size)
padded_tensors.append(padded_tensor)
if idx == 0:
h_pad = tensor_h_pad
head_per_rank = (h + h_pad) // runtime_attn.ulysses_world_size
head_start = runtime_attn.ulysses_rank * head_per_rank
head_end = head_start + head_per_rank
sliced_tensors = [
tensor[:, :, head_start:head_end, :].contiguous()
for tensor in padded_tensors
]
sliced_tensors = [
_maybe_unpad_qkv_head(
tensor,
h_pad,
runtime_attn.ulysses_rank,
runtime_attn.ulysses_world_size,
runtime_attn.ulysses_pg,
).contiguous()
for tensor in sliced_tensors
]
return tuple(sliced_tensors)
if global_head_num <= 0 or global_head_num % runtime_attn.ulysses_world_size != 0:
raise ValueError("global head num must be divisible by ulysses_world_size in standard ulysses.")
head_per_rank = global_head_num // runtime_attn.ulysses_world_size
head_start = runtime_attn.ulysses_rank * head_per_rank
head_end = head_start + head_per_rank
return tuple(
tensor[:, :, head_start:head_end, :].contiguous()
for tensor in (q, k, v)
)
@staticmethod
def _ulysses_all_gather_heads_bshd(runtime_attn, x: torch.Tensor) -> torch.Tensor:
if runtime_attn.ulysses_world_size <= 1:
return x
return all_gather_anything(
tensor=x,
dim=2,
world_size=runtime_attn.ulysses_world_size,
group=runtime_attn.ulysses_pg,
).contiguous()
@staticmethod
def _ulysses_all_to_all_qkv(runtime_attn, x: torch.Tensor) -> torch.Tensor:
if runtime_attn.ulysses_anything:
return getattr(runtime_attn, "_all_to_all_qkv_anything")(x)()
return getattr(runtime_attn, "_all_to_all_qkv")(x)
@staticmethod
def _ulysses_all_to_all_o(
runtime_attn,
x: torch.Tensor,
*,
num_qo_head: int,
q_s_local: int,
) -> torch.Tensor:
if runtime_attn.ulysses_anything:
return getattr(runtime_attn, "_all_to_all_o_anything")(
x,
NUM_QO_HEAD=num_qo_head,
Q_S_LOCAL=q_s_local,
)()
return getattr(runtime_attn, "_all_to_all_o")(x)
def _ulysses_all_to_all_qkv_triplet(self, runtime_attn, q, k, v):
    """Apply ``_ulysses_all_to_all_qkv`` to q, k and v (in that order) and return the triple."""
    q_out = self._ulysses_all_to_all_qkv(runtime_attn, q)
    k_out = self._ulysses_all_to_all_qkv(runtime_attn, k)
    v_out = self._ulysses_all_to_all_qkv(runtime_attn, v)
    return q_out, k_out, v_out
@staticmethod
def _move_sink_qkv_to_end(q, k, v, *, q_sink_len: int, kv_sink_len: int):
sink_q, rest_q = q[:, :, :q_sink_len, :], q[:, :, q_sink_len:, :]
sink_k, rest_k = k[:, :, :kv_sink_len, :], k[:, :, kv_sink_len:, :]
sink_v, rest_v = v[:, :, :kv_sink_len, :], v[:, :, kv_sink_len:, :]
return (
torch.cat((rest_q, sink_q), dim=2).contiguous(),
torch.cat((rest_k, sink_k), dim=2).contiguous(),
torch.cat((rest_v, sink_v), dim=2).contiguous(),
)
@staticmethod
def _split_img_txt_qkv(q_full, k_full, v_full, img_q_len: int, img_kv_len: int):
return {
"q_img": q_full[:, :img_q_len, :, :].contiguous(),
"k_img": k_full[:, :img_kv_len, :, :].contiguous(),
"v_img": v_full[:, :img_kv_len, :, :].contiguous(),
"txt_q": q_full[:, img_q_len:, :, :].contiguous(),
"txt_k": k_full[:, img_kv_len:, :, :].contiguous(),
"txt_v": v_full[:, img_kv_len:, :, :].contiguous(),
}
@staticmethod
def _ring_all_gather_seq(runtime_attn, x: torch.Tensor) -> torch.Tensor:
    """All-gather ``x`` over the ring group and concatenate the shards along dim 1.

    Every rank must contribute a tensor of the same 4-D shape; the result has
    ``ring_world_size`` times the local length on dim 1, in rank order.
    """
    world = runtime_attn.ring_world_size
    b, s_local, h, d = x.shape
    buffer = torch.empty((world, b, s_local, h, d), dtype=x.dtype, device=x.device)
    dist.all_gather_into_tensor(buffer, x.contiguous(), group=runtime_attn.ring_pg)
    # (world, b, s, h, d) -> (b, world, s, h, d) -> (b, world * s, h, d)
    by_batch = buffer.permute(1, 0, 2, 3, 4)
    return by_batch.reshape(b, runtime_attn.ring_world_size * s_local, h, d)
@staticmethod
def _ring_gathered_img_to_bnsd(pre_attn_layout, gathered, txt_bnsd, rank_indices=None):
if rank_indices is not None:
gathered = torch.index_select(gathered, dim=0, index=rank_indices)
_, b, _, n, d = gathered.shape
img_bshd = gathered.permute(1, 0, 2, 3, 4).reshape(b, -1, n, d).contiguous()
return torch.cat([pre_attn_layout(img_bshd), txt_bnsd], dim=2).contiguous()
@staticmethod
def _split_patched_width(total_width: int, world_size: int):
base = int(total_width) // int(world_size)
extra = int(total_width) % int(world_size)
return [base + (1 if rank < extra else 0) for rank in range(int(world_size))]
def _get_ring_local_sink_token_len(
    self,
    *,
    local_img_token_len: int,
    ring_rank: int,
    ring_world_size: int,
) -> int:
    """Count how many of the global sink tokens fall into this ring rank's shard.

    The global sink covers the first ``min(sink_frame_len, img_token_len)``
    image tokens, laid out frame by frame and row by row. Each ring rank is
    taken to own a contiguous column range of every row, as produced by
    ``_split_patched_width``.

    Args:
        local_img_token_len: Number of image tokens held locally; the result
            is clamped to this value.
        ring_rank: This rank's index inside the ring group.
        ring_world_size: Size of the ring group.

    Returns:
        int: Local sink token count in ``[0, local_img_token_len]``; ``0``
        when no sink is configured.

    Raises:
        ValueError: If a sink is configured but ``frame_patch_h`` or
            ``frame_patch_w`` is missing or non-positive (this previously
            surfaced as a ``ZeroDivisionError``).
    """
    sink_len = min(int(getattr(self, "sink_frame_len", 0)), int(getattr(self, "img_token_len", 0)))
    if sink_len <= 0:
        return 0
    frame_patch_h = int(getattr(self, "frame_patch_h", 0))
    frame_patch_w = int(getattr(self, "frame_patch_w", 0))
    if frame_patch_h <= 0 or frame_patch_w <= 0:
        raise ValueError(
            "frame_patch_h and frame_patch_w must be positive to locate sink tokens, "
            f"got {frame_patch_h} x {frame_patch_w}"
        )
    width_splits = self._split_patched_width(frame_patch_w, ring_world_size)
    width_start = int(sum(width_splits[:ring_rank]))
    width_len = int(width_splits[ring_rank])

    # Whole frames inside the sink contribute this rank's full column range.
    full_frames, rem_tokens = divmod(sink_len, frame_patch_h * frame_patch_w)
    local_sink = full_frames * frame_patch_h * width_len
    # The trailing partial frame: complete rows contribute width_len each,
    # the last partial row only the overlap with [width_start, width_start + width_len).
    full_rows, partial_cols = divmod(rem_tokens, frame_patch_w)
    local_sink += full_rows * width_len
    local_sink += max(0, min(width_start + width_len, partial_cols) - width_start)
    return max(0, min(int(local_sink), int(local_img_token_len)))
@staticmethod
def _ring_all_gather_seq_bnsd(runtime_attn, x: torch.Tensor) -> torch.Tensor:
    """All-gather ``x`` over the ring group along dim 2, allowing unequal lengths.

    Local lengths are exchanged first; each shard is zero-padded to the
    longest length for the tensor all-gather, and the padding is stripped
    again before the shards are concatenated in rank order.
    """
    world = runtime_attn.ring_world_size
    b, n, s_local, d = x.shape

    # Step 1: learn every rank's sequence length.
    my_len = torch.tensor([s_local], dtype=torch.int32, device=x.device)
    len_slots = [torch.empty_like(my_len) for _ in range(world)]
    dist.all_gather(len_slots, my_len, group=runtime_attn.ring_pg)
    seq_lens = [int(slot.item()) for slot in len_slots]
    longest = max(seq_lens)

    # Step 2: pad to a common length so a single tensor all-gather can be used.
    if s_local < longest:
        filler = torch.zeros((b, n, longest - s_local, d), dtype=x.dtype, device=x.device)
        x = torch.cat([x, filler], dim=2).contiguous()
    buffer = torch.empty((world, b, n, longest, d), dtype=x.dtype, device=x.device)
    dist.all_gather_into_tensor(buffer, x.contiguous(), group=runtime_attn.ring_pg)

    # Step 3: drop the padding of every shard and stitch them together.
    shards = []
    for rank in range(runtime_attn.ring_world_size):
        shards.append(buffer[rank, :, :, :seq_lens[rank], :])
    return torch.cat(shards, dim=2).contiguous()
@staticmethod
def _merge_two_sparse_outputs(out1, lse1, out2, lse2):
if lse1 is None or lse2 is None:
raise ValueError("Sparse ring overlap requires LSE outputs.")
lse1 = BaseSparsePredictor._format_sparse_lse_for_merge(lse1, out1)
lse2 = BaseSparsePredictor._format_sparse_lse_for_merge(lse2, out2)
merged, _ = BaseSparsePredictor._merge_sparse_outputs_with_lse(out1, lse1, out2, lse2)
return merged.to(dtype=out1.dtype)
@staticmethod
def _format_sparse_lse_for_merge(lse: torch.Tensor, ref_out: torch.Tensor) -> torch.Tensor:
b, h, s = ref_out.shape[:3]
if lse.dim() == 4 and lse.shape == (b, h, s, 1):
return lse.to(dtype=torch.float32)
if lse.dim() == 3 and lse.shape == (b, h, s):
return lse.unsqueeze(-1).to(dtype=torch.float32)
raise ValueError(f"Unsupported sparse LSE shape: {tuple(lse.shape)}")
@staticmethod
def _merge_sparse_outputs_with_lse(out1, lse1, out2, lse2):
lses = (lse1, lse2)
valid = [torch.isfinite(lse) for lse in lses]
any_valid = valid[0] | valid[1]
masked_lse = [
torch.where(mask, lse, torch.full_like(lse, float("-inf")))
for lse, mask in zip(lses, valid)
]
max_lse = torch.maximum(masked_lse[0], masked_lse[1])
safe_max_lse = torch.where(any_valid, max_lse, torch.zeros_like(max_lse))
exp_lse = [
torch.where(mask, torch.exp(lse - safe_max_lse), torch.zeros_like(lse))
for lse, mask in zip(lses, valid)
]
denom = torch.clamp_min(exp_lse[0] + exp_lse[1], 1e-30)
weights = [
torch.where(any_valid, exp_item / denom, torch.zeros_like(exp_item))
for exp_item in exp_lse
]
merged = weights[0] * out1.to(torch.float32)
merged = merged + weights[1] * out2.to(torch.float32)
merged_lse = safe_max_lse + torch.log(denom)
merged_lse = torch.where(any_valid, merged_lse, torch.full_like(merged_lse, float("-inf")))
return merged.to(dtype=out1.dtype), merged_lse
# Build the per-query-block key-block index tensor ("sabi") that is handed to
# the blitz sparse attention call for one ring partial.
#
# q / k are indexed as (batch, heads, seq, dim) here: shape[2] is divided by
# the block sizes and shape[1] is used as the head count.
# q_dense_prefix_len / q_dense_suffix_len: number of leading / trailing query
#   tokens whose query blocks must attend to every key block.
# k_mean_override: optional pre-pooled key block means; when given, its
#   shape[2] defines the number of key blocks and the top-k selection is done
#   inline below. When omitted the selection is delegated to self.get_sabi_v2
#   (not visible in this chunk).
# Returns the index tensor cast to torch.uint16 on q's device.
def _build_ring_native_topk_sabi(
self,
q: torch.Tensor,
k: torch.Tensor,
*,
q_dense_prefix_len: int = 0,
q_dense_suffix_len: int = 0,
k_mean_override: Optional[torch.Tensor] = None,
) -> torch.Tensor:
block_num_q = math.ceil(int(q.shape[2]) / int(self.block_size_q))
block_num_k = (
int(k_mean_override.shape[2])
if k_mean_override is not None
else math.ceil(int(k.shape[2]) / int(self.block_size_k))
)
if k_mean_override is None:
sabi = self.get_sabi_v2(q, k)
else:
# Score every (query block, key block) pair from pooled q and the given key means.
q_mean = self.pooling_matmul(q, self.block_size_q, block_num_q)
attn = (q_mean @ k_mean_override.transpose(-2, -1)).softmax(dim=-1)
# -1 marks an unused slot.
sabi = torch.full(
size=(q.shape[0], q.shape[1], block_num_q, block_num_k),
fill_value=-1,
dtype=self.index_type,
device=q.device,
)
step_idx, layer_idx = self.get_effective_indices()
# NOTE(review): _get_layer_sparsity_runtime_head_order is not defined on this
# base class in the visible chunk (TopKPredictor defines it) — confirm every
# subclass reaching this branch provides it.
sparsity = self._get_layer_sparsity_runtime_head_order(step_idx, layer_idx, device=attn.device)
# Per-head number of key blocks to keep, clamped to [1, block_num_k].
k_nums = ((1.0 - sparsity) * block_num_k).to(torch.int32)
k_nums = torch.clamp(k_nums, min=1, max=block_num_k)
max_k = int(k_nums.max().item())
if max_k > 0:
# Take the top max_k for all heads, then blank the slots beyond each head's own budget.
_, indices = torch.topk(attn, max_k, dim=-1)
k_nums_expanded = k_nums.view(1, q.shape[1], 1, 1)
arange_k = torch.arange(max_k, device=q.device).view(1, 1, 1, max_k)
mask = arange_k < k_nums_expanded
sabi[:, :, :, :max_k] = torch.where(mask, indices, -1)
# Row listing every key block index, used for the dense prefix/suffix query blocks.
full_k = torch.arange(block_num_k, dtype=self.index_type, device=q.device).view(1, 1, 1, block_num_k)
full_k = full_k.expand(q.shape[0], q.shape[1], -1, -1)
prefix_blocks = math.ceil(max(0, int(q_dense_prefix_len)) / int(self.block_size_q))
suffix_blocks = math.ceil(max(0, int(q_dense_suffix_len)) / int(self.block_size_q))
if prefix_blocks > 0:
end = min(prefix_blocks, block_num_q)
sabi[:, :, :end, :] = full_k.expand(q.shape[0], q.shape[1], end, block_num_k)
if suffix_blocks > 0:
start = max(0, block_num_q - suffix_blocks)
rows = block_num_q - start
sabi[:, :, start:, :] = full_k.expand(q.shape[0], q.shape[1], rows, block_num_k)
# NOTE(review): the -1 padding entries are cast to uint16 here; presumably the
# kernel interprets the wrapped value as "unused" — confirm against torch_bsa.
return sabi.to(device=q.device, dtype=torch.uint16).contiguous()
def _run_ring_topk_sparse_part(
    self,
    q,
    k,
    v,
    *,
    batch_size,
    num_heads,
    scale,
    dense_prefix_len,
    dense_suffix_len,
):
    """Build the sparse block index for ``q``/``k`` and run blitz sparse attention.

    The sequence lengths passed to the kernel are ``q.shape[2]`` and
    ``k.shape[2]``, repeated ``batch_size`` times. LSE output is always
    requested. Returns whatever ``_call_blitz_sparse_attention`` returns.
    """
    block_index = self._build_ring_native_topk_sabi(
        q,
        k,
        q_dense_prefix_len=dense_prefix_len,
        q_dense_suffix_len=dense_suffix_len,
    )
    q_lengths = [int(q.shape[2])] * batch_size
    kv_lengths = [int(k.shape[2])] * batch_size
    return self._call_blitz_sparse_attention(
        q,
        k,
        v,
        sabi=block_index,
        actual_seq_lengths=q_lengths,
        actual_seq_lengths_kv=kv_lengths,
        num_heads=num_heads,
        scale=scale,
        return_lse=True,
    )
def _call_blitz_sparse_attention(
    self,
    q,
    k,
    v,
    *,
    sabi,
    actual_seq_lengths,
    actual_seq_lengths_kv,
    num_heads,
    scale,
    return_lse: bool = False,
):
    """Invoke ``torch_bsa.blitz_sparse_attention`` in BNSD layout.

    Args:
        q, k, v: Attention inputs; made contiguous before the call.
        sabi: Sparse block index tensor.
        actual_seq_lengths / actual_seq_lengths_kv: Per-batch sequence lengths.
        num_heads: Head count, used for both query and key/value heads.
        scale: Softmax scale.
        return_lse: When true, ``softmax_lse_flag=True`` is passed to the kernel.

    Returns:
        tuple: ``(out, lse)``; ``lse`` is ``None`` when the kernel returned a
        bare tensor or a one-element tuple.

    Raises:
        RuntimeError: If the optional ``torch_bsa`` extension could not be
            imported (the module-level import falls back to ``None``).
    """
    if torch_bsa is None:
        raise RuntimeError(
            "torch_bsa is not installed; blitz sparse attention is unavailable."
        )
    kwargs = dict(
        sabi=sabi,
        actual_seq_lengths=actual_seq_lengths,
        actual_seq_lengths_kv=actual_seq_lengths_kv,
        num_heads=num_heads,
        num_key_value_heads=num_heads,
        input_layout="BNSD",
        scale_value=scale,
        atten_mask=None,
        sparse_mode=0,
        block_shape=[self.block_size_q, self.block_size_k],
    )
    if return_lse:
        kwargs["softmax_lse_flag"] = True
    out = torch_bsa.blitz_sparse_attention(q.contiguous(), k.contiguous(), v.contiguous(), **kwargs)
    if not isinstance(out, tuple):
        return out, None
    return out[0], (out[1] if len(out) >= 2 else None)
@staticmethod
def _take_attention_output(attn_out):
if isinstance(attn_out, tuple):
return attn_out[0]
return attn_out
@staticmethod
def _replace_prefix_with_dense_attention(
sparse_out: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
*,
prefix_len: int,
num_heads: int,
scale: float,
) -> torch.Tensor:
if prefix_len <= 0:
return sparse_out
dense_prefix = torch_npu.npu_fused_infer_attention_score(
q[:, :, :prefix_len, :].contiguous(),
k,
v,
num_heads=num_heads,
input_layout="BNSD",
scale=scale,
)[0]
if prefix_len < sparse_out.shape[2]:
return torch.cat([dense_prefix, sparse_out[:, :, prefix_len:, :]], dim=2).contiguous()
return dense_prefix.contiguous()
# Ring-parallel sparse attention, "global" variant: the image K/V shards of
# all ring ranks are gathered up front and a single self.attention call is
# made with the local queries against the full key/value sequence.
#
# block_args must provide q_img_local / k_img_local / v_img_local and
# txt_q / txt_k / txt_v; they are concatenated along dim 1 below.
# Returns the local image output followed by the text output along dim 1.
def _forward_ring_topk_global(
self,
runtime_attn,
block_args: dict,
softmax_scale: Optional[float],
) -> torch.Tensor:
q_img_local, k_img_local, v_img_local = (
block_args["q_img_local"], block_args["k_img_local"], block_args["v_img_local"]
)
txt_q, txt_k, txt_v = block_args["txt_q"], block_args["txt_k"], block_args["txt_v"]
k_img_global = self._ring_all_gather_seq(runtime_attn, k_img_local)
v_img_global = self._ring_all_gather_seq(runtime_attn, v_img_local)
local_sink_len = self._get_ring_local_sink_token_len(
local_img_token_len=int(q_img_local.shape[1]),
ring_rank=int(runtime_attn.ring_rank),
ring_world_size=int(runtime_attn.ring_world_size),
)
# Rotate the local sink queries to the end of the image segment before the
# sparse call; the rotation is undone on the output further down.
q_img_exec = torch.cat(
[q_img_local[:, local_sink_len:, :, :], q_img_local[:, :local_sink_len, :, :]],
dim=1,
).contiguous() if local_sink_len > 0 else q_img_local
q_full = torch.cat([q_img_exec, txt_q], dim=1).contiguous()
k_full = torch.cat([k_img_global, txt_k], dim=1).contiguous()
v_full = torch.cat([v_img_global, txt_v], dim=1).contiguous()
prefix_q_len, prefix_kv_len = int(q_full.shape[1]), int(k_full.shape[1])
# self.attention is defined outside this chunk.
out_full = self.attention(
q=q_full,
k=k_full,
v=v_full,
cu_seqlens_q=[0, prefix_q_len],
cu_seqlens_kv=[0, prefix_kv_len],
return_bshd=True,
softmax_scale=softmax_scale,
img_token_len_q=int(q_img_local.shape[1]),
img_token_len_k=int(k_img_global.shape[1]),
sink_frame_len_q=0,
sink_frame_len_k=int(self.sink_frame_len),
)
img_out = out_full[:, :int(q_img_local.shape[1]), :, :].contiguous()
if local_sink_len > 0:
# Undo the query rotation so rows are back in the original local order.
img_out = torch.cat([img_out[:, -local_sink_len:, :, :], img_out[:, :-local_sink_len, :, :]], dim=1)
# Recompute the sink rows with dense attention over the full K/V and
# substitute them for the sparse result.
dense_sink = torch_npu.npu_fused_infer_attention_score(
q_img_local[:, :local_sink_len, :, :].contiguous(),
k_full,
v_full,
num_heads=int(q_img_local.shape[2]),
input_layout="BSND",
scale=softmax_scale if softmax_scale is not None else q_img_local.shape[-1] ** (-0.5),
)[0]
img_out = torch.cat([dense_sink, img_out[:, local_sink_len:, :, :]], dim=1).contiguous()
txt_out = out_full[:, int(q_img_local.shape[1]):, :, :].contiguous()
return torch.cat([img_out.contiguous(), txt_out], dim=1).contiguous()
# Ring-parallel sparse attention, "overlap" variant: the K/V all-gather is
# launched asynchronously and sparse attention against the local image K/V is
# computed while it is in flight. A second sparse pass then covers the image
# K/V of the other ranks plus the text K/V, and both partial results are
# merged through their LSE values.
#
# block_args must provide q_img_local / k_img_local / v_img_local and
# txt_q / txt_k / txt_v.
# Returns the merged output converted back with the "BNSD" post-attention layout function.
def _forward_ring_topk_overlap(
self,
runtime_attn,
block_args: dict,
softmax_scale: Optional[float],
) -> torch.Tensor:
q_img_local, k_img_local, v_img_local = (
block_args["q_img_local"], block_args["k_img_local"], block_args["v_img_local"]
)
txt_q, txt_k, txt_v = block_args["txt_q"], block_args["txt_k"], block_args["txt_v"]
# MEMORY_LAYOUT is defined outside this chunk; it yields a (pre, post) pair of layout converters.
pre_attn_layout, post_attn_layout = MEMORY_LAYOUT["BNSD"]
q_full = torch.cat([q_img_local, txt_q], dim=1).contiguous()
b, _, n, d = q_full.shape
scale = softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(d)
q1 = pre_attn_layout(q_full).contiguous()
k_local = pre_attn_layout(k_img_local).contiguous()
v_local = pre_attn_layout(v_img_local).contiguous()
txt_k_bnsd = pre_attn_layout(txt_k).contiguous()
txt_v_bnsd = pre_attn_layout(txt_v).contiguous()
world_size = int(runtime_attn.ring_world_size)
rank = int(runtime_attn.ring_rank)
local_sink_len = self._get_ring_local_sink_token_len(
local_img_token_len=int(q_img_local.shape[1]),
ring_rank=rank,
ring_world_size=world_size,
)
# Receive buffers with a leading rank axis.
k_gathered = torch.empty(
(runtime_attn.ring_world_size, *k_img_local.shape),
dtype=k_img_local.dtype,
device=k_img_local.device,
)
v_gathered = torch.empty(
(runtime_attn.ring_world_size, *v_img_local.shape),
dtype=v_img_local.dtype,
device=v_img_local.device,
)
# Start both gathers without blocking; they are awaited after the local pass.
k_handle = dist.all_gather_into_tensor(
k_gathered,
k_img_local.contiguous(),
group=runtime_attn.ring_pg,
async_op=True,
)
v_handle = dist.all_gather_into_tensor(
v_gathered,
v_img_local.contiguous(),
group=runtime_attn.ring_pg,
async_op=True,
)
# Pass 1: local image K/V only (no text keys), issued while the gathers run.
out_local, lse_local = self._run_ring_topk_sparse_part(
q1,
k_local,
v_local,
batch_size=b,
num_heads=n,
scale=scale,
dense_prefix_len=local_sink_len,
dense_suffix_len=0,
)
k_handle.wait()
v_handle.wait()
# Full K/V (all ranks + text) is only needed for the dense sink replacement at the end.
k_full = self._ring_gathered_img_to_bnsd(pre_attn_layout, k_gathered, txt_k_bnsd)
v_full = self._ring_gathered_img_to_bnsd(pre_attn_layout, v_gathered, txt_v_bnsd)
# Pass 2 uses every rank except this one, so no key is counted twice in the merge.
other_indices = torch.tensor(
[idx for idx in range(world_size) if idx != rank],
dtype=torch.long,
device=q1.device,
)
k_other = self._ring_gathered_img_to_bnsd(pre_attn_layout, k_gathered, txt_k_bnsd, other_indices)
v_other = self._ring_gathered_img_to_bnsd(pre_attn_layout, v_gathered, txt_v_bnsd, other_indices)
out_other, lse_other = self._run_ring_topk_sparse_part(
q1,
k_other,
v_other,
batch_size=b,
num_heads=n,
scale=scale,
dense_prefix_len=local_sink_len,
dense_suffix_len=int(txt_q.shape[1]),
)
sparse_out = self._merge_two_sparse_outputs(out_local, lse_local, out_other, lse_other)
# Sink query rows are finally recomputed densely against the full K/V.
sparse_out = self._replace_prefix_with_dense_attention(
sparse_out,
q1,
k_full,
v_full,
prefix_len=local_sink_len,
num_heads=n,
scale=scale,
)
return post_attn_layout(sparse_out).contiguous()
def forward_ring_sparse(
    self,
    runtime_attn,
    block_args: dict,
    softmax_scale: Optional[float] = None,
) -> torch.Tensor:
    """Run ring-parallel sparse attention.

    Uses the overlap implementation when ``self.ring_sparse_overlap`` is
    truthy, otherwise the global-gather implementation.
    """
    if not self.ring_sparse_overlap:
        return self._forward_ring_topk_global(runtime_attn, block_args, softmax_scale)
    return self._forward_ring_topk_overlap(runtime_attn, block_args, softmax_scale)
def _get_local_seq_remap_indices(
    self,
    device: torch.device,
    *,
    ulysses_world_size: int,
    reverse: bool = False,
) -> torch.Tensor:
    """Return the cached image-token permutation between the two token orders.

    The "current" order is the concatenation of per-rank shards, where every
    shard holds a column range of each frame row; the "canonical" order is
    frame-major, row-major with full-width rows. Index pairs are built once on
    CPU per ``(device, world size, frame geometry)`` and cached on the
    instance.

    Returns:
        torch.Tensor: 1-D long tensor on ``device``; the current-to-canonical
        gather index, or its inverse when ``reverse`` is true.
    """
    cache = getattr(self, "_seq_remap_index_cache", None)
    if cache is None:
        cache = {}
        self._seq_remap_index_cache = cache

    frame_num = int(getattr(self, "frame_num", 0))
    frame_patch_h = int(getattr(self, "frame_patch_h", 0))
    frame_patch_w = int(getattr(self, "frame_patch_w", 0))
    img_token_len = int(getattr(self, "img_token_len", 0))

    key = (device, ulysses_world_size, frame_num, frame_patch_h, frame_patch_w)
    entry = cache.get(key)
    if entry is None:
        cpu = torch.device("cpu")
        base, extra = divmod(frame_patch_w, ulysses_world_size)
        widths = [base + (1 if rank < extra else 0) for rank in range(ulysses_world_size)]
        shard_sizes = [frame_num * frame_patch_h * w for w in widths]

        positions = torch.arange(img_token_len, dtype=torch.long, device=cpu).view(1, 1, img_token_len, 1)
        shards = torch.split(positions, shard_sizes, dim=2)
        # Give each shard its (frame, row, column) structure, then stitch the
        # column ranges of all shards back into full-width rows.
        structured = []
        for shard, width in zip(shards, widths):
            structured.append(shard.view(1, 1, frame_num, frame_patch_h, width, 1))
        canonical = torch.cat(structured, dim=4).reshape(1, 1, img_token_len, 1).contiguous()

        cur2can = canonical.view(-1).long()
        can2cur = torch.empty_like(cur2can)
        can2cur.index_copy_(0, cur2can, torch.arange(img_token_len, dtype=torch.long, device=cpu))
        entry = (cur2can.to(device=device), can2cur.to(device=device))
        cache[key] = entry

    forward_index, reverse_index = entry
    return reverse_index if reverse else forward_index
def get_effective_indices(self) -> Tuple[int, int]:
    """Return the currently effective ``(step, layer)`` index pair.

    The step index is reported as 0 unless more than one step is configured.
    """
    if self.total_steps > 1:
        return self.step, self.layer_counter
    return 0, self.layer_counter
def update_layer_counter(self):
    """Advance to the next layer, rolling over to the next step when needed.

    After the last layer of a step the layer counter resets to 0 and the step
    is incremented, wrapping to 0 once ``total_steps`` is reached. The
    ``current`` dict is refreshed and the effective ``(step, layer)`` pair is
    returned.
    """
    self.layer_counter += 1
    finished_step = self.layer_counter >= self.total_layers_per_step
    if finished_step:
        self.step += 1
        if self.step >= self.total_steps:
            self.step = 0
        self.layer_counter = 0
    self.current["step"] = self.step
    self.current["layer"] = self.layer_counter
    return self.get_effective_indices()
def _apply_local_seq_remap_tensor(
    self,
    x: torch.Tensor,
    *,
    ulysses_world_size: int,
    reverse: bool = False,
) -> torch.Tensor:
    """Permute the image tokens of ``x`` (dim 2) with the cached remap index.

    Only the first ``img_token_len`` positions are permuted; any trailing
    positions are kept in place behind them.

    Raises:
        ValueError: If ``x`` has fewer than ``img_token_len`` positions on dim 2.
    """
    index = self._get_local_seq_remap_indices(
        x.device,
        ulysses_world_size=ulysses_world_size,
        reverse=reverse,
    )
    n_img = int(self.img_token_len)
    if x.shape[2] < n_img:
        raise ValueError(
            f"Ulysses seq len ({x.shape[2]}) is smaller than img_token_len ({n_img})."
        )
    remapped_img = x[:, :, :n_img, :].index_select(2, index)
    trailing = x[:, :, n_img:, :]
    if trailing.shape[2] > 0:
        return torch.cat((remapped_img, trailing), dim=2).contiguous()
    return remapped_img.contiguous()
# Sparse attention under Ulysses parallelism.
#
# Image q/k/v go through the runtime's all-to-all, while text q/k/v are
# head-sliced locally for this rank. After self.attention (defined outside
# this chunk) the image output is sent back through the output all-to-all and
# the text output is all-gathered over heads.
#
# block_args must provide q_img_local / k_img_local / v_img_local and
# txt_q / txt_k / txt_v, concatenated along dim 1 (sequence) and routed along
# dim 2 (heads) below.
# Returns the local image output followed by the text output along dim 1.
def forward_ulysses_sparse(
self,
runtime_attn,
block_args: dict,
softmax_scale: Optional[float] = None,
) -> torch.Tensor:
q_img_local, k_img_local, v_img_local = (
block_args["q_img_local"], block_args["k_img_local"], block_args["v_img_local"]
)
txt_q, txt_k, txt_v = block_args["txt_q"], block_args["txt_k"], block_args["txt_v"]
# Optional head permutation; only requested when the subclass enables routing
# and provides get_runtime_head_route.
route_head_perm = route_inv_head_perm = None
if (
getattr(self, "ulysses_head_routing_enabled", False)
and runtime_attn.ulysses_world_size > 1
and hasattr(self, "get_runtime_head_route")
):
route_head_perm, route_inv_head_perm = self.get_runtime_head_route()
if route_head_perm is not None:
# Permute heads on the joined img+txt tensors, then split them back at the
# original image lengths.
q_local_full = self._route_heads_bsnd(
torch.cat([q_img_local, txt_q], dim=1).contiguous(),
route_head_perm,
)
k_local_full = self._route_heads_bsnd(
torch.cat([k_img_local, txt_k], dim=1).contiguous(),
route_head_perm,
)
v_local_full = self._route_heads_bsnd(
torch.cat([v_img_local, txt_v], dim=1).contiguous(),
route_head_perm,
)
img_q_len_local = q_img_local.shape[1]
img_kv_len_local = k_img_local.shape[1]
split_args = self._split_img_txt_qkv(
q_local_full,
k_local_full,
v_local_full,
img_q_len_local,
img_kv_len_local,
)
q_img_local, k_img_local, v_img_local = split_args["q_img"], split_args["k_img"], split_args["v_img"]
txt_q, txt_k, txt_v = split_args["txt_q"], split_args["txt_k"], split_args["txt_v"]
q_img_global, k_img_global, v_img_global = self._ulysses_all_to_all_qkv_triplet(
runtime_attn,
q_img_local,
k_img_local,
v_img_local,
)
# Text tokens are not exchanged; each rank keeps only its own head slice.
txt_q_local_h, txt_k_local_h, txt_v_local_h = self._slice_ulysses_local_heads_with_uaa(
runtime_attn,
txt_q, txt_k, txt_v, q_img_local.shape[2]
)
# SVG mode is detected from predictor_name or from the class name.
predictor_name = str(getattr(self, "predictor_name", ""))
is_svg_mode = predictor_name.lower() == "svg" or "svg" in self.__class__.__name__.lower()
attn_kwargs = {}
img_seq_len = int(q_img_global.shape[1])
if is_svg_mode:
q_full = torch.cat([q_img_global, txt_q_local_h], dim=1).contiguous()
k_full = torch.cat([k_img_global, txt_k_local_h], dim=1).contiguous()
v_full = torch.cat([v_img_global, txt_v_local_h], dim=1).contiguous()
attn_kwargs["joint_q_local_bnsd"] = txt_q_local_h
else:
# Pack q/k/v along the head axis so img and txt are joined with a single
# contiguous concatenation, then split back into three equal head groups.
h_local = q_img_global.shape[2]
img_packed = torch.cat([q_img_global, k_img_global, v_img_global], dim=2)
txt_packed = torch.cat([txt_q_local_h, txt_k_local_h, txt_v_local_h], dim=2)
qkv_full = torch.cat([img_packed, txt_packed], dim=1).contiguous()
q_full, k_full, v_full = torch.split(qkv_full, [h_local, h_local, h_local], dim=2)
# head_sabi is always passed as None from this method.
head_sabi = None
prefix_q_len = int(q_full.shape[1])
prefix_kv_len = int(k_full.shape[1])
out_full = self.attention(
q=q_full,
k=k_full,
v=v_full,
head_sabi=head_sabi,
ulysses_pg=runtime_attn.ulysses_pg,
ulysses_rank=runtime_attn.ulysses_rank,
ulysses_world_size=runtime_attn.ulysses_world_size,
cu_seqlens_q=[0, prefix_q_len],
cu_seqlens_kv=[0, prefix_kv_len],
return_bshd=True,
softmax_scale=softmax_scale,
**attn_kwargs,
)
img_out_global = out_full[:, :img_seq_len, :, :].contiguous()
txt_out_local_h = out_full[:, img_seq_len:, :, :].contiguous()
# Both branches issue the same output all-to-all; they differ only in how the
# text heads are gathered (the else-branch helper returns its input unchanged
# when the Ulysses world size is <= 1).
if runtime_attn.ulysses_anything:
img_out_local = self._ulysses_all_to_all_o(
runtime_attn,
img_out_global,
num_qo_head=q_img_local.shape[2],
q_s_local=q_img_local.shape[1],
)
txt_out = all_gather_anything(
tensor=txt_out_local_h,
dim=2,
world_size=runtime_attn.ulysses_world_size,
group=runtime_attn.ulysses_pg,
).contiguous()
else:
img_out_local = self._ulysses_all_to_all_o(
runtime_attn,
img_out_global,
num_qo_head=q_img_local.shape[2],
q_s_local=q_img_local.shape[1],
)
txt_out = self._ulysses_all_gather_heads_bshd(runtime_attn, txt_out_local_h)
out = torch.cat([img_out_local, txt_out], dim=1).contiguous()
# Defensive trimming: if the joined length differs from the local input
# length, cut each part (and finally the whole) down to the expected size.
# Parts that are too short are left as they are.
expected_local_seq_len = int(q_img_local.shape[1]) + int(txt_q.shape[1])
if out.shape[1] != expected_local_seq_len:
img_expected = int(q_img_local.shape[1])
txt_expected = int(txt_q.shape[1])
img_actual = int(img_out_local.shape[1])
txt_actual = int(txt_out.shape[1])
if img_actual >= img_expected:
img_out_local = img_out_local[:, :img_expected, :, :].contiguous()
if txt_actual >= txt_expected:
txt_out = txt_out[:, :txt_expected, :, :].contiguous()
out = torch.cat([img_out_local, txt_out], dim=1).contiguous()
if int(out.shape[1]) > expected_local_seq_len:
out = out[:, :expected_local_seq_len, :, :].contiguous()
# Restore the caller's head order if heads were routed above.
if route_inv_head_perm is not None:
out = self._route_heads_bsnd(out, route_inv_head_perm)
return out
def load_sparsity(self, path: Union[str, Path], step_pattern: str = "step-{}.pt", step_num: int = 50):
    """Load sparsity data from a single file or from a directory of step folders.

    Args:
        path: File or directory to load from.
        step_pattern: Forwarded to ``load_directory``.
        step_num: Number of steps to load when ``path`` is a directory.

    Raises:
        FileNotFoundError: If ``path`` is neither a file nor a directory.
    """
    location = Path(path)
    if location.is_file():
        self.load_single_file(location)
        return
    if location.is_dir():
        self.load_directory(location, step_pattern, step_num)
        return
    raise FileNotFoundError(f"路径不存在: {location}")
def load_single_file(self, file_path: Path):
    """Load one sparsity file with ``torch.load`` and store it under step key 0."""
    loaded = torch.load(file_path, map_location=self.device)
    self.sparsity_dict[0] = loaded
def load_directory(self, file_path: Path, step_pattern: str, step_num: int):
    """Load per-step sparsity files from ``<file_path>/step-<i>/`` for ``i < step_num``.

    Each step folder must contain
    ``sparsity_of_RE_<cac_threshold>_only_img.pt``. ``step_pattern`` is
    accepted but not used by this implementation.
    """
    file_name = f"sparsity_of_RE_{self.cac_threshold}_only_img.pt"
    for step in range(step_num):
        step_dir = os.path.join(file_path, f"step-{step}/")
        self.sparsity_dict[step] = torch.load(
            os.path.join(step_dir, file_name),
            map_location=self.device,
        )
@staticmethod
def padding_sabi(selected_indices_tensor: torch.Tensor, max_width: int, pad_value: int = -1):
padded_selected_indices_tensor = [F.pad(t, (0, max_width - t.shape[1]), value=pad_value)
for t in selected_indices_tensor]
result = torch.stack(padded_selected_indices_tensor, dim=0)
return result
def get_block_mask(self, q, sabi):
    """Turn a block index tensor into a boolean block mask.

    The mask has shape ``(b, h, ceil(n / block_size_q), ceil(n / block_size_k))``
    with ``n = q.shape[2]``. It starts all ``True``; every key block listed
    in ``sabi`` (entries greater than -1) is set to ``False`` for its query
    block.
    """
    batch, heads, seq_len, _ = q.shape
    rows = math.ceil(seq_len / self.block_size_q)
    cols = math.ceil(seq_len / self.block_size_k)
    mask = torch.ones((batch, heads, rows, cols), device=q.device, dtype=torch.bool)
    listed = sabi > -1
    b_idx, h_idx, r_idx, _ = torch.nonzero(listed, as_tuple=True)
    c_idx = sabi[listed].long()
    mask[b_idx, h_idx, r_idx, c_idx] = False
    return mask
def get_token_mask(self, block_mask, n):
    """Expand a block-level mask to token resolution and crop it to ``n x n``.

    Every block entry is repeated ``block_size_q`` times along the query axis
    and ``block_size_k`` times along the key axis.
    """
    b, h, bq, bk = block_mask.shape
    tiled = block_mask.unsqueeze(3).unsqueeze(5)
    tiled = tiled.expand(b, h, bq, self.block_size_q, bk, self.block_size_k)
    token_level = tiled.reshape(b, h, bq * self.block_size_q, bk * self.block_size_k)
    return token_level[:, :, :n, :n]
def get_token_level_sparisty(self, token_mask):
    """Append the mean of ``token_mask`` (as float32) to ``self.runtime_sparisty_path``."""
    ratio = token_mask.to(torch.float32).mean().item()
    line = f"{ratio} \n"
    with open(self.runtime_sparisty_path, 'a') as log_file:
        log_file.write(line)
def get_sparse_token_mask(self, q: torch.Tensor, k: torch.Tensor, final_sabi: torch.Tensor):
    """Build the token-level mask for ``final_sabi``, cropped to ``q.shape[2]``.

    ``k`` is accepted for interface symmetry but not used.
    """
    seq_len = q.shape[2]
    return self.get_token_mask(self.get_block_mask(q, final_sabi), seq_len)
@abstractmethod
def get_sabi(self, q: torch.Tensor, k: torch.Tensor):
    """Compute the sparse block index for ``q`` and ``k``; subclasses must implement this."""
    return None
def get_must_keep_blocks_indices(self, **kwargs):
    """Hook for subclasses; the base implementation ignores ``kwargs`` and returns ``None``."""
    result = None
    return result
def combined_sabi_tensor_list(self, mid_sabi_list, must_keep_indices_q, must_keep_indices_k, num_blocks_k):
    """Hook for subclasses; the base implementation does nothing and returns ``None``."""
    return None
def combined_sabi_tensor(self, mid_sabi_list, must_keep_indices_q, must_keep_indices_k, num_blocks_k):
    """Hook for subclasses; the base implementation ignores its arguments and returns ``None``."""
    result = None
    return result
class TopKPredictor(BaseSparsePredictor):
def __init__(self, sparse_config, sparse_params=None):
    """Set up the TopK predictor and load its per-step sparsity tables.

    Args:
        sparse_config: Flat sparse config; its ``'TopK'`` entry must provide
            ``predictor_name``, ``sparse_time_step``, ``sparsity_files_path``
            and ``CAC_threshold`` (``ring_sparse_overlap`` is optional).
        sparse_params: Runtime parameters forwarded to the base class.
    """
    super().__init__(sparse_config, sparse_params)
    topk_config = sparse_config['TopK']
    logger.info(topk_config)
    self.predictor_name = topk_config['predictor_name']
    self.sparse_time_step = topk_config['sparse_time_step']
    self.sparsity_files_path = topk_config['sparsity_files_path']
    self.cac_threshold = topk_config['CAC_threshold']
    self.ring_sparse_overlap = bool(topk_config.get("ring_sparse_overlap", False))
    self.sparsity_dict = {}
    # The step count must go in by keyword: the second positional parameter
    # of load_sparsity is ``step_pattern``, so passing total_steps
    # positionally left ``step_num`` at its default of 50.
    if self.total_steps is not None:
        self.load_sparsity(self.sparsity_files_path, step_num=self.total_steps)
    else:
        self.load_sparsity(self.sparsity_files_path)
    # The head count is inferred from the first layer of the first loaded step.
    sample_step = next(iter(self.sparsity_dict.keys()))
    sample_layer = self.sparsity_dict[sample_step][0]
    self.head_num = int(torch.as_tensor(sample_layer).numel())
    # Ulysses head-routing state; populated by apply_head_reorder_for_load_balance.
    self.ulysses_degree_for_lb = 1
    self.ulysses_head_routing_enabled = False
    self.step_layer_head_perm = {}
    self.step_layer_inv_head_perm = {}
    self._ulysses_head_split_cache = {}
def _get_layer_sparsity_tensor(self, step_idx: int, layer_idx: int, device=None) -> torch.Tensor:
    """Return the per-head sparsity of one (step, layer) as a float32 tensor.

    The value stored in ``self.sparsity_dict[step_idx][layer_idx]`` is first
    materialised as a contiguous CPU tensor. That tensor is returned as-is when
    ``device`` is ``None``, a CPU ``torch.device`` or a string starting with
    ``"cpu"`` (case-insensitive); otherwise a new tensor is built on ``device``
    from the Python list of values.
    """
    cpu_values = torch.as_tensor(
        self.sparsity_dict[step_idx][layer_idx], dtype=torch.float32, device="cpu"
    ).contiguous()
    wants_cpu = (
        device is None
        or (isinstance(device, torch.device) and device.type == "cpu")
        or (isinstance(device, str) and device.lower().startswith("cpu"))
    )
    if wants_cpu:
        return cpu_values
    return torch.tensor(cpu_values.tolist(), dtype=torch.float32, device=device).contiguous()
def _get_layer_sparsity_runtime_head_order(self, step_idx: int, layer_idx: int, device=None) -> torch.Tensor:
    """Return the layer sparsity reordered by the Ulysses head permutation, if any.

    When head routing is disabled, or no permutation is registered for
    ``(step_idx, layer_idx)``, the sparsity tensor is returned unchanged.
    Otherwise its entries are gathered with the stored permutation.
    """
    sparsity = self._get_layer_sparsity_tensor(step_idx, layer_idx, device=device)
    perm = None
    if self.ulysses_head_routing_enabled:
        perm = self.step_layer_head_perm.get((int(step_idx), int(layer_idx)))
    if perm is None:
        return sparsity
    if not torch.is_tensor(perm):
        perm = torch.as_tensor(perm, dtype=torch.long)
    gather_index = perm.to(device=sparsity.device, dtype=torch.long)
    return sparsity.index_select(0, gather_index).contiguous()
def apply_head_reorder_for_load_balance(self, ulysses_degree: int):
    """Precompute per-(step, layer) head permutations that balance work across ranks.

    The cost of a head is ``1 - sparsity``. Heads are visited from most to
    least costly and each is placed on the rank with the smallest accumulated
    cost (ties: fewer heads, then lower rank index). The permutation is the
    concatenation of the rank buckets; its inverse is stored alongside it.
    With ``ulysses_degree <= 1`` the tables are cleared and routing is disabled.
    """
    world = int(ulysses_degree)
    self.ulysses_degree_for_lb = world
    self.step_layer_head_perm.clear()
    self.step_layer_inv_head_perm.clear()
    if world <= 1:
        self.ulysses_head_routing_enabled = False
        return

    for step_idx, layers in self.sparsity_dict.items():
        for layer_idx in range(len(layers)):
            sparsity = self._get_layer_sparsity_tensor(int(step_idx), int(layer_idx), device="cpu")
            costs = [1.0 - float(value) for value in sparsity.tolist()]
            # Stable descending sort, so equally costly heads keep index order.
            heads_by_cost = sorted(range(len(costs)), key=costs.__getitem__, reverse=True)

            buckets = [[] for _ in range(world)]
            loads = [0.0] * world
            for head in heads_by_cost:
                rank = min(range(world), key=lambda r: (loads[r], len(buckets[r]), r))
                buckets[rank].append(head)
                loads[rank] += costs[head]

            perm_list = [head for bucket in buckets for head in bucket]
            inv_list = [0] * len(perm_list)
            for position, head in enumerate(perm_list):
                inv_list[head] = position

            table_key = (int(step_idx), int(layer_idx))
            self.step_layer_head_perm[table_key] = torch.tensor(perm_list, dtype=torch.long)
            self.step_layer_inv_head_perm[table_key] = torch.tensor(inv_list, dtype=torch.long)
    self.ulysses_head_routing_enabled = True
def get_runtime_head_route(self):
    """Return ``(perm, inverse_perm)`` for the current effective (step, layer).

    Either element is ``None`` when no permutation is registered for that key.
    """
    step_idx, layer_idx = self.get_effective_indices()
    table_key = (int(step_idx), int(layer_idx))
    forward = self.step_layer_head_perm.get(table_key)
    backward = self.step_layer_inv_head_perm.get(table_key)
    return forward, backward
def get_block_attn(self, q: torch.Tensor, k, block_num_q, block_num_k):
    """Compute block-level attention probabilities from block-mean q and k.

    ``q`` and ``k`` are mean-pooled along the sequence axis (dim 2) in chunks
    of ``block_size_q`` / ``block_size_k``; the last chunk may be shorter.

    Returns:
        ``(attn, q_list, k_list)`` where ``attn`` is the softmax over key
        blocks of ``q_mean @ k_mean^T`` and the lists hold the per-block means
        (each with a singleton sequence axis).
    """
    _, _, n_q, _ = q.shape
    n_k = k.shape[-2]

    def block_means(x, block_size, block_num, length):
        # One mean per block; the slice end is clipped to the real length.
        return [
            torch.mean(
                x[:, :, idx * block_size: min((idx + 1) * block_size, length), :],
                dim=2,
                keepdim=True,
            )
            for idx in range(block_num)
        ]

    q_list = block_means(q, self.block_size_q, block_num_q, n_q)
    k_list = block_means(k, self.block_size_k, block_num_k, n_k)
    q_mean = torch.cat(q_list, dim=2)
    k_mean = torch.cat(k_list, dim=2)
    attn = torch.softmax(q_mean @ k_mean.transpose(-2, -1), dim=-1)
    return attn, q_list, k_list
def pooling_matmul(self, q, block_size, block_num):
    """Mean-pool ``q`` (shape ``(b, h, n, d)``) into ``block_num`` blocks along dim 2.

    The sequence is zero-padded or truncated to ``block_num * block_size``
    tokens, summed per block and divided by the block size. Only the last
    block's divisor is adjusted, to the number of real tokens it holds (at
    least 1). Returns a ``(b, h, 0, d)`` zero tensor when ``block_num <= 0``.
    """
    b, h, n, d = q.shape
    if block_num <= 0:
        return q.new_zeros((b, h, 0, d))

    num_blocks, size = int(block_num), int(block_size)
    target_len = num_blocks * size
    if n > target_len:
        q = q[:, :, :target_len, :].contiguous()
    elif n < target_len:
        q = F.pad(q, (0, 0, 0, target_len - n))

    block_sums = q.reshape(b, h, num_blocks, size, d).sum(dim=3)
    counts = torch.full((num_blocks,), size, device=q.device, dtype=q.dtype)
    last_block_tokens = n - (num_blocks - 1) * size
    if last_block_tokens < size:
        counts[-1] = max(last_block_tokens, 1)
    return (block_sums / counts.view(1, 1, num_blocks, 1)).contiguous()
def get_block_attn_by_matmul_v2(self, q, k, block_num_q, block_num_k):
    """Block-level attention probabilities computed from ``pooling_matmul`` means.

    Returns the softmax over key blocks of ``q_mean @ k_mean^T``.
    """
    pooled_q = self.pooling_matmul(q, self.block_size_q, block_num_q)
    pooled_k = self.pooling_matmul(k, self.block_size_k, block_num_k)
    scores = pooled_q @ pooled_k.transpose(-2, -1)
    return torch.softmax(scores, dim=-1)
def get_sabi(self, q: torch.Tensor, k: torch.Tensor, sparsity_override: Optional[torch.Tensor] = None):
    """Select, per head, the top-scoring key blocks for every query block.

    Args:
        q: query tensor of shape ``(b, h, n_q, d)``.
        k: key tensor whose second-to-last axis is the key length.
        sparsity_override: optional per-head sparsity; when ``None`` the value
            for the current effective (step, layer) is used.

    Returns:
        A list with one index tensor per ``(batch, head)`` pair, in
        batch-major order; each has shape ``(block_num_q, kept_blocks)``.

    The number of kept blocks per head is ``int((1 - sparsity) * block_num_k)``
    clamped to ``[1, block_num_k]``, the same rule ``get_sabi_v2`` applies, so
    a sparsity of 1.0 still keeps one block and an out-of-range sparsity can
    no longer ask ``torch.topk`` for more blocks than exist.
    """
    b, h, n_q, _ = q.shape
    n_k = k.shape[-2]
    block_num_q = math.ceil(n_q / self.block_size_q)
    block_num_k = math.ceil(n_k / self.block_size_k)
    attn, _, _ = self.get_block_attn(q, k, block_num_q, block_num_k)
    if sparsity_override is None:
        step_idx, layer_idx = self.get_effective_indices()
        sparsity = self._get_layer_sparsity_runtime_head_order(step_idx, layer_idx, device=attn.device)
    else:
        sparsity = sparsity_override.to(device=attn.device, dtype=torch.float32).contiguous()
    # Resolve all per-head counts with a single host transfer instead of one
    # device synchronisation per (batch, head) iteration.
    keep_counts = torch.clamp(
        ((1.0 - sparsity) * block_num_k).to(torch.int32), min=1, max=block_num_k
    ).tolist()
    new_sabi_list = []
    for batch in range(b):
        for head in range(h):
            _, indices = torch.topk(attn[batch, head], keep_counts[head], dim=-1)
            new_sabi_list.append(indices)
    return new_sabi_list
def get_sabi_v2(self, q: torch.Tensor, k: torch.Tensor, sparsity_override: Optional[torch.Tensor] = None):
    """Batched top-k block selection returning one padded index tensor.

    Returns a ``(b, h, block_num_q, block_num_k)`` tensor of dtype
    ``self.index_type``. For each head the first ``kept`` columns hold the
    key-block indices in descending score order and the remaining columns are
    ``-1``, where ``kept = int((1 - sparsity) * block_num_k)`` clamped to
    ``[1, block_num_k]``.
    """
    b, h, n_q, _ = q.shape
    n_k = k.shape[-2]
    block_num_q = math.ceil(n_q / self.block_size_q)
    block_num_k = math.ceil(n_k / self.block_size_k)
    attn = self.get_block_attn_by_matmul_v2(q, k, block_num_q, block_num_k)

    if sparsity_override is not None:
        sparsity = sparsity_override.to(device=attn.device, dtype=torch.float32).contiguous()
    else:
        step_idx, layer_idx = self.get_effective_indices()
        sparsity = self._get_layer_sparsity_runtime_head_order(step_idx, layer_idx, device=attn.device)

    keep_counts = torch.clamp(((1.0 - sparsity) * block_num_k).to(torch.int32), min=1, max=block_num_k)
    widest = int(keep_counts.max().item())
    # Rank once with the widest budget, then blank out the columns that lie
    # beyond each head's own budget.
    _, ranked = torch.topk(attn, widest, dim=-1)
    column_ids = torch.arange(widest, device=q.device).view(1, 1, 1, widest)
    within_budget = column_ids < keep_counts.view(1, h, 1, 1)

    sabi = torch.full(size=(b, h, block_num_q, block_num_k),
                      fill_value=-1, dtype=self.index_type, device=q.device)
    sabi[:, :, :, :widest] = torch.where(within_budget, ranked, -1)
    return sabi
def get_final_sabi(self, q: torch.Tensor, k: torch.Tensor, **kwargs):
    """Join the predicted sabi with the must-keep block indices to obtain final_sabi.

    Required kwargs: ``sink_frame_len`` and ``img_token_len``; optional:
    ``sparsity_override``. The trailing must-keep blocks are cut off q and k,
    ``get_sabi`` runs on the remaining prefix, and the result is merged by
    ``combined_sabi_tensor_list`` and returned with a new leading axis.
    """
    sink_frame_len = kwargs["sink_frame_len"]
    img_token_len = kwargs["img_token_len"]
    sparsity_override = kwargs.get("sparsity_override")

    total_tokens = q.shape[2]
    q_blocks = math.ceil(total_tokens / self.block_size_q)
    k_blocks = math.ceil(k.shape[-2] / self.block_size_k)

    keep_blocks_q, keep_idx_q, keep_blocks_k, keep_idx_k = self.get_must_keep_blocks_indices(
        token_len=total_tokens,
        sink_frame_len=sink_frame_len,
        txt_token_len=total_tokens - img_token_len,
    )

    # Only the prefix before the must-keep blocks goes through top-k selection.
    mid_q = q[:, :, : (q_blocks - keep_blocks_q) * self.block_size_q, :]
    mid_k = k[:, :, : (k_blocks - keep_blocks_k) * self.block_size_k, :]
    mid_sabi_list = self.get_sabi(mid_q, mid_k, sparsity_override=sparsity_override)

    combined = self.combined_sabi_tensor_list(mid_sabi_list, keep_idx_q, keep_idx_k, k_blocks)
    return combined.unsqueeze(0)
def get_final_sabi_v2(self, q: torch.Tensor, k: torch.Tensor, **kwargs):
    """Join the predicted sabi with the must-keep block indices to obtain final_sabi.

    Tensor-based variant that supports different q / k lengths. Required
    kwargs: ``sink_frame_len`` and ``img_token_len``; the ``*_q`` / ``*_k``
    variants of both default to the shared value, and ``sparsity_override`` is
    optional. Text lengths are the non-image remainder of each sequence
    (never negative).
    """
    sink_frame_len = kwargs["sink_frame_len"]
    img_token_len = kwargs["img_token_len"]
    img_len_q = int(kwargs.get("img_token_len_q", img_token_len))
    img_len_k = int(kwargs.get("img_token_len_k", img_token_len))
    sink_len_q = int(kwargs.get("sink_frame_len_q", sink_frame_len))
    sink_len_k = int(kwargs.get("sink_frame_len_k", sink_frame_len))
    sparsity_override = kwargs.get("sparsity_override")

    n_q, n_k = q.shape[2], k.shape[-2]
    txt_len_q = max(n_q - img_len_q, 0)
    txt_len_k = max(n_k - img_len_k, 0)
    q_blocks = math.ceil(n_q / self.block_size_q)
    k_blocks = math.ceil(n_k / self.block_size_k)

    keep_blocks_q, keep_idx_q, keep_blocks_k, keep_idx_k = self.get_must_keep_blocks_indices(
        token_len_q=n_q,
        token_len_k=n_k,
        sink_frame_len=sink_frame_len,
        sink_frame_len_q=sink_len_q,
        sink_frame_len_k=sink_len_k,
        txt_token_len=txt_len_q,
        txt_token_len_q=txt_len_q,
        txt_token_len_k=txt_len_k,
    )

    # Only the prefix before the must-keep blocks goes through top-k selection.
    mid_q = q[:, :, : (q_blocks - keep_blocks_q) * self.block_size_q, :].contiguous()
    mid_k = k[:, :, : (k_blocks - keep_blocks_k) * self.block_size_k, :].contiguous()
    mid_sabi_tensor = self.get_sabi_v2(mid_q, mid_k, sparsity_override=sparsity_override)
    return self.combined_sabi_tensor(mid_sabi_tensor, keep_idx_q, keep_idx_k, k_blocks)
def _layout_identity(x):
    """Return ``x`` unchanged."""
    return x


def _layout_swap_dims_1_2(x):
    """Swap axes 1 and 2 of ``x``."""
    return x.transpose(1, 2)


def _layout_merge_leading_dims(x):
    """View ``x`` with its first two axes merged into one."""
    return x.view(x.shape[0] * x.shape[1], *x.shape[2:])


# Layout name -> (pre-attention transform, post-attention transform).
MEMORY_LAYOUT = {
    "TND": (_layout_merge_leading_dims, _layout_identity),
    "BNSD": (_layout_swap_dims_1_2, _layout_swap_dims_1_2),
    "BSND": (_layout_identity, _layout_identity),
}
def get_cu_seqlens(text_mask, img_len):
    """Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len

    For every sample ``i`` two boundaries are emitted inside a padded buffer
    of ``max_len = text_mask.shape[1] + img_len`` tokens per sample: the end
    of the valid tokens (``i * max_len + text_len[i] + img_len``) and the end
    of the padded sample (``(i + 1) * max_len``).

    Args:
        text_mask (torch.Tensor): the mask of text
        img_len (int): the length of image

    Returns:
        torch.Tensor: the calculated cu_seqlens for flash attention, an int32
        tensor of length ``2 * batch_size + 1`` on ``text_mask``'s device.
    """
    batch_size = text_mask.shape[0]
    # Follow the input's device instead of a hard-coded "cuda" so the helper
    # also works on NPU and CPU tensors.
    device = text_mask.device
    text_len = text_mask.sum(dim=1)
    max_len = text_mask.shape[1] + img_len
    cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device=device)
    sample_starts = torch.arange(batch_size, device=device) * max_len
    cu_seqlens[1::2] = sample_starts + text_len + img_len
    cu_seqlens[2::2] = sample_starts + max_len
    return cu_seqlens
def get_row_indices(block):
    """List, per row, the column indices where ``block`` is False.

    Returns an int64 tensor of shape ``(rows, block.shape[1])``; each row
    holds its False-column indices in ascending order, right-padded with -1.
    """
    false_columns = [torch.where(~row)[0] for row in block]
    result = torch.full((len(false_columns), block.shape[1]), -1, dtype=torch.long)
    for row_idx, columns in enumerate(false_columns):
        count = len(columns)
        if count > 0:
            result[row_idx, :count] = columns
    return result
def block2sabi(block_mask):
    """Convert a boolean block mask into padded index form via ``get_row_indices``."""
    return get_row_indices(block_mask)
class SVGPredictor(BaseSparsePredictor):
def __init__(self, sparse_config, sparse_params=None):
    """Configure the SVG predictor from ``sparse_config['SVG']``.

    Copies the method settings onto the instance and starts with an empty
    mask cache (no attention masks, no sabi tensor, no cache key).
    """
    super().__init__(sparse_config, sparse_params)
    cfg = sparse_config['SVG']
    logger.info(cfg)

    self.sparse_time_step = cfg['sparse_time_step']
    self.sparsity = cfg['sparsity']
    self.context_length = cfg['context_length']
    self.sample_mse_max_row = cfg['sample_mse_max_row']

    # Cache populated lazily by get_sabi.
    self.attention_masks = []
    self.sabi_tensor = None
    self._svg_mask_cache_key = None
def sparsity_to_width(self, sparsity, num_frame, frame_size):
    """Convert a target sparsity into a band width measured in frames.

    Computes ``seq_len * (1 - sqrt(sparsity)) - context_length`` with
    ``seq_len = context_length + num_frame * frame_size`` and divides the
    result by ``frame_size``.
    """
    total_tokens = self.context_length + num_frame * frame_size
    kept_fraction = 1 - math.sqrt(sparsity)
    return (total_tokens * kept_fraction - self.context_length) / frame_size
def sample_mse(self, query, key, value, context_length):
    """Estimate, per head, how much each cached mask distorts the attention output.

    A random subset of query rows (at most 64, drawn below
    ``min(seq_len, self.sample_mse_max_row)``) attends to the keys/values
    with the trailing ``context_length`` tokens removed, once without a mask
    and once per mask in ``self.attention_masks``.

    Args:
        query, key, value: tensors passed to the NPU kernel in "BNSD" layout.
        context_length: number of trailing context tokens to drop from
            key/value and from the mask columns; 0 drops nothing.

    Returns:
        dict mapping ``"spatial"`` / ``"temporal"`` to the mean squared error
        against the unmasked output, averaged over dims 2 and 3.
    """
    if context_length > 0:
        key = key[:, :, :-context_length]
        value = value[:, :, :-context_length]
        mask_columns = slice(None, -context_length)
    else:
        # ``[:-0]`` would select nothing; keep every column when there is no
        # context suffix, consistent with the key/value handling above.
        mask_columns = slice(None)
    mask_name = ["spatial", "temporal"]
    _, h, seq_len, dim = query.size()
    num_sampled_rows = min(64, seq_len)
    sampled_row_high = min(seq_len, self.sample_mse_max_row)
    sampled_rows = torch.randint(low=0, high=sampled_row_high, size=(num_sampled_rows,), device=query.device)
    sampled_q = query[:, :, sampled_rows, :]
    scale = 1 / dim ** 0.5
    sampled_golden_hidden_states = torch_npu.npu_fusion_attention(
        sampled_q, key, value, head_num=h, input_layout="BNSD", scale=scale,
    )[0]
    sampled_mses = {}
    for mask_idx, attn_mask in enumerate(self.attention_masks):
        sampled_attention_mask = attn_mask[sampled_rows, mask_columns]
        # The stored mask is inverted before being handed to the kernel
        # (presumably the kernel treats True as "masked out" -- confirm
        # against the torch_npu documentation).
        sampled_hidden_states = torch_npu.npu_fusion_attention(
            sampled_q, key, value, head_num=h, input_layout="BNSD", scale=scale,
            atten_mask=~sampled_attention_mask
        )[0]
        mse = torch.mean((sampled_hidden_states - sampled_golden_hidden_states) ** 2, dim=(2, 3))
        sampled_mses[mask_name[mask_idx]] = mse
        del sampled_attention_mask, sampled_hidden_states
    return sampled_mses
def get_attention_mask(self, mask_name, s, num_frame, frame_size, width_frame=1.5, block_size=512, device="NPU"):
    """Build the token-level attention mask and the block mask for one SVG pattern.

    Args:
        mask_name: ``"spatial"`` selects the banded pattern directly; any
            other value selects the frame/position-transposed ("temporal") one.
        s: total sequence length (image tokens followed by context tokens).
        num_frame, frame_size: image part is ``num_frame * frame_size`` tokens.
        width_frame: band width in frames.
        block_size: granularity of the band; must be a multiple of both
            ``self.block_size_q`` and ``self.block_size_k``.
        device: target device of the returned token mask.

    Returns:
        ``(attention_mask, block_mask)``. ``attention_mask`` is a bool tensor
        with at most ``self.sample_mse_max_row`` rows and ``s`` columns, True
        where attention is allowed; context rows and columns are fully True.
        ``block_mask`` is a CPU bool tensor at (block_size_q, block_size_k)
        granularity, True outside the band; for the spatial pattern its last
        context-sized rows and columns are cleared to False.

    Raises:
        ValueError: if ``block_size`` is not divisible by both block sizes.

    The image/context boundary is addressed with the positive offset
    ``img_len`` rather than ``-context_length``, so a sequence without context
    tokens (``context_length == 0``) is handled instead of producing empty
    slices or whole-mask assignments.
    """
    block_size_q, block_size_k = self.block_size_q, self.block_size_k
    if block_size % block_size_q != 0 or block_size % block_size_k != 0:
        raise ValueError("block_size must be divisible")
    q_num_per_block = block_size // block_size_q
    k_num_per_block = block_size // block_size_k
    img_len = num_frame * frame_size
    context_length = s - img_len
    attention_mask = torch.zeros((s, s), dtype=torch.bool, device="cpu")

    # Band over coarse blocks: |i - j| below the threshold stays inside.
    block_thres = frame_size * width_frame
    num_block = math.ceil(img_len / block_size)
    idx = torch.arange(num_block, device="cpu")
    band = (idx[:, None] - idx[None, :]).abs() < int(block_thres // block_size)
    block_mask = (~band).repeat_interleave(q_num_per_block, dim=0) \
        .repeat_interleave(k_num_per_block, dim=1)
    pixel_attn_mask = (~block_mask).repeat_interleave(block_size_q, dim=0) \
        .repeat_interleave(block_size_k, dim=1)
    pixel_attn_mask = pixel_attn_mask[:img_len, :img_len]

    if mask_name != "spatial":
        # Re-index rows and columns from (position, frame) to (frame, position).
        pixel_attn_mask = (
            pixel_attn_mask.reshape(frame_size, num_frame, frame_size, num_frame)
            .permute(1, 0, 3, 2)
            .reshape(frame_size * num_frame, frame_size * num_frame)
        )

    attention_mask[:img_len, :img_len] = pixel_attn_mask
    # Context tokens attend to, and are attended by, everything.
    attention_mask[img_len:, :] = True
    attention_mask[:, img_len:] = True

    if mask_name == "spatial":
        context_blocks_q = math.ceil(context_length / block_size_q)
        context_blocks_k = math.ceil(context_length / block_size_k)
        # Guard the zero case: ``[-0:]`` would address the whole mask.
        if context_blocks_q > 0:
            block_mask[-context_blocks_q:, :] = False
        if context_blocks_k > 0:
            block_mask[:, -context_blocks_k:] = False

    attention_mask = attention_mask[:self.sample_mse_max_row].to(device)
    return attention_mask, block_mask
def get_sabi(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_frames: int, frame_size: int):
    """Refresh the cached SVG masks if needed and pick a pattern per head.

    The spatial/temporal masks and ``self.sabi_tensor`` are rebuilt whenever
    the cache key (sequence length, frame count, frame size, device, sample
    row limit) changes or the cached sabi tensor is missing; if only the
    batch or head count changed, the cached template is re-expanded.

    Args:
        q, k, v: tensors with ``q`` shaped ``(b, h, s, d)``.
        num_frames, frame_size: image part is ``num_frames * frame_size`` tokens.

    Returns:
        Flattened bool tensor, True where the sampled MSE of the spatial mask
        is lower than that of the temporal mask.

    Side effects: sets ``self.device`` and may update ``self.attention_masks``,
    ``self.sabi_tensor`` and ``self._svg_mask_cache_key``.
    """
    b, h, s, _ = q.shape
    device = q.device
    self.device = device
    frame_width = self.sparsity_to_width(self.sparsity, num_frames, frame_size)
    context_length_cu = s - num_frames * frame_size
    cache_key = (
        int(s),
        int(num_frames),
        int(frame_size),
        str(device),
        int(self.sample_mse_max_row),
    )
    # A missing sabi tensor forces a rebuild; previously that state reached
    # the re-expand branch and indexed ``None``.
    if self._svg_mask_cache_key != cache_key or self.sabi_tensor is None:
        spatial_mask, block_mask = self.get_attention_mask("spatial", s, num_frames,
                                                           frame_size, frame_width, device=device)
        temporal_mask, _ = self.get_attention_mask("temporal", s, num_frames,
                                                   frame_size, frame_width, device=device)
        self.attention_masks = [spatial_mask, temporal_mask]
        base_sabi_tensor = block2sabi(block_mask)
        block_num_q, block_num_k = base_sabi_tensor.shape[0], base_sabi_tensor.shape[1]
        self.sabi_tensor = base_sabi_tensor.unsqueeze(0).unsqueeze(0).expand(b, h, block_num_q, block_num_k)
        self._svg_mask_cache_key = cache_key
    elif self.sabi_tensor.shape[0] != b or self.sabi_tensor.shape[1] != h:
        # Same geometry, different batch/head count: re-broadcast the template.
        base_sabi_tensor = self.sabi_tensor[0, 0]
        block_num_q, block_num_k = base_sabi_tensor.shape[0], base_sabi_tensor.shape[1]
        self.sabi_tensor = base_sabi_tensor.unsqueeze(0).unsqueeze(0).expand(b, h, block_num_q, block_num_k)
    mse_result = self.sample_mse(q, k, v, context_length_cu)
    pattern = (mse_result["spatial"] < mse_result["temporal"]).flatten()
    return pattern
def get_final_sabi(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_frames: int, frame_size: int):
    """Join sabi with the must-keep block indices to obtain final_sabi.

    Calls ``get_must_keep_blocks_indices`` with a sink length of 0 (its result
    is not used here), then returns ``(pattern, self.sabi_tensor)`` where
    ``pattern`` is the per-head choice computed by ``get_sabi``.
    """
    total_tokens = q.shape[2]
    self.get_must_keep_blocks_indices(
        token_len=total_tokens,
        sink_frame_len=0,
        txt_token_len=total_tokens - num_frames * frame_size,
    )
    head_pattern = self.get_sabi(q, k, v, num_frames, frame_size)
    return head_pattern, self.sabi_tensor
class HunyuanVideoTopKAdapter(TopKPredictor):
def __init__(self, sparse_config, sparse_params=None):
    """Initialise the TopK predictor, then load the HunyuanVideo-specific sparse parameters."""
    super().__init__(sparse_config, sparse_params)
    self.update_sparse_params(sparse_params)
def update_sparse_params(self, sparse_params):
    """Copy the video-geometry parameters from ``sparse_params`` onto the instance.

    Args:
        sparse_params: mapping that may contain ``sink_frame_len``,
            ``img_token_len``, ``frame_num``, ``frame_patch_h`` and
            ``frame_patch_w``; missing keys default to 0. ``None`` is treated
            as an empty mapping because the constructor forwards its
            ``sparse_params=None`` default straight to this method.
    """
    params = {} if sparse_params is None else sparse_params
    self.sink_frame_len = params.get("sink_frame_len", 0)
    self.img_token_len = params.get("img_token_len", 0)
    self.frame_num = params.get("frame_num", 0)
    self.frame_patch_h = params.get("frame_patch_h", 0)
    self.frame_patch_w = params.get("frame_patch_w", 0)
    logger.info(f"update sparse params successfully, sink_frame_len: {self.sink_frame_len}\
img_token_len: {self.img_token_len}. ")
def get_must_keep_blocks_indices(self, **kwargs):
    """Work out which trailing q / k blocks hold the sink-frame and text tokens.

    Recognised kwargs: ``token_len`` (or ``token_len_q`` / ``token_len_k``),
    ``sink_frame_len`` and ``txt_token_len`` (each with optional ``_q`` /
    ``_k`` overrides; both default to 0).

    Returns:
        ``(sink_txt_blocks_q, must_keep_indices_q, sink_txt_blocks_k,
        must_keep_indices_k)``: the number of trailing blocks to keep and
        their block indices as int64 tensors, for q and k respectively.

    NOTE(review): when the token length is an exact multiple of the block
    size the remainder is 0, so one extra block is counted -- confirm that
    this conservative result is intended.
    """
    def trailing_blocks(token_len, keep_len, block_size):
        # Blocks needed for ``keep_len`` tokens, plus one more if the
        # possibly partial last block leaves them short of covering it.
        keep_blocks = math.ceil(keep_len / block_size)
        total_blocks = math.ceil(token_len / block_size)
        covered_tokens = (keep_blocks - 1) * block_size + token_len % block_size
        if covered_tokens < keep_len:
            keep_blocks += 1
        return keep_blocks, torch.arange(total_blocks - keep_blocks, total_blocks)

    shared_token_len = kwargs.get("token_len")
    shared_sink_len = kwargs.get("sink_frame_len", 0)
    shared_txt_len = kwargs.get("txt_token_len", 0)

    sink_txt_blocks_k, must_keep_indices_k = trailing_blocks(
        kwargs.get("token_len_k", shared_token_len),
        kwargs.get("sink_frame_len_k", shared_sink_len) + kwargs.get("txt_token_len_k", shared_txt_len),
        self.block_size_k,
    )
    sink_txt_blocks_q, must_keep_indices_q = trailing_blocks(
        kwargs.get("token_len_q", shared_token_len),
        kwargs.get("sink_frame_len_q", shared_sink_len) + kwargs.get("txt_token_len_q", shared_txt_len),
        self.block_size_q,
    )
    return sink_txt_blocks_q, must_keep_indices_q, sink_txt_blocks_k, must_keep_indices_k
def combined_sabi_tensor_list(self, mid_sabi_list, must_keep_indices_q, must_keep_indices_k, num_blocks_k):
    """Merge per-head sabi tensors with the must-keep blocks into one tensor.

    Every tensor in ``mid_sabi_list`` gets ``must_keep_indices_k`` appended as
    extra columns; the list is then padded and stacked by ``self.padding_sabi``
    (width ``num_blocks_k``) and one row listing all key blocks
    ``0..num_blocks_k-1`` is appended per must-keep query block.
    """
    first = mid_sabi_list[0]
    keep_columns = must_keep_indices_k.unsqueeze(0).expand(first.size(0), -1).to(first.device)
    widened = [torch.cat((mid_sabi, keep_columns), dim=1) for mid_sabi in mid_sabi_list]
    stacked = self.padding_sabi(widened, max_width=num_blocks_k)

    dense_rows = (
        torch.arange(num_blocks_k, device=first.device)
        .view(1, 1, num_blocks_k)
        .expand(len(mid_sabi_list), must_keep_indices_q.shape[0], num_blocks_k)
    )
    return torch.cat((stacked, dense_rows), dim=1)
def move_sink_frame_to_end(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, sink_frame_len: int):
    """Delegate to ``_move_sink_qkv_to_end`` using the same sink length for q and for k/v."""
    return self._move_sink_qkv_to_end(q, k, v, q_sink_len=sink_frame_len, kv_sink_len=sink_frame_len)
def move_sink_frame_back(self, attn_out: torch.Tensor, sink_frame_len: int):
    """Move the last ``sink_frame_len`` positions of dim 2 back to the front."""
    boundary = attn_out.shape[2] - sink_frame_len
    sink_part = attn_out[:, :, boundary:, :]
    leading_part = attn_out[:, :, :boundary, :]
    return torch.cat((sink_part, leading_part), dim=2)
def combined_sabi_tensor(self, mid_sabi_tensor, must_keep_indices_q, must_keep_indices_k, total_num_blocks_k):
    """Merge the batched sabi tensor with the must-keep blocks.

    For every existing query-block row the must-keep key indices are placed in
    front of the predicted indices and the row is cut to
    ``total_num_blocks_k`` columns. One row listing all key blocks is then
    appended per must-keep query block.
    """
    b, h, num_blocks_q, _ = mid_sabi_tensor.shape
    keep_cols = must_keep_indices_k.shape[0]
    dense_row_count = must_keep_indices_q.shape[0]
    device = mid_sabi_tensor.device

    keep_prefix = (
        must_keep_indices_k.view(1, 1, 1, keep_cols)
        .expand(b, h, num_blocks_q, keep_cols)
        .to(device)
    )
    sparse_rows = torch.cat((keep_prefix, mid_sabi_tensor), dim=3)[:, :, :, :total_num_blocks_k]

    dense_rows = (
        torch.arange(total_num_blocks_k, device=device, dtype=self.index_type)
        .view(1, 1, 1, total_num_blocks_k)
        .expand(b, h, dense_row_count, total_num_blocks_k)
    )
    return torch.cat((sparse_rows, dense_rows), dim=2)
def _resolve_ulysses_local_sparsity_override(
    self,
    *,
    local_head_num: int,
    device: torch.device,
    ulysses_pg=None,
    ulysses_rank: int = 0,
    ulysses_world_size: int = 1,
) -> Optional[torch.Tensor]:
    """Return the sparsity values that belong to this rank's local heads.

    Returns ``None`` when Ulysses is not in use (world size <= 1), the full
    layer sparsity when it already has ``local_head_num`` entries, and
    otherwise the slice for this rank. The slice defaults to equal-sized head
    ranges; with an initialised process group the per-rank head counts are
    all-gathered once and the resulting range is cached.
    """
    if ulysses_world_size <= 1:
        return None

    step_idx, layer_idx = self.get_effective_indices()
    layer_sparsity = self._get_layer_sparsity_runtime_head_order(step_idx, layer_idx, device=device)
    total_heads = int(layer_sparsity.numel())
    if total_heads == local_head_num:
        return layer_sparsity

    head_start = ulysses_rank * local_head_num
    head_end = head_start + local_head_num
    if ulysses_pg is not None and dist.is_initialized():
        cache_key = (
            int(ulysses_world_size),
            int(ulysses_rank),
            int(local_head_num),
            str(device),
            total_heads,
        )
        head_range = self._ulysses_head_split_cache.get(cache_key)
        if head_range is None:
            # Ranks may own different head counts: gather them to find
            # where this rank's heads start.
            local_h = torch.tensor([local_head_num], dtype=torch.int64, device=device)
            gathered_h = torch.empty((ulysses_world_size, 1), dtype=torch.int64, device=device)
            dist.all_gather_into_tensor(gathered_h, local_h, group=ulysses_pg)
            head_splits = gathered_h.view(-1).to(device="cpu", dtype=torch.long).tolist()
            range_start = int(sum(head_splits[:ulysses_rank]))
            head_range = (range_start, range_start + int(head_splits[ulysses_rank]))
            self._ulysses_head_split_cache[cache_key] = head_range
        head_start, head_end = head_range
    return layer_sparsity[head_start:head_end].contiguous()
def attention(self, q,
              k,
              v,
              drop_rate=0,
              attn_mask=None,
              causal=False,
              cu_seqlens_q=None,
              cu_seqlens_kv=None,
              max_seqlen_q=None,
              max_seqlen_kv=None,
              batch_size=1,
              head_sabi: Optional[torch.Tensor] = None,
              ulysses_pg=None,
              ulysses_rank: int = 0,
              ulysses_world_size: int = 1,
              return_bshd: bool = False,
              softmax_scale: Optional[float] = None,
              img_token_len_q: Optional[int] = None,
              img_token_len_k: Optional[int] = None,
              sink_frame_len_q: Optional[int] = None,
              sink_frame_len_k: Optional[int] = None,):
    """TopK block-sparse attention on the valid prefix, dense attention on the rest.

    Args:
        q, k, v: inputs of shape ``(b, s, n, d)``; they are transposed to
            "BNSD" before any kernel call.
        cu_seqlens_q, cu_seqlens_kv: either both ``None`` (plain dense
            attention over the whole sequence) or both given; element ``[1]``
            is used as the length of the valid prefix.
        ulysses_pg, ulysses_rank, ulysses_world_size: Ulysses sequence-parallel
            context; with a world size above 1 the local sequence is remapped
            before and after the sparse kernel.
        return_bshd: return the ``(b, s, n, d)`` tensor instead of flattening
            the last two axes.
        softmax_scale: overrides the default ``1 / sqrt(d)``.
        img_token_len_q / _k, sink_frame_len_q / _k: per-call overrides of the
            instance-level ``img_token_len`` / ``sink_frame_len``.
        drop_rate, attn_mask, causal, max_seqlen_q, max_seqlen_kv, batch_size,
        head_sabi: accepted but not used by this implementation.

    Returns:
        Attention output shaped ``(b, s, n * d)``, or ``(b, s, n, d)`` when
        ``return_bshd`` is set.

    Raises:
        ValueError: if only one of the two ``cu_seqlens`` arguments is given.
        ImportError: if the sparse path is requested but ``torch_bsa`` could
            not be imported.
    """
    pre_attn_layout, post_attn_layout = MEMORY_LAYOUT["BNSD"]
    b, s, n, d = q.shape
    q, k, v = pre_attn_layout(q), pre_attn_layout(k), pre_attn_layout(v)
    scale = softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(d)
    if cu_seqlens_q is None or cu_seqlens_kv is None:
        if cu_seqlens_q is not None or cu_seqlens_kv is not None:
            raise ValueError("TopK sparse attention requires both cu_seqlens_q and cu_seqlens_kv, or neither.")
        # No sequence boundaries: dense attention over everything.
        x = torch_npu.npu_fused_infer_attention_score(
            q, k, v,
            num_heads=n,
            input_layout="BNSD",
            scale=scale,
        )[0]
        x = post_attn_layout(x)
        if return_bshd:
            return x
        return x.reshape(b, s, -1)
    if torch_bsa is None:
        # The module-level import falls back to None; fail early with a clear
        # message instead of an AttributeError after all the preparation work.
        raise ImportError("torch_bsa is required for TopK sparse attention but could not be imported.")
    valid_q_len = cu_seqlens_q[1]
    valid_kv_len = cu_seqlens_kv[1]
    q1 = q[:, :, :valid_q_len, :].contiguous()
    k1 = k[:, :, :valid_kv_len, :].contiguous()
    v1 = v[:, :, :valid_kv_len, :].contiguous()
    if ulysses_world_size > 1:
        # Remap q, k and v together by stacking them along the head axis.
        qkv1 = torch.cat([q1, k1, v1], dim=1).contiguous()
        qkv1 = self._apply_local_seq_remap_tensor(
            qkv1,
            ulysses_world_size=ulysses_world_size,
        )
        q1, k1, v1 = torch.split(qkv1, [n, n, n], dim=1)
    sink_frame_len_q_eff = int(self.sink_frame_len if sink_frame_len_q is None else sink_frame_len_q)
    sink_frame_len_k_eff = int(self.sink_frame_len if sink_frame_len_k is None else sink_frame_len_k)
    sink_frame_len_q_eff = max(0, min(sink_frame_len_q_eff, int(q1.shape[2])))
    sink_frame_len_k_eff = max(0, min(sink_frame_len_k_eff, int(k1.shape[2]), int(v1.shape[2])))
    sink_dense_q_len = sink_frame_len_q_eff
    # Keep the original ordering for the dense recomputation of the prefix.
    q1_dense_ref = q1
    k1_dense_ref = k1
    v1_dense_ref = v1
    q1, k1, v1 = self._move_sink_qkv_to_end(
        q1, k1, v1, q_sink_len=sink_frame_len_q_eff, kv_sink_len=sink_frame_len_k_eff
    )
    sparsity_override = self._resolve_ulysses_local_sparsity_override(
        local_head_num=n,
        device=q1.device,
        ulysses_pg=ulysses_pg,
        ulysses_rank=ulysses_rank,
        ulysses_world_size=ulysses_world_size,
    )
    sabi_tensor = self.get_final_sabi_v2(
        q1,
        k1,
        sink_frame_len=self.sink_frame_len,
        sink_frame_len_q=sink_frame_len_q_eff,
        sink_frame_len_k=sink_frame_len_k_eff,
        img_token_len=self.img_token_len,
        img_token_len_q=img_token_len_q if img_token_len_q is not None else self.img_token_len,
        img_token_len_k=img_token_len_k if img_token_len_k is not None else self.img_token_len,
        sparsity_override=sparsity_override,
    )
    sabi_tensor = sabi_tensor.to(device=q1.device, dtype=torch.uint16).contiguous()
    actual_seq_lengths = [valid_q_len] * b
    actual_seq_lengths_kv = [valid_kv_len] * b
    attn1, _ = torch_bsa.blitz_sparse_attention(
        q1,
        k1,
        v1,
        sabi=sabi_tensor,
        actual_seq_lengths=actual_seq_lengths,
        actual_seq_lengths_kv=actual_seq_lengths_kv,
        num_heads=n,
        num_key_value_heads=n,
        input_layout="BNSD",
        scale_value=scale,
        atten_mask=None,
        sparse_mode=0,
        softmax_lse_flag=False,
        block_shape=[self.block_size_q, self.block_size_k],
    )
    attn1 = self.move_sink_frame_back(attn1, sink_frame_len_q_eff)
    attn1 = self._replace_prefix_with_dense_attention(
        attn1,
        q1_dense_ref,
        k1_dense_ref,
        v1_dense_ref,
        prefix_len=sink_dense_q_len,
        num_heads=n,
        scale=scale,
    )
    if ulysses_world_size > 1:
        attn1 = self._apply_local_seq_remap_tensor(
            attn1,
            ulysses_world_size=ulysses_world_size,
            reverse=True,
        )
    if valid_q_len < s:
        # Tokens after the valid prefix get ordinary dense attention.
        attn2 = torch_npu.npu_fused_infer_attention_score(
            q[:, :, valid_q_len:, :],
            k[:, :, valid_kv_len:, :],
            v[:, :, valid_kv_len:, :],
            num_heads=n,
            input_layout="BNSD",
            scale=scale,
        )[0]
        x = torch.cat([attn1, attn2], dim=2)
    else:
        x = attn1
    x = post_attn_layout(x)
    if return_bshd:
        return x
    out = x.reshape(b, s, -1)
    return out
class HunyuanVideoSVGAdapter(SVGPredictor):
def __init__(self, sparse_config, sparse_params=None):
    """Initialise the SVG predictor plus the ring-attention options.

    Reads ``ring_sparse_overlap`` (default False) and ``ring_sample_mse_rows``
    (default 64) from ``sparse_config["SVG"]`` and then loads the
    video-geometry parameters.
    """
    super().__init__(sparse_config, sparse_params)
    svg_cfg = sparse_config["SVG"]
    self.ring_svg_overlap = bool(svg_cfg.get("ring_sparse_overlap", False))
    self.ring_sample_mse_rows = int(svg_cfg.get("ring_sample_mse_rows", 64))
    self.update_sparse_params(sparse_params)
def update_sparse_params(self, sparse_params):
    """Copy the video-geometry parameters from ``sparse_params`` as integers.

    Args:
        sparse_params: mapping that may contain ``sink_frame_len``,
            ``img_token_len``, ``frame_num``, ``frame_patch_h`` and
            ``frame_patch_w``; missing keys default to 0. ``None`` is treated
            as an empty mapping because the constructor forwards its
            ``sparse_params=None`` default straight to this method.
    """
    params = {} if sparse_params is None else sparse_params
    self.sink_frame_len = int(params.get("sink_frame_len", 0))
    self.img_token_len = int(params.get("img_token_len", 0))
    self.frame_num = int(params.get("frame_num", 0))
    self.frame_patch_h = int(params.get("frame_patch_h", 0))
    self.frame_patch_w = int(params.get("frame_patch_w", 0))
    logger.info(f"update sparse params successfully, sink_frame_len: {self.sink_frame_len}\
img_token_len: {self.img_token_len}, frame_num:{self.frame_num}. ")
def forward_ring_sparse(
    self,
    runtime_attn,
    block_args: dict,
    softmax_scale: Optional[float] = None,
) -> torch.Tensor:
    """Entry point for ring sparse attention with the SVG pattern.

    Always runs ``_forward_ring_svg_global``; if the overlap option is set a
    log line notes that the overlap variant is not available yet.
    """
    overlap_requested = self.ring_svg_overlap
    if overlap_requested:
        logger.info("Ring SVG overlap is not adapted yet; using the Ring SVG path.")
    result = self._forward_ring_svg_global(runtime_attn, block_args, softmax_scale)
    return result
def _ring_all_gather_canonical_img(self, runtime_attn, x: torch.Tensor) -> torch.Tensor:
    """All-gather image tokens across the ring and restore full-width frame order.

    The gathered sequence is rank-major: each rank contributes the tokens of
    its own width slice. Every rank chunk is reshaped to
    ``(b, frame_num, frame_patch_h, width, h, d)``, the chunks are joined
    along the width axis and the result is flattened back to ``(b, tokens, h, d)``.
    """
    gathered = self._ring_all_gather_seq_bnsd(runtime_attn, x.transpose(1, 2).contiguous())
    gathered = gathered.transpose(1, 2).contiguous()

    frames, patch_h = self.frame_num, self.frame_patch_h
    widths = self._split_patched_width(self.frame_patch_w, int(runtime_attn.ring_world_size))
    batch, _, heads, dim = gathered.shape

    chunks = torch.split(gathered, [frames * patch_h * width for width in widths], dim=1)
    columns = [
        chunk.reshape(batch, frames, patch_h, width, heads, dim)
        for chunk, width in zip(chunks, widths)
    ]
    return torch.cat(columns, dim=3).reshape(batch, -1, heads, dim).contiguous()
def _get_ring_svg_local_metadata(
    self,
    runtime_attn,
    *,
    device: torch.device,
    img_q_len: int,
    img_kv_len: int,
    q_len: int,
    kv_len: int,
):
    """Build the rank-local SVG metadata and the two sabi templates.

    Args:
        runtime_attn: object providing ``ring_world_size`` and ``ring_rank``.
        device: device for all tensors created here.
        img_q_len: number of local image query tokens.
        img_kv_len: number of image key tokens.
        q_len, kv_len: total query / key lengths, used for the block counts.

    Returns:
        dict with ``frame_num``, ``frame_size``, ``svg_block_size``,
        ``local_ids`` (global token id of every local image query),
        ``nearby`` (group-by-group bool band) and the ``spatial_sabi`` /
        ``temporal_sabi`` templates (uint16, padded with -1 cast to uint16).
    """
    frame_num = self.frame_num
    frame_patch_h = self.frame_patch_h
    frame_patch_w = self.frame_patch_w
    world_size = int(runtime_attn.ring_world_size)
    frame_size = frame_patch_h * frame_patch_w
    width_splits = self._split_patched_width(frame_patch_w, world_size)
    rank = int(runtime_attn.ring_rank)
    width_start, width_len = sum(width_splits[:rank]), width_splits[rank]
    # Map local positions (frame, row, local column) to ids in the full frame
    # layout: frame * frame_size + row * frame_patch_w + width_start + column.
    local_ids = torch.arange(img_q_len, device=device, dtype=torch.long)
    local_ids = local_ids.reshape(frame_num, frame_patch_h, width_len)
    local_ids = local_ids + torch.arange(frame_num, device=device).view(-1, 1, 1) * (
        frame_size - frame_patch_h * width_len
    )
    local_ids = local_ids + torch.arange(frame_patch_h, device=device).view(1, -1, 1) * (
        frame_patch_w - width_len
    )
    local_ids = (local_ids + width_start).reshape(-1)
    block_num_q = math.ceil(int(q_len) / int(self.block_size_q))
    block_num_k = math.ceil(int(kv_len) / int(self.block_size_k))
    img_block_num_k = math.ceil(int(img_kv_len) / int(self.block_size_k))
    key_ids = torch.arange(img_kv_len, device=device, dtype=torch.long)
    key_blocks = key_ids // int(self.block_size_k)
    # Granularity of the SVG band groups (hard-coded here).
    svg_block_size = 512
    # Spatial grouping follows token order; temporal grouping first re-indexes
    # tokens as position * frame_num + frame.
    spatial_key_groups = key_ids // svg_block_size
    temporal_key_groups = (
        (key_ids % frame_size) * frame_num + key_ids // frame_size
    ) // svg_block_size
    group_num = math.ceil(int(img_kv_len) / svg_block_size)
    # relation_*[g, blk] is True when key block ``blk`` contains a key of group ``g``.
    relation_spatial = torch.zeros((group_num, img_block_num_k), dtype=torch.bool, device=device)
    relation_temporal = torch.zeros_like(relation_spatial)
    relation_spatial[spatial_key_groups, key_blocks] = True
    relation_temporal[temporal_key_groups, key_blocks] = True
    context_length = int(kv_len) - img_kv_len
    # Same width formula as sparsity_to_width, using the runtime context length.
    width_frame = ((context_length + frame_num * frame_size) * (1 - math.sqrt(self.sparsity)) - context_length)
    width_frame = width_frame / frame_size
    radius = max(1, int(frame_size * width_frame // svg_block_size))
    group_ids = torch.arange(group_num, device=device)
    nearby = (group_ids[:, None] - group_ids[None, :]).abs() < radius
    # For each query group: key blocks that contain a key of some nearby group.
    spatial_groups = (nearby.float() @ relation_spatial.float()) > 0
    temporal_groups = (nearby.float() @ relation_temporal.float()) > 0
    spatial_rows, temporal_rows = [], []
    for row in range(math.ceil(img_q_len / int(self.block_size_q))):
        # A query block keeps a key block if any of its tokens' groups does.
        row_ids = local_ids[row * int(self.block_size_q):min((row + 1) * int(self.block_size_q), img_q_len)]
        spatial_rows.append(spatial_groups.index_select(0, row_ids // svg_block_size).any(dim=0))
        temporal_ids = (
            (row_ids % frame_size) * frame_num + row_ids // frame_size
        ) // svg_block_size
        temporal_rows.append(temporal_groups.index_select(0, temporal_ids).any(dim=0))

    def build_sabi_template(img_rows):
        # Turn the image-only selection into a full (block_num_q, block_num_k)
        # index table: key blocks from the text start onward and query blocks
        # from the text start onward are always selected.
        selected = torch.zeros((block_num_q, block_num_k), dtype=torch.bool, device=device)
        selected[:img_rows.shape[0], :img_block_num_k] = img_rows
        txt_k_start = min(block_num_k, int(img_kv_len) // int(self.block_size_k))
        selected[:, txt_k_start:] = True
        txt_q_start = min(block_num_q, int(img_q_len) // int(self.block_size_q))
        selected[txt_q_start:, :] = True
        block_ids = torch.arange(block_num_k, dtype=self.index_type, device=device)
        block_ids = block_ids.view(1, block_num_k).expand_as(selected)
        # Unselected entries get the sentinel block_num_k so that sorting
        # pushes them to the end; they are then replaced by -1.
        sabi = torch.where(selected, block_ids, torch.full_like(block_ids, block_num_k))
        sabi = torch.sort(sabi, dim=-1).values
        return torch.where(sabi < block_num_k, sabi, -1).to(dtype=torch.uint16).contiguous()
    metadata = {
        "frame_num": frame_num,
        "frame_size": frame_size,
        "svg_block_size": svg_block_size,
        "local_ids": local_ids,
        "nearby": nearby,
        "spatial_sabi": build_sabi_template(torch.stack(spatial_rows)),
        "temporal_sabi": build_sabi_template(torch.stack(temporal_rows)),
    }
    return metadata
def _build_ring_svg_pattern(
    self,
    runtime_attn,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    *,
    img_q_len: int,
    img_kv_len: int,
    metadata,
) -> torch.Tensor:
    """Choose, per head, between the spatial and the temporal SVG pattern.

    A random sample of local image queries attends to the image keys three
    times: unmasked, with the spatial band and with the temporal band. The
    mean squared errors of the two masked outputs against the unmasked one
    are summed across the ring group.

    Args:
        runtime_attn: object providing ``ring_world_size`` and ``ring_pg``.
        q, k, v: tensors indexed as (batch, head, sequence, dim) here.
        img_q_len, img_kv_len: image token counts on the query / key side.
        metadata: dict produced by ``_get_ring_svg_local_metadata``.

    Returns:
        Bool tensor with one entry per head, True where the spatial pattern
        has the lower summed error.
    """
    local_ids = metadata["local_ids"]
    frame_num = metadata["frame_num"]
    frame_size = metadata["frame_size"]
    svg_block_size = metadata["svg_block_size"]
    nearby = metadata["nearby"]
    key_global_ids = torch.arange(img_kv_len, dtype=torch.long, device=q.device)
    spatial_key_groups = key_global_ids // svg_block_size
    temporal_key_groups = (
        (key_global_ids % frame_size) * frame_num + key_global_ids // frame_size
    ) // svg_block_size
    # Split the configured sample budget across ranks (at least one row each).
    sample_num = max(1, math.ceil(self.ring_sample_mse_rows / int(runtime_attn.ring_world_size)))
    sample_num = min(sample_num, int(img_q_len))
    sample_local = torch.randint(0, int(img_q_len), (sample_num,), device=q.device)
    sample_global = local_ids.index_select(0, sample_local)
    sampled_q = q[:, :, sample_local, :]
    image_k, image_v = k[:, :, :img_kv_len, :], v[:, :, :img_kv_len, :]
    scores = torch.matmul(sampled_q, image_k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
    golden = torch.matmul(F.softmax(scores, dim=-1), image_v)
    mse_values = []
    # First pair: spatial grouping; second pair: temporal grouping.
    for query_groups, key_groups in (
        (sample_global // svg_block_size, spatial_key_groups),
        (
            ((sample_global % frame_size) * frame_num + sample_global // frame_size) // svg_block_size,
            temporal_key_groups,
        ),
    ):
        allowed = nearby.index_select(0, query_groups).index_select(1, key_groups)
        # NOTE(review): a sampled row with no allowed key would give a NaN
        # softmax row -- confirm this cannot happen for valid inputs.
        sparse_hidden = torch.matmul(F.softmax(scores.masked_fill(~allowed, float("-inf")), dim=-1), image_v)
        # Average over batch, sampled rows and feature dim: one value per head.
        mse_values.append(torch.mean((sparse_hidden - golden) ** 2, dim=(0, 2, 3)))
    mse = torch.stack(mse_values)
    dist.all_reduce(mse, op=dist.ReduceOp.SUM, group=runtime_attn.ring_pg)
    return mse[0] < mse[1]
@staticmethod
def _select_ring_svg_local_sabi(pattern: torch.Tensor, metadata, batch_size: int) -> torch.Tensor:
    """Materialise the per-head block-index template chosen by ``pattern``.

    Args:
        pattern: boolean tensor, one entry per head; True selects
            ``metadata["spatial_sabi"]``, False ``metadata["temporal_sabi"]``.
        metadata: mapping holding the two templates.
        batch_size: size of the leading dimension of the result.

    Returns:
        A contiguous tensor with the per-head selection repeated
        ``batch_size`` times along a new leading dimension.
    """
    # Reshape to (heads, 1, 1) so the choice broadcasts over each template.
    use_spatial = pattern.view(-1, 1, 1)
    per_head = torch.where(use_spatial, metadata["spatial_sabi"], metadata["temporal_sabi"])
    batched = per_head[None].expand(batch_size, -1, -1, -1)
    return batched.contiguous()
def _build_ring_svg_local_sabi(
    self,
    runtime_attn,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    *,
    img_q_len: int,
    img_kv_len: int,
) -> torch.Tensor:
    """Build the batched block-index tensor for this rank's ring-SVG call.

    Three steps: fetch the layout metadata for the current sequence
    lengths, decide spatial vs. temporal per head, then expand the chosen
    templates over the batch (``q.shape[0]``).

    Args:
        runtime_attn: ring runtime handle, forwarded to the helpers.
        q, k, v: tensors whose dim 2 is read as the sequence length.
        img_q_len: number of image query tokens.
        img_kv_len: number of image key/value tokens.

    Returns:
        The tensor produced by ``_select_ring_svg_local_sabi``.
    """
    total_q_len = int(q.shape[2])
    total_kv_len = int(k.shape[2])
    metadata = self._get_ring_svg_local_metadata(
        runtime_attn,
        device=q.device,
        img_q_len=img_q_len,
        img_kv_len=img_kv_len,
        q_len=total_q_len,
        kv_len=total_kv_len,
    )
    head_pattern = self._build_ring_svg_pattern(
        runtime_attn, q, k, v,
        img_q_len=img_q_len,
        img_kv_len=img_kv_len,
        metadata=metadata,
    )
    batch = q.shape[0]
    return self._select_ring_svg_local_sabi(head_pattern, metadata, batch)
def _forward_ring_svg_global(
    self,
    runtime_attn,
    block_args: dict,
    softmax_scale: Optional[float] = None,
) -> torch.Tensor:
    """Run ring-SVG sparse attention for local queries against gathered keys/values.

    Local image keys and values are all-gathered across the ring, text
    tokens are appended along dim 1, everything is converted with the
    ``"BNSD"`` entry of ``MEMORY_LAYOUT``, and the block-sparse kernel is
    invoked with a per-head block-index tensor built for this rank.

    Args:
        runtime_attn: ring runtime handle, forwarded to the helpers.
        block_args: mapping with ``q_img_local``, ``k_img_local``,
            ``v_img_local``, ``txt_q``, ``txt_k`` and ``txt_v``.
        softmax_scale: optional scale; defaults to ``1 / sqrt(d)`` with ``d``
            the last dimension of the converted query.

    Returns:
        The kernel output converted back through the post-attention layout,
        made contiguous.
    """
    q_img_local = block_args["q_img_local"]
    k_img_local = block_args["k_img_local"]
    v_img_local = block_args["v_img_local"]
    txt_q = block_args["txt_q"]
    txt_k = block_args["txt_k"]
    txt_v = block_args["txt_v"]

    # Only keys/values are gathered; queries stay local to this rank.
    k_img_global = self._ring_all_gather_canonical_img(runtime_attn, k_img_local)
    v_img_global = self._ring_all_gather_canonical_img(runtime_attn, v_img_local)

    joined = [
        torch.cat([img_part, txt_part], dim=1).contiguous()
        for img_part, txt_part in (
            (q_img_local, txt_q),
            (k_img_global, txt_k),
            (v_img_global, txt_v),
        )
    ]

    to_kernel_layout, from_kernel_layout = MEMORY_LAYOUT["BNSD"]
    q_exec, k_exec, v_exec = [to_kernel_layout(t).contiguous() for t in joined]

    batch, heads, q_len, head_dim = q_exec.shape
    kv_len = int(k_exec.shape[2])
    if softmax_scale is not None:
        scale = softmax_scale
    else:
        scale = 1.0 / math.sqrt(head_dim)

    sabi = self._build_ring_svg_local_sabi(
        runtime_attn,
        q_exec,
        k_exec,
        v_exec,
        img_q_len=int(q_img_local.shape[1]),
        img_kv_len=int(k_img_global.shape[1]),
    )
    sabi = sabi.to(device=q_exec.device, dtype=torch.uint16).contiguous()

    attn_out, _ = self._call_blitz_sparse_attention(
        q_exec,
        k_exec,
        v_exec,
        sabi=sabi,
        actual_seq_lengths=[q_len] * batch,
        actual_seq_lengths_kv=[kv_len] * batch,
        num_heads=heads,
        scale=scale,
    )
    return from_kernel_layout(attn_out).contiguous()
def get_must_keep_blocks_indices(self, **kwargs):
    """Return the trailing Q/K blocks that cover the sink-frame and text tokens.

    The last ``sink_frame_len + txt_token_len`` tokens of the sequence must
    always be attended. For each block size this computes the smallest number
    of trailing blocks whose tokens cover that tail, taking into account that
    the final block may be only partially filled.

    Keyword Args:
        token_len: total number of tokens in the sequence.
        sink_frame_len: number of sink-frame tokens.
        txt_token_len: number of text tokens.

    Returns:
        ``(blocks_q, indices_q, blocks_k, indices_k)`` where ``blocks_*`` is
        the number of trailing blocks (always at least one) and ``indices_*``
        is a 1-D tensor with their block indices.

    Raises:
        KeyError: if one of the keyword arguments is missing.
    """
    token_len = kwargs["token_len"]
    sink_txt_len = kwargs["sink_frame_len"] + kwargs["txt_token_len"]

    def trailing_blocks(block_size):
        total_blocks = math.ceil(token_len / block_size)
        # A block-aligned sequence ends in a *full* block. Using the bare
        # remainder (0) here made the old code reserve one block too many.
        last_block_tokens = token_len % block_size or block_size
        # The last block is always kept; every further block adds block_size tokens.
        still_uncovered = max(sink_txt_len - last_block_tokens, 0)
        needed = 1 + math.ceil(still_uncovered / block_size)
        return needed, torch.arange(total_blocks - needed, total_blocks)

    sink_txt_blocks_k, must_keep_indices_k = trailing_blocks(self.block_size_k)
    sink_txt_blocks_q, must_keep_indices_q = trailing_blocks(self.block_size_q)
    return sink_txt_blocks_q, must_keep_indices_q, sink_txt_blocks_k, must_keep_indices_k
def rearrange_x(self, x, frame_num, frame_size):
    """Reorder the leading video tokens from frame-major to position-major.

    The first ``frame_num * frame_size`` rows of the 2-D tensor ``x`` are
    viewed as ``(frame_num, frame_size, d)`` and their first two axes are
    swapped, so rows sharing an in-frame position become adjacent. Rows
    after that prefix are appended unchanged.

    Returns:
        A new tensor with the same shape as ``x``.
    """
    _rows, width = x.shape
    video_len = frame_num * frame_size
    video, context = x[:video_len], x[video_len:]
    frame_major = video.reshape(frame_num, frame_size, width)
    position_major = frame_major.transpose(0, 1).reshape(video_len, -1).contiguous()
    return torch.cat([position_major, context], dim=0)
def inv_rearrange_x(self, x, frame_num, frame_size):
    """Undo ``rearrange_x``: reorder the leading rows back to frame-major.

    The first ``frame_num * frame_size`` rows of the 2-D tensor ``x`` are
    viewed as ``(frame_size, frame_num, d)`` and their first two axes are
    swapped. Rows after that prefix are appended unchanged.

    Returns:
        A new tensor with the same shape as ``x``.
    """
    _rows, width = x.shape
    video_len = frame_num * frame_size
    video, context = x[:video_len], x[video_len:]
    position_major = video.reshape(frame_size, frame_num, width)
    frame_major = position_major.transpose(0, 1).reshape(video_len, -1).contiguous()
    return torch.cat([frame_major, context], dim=0)
def build_sp_sabi_before_head_shard(self, block_args: dict):
    """Compute the SVG head pattern and block-index tensor before heads are sharded.

    Args:
        block_args: mapping with ``q_local_bnsd``, ``k_local_bnsd``,
            ``v_local_bnsd`` (each has dims 1 and 2 swapped before use) and
            ``ulysses_world_size``. Other entries, such as
            ``joint_q_local_bnsd``, are accepted but not used here.

    Returns:
        ``{"pattern": ..., "sabi": ...}`` where ``sabi`` is the result of
        ``get_final_sabi`` truncated along dim 2 to the number of image query
        blocks, ``ceil(img_token_len / block_size_q)``.

    Raises:
        ValueError: if ``block_args["v_local_bnsd"]`` is None.
    """
    if block_args["v_local_bnsd"] is None:
        raise ValueError("SVG sparse path requires v_local_bnsd.")

    q_global, k_global, v_global = [
        block_args[name].transpose(1, 2).contiguous()
        for name in ("q_local_bnsd", "k_local_bnsd", "v_local_bnsd")
    ]

    # The previous version also transposed and copied ``joint_q_local_bnsd``
    # here without ever using the result; that dead copy has been removed.

    ulysses_world_size = block_args["ulysses_world_size"]
    if ulysses_world_size > 1:
        # Remap q, k and v in one call by stacking them along dim 1, then split back.
        heads = q_global.shape[1]
        stacked = torch.cat([q_global, k_global, v_global], dim=1).contiguous()
        stacked = self._apply_local_seq_remap_tensor(
            stacked,
            ulysses_world_size=ulysses_world_size,
        )
        q_global, k_global, v_global = torch.split(stacked, [heads, heads, heads], dim=1)

    frame_size = self.img_token_len // self.frame_num
    pattern, sabi_tensor = self.get_final_sabi(q_global, k_global, v_global, self.frame_num, frame_size)

    # Keep only the query-block rows that belong to image tokens.
    img_blocks_q = math.ceil(self.img_token_len / self.block_size_q)
    sabi_tensor = sabi_tensor[:, :, :img_blocks_q, :].contiguous()
    return {"pattern": pattern, "sabi": sabi_tensor}
def attention(self, q,
              k,
              v,
              drop_rate=0,
              attn_mask=None,
              causal=False,
              cu_seqlens_q=None,
              cu_seqlens_kv=None,
              max_seqlen_q=None,
              max_seqlen_kv=None,
              batch_size=1,
              head_sabi: Optional[Dict[str, torch.Tensor]] = None,
              ulysses_pg=None,
              ulysses_rank: int = 0,
              ulysses_world_size: int = 1,
              return_bshd: bool = False,
              softmax_scale: Optional[float] = None,
              joint_q_local_bnsd: Optional[torch.Tensor] = None,
              ):
    """Run SVG attention: block-sparse on the leading prefix, dense on the rest.

    When both ``cu_seqlens_q`` and ``cu_seqlens_kv`` are None the whole
    sequence goes through the fused dense kernel. Otherwise the first
    ``cu_seqlens_q[1]`` queries and ``cu_seqlens_kv[1]`` keys/values are
    processed by ``torch_bsa.blitz_sparse_attention``; any remaining tokens
    attend densely among themselves and the two results are concatenated
    along the sequence dimension.

    Args:
        q, k, v: 4-D tensors; ``q.shape`` is unpacked as ``(b, s, n, d)``.
        cu_seqlens_q, cu_seqlens_kv: indexable; element ``[1]`` is the length
            of the sparse prefix. Pass both or neither.
        head_sabi: optional precomputed ``{"pattern", "sabi"}``; when None it
            is computed here, or via ``build_sp_sabi_before_head_shard`` when
            ``ulysses_world_size > 1``.
        ulysses_pg, ulysses_rank, ulysses_world_size: sequence-parallel
            settings; only the world size is used directly, the others are
            forwarded.
        return_bshd: return the post-layout tensor without the final reshape.
        softmax_scale: optional scale; defaults to ``1 / sqrt(d)``.
        joint_q_local_bnsd: forwarded to ``build_sp_sabi_before_head_shard``;
            derived from ``q`` when None.
        drop_rate, attn_mask, causal, max_seqlen_q, max_seqlen_kv, batch_size:
            accepted for interface compatibility; not read by this method.

    Returns:
        The attention output, reshaped to ``(b, s, -1)`` unless
        ``return_bshd`` is set.

    Raises:
        ValueError: if exactly one of the two ``cu_seqlens`` arguments is given.
        ImportError: if the sparse path is requested but ``torch_bsa`` failed
            to import.
    """
    pre_attn_layout, post_attn_layout = MEMORY_LAYOUT["BNSD"]
    frame_num = self.frame_num
    frame_size = (q.shape[1] - self.context_length) // frame_num
    b, s, n, d = q.shape
    q, k, v = pre_attn_layout(q), pre_attn_layout(k), pre_attn_layout(v)
    scale = softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(d)

    if cu_seqlens_q is None or cu_seqlens_kv is None:
        if cu_seqlens_q is not None or cu_seqlens_kv is not None:
            raise ValueError("SVG sparse attention requires both cu_seqlens_q and cu_seqlens_kv, or neither.")
        # No prefix information: plain dense attention over everything.
        x = torch_npu.npu_fused_infer_attention_score(
            q, k, v,
            num_heads=n,
            input_layout="BNSD",
            scale=scale,
        )[0]
    else:
        if torch_bsa is None:
            # The module-level import is optional; fail early with a clear
            # message instead of an AttributeError on None further down.
            raise ImportError("torch_bsa is required for SVG sparse attention but could not be imported.")

        actseqlen = [cu_seqlens_q[1]] * b
        actseqlenkv = [cu_seqlens_kv[1]] * b
        q1 = q[:, :, :cu_seqlens_q[1], :].contiguous()
        k1 = k[:, :, :cu_seqlens_kv[1], :].contiguous()
        v1 = v[:, :, :cu_seqlens_kv[1], :].contiguous()
        # Remap q, k and v in one call by stacking them along the head dimension.
        qkv1 = torch.cat([q1, k1, v1], dim=1).contiguous()
        qkv1 = self._apply_local_seq_remap_tensor(
            qkv1,
            ulysses_world_size=ulysses_world_size,
        )
        q1, k1, v1 = torch.split(qkv1, [n, n, n], dim=1)

        if head_sabi is None:
            if ulysses_world_size <= 1:
                pattern, sabi_tensor = self.get_final_sabi(q1, k1, v1, frame_num, frame_size)
                head_sabi = {"pattern": pattern, "sabi": sabi_tensor}
            else:
                if joint_q_local_bnsd is None:
                    prefix_q_len = int(cu_seqlens_q[1])
                    img_q_len = min(int(self.img_token_len), prefix_q_len)
                    joint_q_local_bnsd = q[:, :, img_q_len:prefix_q_len, :].transpose(1, 2).contiguous()
                head_sabi = self.build_sp_sabi_before_head_shard(
                    {
                        "q_local_bnsd": q.transpose(1, 2).contiguous(),
                        "k_local_bnsd": k.transpose(1, 2).contiguous(),
                        "v_local_bnsd": v.transpose(1, 2).contiguous(),
                        "joint_q_local_bnsd": joint_q_local_bnsd,
                        "ulysses_pg": ulysses_pg,
                        "ulysses_rank": ulysses_rank,
                        "ulysses_world_size": ulysses_world_size,
                    }
                )
        pattern = head_sabi["pattern"]
        sabi_tensor = head_sabi["sabi"]

        # Heads whose pattern entry is falsy use the rearranged token order.
        # Resolve the list once; it is needed before and after the kernel.
        temporal_heads = [h for h in range(n) if not pattern[h]]

        # NOTE(review): only batch index 0 is rearranged below, so batches
        # larger than one would not be handled -- confirm callers use b == 1.
        for h in temporal_heads:
            # Right-hand sides are fresh tensors, so writing into the split views is safe.
            q1[0, h] = self.rearrange_x(q1[0, h], frame_num, frame_size).contiguous()
            k1[0, h] = self.rearrange_x(k1[0, h], frame_num, frame_size).contiguous()
            v1[0, h] = self.rearrange_x(v1[0, h], frame_num, frame_size).contiguous()

        sabi_tensor = sabi_tensor.contiguous()
        sabi_tensor = sabi_tensor.to(q.device).to(torch.uint16)
        attn1, _ = torch_bsa.blitz_sparse_attention(
            q1, k1, v1,
            sabi=sabi_tensor,
            actual_seq_lengths=actseqlen,
            actual_seq_lengths_kv=actseqlenkv,
            num_heads=n,
            num_key_value_heads=n,
            input_layout="BNSD",
            scale_value=scale,
            atten_mask=None,
            sparse_mode=0,
            softmax_lse_flag=False,
            block_shape=[self.block_size_q, self.block_size_k],
        )
        attn1 = self._take_attention_output(attn1)

        # Restore the original token order for the rearranged heads.
        for h in temporal_heads:
            attn1[0, h] = self.inv_rearrange_x(attn1[0, h], frame_num, frame_size).contiguous()

        attn1 = self._apply_local_seq_remap_tensor(
            attn1,
            ulysses_world_size=ulysses_world_size,
            reverse=True,
        )

        if cu_seqlens_q[1] < s:
            # Tokens past the sparse prefix attend densely among themselves.
            attn2 = torch_npu.npu_fused_infer_attention_score(
                q[:, :, cu_seqlens_q[1]:, :],
                k[:, :, cu_seqlens_kv[1]:, :],
                v[:, :, cu_seqlens_kv[1]:, :],
                num_heads=n,
                input_layout="BNSD",
                scale=scale,
            )[0]
            x = torch.cat([attn1, attn2], dim=2)
        else:
            x = attn1

    x = post_attn_layout(x)
    if return_bshd:
        return x
    out = x.reshape(b, s, -1)
    return out
if __name__ == "__main__":
    # Manual smoke check when the module is run as a script: load the default
    # sparse config through the predictor manager and log the resulting config.
    sparse_predictor_manager.from_config(DEFAULT_CONFIG_PATH)
    logger.info(sparse_predictor_manager.config)