import inspect
import sys
from typing import List, Optional, Union
import numpy as np
import torch
import torch.nn as nn
import torch_npu
from diffusers.models.transformers.transformer_flux import (
FluxAttention,
FluxAttnProcessor,
FluxIPAdapterAttnProcessor,
)
from diffusers.utils import logging
logger = logging.get_logger(__name__)
def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[np.ndarray, int],
    theta: float = 10000.0,
    use_real=False,
    linear_factor=1.0,
    ntk_factor=1.0,
    repeat_interleave_real=True,
    freqs_dtype=torch.float32,
):
    """
    Build 1-D rotary position-embedding (RoPE) rotations for a sequence of positions.

    For every position ``p`` and every even channel index ``2k < dim`` the rotation angle is
    ``p / ((theta * ntk_factor) ** (2k / dim)) / linear_factor``.

    Args:
        dim (`int`): Embedding dimension; must be even.
        pos (`np.ndarray`, `torch.Tensor` or `int`): 1-D position values of length S. An int ``n``
            is expanded to ``torch.arange(n)``; a NumPy array is wrapped with ``torch.from_numpy``;
            a tensor is used as-is.
        theta (`float`, *optional*, defaults to 10000.0):
            Base of the geometric frequency progression.
        use_real (`bool`, *optional*, defaults to `False`):
            If True, return cosine and sine tensors separately instead of one complex tensor.
        linear_factor (`float`, *optional*, defaults to 1.0):
            Divisor applied to every frequency (linear context-extrapolation scaling).
        ntk_factor (`float`, *optional*, defaults to 1.0):
            Multiplier applied to ``theta`` (NTK-aware scaling).
        repeat_interleave_real (`bool`, *optional*, defaults to `True`):
            Only used when ``use_real`` is True. If True, each cos/sin column is duplicated in place
            (``[c0, c0, c1, c1, ...]``); otherwise the half-width block is concatenated with itself
            (``[c0, c1, ..., c0, c1, ...]``).
        freqs_dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            dtype used to compute the per-channel frequencies.

    Returns:
        If ``use_real``: a ``(cos, sin)`` tuple of float32 tensors, each of shape [S, dim].
        Otherwise: a unit-magnitude complex tensor ``exp(i * angle)`` of shape [S, dim / 2]
        (complex64 for float32 angles, complex128 for float64 angles).

    Raises:
        ValueError: If ``dim`` is odd.
    """
    if dim % 2:
        raise ValueError("dim must be divisible by 2 (even number)")

    # Normalise `pos` to a tensor; tensors fall through untouched.
    if isinstance(pos, int):
        pos = torch.arange(pos)
    elif isinstance(pos, np.ndarray):
        pos = torch.from_numpy(pos)

    scaled_theta = theta * ntk_factor
    exponents = torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device) / dim
    inv_freq = 1.0 / scaled_theta**exponents / linear_factor
    angles = torch.outer(pos, inv_freq)  # [S, dim / 2]

    # On NPU the angles are forced to float32 before the trig ops
    # (presumably float64 is unsupported or slow there — confirm).
    if angles.device.type == "npu":
        angles = angles.float()

    if not use_real:
        return torch.polar(torch.ones_like(angles), angles)

    if not repeat_interleave_real:
        cos_half = angles.cos()
        sin_half = angles.sin()
        return (
            torch.cat([cos_half, cos_half], dim=-1).float(),
            torch.cat([sin_half, sin_half], dim=-1).float(),
        )

    full_width = angles.shape[1] * 2

    def _duplicate_columns(values):
        # Column-wise interleave done as transpose -> row repeat -> transpose back,
        # with an explicit output_size so the output length is known up front.
        return (
            values.T.repeat_interleave(2, dim=0, output_size=full_width)
            .T.float()
            .contiguous()
        )

    return _duplicate_columns(angles.cos()), _duplicate_columns(angles.sin())
class RMSNorm_npu(torch.nn.Module):
    """RMS normalisation over the last dimension, backed by ``torch_npu.npu_rms_norm``.

    Args:
        dim: Size of the normalised (last) dimension and length of the learnable scale.
        eps: Small constant forwarded to the kernel as ``epsilon`` for numerical stability.
        elementwise_affine: If True, a per-channel scale initialised to ones is learned.
            If False, no parameter is registered and a unit scale is used at runtime.
    """

    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = True):
        super().__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.weight = None

    def _gamma(self, x: torch.Tensor) -> torch.Tensor:
        """Return the scale tensor handed to the NPU kernel for input ``x``.

        ``npu_rms_norm`` takes ``gamma`` as a required tensor, so passing ``None`` (the
        ``elementwise_affine=False`` case) would fail. In that case a ones vector matching
        ``x``'s last dimension, dtype and device is supplied instead.
        """
        if self.weight is not None:
            return self.weight
        return torch.ones(x.shape[-1], dtype=x.dtype, device=x.device)

    def forward(self, x):
        # The kernel returns a tuple; element 0 is the normalised output.
        return torch_npu.npu_rms_norm(x, self._gamma(x), epsilon=self.eps)[0]
class PatchedFluxAttention(FluxAttention):
    """``FluxAttention`` subclass whose q/k RMSNorm layers use :class:`RMSNorm_npu`.

    The constructor builds these sub-modules itself rather than running
    ``FluxAttention.__init__``:
    - ``norm_q`` / ``norm_k`` as ``RMSNorm_npu``;
    - ``to_q`` / ``to_k`` / ``to_v`` linear projections;
    - ``to_out`` (a Linear + Dropout ``ModuleList``) unless ``pre_only``;
    - ``norm_added_q`` / ``norm_added_k`` and the ``add_*_proj`` / ``to_add_out`` projections
      when ``added_kv_proj_dim`` is given.

    ``forward`` delegates to ``self.processor`` and drops keyword arguments the processor's
    ``__call__`` does not accept.
    """

    _default_processor_cls = FluxAttnProcessor
    _available_processors = [
        FluxAttnProcessor,
        FluxIPAdapterAttnProcessor,
    ]

    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        added_kv_proj_dim: Optional[int] = None,
        added_proj_bias: Optional[bool] = True,
        out_bias: bool = True,
        eps: float = 1e-5,
        out_dim: int = None,  # NOTE(review): effectively Optional[int]
        context_pre_only: Optional[bool] = None,
        pre_only: bool = False,
        elementwise_affine: bool = True,
        processor=None,
    ):
        # Start the MRO *after* FluxAttention, so FluxAttention.__init__ is skipped on purpose
        # (presumably to avoid building the upstream norm layers first — confirm).
        super(FluxAttention, self).__init__()

        # ---- plain configuration attributes -------------------------------------------
        self.head_dim = dim_head
        self.query_dim = query_dim
        self.use_bias = bias
        self.dropout = dropout
        self.context_pre_only = context_pre_only
        self.pre_only = pre_only
        self.added_kv_proj_dim = added_kv_proj_dim
        self.added_proj_bias = added_proj_bias
        if out_dim is None:
            self.inner_dim = dim_head * heads
            self.out_dim = query_dim
            self.heads = heads
        else:
            # An explicit out_dim fixes both the projection width and the head count.
            self.inner_dim = out_dim
            self.out_dim = out_dim
            self.heads = out_dim // dim_head

        # ---- sub-modules (creation order kept stable for registration / RNG init) -----
        inner = self.inner_dim
        self.norm_q = RMSNorm_npu(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        self.norm_k = RMSNorm_npu(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        self.to_q = torch.nn.Linear(query_dim, inner, bias=bias)
        self.to_k = torch.nn.Linear(query_dim, inner, bias=bias)
        self.to_v = torch.nn.Linear(query_dim, inner, bias=bias)

        if not pre_only:
            self.to_out = torch.nn.ModuleList(
                [
                    torch.nn.Linear(inner, self.out_dim, bias=out_bias),
                    torch.nn.Dropout(dropout),
                ]
            )

        if added_kv_proj_dim is not None:
            # Context-stream norms use the default elementwise_affine of RMSNorm_npu.
            self.norm_added_q = RMSNorm_npu(dim_head, eps=eps)
            self.norm_added_k = RMSNorm_npu(dim_head, eps=eps)
            self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, inner, bias=added_proj_bias)
            self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, inner, bias=added_proj_bias)
            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, inner, bias=added_proj_bias)
            self.to_add_out = torch.nn.Linear(inner, query_dim, bias=out_bias)

        self.set_processor(
            processor if processor is not None else self._default_processor_cls()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run the attached attention processor.

        Extra ``kwargs`` are passed on only if they appear in the processor's ``__call__``
        signature. Unknown names are dropped with a warning. ``ip_adapter_masks`` and
        ``ip_hidden_states`` are dropped silently.
        """
        accepted = set(inspect.signature(self.processor.__call__).parameters)
        silently_dropped = {"ip_adapter_masks", "ip_hidden_states"}

        unused_kwargs = [
            name
            for name in kwargs
            if name not in accepted and name not in silently_dropped
        ]
        if unused_kwargs:
            logger.warning(
                f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
            )

        forwarded = {name: value for name, value in kwargs.items() if name in accepted}
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states,
            attention_mask,
            image_rotary_emb,
            **forwarded,
        )
class PatchedFluxPosEmbed(nn.Module):
    """Multi-axis rotary position embedding for Flux with an NPU-aware frequency dtype.

    Args:
        theta: RoPE base passed to ``get_1d_rotary_pos_embed`` for every axis.
        axes_dim: Embedding width for each position axis; entry ``i`` is used for column ``i``
            of ``ids``.
    """

    def __init__(self, theta: int, axes_dim: List[int]):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        """Compute per-axis RoPE cos/sin tables and concatenate them along the last dim.

        ``ids`` is indexed as ``ids[:, axis]``, i.e. one row per token and one column per axis.
        Despite the annotation, this returns a ``(freqs_cos, freqs_sin)`` tuple of tensors.
        """
        # mps and npu compute frequencies in float32; other devices use float64.
        on_reduced_precision_device = ids.device.type in ("mps", "npu")
        freqs_dtype = torch.float32 if on_reduced_precision_device else torch.float64
        positions = ids.float()

        cos_parts = []
        sin_parts = []
        for axis in range(ids.shape[-1]):
            axis_cos, axis_sin = get_1d_rotary_pos_embed(
                self.axes_dim[axis],
                positions[:, axis],
                theta=self.theta,
                repeat_interleave_real=True,
                use_real=True,
                freqs_dtype=freqs_dtype,
            )
            cos_parts.append(axis_cos)
            sin_parts.append(axis_sin)

        return (
            torch.cat(cos_parts, dim=-1).to(ids.device),
            torch.cat(sin_parts, dim=-1).to(ids.device),
        )
def apply_patches():
    """Swap the NPU-patched classes into diffusers' Flux transformer module.

    ``FluxAttention`` and ``FluxPosEmbed`` are rebound on the
    ``diffusers.models.transformers.transformer_flux`` module object. Only code that looks those
    names up on that module after this call sees the replacements. The module must already be in
    ``sys.modules``; this file imports from it at load time, so it is.
    """
    flux_module = sys.modules["diffusers.models.transformers.transformer_flux"]
    replacements = {
        "FluxAttention": PatchedFluxAttention,
        "FluxPosEmbed": PatchedFluxPosEmbed,
    }
    for attr_name, replacement in replacements.items():
        setattr(flux_module, attr_name, replacement)