import re
from copy import deepcopy
from dataclasses import asdict, dataclass
from enum import Enum
from typing import List, Optional, Union
import numpy as np
import torch
from numpy import exp, pi, sqrt
from torchvision.transforms.functional import resize
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
def preprocess_image(image):
    """Preprocess an input image

    Same as
    https://github.com/huggingface/diffusers/blob/1138d63b519e37f0ce04e027b9f4a3261d27c628/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py#L44

    Args:
        image: object exposing ``size`` as ``(width, height)`` and a ``resize``
            method (presumably a ``PIL.Image.Image`` -- confirm against callers).

    Returns:
        A ``torch.Tensor`` built from the resized image: a leading axis of size 1
        is added, the last array axis is moved to position 1, and the values are
        mapped through ``x / 255 * 2 - 1`` (i.e. to [-1, 1] for 8-bit pixel data).
    """
    # PIL is imported inside the function, so it is only required when this
    # function is actually called.
    from PIL import Image

    width, height = image.size
    # Round both sides down to the closest multiple of 32.
    width, height = (side - side % 32 for side in (width, height))
    image = image.resize((width, height), resample=Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    # Prepend an axis of size 1 and reorder (0, 1, 2, 3) -> (0, 3, 1, 2); this
    # requires the image array to have exactly three axes.
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return 2.0 * image - 1.0
@dataclass
class CanvasRegion:
    """Class defining a rectangular region in the canvas

    Coordinates are given in pixel space and must be non-negative multiples
    of 8; the matching latent-space coordinates (pixel coordinate // 8) are
    stored as ``latent_row_init``, ``latent_row_end``, ``latent_col_init`` and
    ``latent_col_end``.

    Attributes:
        row_init: first pixel row of the region.
        row_end: end pixel row of the region (used as a slice end, so exclusive).
        col_init: first pixel column of the region.
        col_end: end pixel column of the region (used as a slice end, so exclusive).
        region_seed: seed for this region's random generator; drawn at random
            from ``np.random`` when left as None.
        noise_eps: non-negative scale of extra noise associated with the region.

    Raises:
        ValueError: if a coordinate is negative or not divisible by 8, or if
            ``noise_eps`` is negative.
    """

    row_init: int
    row_end: int
    col_init: int
    col_end: int
    region_seed: Optional[int] = None
    noise_eps: float = 0.0

    def __post_init__(self):
        if self.region_seed is None:
            # The upper bound does not fit in 32 bits, so request int64
            # explicitly: NumPy's default integer is 32-bit on some platforms,
            # where the bare call raises ValueError. Converted to a plain int
            # so the seed stays a built-in type.
            self.region_seed = int(np.random.randint(9999999999, dtype=np.int64))

        coords = (self.row_init, self.row_end, self.col_init, self.col_end)
        # Two separate passes so that a negative coordinate is always reported
        # before a divisibility problem, whatever its position.
        if any(coord < 0 for coord in coords):
            raise ValueError(
                f"A CanvasRegion must be defined with non-negative indices, found ({self.row_init}, {self.row_end}, {self.col_init}, {self.col_end})"
            )
        if any(coord % 8 != 0 for coord in coords):
            raise ValueError(
                f"A CanvasRegion must be defined with locations divisible by 8, found ({self.row_init}-{self.row_end}, {self.col_init}-{self.col_end})"
            )
        if self.noise_eps < 0:
            raise ValueError(f"A CanvasRegion must be defined noises eps non-negative, found {self.noise_eps}")

        # Latent-space coordinates: one latent cell per 8 pixels.
        self.latent_row_init = self.row_init // 8
        self.latent_row_end = self.row_end // 8
        self.latent_col_init = self.col_init // 8
        self.latent_col_end = self.col_end // 8

    @property
    def width(self):
        """Width of the region in pixels (``col_end - col_init``)"""
        return self.col_end - self.col_init

    @property
    def height(self):
        """Height of the region in pixels (``row_end - row_init``)"""
        return self.row_end - self.row_init

    def get_region_generator(self, device="cpu"):
        """Creates a torch.Generator based on the random seed of this region"""
        return torch.Generator(device).manual_seed(self.region_seed)

    @property
    def __dict__(self):
        # Expose only the dataclass fields (via asdict); the derived latent_*
        # attributes are not dataclass fields and so are not included.
        return asdict(self)
class MaskModes(Enum):
    """Modes in which the influence of diffuser is masked

    Each member's value is the lowercase string used to select the mask.
    """

    # One member per weighting scheme.
    CONSTANT = "constant"
    GAUSSIAN = "gaussian"
    QUARTIC = "quartic"
@dataclass
class DiffusionRegion(CanvasRegion):
    """Abstract class defining a region where some class of diffusion process is acting

    Adds no fields or behaviour of its own on top of CanvasRegion.
    """
@dataclass
class Text2ImageRegion(DiffusionRegion):
    """Class defining a region where a text guided diffusion process is acting

    Attributes:
        prompt: text prompt for the region. On construction, runs of spaces
            are collapsed to a single space and newlines are then replaced by
            spaces.
        guidance_scale: guidance scale for the region; when None, an integer
            is drawn with ``np.random.randint(5, 30)``.
        mask_type: one of the ``MaskModes`` *values* (a string such as
            ``"gaussian"``), not an enum member.
        mask_weight: non-negative weight applied to the region's mask.

    Raises:
        ValueError: if ``mask_weight`` is negative or ``mask_type`` is not one
            of the ``MaskModes`` values (plus anything raised by the parent
            ``__post_init__``).
    """

    prompt: str = ""
    guidance_scale: float = 7.5
    # Validated below against the enum *values*, hence annotated as str.
    mask_type: str = MaskModes.GAUSSIAN.value
    mask_weight: float = 1.0
    # Unannotated, so the dataclass machinery does not treat these as fields;
    # they are filled in by tokenize_prompt / encode_prompt.
    tokenized_prompt = None
    encoded_prompt = None

    def __post_init__(self):
        super().__post_init__()
        if self.mask_weight < 0:
            raise ValueError(
                f"A Text2ImageRegion must be defined with non-negative mask weight, found {self.mask_weight}"
            )
        accepted_masks = [e.value for e in MaskModes]
        if self.mask_type not in accepted_masks:
            raise ValueError(
                f"A Text2ImageRegion was defined with mask {self.mask_type}, which is not an accepted mask ({accepted_masks})"
            )
        if self.guidance_scale is None:
            self.guidance_scale = np.random.randint(5, 30)
        # Spaces are collapsed before newlines are replaced, so a newline next
        # to a space can still leave two consecutive spaces in the result.
        self.prompt = re.sub(" +", " ", self.prompt).replace("\n", " ")

    def tokenize_prompt(self, tokenizer):
        """Tokenizes the prompt for this diffusion region using a given tokenizer

        The result is stored in ``self.tokenized_prompt``; the prompt is padded
        and truncated to ``tokenizer.model_max_length``.
        """
        self.tokenized_prompt = tokenizer(
            self.prompt,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )

    def encode_prompt(self, text_encoder, device):
        """Encodes the previously tokenized prompt for this diffusion region using a given encoder

        The first element of the encoder output is stored in
        ``self.encoded_prompt``.

        Raises:
            ValueError: if ``tokenize_prompt`` has not been called yet.
        """
        # Raise explicitly rather than assert: asserts are stripped under -O
        # and would surface as AssertionError instead of the intended type.
        if self.tokenized_prompt is None:
            raise ValueError("Prompt in diffusion region must be tokenized before encoding")
        self.encoded_prompt = text_encoder(self.tokenized_prompt.input_ids.to(device))[0]
@dataclass
class Image2ImageRegion(DiffusionRegion):
    """Class defining a region where an image guided diffusion process is acting

    Attributes:
        reference_image: image guiding the region; mandatory despite the None
            default, and resized to the region's (height, width) on construction.
        strength: value in [0.0, 1.0].
    """

    reference_image: torch.FloatTensor = None
    strength: float = 0.8

    def __post_init__(self):
        super().__post_init__()
        image_missing = self.reference_image is None
        if image_missing:
            raise ValueError("Must provide a reference image when creating an Image2ImageRegion")
        if self.strength < 0 or self.strength > 1:
            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {self.strength}")
        # Fit the reference image to the pixel size of the region.
        target_size = [self.height, self.width]
        self.reference_image = resize(self.reference_image, size=target_size)

    def encode_reference_image(self, encoder, device, generator, cpu_vae=False):
        """Encodes the reference image for this Image2Image region into the latent space

        With ``cpu_vae`` the encoder's ``.cpu()`` is called, the mean of the
        latent distribution is taken and moved to ``device``; otherwise the
        image is moved to ``device`` and a latent is sampled with ``generator``.
        Either way the latents are multiplied by 0.18215 and stored in
        ``self.reference_latents``.
        """
        # NOTE(review): if `encoder` is a torch module, `.cpu()` moves the
        # module itself rather than a copy -- confirm this is intended.
        if cpu_vae:
            latent_dist = encoder.cpu().encode(self.reference_image).latent_dist
            raw_latents = latent_dist.mean.to(device)
        else:
            latent_dist = encoder.encode(self.reference_image.to(device)).latent_dist
            raw_latents = latent_dist.sample(generator=generator)
        self.reference_latents = raw_latents
        self.reference_latents = 0.18215 * self.reference_latents

    @property
    def __dict__(self):
        # Fields inherited from DiffusionRegion come first, then this class's
        # own fields, with the reference image converted to nested lists.
        serialised = {name: getattr(self, name) for name in DiffusionRegion.__dataclass_fields__.keys()}
        serialised["reference_image"] = self.reference_image.cpu().tolist()
        serialised["strength"] = self.strength
        return serialised
class RerollModes(Enum):
    """Modes in which the reroll regions operate

    Each member's value is the lowercase string used to select the mode.
    """

    # One member per reroll behaviour.
    RESET = "reset"
    EPSILON = "epsilon"
@dataclass
class RerollRegion(CanvasRegion):
    """Class defining a rectangular canvas region in which initial latent noise will be rerolled

    Attributes:
        reroll_mode: one of the ``RerollModes`` values (``"reset"`` or
            ``"epsilon"``). A ``RerollModes`` member is also accepted and is
            stored as its string value.

    Raises:
        ValueError: if ``reroll_mode`` is not a known reroll mode (plus
            anything raised by ``CanvasRegion.__post_init__``).
    """

    reroll_mode: Union[str, RerollModes] = RerollModes.RESET.value

    def __post_init__(self):
        super().__post_init__()
        # Store the plain string value so that comparisons against
        # ``RerollModes.<X>.value`` also work when an enum member was passed.
        if isinstance(self.reroll_mode, RerollModes):
            self.reroll_mode = self.reroll_mode.value
        accepted_modes = [mode.value for mode in RerollModes]
        # Fail fast: a mode matching no enum value would otherwise never
        # compare equal to any mode and the region would have no effect.
        if self.reroll_mode not in accepted_modes:
            raise ValueError(
                f"A RerollRegion was defined with reroll mode {self.reroll_mode}, which is not an accepted mode ({accepted_modes})"
            )
@dataclass
class MaskWeightsBuilder:
    """Auxiliary class to compute a tensor of weights for a given diffusion region

    Attributes:
        latent_space_dim: size of the second axis of the produced tensors.
        nbatch: size of the first axis of the produced tensors.

    Every builder returns a tensor of shape
    ``(nbatch, latent_space_dim, latent_height, latent_width)``, where the last
    two sizes come from the region's latent coordinates.
    """

    latent_space_dim: int
    nbatch: int = 1

    def compute_mask_weights(self, region: "DiffusionRegion") -> torch.Tensor:
        """Computes a tensor of weights for a given diffusion region

        Dispatches on ``region.mask_type``, which must be one of the
        ``MaskModes`` values; any other value raises ``KeyError``.
        """
        mask_builders = {
            MaskModes.CONSTANT.value: self._constant_weights,
            MaskModes.GAUSSIAN.value: self._gaussian_weights,
            MaskModes.QUARTIC.value: self._quartic_weights,
        }
        return mask_builders[region.mask_type](region)

    def _constant_weights(self, region: "DiffusionRegion") -> torch.Tensor:
        """Computes a tensor of constant for a given diffusion region"""
        latent_width = region.latent_col_end - region.latent_col_init
        latent_height = region.latent_row_end - region.latent_row_init
        return torch.ones(self.nbatch, self.latent_space_dim, latent_height, latent_width) * region.mask_weight

    def _gaussian_weights(self, region: "DiffusionRegion") -> torch.Tensor:
        """Generates a gaussian mask of weights for tile contributions"""
        latent_width = region.latent_col_end - region.latent_col_init
        latent_height = region.latent_row_end - region.latent_row_init
        var = 0.01
        x_probs = self._gaussian_profile(latent_width, var)
        y_probs = self._gaussian_profile(latent_height, var)
        return self._tile_profiles(y_probs, x_probs, region.mask_weight)

    def _quartic_weights(self, region: "DiffusionRegion") -> torch.Tensor:
        """Generates a quartic mask of weights for tile contributions

        The quartic kernel has bounded support over the diffusion region, and a smooth decay to the region limits.
        """
        x_probs = self._quartic_profile(region.latent_col_end - region.latent_col_init)
        y_probs = self._quartic_profile(region.latent_row_end - region.latent_row_init)
        return self._tile_profiles(y_probs, x_probs, region.mask_weight)

    @staticmethod
    def _gaussian_profile(size, var):
        """Returns the 1-D gaussian weights for `size` cells, centred on the middle cell"""
        positions = np.arange(size)
        midpoint = (size - 1) / 2
        # Offsets are divided by `size`, so `var` is relative to the extent.
        return np.exp(-np.square(positions - midpoint) / (size * size) / (2 * var)) / np.sqrt(2 * np.pi * var)

    @staticmethod
    def _quartic_profile(size):
        """Returns the 1-D quartic kernel weights for `size` cells"""
        quartic_constant = 15.0 / 16.0
        if size == 1:
            # A single cell would divide 0 by 0 below and produce a NaN
            # weight; evaluate the kernel at its centre instead.
            support = np.zeros(1)
        else:
            # Spread the cells evenly over [-0.995, 0.995], strictly inside
            # (-1, 1), so the kernel is non-zero at the edges.
            support = np.arange(size) / (size - 1) * 1.99 - (1.99 / 2.0)
        return quartic_constant * np.square(1 - np.square(support))

    def _tile_profiles(self, y_probs, x_probs, mask_weight) -> torch.Tensor:
        """Combines row/column profiles into a 2-D mask repeated over batch and channels"""
        weights = np.outer(y_probs, x_probs) * mask_weight
        return torch.tile(torch.tensor(weights), (self.nbatch, self.latent_space_dim, 1, 1))
class StableDiffusionCanvasPipeline(DiffusionPipeline):
    """Stable Diffusion pipeline that mixes several diffusers in the same canvas"""

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
    ):
        """Registers the given models, scheduler and checker components as pipeline modules"""
        super().__init__()
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    @staticmethod
    def _latent_window(region):
        """Returns the index selecting `region` in a 4-D latent tensor (all of axes 0 and 1, region rows, region columns)"""
        return (
            slice(None),
            slice(None),
            slice(region.latent_row_init, region.latent_row_end),
            slice(region.latent_col_init, region.latent_col_end),
        )

    def decode_latents(self, latents, cpu_vae=False):
        """Decodes a given array of latents into pixel space

        Args:
            latents: latents to decode; they are divided by 0.18215 before
                being passed to the VAE decoder.
            cpu_vae: if True, deep copies of the latents and of ``self.vae``
                are moved to the CPU and decoding runs on those copies.

        Returns:
            The result of ``self.numpy_to_pil`` on the decoded batch.
        """
        if cpu_vae:
            lat = deepcopy(latents).cpu()
            vae = deepcopy(self.vae).cpu()
        else:
            lat = latents
            vae = self.vae

        lat = 1 / 0.18215 * lat
        image = vae.decode(lat).sample

        # Map the decoder output from [-1, 1] to [0, 1] and clamp.
        image = (image / 2 + 0.5).clamp(0, 1)
        # Move axis 1 to the end before converting to numpy.
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        return self.numpy_to_pil(image)

    def get_latest_timestep_img2img(self, num_inference_steps, strength):
        """Finds the latest timesteps where an img2img strength does not impose latents anymore

        Args:
            num_inference_steps: number of steps the scheduler was set up with.
            strength: strength of the img2img region.

        Returns:
            The entry of ``self.scheduler.timesteps`` at the computed start
            index, clamped to ``[0, num_inference_steps - 1]``.
        """
        offset = self.scheduler.config.get("steps_offset", 0)
        init_timestep = int(num_inference_steps * (1 - strength)) + offset
        init_timestep = min(init_timestep, num_inference_steps)

        t_start = min(max(num_inference_steps - init_timestep + offset, 0), num_inference_steps - 1)
        latest_timestep = self.scheduler.timesteps[t_start]

        return latest_timestep

    @torch.no_grad()
    def __call__(
        self,
        canvas_height: int,
        canvas_width: int,
        regions: List[DiffusionRegion],
        num_inference_steps: Optional[int] = 50,
        seed: Optional[int] = 12345,
        reroll_regions: Optional[List[RerollRegion]] = None,
        cpu_vae: Optional[bool] = False,
        decode_steps: Optional[bool] = False,
    ):
        """Runs the diffusion process over a canvas made of several regions

        Args:
            canvas_height: canvas height in pixels (latents use ``canvas_height // 8`` rows).
            canvas_width: canvas width in pixels (latents use ``canvas_width // 8`` columns).
            regions: diffusion regions; ``Text2ImageRegion`` and
                ``Image2ImageRegion`` instances are handled, others only
                contribute their ``noise_eps``.
            num_inference_steps: number of scheduler steps.
            seed: seed of the generator used for the initial canvas noise.
            reroll_regions: regions whose initial noise is redrawn ("reset")
                or perturbed ("epsilon"); None means no such regions.
            cpu_vae: forwarded to ``decode_latents``.
            decode_steps: if True, the latents are also decoded after every
                step and returned under ``"steps_images"``.

        Returns:
            A dict with key ``"images"`` (the decoded final latents) and, when
            ``decode_steps`` is set, ``"steps_images"`` (one decoded result per step).
        """
        if reroll_regions is None:
            reroll_regions = []
        batch_size = 1
        steps_images = []  # only filled and returned when decode_steps is set

        # Prepare scheduler
        self.scheduler.set_timesteps(num_inference_steps, device=self.device)

        # Split diffusion regions by their kind
        text2image_regions = [region for region in regions if isinstance(region, Text2ImageRegion)]
        image2image_regions = [region for region in regions if isinstance(region, Image2ImageRegion)]

        # Prepare text embeddings
        for region in text2image_regions:
            region.tokenize_prompt(self.tokenizer)
            region.encode_prompt(self.text_encoder, self.device)

        # Create original noisy latents using the timesteps
        latents_shape = (batch_size, self.unet.config.in_channels, canvas_height // 8, canvas_width // 8)
        generator = torch.Generator(self.device).manual_seed(seed)
        init_noise = torch.randn(latents_shape, generator=generator, device=self.device)

        # Reset latents in seed reroll regions, each with its own generator
        for region in reroll_regions:
            if region.reroll_mode != RerollModes.RESET.value:
                continue
            region_shape = (
                latents_shape[0],
                latents_shape[1],
                region.latent_row_end - region.latent_row_init,
                region.latent_col_end - region.latent_col_init,
            )
            init_noise[self._latent_window(region)] = torch.randn(
                region_shape, generator=region.get_region_generator(self.device), device=self.device
            )

        # Apply epsilon noise to regions: first diffusion regions, then reroll regions.
        # list() lets `regions` be any sequence, not only a list.
        eps_reroll_regions = [r for r in reroll_regions if r.reroll_mode == RerollModes.EPSILON.value]
        for region in list(regions) + eps_reroll_regions:
            if region.noise_eps > 0:
                window = self._latent_window(region)
                eps_noise = (
                    torch.randn(
                        init_noise[window].shape,
                        generator=region.get_region_generator(self.device),
                        device=self.device,
                    )
                    * region.noise_eps
                )
                init_noise[window] += eps_noise

        # Scale the initial noise by the standard deviation required by the scheduler
        latents = init_noise * self.scheduler.init_noise_sigma

        # Prepend unconditional (empty prompt) embeddings to each region's
        # prompt embeddings, for classifier-free guidance
        for region in text2image_regions:
            max_length = region.tokenized_prompt.input_ids.shape[-1]
            uncond_input = self.tokenizer(
                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

            # Concatenated so that both predictions come from a single forward pass
            region.encoded_prompt = torch.cat([uncond_embeddings, region.encoded_prompt])

        # Prepare image latents.
        # NOTE(review): cpu_vae is not forwarded here; encode_reference_image
        # calls `.cpu()` on the encoder it is given, which would affect self.vae.
        for region in image2image_regions:
            region.encode_reference_image(self.vae, device=self.device, generator=generator)

        # Prepare mask of weights for each region
        mask_builder = MaskWeightsBuilder(latent_space_dim=self.unet.config.in_channels, nbatch=batch_size)
        mask_weights = [mask_builder.compute_mask_weights(region).to(self.device) for region in text2image_regions]

        # Diffusion timesteps. tqdm wraps the timesteps directly (not an
        # enumerate object) so that it can read their length.
        for t in tqdm(self.scheduler.timesteps):
            # Diffuse each text region
            noise_preds_regions = []
            for region in text2image_regions:
                region_latents = latents[self._latent_window(region)]
                # Duplicate the latents: one copy per half of encoded_prompt
                latent_model_input = torch.cat([region_latents] * 2)
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                unet_output = self.unet(latent_model_input, t, encoder_hidden_states=region.encoded_prompt)["sample"]
                # Perform guidance
                noise_pred_uncond, noise_pred_text = unet_output.chunk(2)
                noise_preds_regions.append(
                    noise_pred_uncond + region.guidance_scale * (noise_pred_text - noise_pred_uncond)
                )

            # Merge noise predictions for all regions as a mask-weighted average
            noise_pred = torch.zeros(latents.shape, device=self.device)
            contributors = torch.zeros(latents.shape, device=self.device)
            for region, noise_pred_region, mask_weights_region in zip(
                text2image_regions, noise_preds_regions, mask_weights
            ):
                window = self._latent_window(region)
                noise_pred[window] += noise_pred_region * mask_weights_region
                contributors[window] += mask_weights_region
            noise_pred /= contributors
            # Positions covered by no text region are 0 / 0 = NaN; zero them
            noise_pred = torch.nan_to_num(noise_pred)

            # Compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

            # Image2Image regions: while t is above the region's influence
            # step, overwrite the region with the noised reference latents
            for region in image2image_regions:
                influence_step = self.get_latest_timestep_img2img(num_inference_steps, region.strength)
                if t > influence_step:
                    timestep = t.repeat(batch_size)
                    window = self._latent_window(region)
                    region_init_noise = init_noise[window]
                    region_latents = self.scheduler.add_noise(region.reference_latents, region_init_noise, timestep)
                    latents[window] = region_latents

            if decode_steps:
                steps_images.append(self.decode_latents(latents, cpu_vae))

        # Decode the final latents
        image = self.decode_latents(latents, cpu_vae)

        output = {"images": image}
        if decode_steps:
            output = {**output, "steps_images": steps_images}
        return output