import gc
import sys
import math
import pickle
from pathlib import Path
from typing import Tuple, Optional, Union, List
import torch
import torch_npu
import torch.nn as nn
from tqdm import tqdm
from einops import rearrange
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
from diffsynth.models.qwen_image_dit import (
qwen_image_flash_attention,
apply_rotary_emb_qwen,
QwenEmbedRope,
QwenDoubleStreamAttention,
QwenImageTransformerBlock,
QwenImageDiT,
QwenImageDiTStateDictConverter
)
from diffsynth.trainers.utils import (
DiffusionTrainingModule,
ModelLogger,
launch_training_task
)
from .sd3_dit import TimestepEmbeddings, RMSNorm
from .flux_dit import AdaLayerNorm
class RMSNorm_npu(torch.nn.Module):
    """RMS normalization layer that delegates to ``torch_npu.npu_rms_norm``.

    The layer owns a single learnable scale vector ``weight`` of length
    ``dim`` (initialised to ones) and a numerical-stability constant ``eps``.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        """Create the scale parameter.

        Args:
            dim: Length of the learnable scale vector.
            eps: Epsilon forwarded to the NPU kernel.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Normalize ``x`` with the fused NPU kernel and return the result."""
        kernel_outputs = torch_npu.npu_rms_norm(x, self.weight, epsilon=self.eps)
        # Only the first element of the kernel's result is handed back.
        return kernel_outputs[0]
def patched_qwen_image_flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, attention_mask=None, enable_fp8_attention: bool = False):
    """Attention over ``(b, n, s, d)`` tensors, using the fused NPU kernel for bf16.

    bfloat16 inputs go through ``torch_npu.npu_fusion_attention``; every other
    dtype falls back to ``torch.nn.functional.scaled_dot_product_attention``.
    In both cases the result is rearranged from ``b n s d`` to ``b s (n d)``.

    Args:
        q, k, v: Query/key/value tensors laid out as ``b n s d``.
        num_heads: Number of heads ``n``; also checked by the final rearrange.
        attention_mask: Optional mask following the
            ``scaled_dot_product_attention`` convention: a boolean mask where
            ``True`` means "attend", or an additive float mask where ``0``
            means "attend" and ``-inf`` means "masked".
        enable_fp8_attention: Accepted for signature compatibility; not used
            by this implementation.

    Returns:
        Tensor of shape ``b s (n d)``.
    """
    if q.dtype == torch.bfloat16:
        atten_mask = attention_mask
        if atten_mask is not None:
            # The fused NPU kernel takes a boolean mask in which True marks
            # positions that are EXCLUDED from attention, i.e. the opposite of
            # the SDPA convention used by the fallback branch. Convert so that
            # the same `attention_mask` means the same thing for every dtype.
            # NOTE(review): mask convention taken from the Ascend
            # npu_fusion_attention documentation — confirm on target CANN version.
            if atten_mask.dtype == torch.bool:
                atten_mask = torch.logical_not(atten_mask)
            else:
                # Additive mask: any non-zero entry (e.g. -inf) is a masked position.
                atten_mask = atten_mask.bool()
        x = torch_npu.npu_fusion_attention(
            q,
            k,
            v,
            head_num=num_heads,
            input_layout="BNSD",
            pse=None,
            scale=1.0 / math.sqrt(q.shape[-1]),
            pre_tockens=65536,
            next_tockens=65536,
            keep_prob=1.0,
            sync=False,
            inner_precise=0,
            atten_mask=atten_mask,
        )[0]
    else:
        x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
    # Merge the head dimension back into the feature dimension.
    return rearrange(x, "b n s d -> b s (n d)", n=num_heads)
def patched_apply_rotary_emb_qwen(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]
):
    """Apply rotary position embedding to ``x`` via complex multiplication.

    The last dimension of ``x`` is grouped into (real, imag) pairs, multiplied
    by ``freqs_cis`` cast to ``complex64``, and unpacked back to real values.

    Args:
        x: Input tensor; its last dimension is split into pairs of two.
        freqs_cis: Rotation factors; ``.to(torch.complex64)`` is called on it,
            so it must be a tensor broadcastable against the paired view of ``x``.

    Returns:
        Tensor with the dtype/type of ``x``; dimensions from index 3 onward
        are flattened into one.
    """
    # View consecutive pairs of the last dimension as complex numbers (in fp32).
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    as_complex = torch.view_as_complex(pairs)
    rotated = as_complex * freqs_cis.to(torch.complex64)
    # Back to interleaved real values, then restore the caller's tensor type.
    unpacked = torch.view_as_real(rotated).flatten(3)
    return unpacked.type_as(x)
class PatchedQwenDoubleStreamAttention(QwenDoubleStreamAttention):
    """Joint text/image attention that uses the NPU RMSNorm and attention helpers.

    Image and text tokens get their own q/k/v projections and q/k norms, are
    concatenated along the sequence axis (text first), attended jointly, and
    are then split again and passed through per-stream output projections.
    """

    def __init__(self, dim_a, dim_b, num_heads, head_dim):
        """Create projections and norms for both streams.

        Args:
            dim_a: Feature size of the image stream.
            dim_b: Feature size of the text stream.
            num_heads: Number of attention heads.
            head_dim: Per-head feature size (size of the q/k norms).
        """
        super().__init__(dim_a, dim_b, num_heads, head_dim)
        self.num_heads = num_heads
        self.head_dim = head_dim

        # Image stream projections and q/k norms.
        self.to_q = nn.Linear(dim_a, dim_a)
        self.to_k = nn.Linear(dim_a, dim_a)
        self.to_v = nn.Linear(dim_a, dim_a)
        self.norm_q = RMSNorm_npu(head_dim, eps=1e-6)
        self.norm_k = RMSNorm_npu(head_dim, eps=1e-6)

        # Text stream projections and q/k norms.
        self.add_q_proj = nn.Linear(dim_b, dim_b)
        self.add_k_proj = nn.Linear(dim_b, dim_b)
        self.add_v_proj = nn.Linear(dim_b, dim_b)
        self.norm_added_q = RMSNorm_npu(head_dim, eps=1e-6)
        self.norm_added_k = RMSNorm_npu(head_dim, eps=1e-6)

        # Output projections.
        self.to_out = torch.nn.Sequential(nn.Linear(dim_a, dim_a))
        self.to_add_out = nn.Linear(dim_b, dim_b)

    def forward(
        self,
        image: torch.FloatTensor,
        text: torch.FloatTensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        enable_fp8_attention: bool = False,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Run joint attention over text and image tokens.

        Args:
            image: Image tokens, ``b s c`` layout.
            text: Text tokens, ``b s c`` layout.
            image_rotary_emb: Optional ``(img_freqs, txt_freqs)`` pair applied
                to the image and text q/k respectively.
            attention_mask: Optional mask forwarded to the attention helper.
            enable_fp8_attention: Forwarded to the attention helper.

        Returns:
            ``(image_output, text_output)`` after the output projections.
        """

        def split_heads(tokens):
            # b s (h d) -> b h s d
            return rearrange(tokens, 'b s (h d) -> b h s d', h=self.num_heads)

        txt_q_flat = self.add_q_proj(text)
        # Text length is needed later to split the joint output.
        text_len = txt_q_flat.shape[1]

        q_img = split_heads(self.to_q(image))
        k_img = split_heads(self.to_k(image))
        v_img = split_heads(self.to_v(image))
        q_txt = split_heads(txt_q_flat)
        k_txt = split_heads(self.add_k_proj(text))
        v_txt = split_heads(self.add_v_proj(text))

        q_img = self.norm_q(q_img)
        k_img = self.norm_k(k_img)
        q_txt = self.norm_added_q(q_txt)
        k_txt = self.norm_added_k(k_txt)

        if image_rotary_emb is not None:
            freqs_img, freqs_txt = image_rotary_emb
            q_img = patched_apply_rotary_emb_qwen(q_img, freqs_img)
            k_img = patched_apply_rotary_emb_qwen(k_img, freqs_img)
            q_txt = patched_apply_rotary_emb_qwen(q_txt, freqs_txt)
            k_txt = patched_apply_rotary_emb_qwen(k_txt, freqs_txt)

        # Text tokens come first along the sequence axis.
        q_joint = torch.cat([q_txt, q_img], dim=2)
        k_joint = torch.cat([k_txt, k_img], dim=2)
        v_joint = torch.cat([v_txt, v_img], dim=2)

        attended = patched_qwen_image_flash_attention(
            q_joint,
            k_joint,
            v_joint,
            num_heads=q_joint.shape[1],
            attention_mask=attention_mask,
            enable_fp8_attention=enable_fp8_attention,
        )
        attended = attended.to(q_joint.dtype)

        text_part = attended[:, :text_len, :]
        image_part = attended[:, text_len:, :]
        return self.to_out(image_part), self.to_add_out(text_part)
class PatchedQwenImageDiT(QwenImageDiT):
    """Qwen-Image DiT whose text norm uses the NPU-backed ``RMSNorm_npu``.

    The constructor rebuilds the embedding layers, the transformer block stack
    and the output head on top of whatever the parent constructor created.
    """

    def __init__(
        self,
        num_layers: int = 60,
    ):
        """Build embeddings, ``num_layers`` transformer blocks and the output head.

        Args:
            num_layers: Number of transformer blocks to create.
        """
        super().__init__()
        # `PatchedQwenImageTransformerBlock` is neither defined nor imported in
        # the visible part of this module, so referencing it directly raises
        # NameError. Use it when some other part of the module provides it,
        # otherwise fall back to the imported stock block.
        # NOTE(review): assumes the stock block accepts the same keyword
        # arguments (dim, num_attention_heads, attention_head_dim) — confirm.
        block_cls = globals().get("PatchedQwenImageTransformerBlock", QwenImageTransformerBlock)

        self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16, 56, 56], scale_rope=True)
        self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=True)
        self.txt_norm = RMSNorm_npu(3584, eps=1e-6)
        self.img_in = nn.Linear(64, 3072)
        self.txt_in = nn.Linear(3584, 3072)
        self.transformer_blocks = nn.ModuleList(
            [
                block_cls(
                    dim=3072,
                    num_attention_heads=24,
                    attention_head_dim=128,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_out = AdaLayerNorm(3072, single=True)
        self.proj_out = nn.Linear(3072, 64)

    def process_entity_masks(self, latents, prompt_emb, prompt_emb_mask, entity_prompt_emb, entity_prompt_emb_mask, entity_masks, height, width, image, img_shapes):
        """Build joint prompt embeddings, rotary embeddings and the entity attention mask.

        Entity prompts are placed before the global prompt, and all prompt
        tokens are placed before the image tokens in the mask layout.

        Args:
            latents: Latent tensor; supplies device, dtype and the channel
                repeat count (``latents.shape[1]``).
            prompt_emb: Global prompt embedding.
            prompt_emb_mask: Mask of the global prompt; summed over dim 1 to
                obtain its length.
            entity_prompt_emb: List of per-entity prompt embeddings.
            entity_prompt_emb_mask: List of per-entity prompt masks.
            entity_masks: Per-entity spatial masks, indexed on dim 1.
            height, width: Pixel size; ``// 16`` gives the patch grid.
            image: Image tokens; ``image.shape[1]`` is the image sequence length.
            img_shapes: Shape descriptor forwarded to ``self.pos_embed``.

        Returns:
            ``(all_prompt_emb, image_rotary_emb, attention_mask)`` where the
            mask is additive (``0`` = attend, ``-inf`` = masked) with a
            singleton dimension inserted at index 1.

        Note:
            Sequence lengths are read with ``.item()``, so each prompt mask
            must sum to a single value (i.e. batch size 1).
        """
        # Embed entity prompts followed by the global prompt, then join on the sequence axis.
        all_prompt_emb = entity_prompt_emb + [prompt_emb]
        all_prompt_emb = [self.txt_in(self.txt_norm(local_prompt_emb)) for local_prompt_emb in all_prompt_emb]
        all_prompt_emb = torch.cat(all_prompt_emb, dim=1)

        # Rotary embeddings: image part from the global prompt call, text part
        # is the concatenation of every entity's text part plus the global one.
        txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
        entity_seq_lens = [emb_mask.sum(dim=1).tolist() for emb_mask in entity_prompt_emb_mask]
        entity_rotary_emb = [self.pos_embed(img_shapes, entity_seq_len, device=latents.device)[1] for entity_seq_len in entity_seq_lens]
        txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0)
        image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb)

        # Repeat each entity mask across the latent channels and append an
        # all-ones mask for the global prompt.
        repeat_dim = latents.shape[1]
        max_masks = entity_masks.shape[1]
        entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
        entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
        global_mask = torch.ones_like(entity_masks[0]).to(device=latents.device, dtype=latents.dtype)
        entity_masks = entity_masks + [global_mask]

        N = len(entity_masks)
        batch_size = entity_masks[0].shape[0]
        seq_lens = [mask_.sum(dim=1).item() for mask_ in entity_prompt_emb_mask] + [prompt_emb_mask.sum(dim=1).item()]
        total_seq_len = sum(seq_lens) + image.shape[1]

        # Patchify the spatial masks the same way the latents are patchified.
        patched_masks = []
        for i in range(N):
            patched_mask = rearrange(entity_masks[i], "B C (H P) (W Q) -> B (H W) (C P Q)", H=height // 16, W=width // 16, P=2, Q=2)
            patched_masks.append(patched_mask)

        # Start from "everything attends to everything".
        attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device)
        image_start = sum(seq_lens)
        image_end = total_seq_len
        cumsum = [0]
        single_image_seq = image_end - image_start
        for length in seq_lens:
            cumsum.append(cumsum[-1] + length)

        # Prompt <-> image: a prompt only sees image patches covered by its mask.
        for i in range(N):
            prompt_start = cumsum[i]
            prompt_end = cumsum[i + 1]
            image_mask = torch.sum(patched_masks[i], dim=-1) > 0
            image_mask = image_mask.unsqueeze(1).repeat(1, seq_lens[i], 1)
            # Tile the patch mask when the image sequence is a multiple of it.
            repeat_time = single_image_seq // image_mask.shape[-1]
            image_mask = image_mask.repeat(1, 1, repeat_time)
            attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
            attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)

        # Prompt <-> prompt: different prompts never attend to each other.
        for i in range(N):
            for j in range(N):
                if i == j:
                    continue
                start_i, end_i = cumsum[i], cumsum[i + 1]
                start_j, end_j = cumsum[j], cumsum[j + 1]
                attention_mask[:, start_i:end_i, start_j:end_j] = False

        # Convert the boolean mask to additive form: False -> -inf, True -> 0.
        attention_mask = attention_mask.float()
        attention_mask[attention_mask == 0] = float('-inf')
        attention_mask[attention_mask == 1] = 0
        attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype).unsqueeze(1)
        return all_prompt_emb, image_rotary_emb, attention_mask

    def forward(
        self,
        latents=None,
        timestep=None,
        prompt_emb=None,
        prompt_emb_mask=None,
        height=None,
        width=None,
    ):
        """Run the DiT on ``latents`` and return the prediction in latent layout.

        Args:
            latents: Input latents in ``B C (H P) (W Q)`` layout with ``P = Q = 2``.
            timestep: Timestep passed to ``self.time_text_embed``.
            prompt_emb: Prompt embedding fed through ``txt_norm`` and ``txt_in``.
            prompt_emb_mask: Prompt mask; summed over dim 1 for text lengths.
            height, width: Pixel size; ``// 16`` gives the patch grid.

        Returns:
            The projected output rearranged back to ``B C (H P) (W Q)``.
        """
        img_shapes = [(latents.shape[0], latents.shape[2] // 2, latents.shape[3] // 2)]
        txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()

        # Patchify latents into a token sequence and embed both streams.
        image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height // 16, W=width // 16, P=2, Q=2)
        image = self.img_in(image)
        text = self.txt_in(self.txt_norm(prompt_emb))
        conditioning = self.time_text_embed(timestep, image.dtype)
        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)

        for block in self.transformer_blocks:
            text, image = block(
                image=image,
                text=text,
                temb=conditioning,
                image_rotary_emb=image_rotary_emb,
            )

        image = self.norm_out(image, conditioning)
        image = self.proj_out(image)

        # Undo the patchify step. The original code computed this value and
        # then returned the token sequence instead, discarding the result.
        latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height // 16, W=width // 16, P=2, Q=2)
        return latents

    @staticmethod
    def state_dict_converter():
        """Return the state-dict converter used for the stock Qwen-Image DiT."""
        return QwenImageDiTStateDictConverter()
def Patched_launch_training_task(
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    learning_rate: float = 1e-5,
    weight_decay: float = 1e-2,
    num_workers: int = 8,
    save_steps: int = None,
    num_epochs: int = 1,
    gradient_accumulation_steps: int = 1,
    find_unused_parameters: bool = False,
    args=None,
):
    """Train ``model`` after pre-computing all inputs and releasing the text encoder.

    The whole dataset is first pushed through ``forward_preprocess`` (without
    gradients) and the results are kept in memory. The pipeline's text encoder
    is then dropped to free device memory, and training iterates over the
    cached inputs for ``num_epochs`` epochs.

    Args:
        dataset: Dataset to iterate; each loader batch is a single sample.
        model: Training module exposing ``trainable_modules()``,
            ``forward_preprocess()`` and a ``pipe`` attribute.
        model_logger: Receives ``on_step_end`` / ``on_epoch_end`` callbacks.
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight decay.
        num_workers: DataLoader worker count.
        save_steps: Forwarded to ``on_step_end``; when ``None``,
            ``on_epoch_end`` is called after every epoch instead.
        num_epochs: Number of passes over the cached inputs.
        gradient_accumulation_steps: Micro-steps per optimizer update.
        find_unused_parameters: Forwarded to the DDP kwargs handler.
        args: Optional namespace; when given, its fields override all of the
            hyper-parameters above.
    """
    if args is not None:
        learning_rate = args.learning_rate
        weight_decay = args.weight_decay
        num_workers = args.dataset_num_workers
        save_steps = args.save_steps
        num_epochs = args.num_epochs
        gradient_accumulation_steps = args.gradient_accumulation_steps
        find_unused_parameters = args.find_unused_parameters

    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
    dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers)
    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=find_unused_parameters)],
    )
    model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)

    # `prepare` may wrap the model (e.g. in DistributedDataParallel), and such
    # wrappers do not expose custom attributes like `forward_preprocess` or
    # `pipe`. Always reach those through the underlying module.
    original_model = model.module if hasattr(model, 'module') else model

    # Phase 1: run the preprocessing once and cache every sample's inputs.
    preprocessed_inputs = []
    model.eval()
    with torch.no_grad():
        for data in tqdm(dataloader, desc="preprocess prompts"):
            preprocessed_inputs.append(original_model.forward_preprocess(data))

    # Phase 2: the text encoder is no longer needed; drop it to free memory.
    pipe = original_model.pipe
    if getattr(pipe, 'text_encoder', None) is not None:
        if accelerator.is_main_process:
            print("开始释放 text_encoder 内存...")
        pipe.text_encoder = pipe.text_encoder.to("cpu")
        del pipe.text_encoder
        pipe.text_encoder = None
        gc.collect()
        torch.npu.empty_cache()
    if accelerator.is_main_process:
        print("Prompt preprocess complated, release memory allready...")

    # Phase 3: train on the cached inputs.
    model.train()
    batches_per_epoch = len(preprocessed_inputs)
    total_steps = num_epochs * batches_per_epoch
    global_step = 0
    progress_bar = tqdm(
        range(0, total_steps),
        initial=global_step,
        desc="Steps",
        disable=not accelerator.is_local_main_process,
    )
    for epoch_id in range(num_epochs):
        for inputs in preprocessed_inputs:
            with accelerator.accumulate(model):
                loss = model({}, inputs=inputs)
                accelerator.backward(loss)
                optimizer.step()
                # Clear gradients only AFTER the update. Under `accumulate`
                # the prepared optimizer's step/zero_grad take effect only on
                # the synchronizing micro-step, so zeroing before backward
                # would wipe the gradients accumulated so far.
                optimizer.zero_grad()
                global_step += 1
                progress_bar.update(1)
                logs = {"loss": loss.detach().item(), "lr": scheduler.get_last_lr()[0]}
                progress_bar.set_postfix(refresh=False, **logs)
                model_logger.on_step_end(accelerator, model, save_steps)
                scheduler.step()
        if save_steps is None:
            model_logger.on_epoch_end(accelerator, model, epoch_id)
def apply_patches():
    """Swap the NPU-specific implementations into the already-imported diffsynth modules.

    Replaces four attributes of ``diffsynth.models.qwen_image_dit`` and the
    ``launch_training_task`` attribute of ``diffsynth.trainers.utils``. Both
    modules are looked up in ``sys.modules`` and must already be imported.
    """
    dit_overrides = (
        ("qwen_image_flash_attention", patched_qwen_image_flash_attention),
        ("apply_rotary_emb_qwen", patched_apply_rotary_emb_qwen),
        ("QwenDoubleStreamAttention", PatchedQwenDoubleStreamAttention),
        ("QwenImageDiT", PatchedQwenImageDiT),
    )
    dit_module = sys.modules["diffsynth.models.qwen_image_dit"]
    for attr_name, replacement in dit_overrides:
        setattr(dit_module, attr_name, replacement)

    trainer_module = sys.modules["diffsynth.trainers.utils"]
    setattr(trainer_module, "launch_training_task", Patched_launch_training_task)