import math
import os
from typing import Optional
import torch
import torch_npu
from torch import nn
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Build sinusoidal embeddings for a batch of timesteps.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      Fractional values are allowed.
    :param dim: width of each output embedding.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] float Tensor. The first ``dim // 2`` columns are
             cosines, the next ``dim // 2`` columns are sines, and one zero
             column is appended when ``dim`` is odd.
    """
    n_freqs = dim // 2
    # Geometric ladder of frequencies starting at 1 and decaying towards
    # 1 / max_period. It is created without an explicit device and only
    # afterwards moved next to the input.
    exponents = -math.log(max_period) * torch.arange(
        start=0, end=n_freqs, dtype=torch.float32
    ) / n_freqs
    freqs = torch.exp(exponents).to(device=timesteps.device)

    # Outer product: one row per timestep, one column per frequency.
    phases = timesteps[:, None].float() * freqs[None]
    parts = [torch.cos(phases), torch.sin(phases)]
    if dim % 2:
        # Pad odd widths with a zero column so the result is `dim` wide.
        parts.append(torch.zeros_like(phases[:, :1]))
    return torch.cat(parts, dim=-1)
def get_timestep_embedding(x, outdim):
    """
    Embed every scalar of a 2-D tensor and lay the embeddings out per row.

    :param x: a 2-D Tensor of shape [b, dims].
    :param outdim: embedding width used for each individual scalar.
    :return: a [b, dims * outdim] Tensor in which each row holds the
             embeddings of that row's scalars, in order.
    :raises ValueError: if ``x`` does not have exactly two dimensions.
    """
    if x.ndim != 2:
        raise ValueError("timestep embedding has to be 2 dimensions")
    batch, n_values = x.shape
    # Embed all scalars in one call, then fold them back into their rows.
    flat_emb = timestep_embedding(x.reshape(-1), outdim)
    return flat_emb.reshape(batch, n_values * outdim)
def get_size_embeddings(orig_size, crop_size, target_size, device):
    """
    Build the concatenated size-conditioning embedding.

    Each of the three inputs is passed through ``get_timestep_embedding``
    with a per-scalar width of 256, and the results are joined along
    dimension 1 in the order original size, crop, target size.

    :param orig_size: 2-D Tensor accepted by ``get_timestep_embedding``.
    :param crop_size: 2-D Tensor accepted by ``get_timestep_embedding``.
    :param target_size: 2-D Tensor accepted by ``get_timestep_embedding``.
    :param device: device the joined embedding is moved to.
    :return: the joined embedding Tensor on ``device``.
    """
    parts = [
        get_timestep_embedding(sizes, 256)
        for sizes in (orig_size, crop_size, target_size)
    ]
    return torch.cat(parts, dim=1).to(device)
def pool_workaround(
    text_encoder: CLIPTextModelWithProjection,
    last_hidden_state: torch.Tensor,
    input_ids: torch.Tensor,
    eos_token_id: int,
):
    r"""
    Pool the encoder output at the EOS token and apply the text projection.

    This works around CLIP's stock pooling, which takes the hidden state at
    the position of the highest token id in each sequence instead of the
    position of the EOS token. With Textual Inversion the EOS position has
    to be used, so it is looked up explicitly here.

    For reference, the pooling code in CLIP that is being worked around:

        # text_embeds.shape = [batch_size, sequence_length, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
        pooled_output = last_hidden_state[
            torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
            input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
        ]

    :param text_encoder: encoder whose ``text_projection`` layer is applied
                         to the pooled states.
    :param last_hidden_state: encoder output, indexed as [row, position, ...].
    :param input_ids: token ids, one row per row of ``last_hidden_state``.
    :param eos_token_id: id of the EOS token to locate in each row.
    :return: projected pooled states, cast back to the dtype and device of
             ``last_hidden_state``.
    """
    target_device = last_hidden_state.device

    # argmax over a 0/1 mask gives the first position holding the EOS id
    # (and position 0 for a row that contains no EOS at all).
    is_eos = (input_ids == eos_token_id).int()
    eos_positions = torch.argmax(is_eos, dim=1).to(device=target_device)

    row_index = torch.arange(last_hidden_state.shape[0], device=target_device)
    eos_states = last_hidden_state[row_index, eos_positions]

    # Run the projection in its own weight dtype, then restore the caller's.
    projection = text_encoder.text_projection
    projected = projection(eos_states.to(projection.weight.dtype))
    return projected.to(dtype=last_hidden_state.dtype, device=target_device)
def _stitch_token_windows(hidden_states, max_token_length, window):
    """Join per-window encoder states back into one sequence per prompt.

    Keeps the very first position, then for every window start in
    ``range(1, max_token_length, window)`` the next ``window - 2`` positions,
    and finally the very last position.
    """
    pieces = [hidden_states[:, 0].unsqueeze(1)]
    pieces.extend(
        hidden_states[:, start : start + window - 2]
        for start in range(1, max_token_length, window)
    )
    pieces.append(hidden_states[:, -1].unsqueeze(1))
    return torch.cat(pieces, dim=1)


def get_hidden_states_sdxl(
    max_token_length: int,
    input_ids1: torch.Tensor,
    input_ids2: torch.Tensor,
    tokenizer1: CLIPTokenizer,
    tokenizer2: CLIPTokenizer,
    text_encoder1: CLIPTextModel,
    text_encoder2: CLIPTextModelWithProjection,
    weight_dtype: Optional[str] = None,
):
    """
    Run both SDXL text encoders and return their conditioning tensors.

    Token ids are reshaped into rows of ``model_max_length`` tokens before
    encoding. Encoder 1 contributes entry 11 of its ``hidden_states``;
    encoder 2 contributes its second-to-last entry plus a pooled output
    taken at the EOS token via ``pool_workaround``. When
    ``max_token_length`` is given, the per-window states of each prompt are
    stitched back into one sequence and only every n-th pooled row is kept,
    where n is ``max_token_length // (tokenizer1.model_max_length - 2)``.

    :param max_token_length: extended prompt length, or ``None`` to skip the
                             window stitching.
    :param input_ids1: token ids for encoder 1; the first dimension is the
                       batch size.
    :param input_ids2: token ids for encoder 2.
    :param tokenizer1: supplies ``model_max_length`` for encoder 1.
    :param tokenizer2: supplies ``model_max_length`` and ``eos_token_id``
                       for encoder 2.
    :param text_encoder1: first text encoder.
    :param text_encoder2: second text encoder (with projection).
    :param weight_dtype: if not ``None``, passed to ``Tensor.to`` for both
                         hidden-state tensors. NOTE(review): annotated as
                         ``str`` but presumably a ``torch.dtype`` -- confirm
                         against callers.
    :return: ``(hidden_states1, hidden_states2, pool2)``.
    """
    batch = input_ids1.size()[0]
    window1 = tokenizer1.model_max_length
    window2 = tokenizer2.model_max_length

    # One row per tokenizer window.
    input_ids1 = input_ids1.reshape((-1, window1))
    input_ids2 = input_ids2.reshape((-1, window2))

    out1 = text_encoder1(input_ids1, output_hidden_states=True, return_dict=True)
    hidden_states1 = out1["hidden_states"][11]

    out2 = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
    hidden_states2 = out2["hidden_states"][-2]
    pool2 = pool_workaround(
        text_encoder2, out2["last_hidden_state"], input_ids2, tokenizer2.eos_token_id
    )

    if max_token_length is None:
        windows_per_prompt = 1
    else:
        windows_per_prompt = max_token_length // (window1 - 2)

    # Lay all windows of one prompt side by side along the sequence axis.
    hidden_states1 = hidden_states1.reshape((batch, -1, hidden_states1.shape[-1]))
    hidden_states2 = hidden_states2.reshape((batch, -1, hidden_states2.shape[-1]))

    if max_token_length is not None:
        hidden_states1 = _stitch_token_windows(
            hidden_states1, max_token_length, window1
        )
        hidden_states2 = _stitch_token_windows(
            hidden_states2, max_token_length, window2
        )
        # Keep the pooled output of the first window of every prompt only.
        pool2 = pool2[::windows_per_prompt]

    if weight_dtype is not None:
        hidden_states1 = hidden_states1.to(weight_dtype)
        hidden_states2 = hidden_states2.to(weight_dtype)

    return hidden_states1, hidden_states2, pool2
def get_noise_noisy_latents_and_timesteps(
    args, noise_scheduler, latents, epoch, step, weight_dtype
):
    """
    Sample noise and timesteps and produce the noised latents.

    :param args: not used by this function.
    :param noise_scheduler: object exposing ``config.num_train_timesteps``
                            and ``add_noise(latents, noise, timesteps)``.
    :param latents: clean latents; the first dimension is the batch size.
    :param epoch: not used by this function.
    :param step: not used by this function.
    :param weight_dtype: dtype the noised latents are cast to.
    :return: ``(noise, noisy_latents, timesteps)`` where ``timesteps`` holds
             one integer in ``[0, num_train_timesteps)`` per batch element.
    """
    # Noise is drawn before the timesteps; keep this order so the random
    # stream is consumed the same way for a given seed.
    noise = torch.randn_like(latents, device=latents.device)

    num_samples = latents.shape[0]
    upper_bound = noise_scheduler.config.num_train_timesteps
    timesteps = torch.randint(
        0, upper_bound, (num_samples,), device=latents.device
    ).long()

    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
    return noise, noisy_latents.to(weight_dtype), timesteps
class SdxlPretrainModels(nn.Module):
    """
    Container module holding the SDXL UNet and both text encoders.

    ``forward`` runs one noise-prediction step of the UNet; the
    ``save_*`` methods write the three sub-models to disk.
    """

    def __init__(
        self,
        args,
        unet: nn.Module,
        text_encoder1: nn.Module,
        text_encoder2: nn.Module,
        weight_dtype,
    ):
        """
        :param args: run configuration; this class reads ``resolution`` and
                     ``pretrained_model_name_or_path`` from it.
        :param unet: UNet called in ``forward`` and saved by
                     ``save_pretrained``.
        :param text_encoder1: first text encoder (saved under
                              ``text_encoder``).
        :param text_encoder2: second text encoder (saved under
                              ``text_encoder_2``).
        :param weight_dtype: dtype used for the tensors fed to the UNet.
        """
        super().__init__()
        self.args = args
        self.unet = unet
        self.text_encoder1 = text_encoder1
        self.text_encoder2 = text_encoder2
        self.weight_dtype = weight_dtype

    def _compute_time_ids(self, original_size, crops_coords_top_left, device):
        """Build the ``[1, n]`` time-id row for one sample.

        The row is ``original_size + crops_coords_top_left + (resolution,
        resolution)``. Assumes both arguments are tuples so that ``+``
        concatenates them -- TODO confirm against the dataset code.
        """
        target_size = (self.args.resolution, self.args.resolution)
        add_time_ids = list(original_size + crops_coords_top_left + target_size)
        return torch.tensor([add_time_ids]).to(device, dtype=self.weight_dtype)

    def forward(
        self,
        batch,
        accelerator,
        noise_scheduler,
        latent,
        epoch,
        step,
        encoder_hidden_state1,
        encoder_hidden_state2,
        pool2,
    ):
        """
        Noise the latents and predict that noise with the UNet.

        Size conditioning reaches the UNet only through
        ``added_cond_kwargs["time_ids"]``; the pooled text embedding is
        passed as ``added_cond_kwargs["text_embeds"]`` exactly as received.

        :param batch: mapping providing ``"original_sizes"`` and
                      ``"crop_top_lefts_list"``, iterated pairwise per sample.
        :param accelerator: object exposing ``device`` and ``autocast()``.
        :param noise_scheduler: forwarded to
                                ``get_noise_noisy_latents_and_timesteps``.
        :param latent: clean latents to be noised.
        :param epoch: forwarded to the noise helper.
        :param step: forwarded to the noise helper.
        :param encoder_hidden_state1: hidden states of text encoder 1.
        :param encoder_hidden_state2: hidden states of text encoder 2;
                                      joined with the first along dim 2.
        :param pool2: pooled output of text encoder 2.
        :return: ``(noise_pred, noise, timesteps)``.
        """
        # Gradients are forced on for the whole step, regardless of the
        # caller's grad mode.
        with torch.set_grad_enabled(True):
            device = accelerator.device

            # One time-id row per sample, stacked along the batch axis.
            add_time_ids = torch.cat(
                [
                    self._compute_time_ids(size, crop, device)
                    for size, crop in zip(
                        batch["original_sizes"], batch["crop_top_lefts_list"]
                    )
                ]
            )

            text_embedding = torch.cat(
                [encoder_hidden_state1, encoder_hidden_state2], dim=2
            ).to(device=device, dtype=self.weight_dtype)

            noise, noisy_latents, timesteps = get_noise_noisy_latents_and_timesteps(
                self.args,
                noise_scheduler,
                latent,
                epoch,
                step,
                self.weight_dtype,
            )

            unet_added_conditions = {
                "time_ids": add_time_ids,
                "text_embeds": pool2,
            }
            with accelerator.autocast():
                noise_pred = self.unet(
                    noisy_latents,
                    timesteps,
                    text_embedding,
                    added_cond_kwargs=unet_added_conditions,
                ).sample

            return noise_pred, noise, timesteps

    def save_text_encoder(self, model_type, path):
        """
        Save one of the text encoders in ``save_pretrained`` format.

        A fresh copy of the pretrained encoder is loaded from local files,
        the current weights are copied into it with ``strict=False`` and the
        copy is written to ``path``.

        :param model_type: ``1`` selects text encoder 1; any other value
                           selects text encoder 2.
        :param path: destination directory.
        """
        import warnings  # local import: the module header is left untouched

        if model_type == 1:
            model_cls, subfolder, trained = (
                CLIPTextModel,
                "text_encoder",
                self.text_encoder1,
            )
        else:
            model_cls, subfolder, trained = (
                CLIPTextModelWithProjection,
                "text_encoder_2",
                self.text_encoder2,
            )

        text_encoder_bak = model_cls.from_pretrained(
            self.args.pretrained_model_name_or_path,
            subfolder=subfolder,
            local_files_only=True,
        ).to("npu")

        # strict=False never raises on a key mismatch, so surface it: any
        # weight that fails to match would otherwise be saved with its
        # pretrained value without anyone noticing.
        result = text_encoder_bak.load_state_dict(trained.state_dict(), strict=False)
        if result.missing_keys or result.unexpected_keys:
            warnings.warn(
                f"save_text_encoder({model_type}): state dict mismatch "
                f"(missing={len(result.missing_keys)}, "
                f"unexpected={len(result.unexpected_keys)}); "
                "unmatched weights keep their pretrained values",
                RuntimeWarning,
                stacklevel=2,
            )

        text_encoder_bak.save_pretrained(path)

    def save_pretrained(self, path):
        """
        Save the UNet and both text encoders below ``path``.

        Sub-directories used: ``unet``, ``text_encoder``, ``text_encoder_2``.

        :param path: root output directory.
        """
        self.unet.save_pretrained(os.path.join(path, "unet"))
        self.save_text_encoder(1, os.path.join(path, "text_encoder"))
        self.save_text_encoder(2, os.path.join(path, "text_encoder_2"))