import threading
from contextlib import contextmanager
from typing import Optional
import diffusers
import torch
import torch.nn.functional as F
from aenum import extend_enum
from diffusers.models.attention_dispatch import _AttentionBackendRegistry
from ..parallel_group import ParallelGroup
# Thread-local holder for the sequence-parallel group; it is written by
# set_sp_group() and read by get_sp_group().
_thread_local = threading.local()

# Add a TENSOR_CAST member (value "tensor_cast") to diffusers' backend-name
# enum. The hasattr guard skips the extension when the member already exists,
# e.g. when this module body runs a second time.
if not hasattr(diffusers.models.attention_dispatch.AttentionBackendName, "TENSOR_CAST"):
    extend_enum(diffusers.models.attention_dispatch.AttentionBackendName, "TENSOR_CAST", "tensor_cast")
def set_sp_group(sp_group: Optional[ParallelGroup]):
    """Store ``sp_group`` as the calling thread's sequence-parallel group.

    Args:
        sp_group: The group to remember for this thread, or ``None`` to
            record that no group is active.
    """
    setattr(_thread_local, "sp_group", sp_group)
def get_sp_group() -> Optional[ParallelGroup]:
    """Return the calling thread's sequence-parallel group.

    Returns:
        The value last stored by ``set_sp_group`` on this thread, or ``None``
        when nothing has been stored on this thread yet.
    """
    try:
        return _thread_local.sp_group
    except AttributeError:
        # The attribute has not been set on this thread's local storage.
        return None
@_AttentionBackendRegistry.register("tensor_cast")
def _attention(query, key, value, **kwargs):
    """Attention backend registered under the name ``"tensor_cast"``.

    Without an active sequence-parallel group (``get_sp_group()`` returns
    ``None``) the inputs go straight to ``torch.ops.tensor_cast.attention``.
    With a group, the function issues all-to-all calls on placeholder
    tensors, reinterprets q/k/v with ``world_size`` times the sequence length
    and ``1 / world_size`` of the heads, runs the attention op, issues one
    more all-to-all, and views the result back to the original query shape.

    NOTE(review): every all-to-all result is discarded and the exchanged
    tensors are all-ones placeholders, so no real activations are moved here.
    Presumably the calls exist only so the communication is accounted for
    (Ulysses-style sequence parallelism) -- confirm with the owner.

    Args:
        query: 4-D tensor, unpacked as (batch, seq_per_rank, heads, head_dim).
        key: 4-D tensor, unpacked the same way as ``query``.
        value: Tensor viewed with the same target shape as ``key``.
        **kwargs: Accepted but not used by this implementation.

    Returns:
        The result of ``torch.ops.tensor_cast.attention``; in the
        sequence-parallel path it is viewed as the original ``query`` shape.
    """
    group = get_sp_group()
    if group is None:
        return torch.ops.tensor_cast.attention(query, key, value, None, None, None, None, None)

    world = group.world_size
    q_batch, q_seq, q_heads, q_dim = query.shape
    kv_batch, kv_seq, kv_heads, kv_dim = key.shape
    q_heads_local = q_heads // world
    kv_heads_local = kv_heads // world

    # Placeholder payloads for the all-to-all calls. Both take dtype and
    # device from `query`, including the key/value placeholder.
    q_placeholder = torch.ones(
        (q_batch, q_seq, q_heads_local, q_dim),
        dtype=query.dtype,
        device=query.device,
    )
    kv_placeholder = torch.ones(
        (kv_batch, kv_seq, kv_heads_local, kv_dim),
        dtype=query.dtype,
        device=query.device,
    )

    # Two separate lists of (world - 1) ones, reused across every call below.
    send_splits = [1] * (world - 1)
    recv_splits = [1] * (world - 1)

    def _exchange(payload):
        # The return value is intentionally ignored by all callers.
        group.all_to_all(
            payload,
            output_split_sizes=recv_splits,
            input_split_sizes=send_splits,
        )

    # One exchange for the query placeholder, then two for the key/value
    # placeholder (presumably one each for key and value).
    _exchange(q_placeholder)
    _exchange(kv_placeholder)
    _exchange(kv_placeholder)

    q_gathered_shape = (q_batch, q_seq * world, q_heads_local, q_dim)
    kv_gathered_shape = (kv_batch, kv_seq * world, kv_heads_local, kv_dim)
    query = query.view(*q_gathered_shape)
    key = key.view(*kv_gathered_shape)
    value = value.view(*kv_gathered_shape)

    attn_out = torch.ops.tensor_cast.attention(query, key, value, None, None, None, None, None)

    # Final exchange after attention, again on the query placeholder.
    _exchange(q_placeholder)

    return attn_out.view(q_batch, q_seq, q_heads, q_dim)
@contextmanager
def use_custom_sdpa():
    """Temporarily route ``F.scaled_dot_product_attention`` to the tensor_cast op.

    While the context is active, the ``scaled_dot_product_attention``
    attribute of ``torch.nn.functional`` is replaced by a wrapper that calls
    ``torch.ops.tensor_cast.attention``. The original function is restored on
    exit, including when the body raises.

    The wrapper mirrors the parameter names of the function it replaces
    (``query``, ``key``, ``value``, ..., ``enable_gqa``) so that callers which
    pass arguments by keyword keep working while the patch is installed.

    NOTE: the replacement is an assignment on the ``torch.nn.functional``
    module itself, so it is visible to every caller that looks the function up
    through that module while the context is active, not only to the code
    inside the ``with`` block.

    Yields:
        None.
    """
    original_sdpa = F.scaled_dot_product_attention

    def _custom_sdpa(
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale=None,
        enable_gqa=False,
    ):
        # Only the three tensors and the mask are forwarded. dropout_p,
        # is_causal, scale and enable_gqa are accepted for signature
        # compatibility but are not passed to the op.
        return torch.ops.tensor_cast.attention(query, key, value, attn_mask, None, None, None, None)

    F.scaled_dot_product_attention = _custom_sdpa
    try:
        yield
    finally:
        # Always put the original back, even if the managed body raised.
        F.scaled_dot_product_attention = original_sdpa