import argparse
import ast
import asyncio
import datetime
import importlib
import json
import pathlib
import re
import shutil
import time
from typing import (
Dict,
List,
NamedTuple,
Optional,
Sequence,
Tuple,
Union,
)
from accelerate import Accelerator, InitProcessGroupKwargs
import gc
import glob
import math
import os
import random
import hashlib
import subprocess
from io import BytesIO
import toml
from tqdm import tqdm
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torchvision import transforms
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
import transformers
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
from diffusers import (
StableDiffusionPipeline,
DDPMScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DPMSolverSinglestepScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
DDIMScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
KDPM2DiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
AutoencoderKL,
)
from library import custom_train_functions
from library.original_unet import UNet2DConditionModel
from huggingface_hub import hf_hub_download
import numpy as np
from PIL import Image
import cv2
import safetensors.torch
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
import library.model_util as model_util
import library.huggingface_util as huggingface_util
import library.sai_model_spec as sai_model_spec
from library.original_unet import UNet2DConditionModel
# Default tokenizer identifier (looks like a Hugging Face Hub repo id; confirm at the call site).
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
# Identifier for Stable Diffusion v2 resources (presumably also a Hub repo id; confirm at the call site).
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2"

# str.format templates for per-epoch outputs: "{}" takes a name, "{:06d}" a zero-padded integer.
EPOCH_STATE_NAME = "{}-{:06d}-state"
EPOCH_FILE_NAME = "{}-{:06d}"
EPOCH_DIFFUSERS_DIR_NAME = "{}-{:06d}"
LAST_STATE_NAME = "{}-state"
DEFAULT_EPOCH_NAME = "epoch"
DEFAULT_LAST_OUTPUT_NAME = "last"

# str.format templates for per-step outputs: "{}" takes a name, "{:08d}" a zero-padded integer.
DEFAULT_STEP_NAME = "at"
STEP_STATE_NAME = "{}-step{:08d}-state"
STEP_FILE_NAME = "{}-step{:08d}"
STEP_DIFFUSERS_DIR_NAME = "{}-step{:08d}"

# Image file suffixes recognised by this module (lower- and upper-case variants).
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]

# Optional image plugins: each suffix is registered only when its plugin imports
# successfully. Loading stays best-effort (any ordinary failure is ignored), but
# KeyboardInterrupt / SystemExit are no longer swallowed as they were by a bare
# ``except:``. Suffixes are added only if missing, so having both JPEG-XL
# plugins installed does not produce duplicate entries.
try:
    import pillow_avif  # noqa: F401

    IMAGE_EXTENSIONS.extend(ext for ext in (".avif", ".AVIF") if ext not in IMAGE_EXTENSIONS)
except Exception:
    pass

try:
    from jxlpy import JXLImagePlugin  # noqa: F401

    IMAGE_EXTENSIONS.extend(ext for ext in (".jxl", ".JXL") if ext not in IMAGE_EXTENSIONS)
except Exception:
    pass

try:
    import pillow_jxl  # noqa: F401

    IMAGE_EXTENSIONS.extend(ext for ext in (".jxl", ".JXL") if ext not in IMAGE_EXTENSIONS)
except Exception:
    pass
# Shared image preprocessing: ToTensor followed by Normalize with mean 0.5 and std 0.5
# (i.e. (x - 0.5) / 0.5 per channel).
IMAGE_TRANSFORMS = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])

# Suffix appended to an image path (extension removed) to name its text-encoder output cache file.
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
class ImageInfo:
    """Mutable per-image record shared by the dataset, bucketing and caching code.

    Only the five constructor arguments are known up front; every other field
    starts as ``None`` and is filled in later, so those fields are annotated
    ``Optional``.
    """

    def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
        self.image_key: str = image_key
        self.num_repeats: int = num_repeats
        self.caption: str = caption
        self.is_reg: bool = is_reg
        self.absolute_path: str = absolute_path

        # (width, height) of the source image; unpacked in that order by the bucketing code.
        self.image_size: Optional[Tuple[int, int]] = None
        # Size to resize to and the bucket assigned, both produced by BucketManager.select_bucket.
        self.resized_size: Optional[Tuple[int, int]] = None
        self.bucket_reso: Optional[Tuple[int, int]] = None

        # In-memory latents (plain and horizontally flipped variants).
        self.latents: Optional[torch.Tensor] = None
        self.latents_flipped: Optional[torch.Tensor] = None
        # Path of the on-disk latent cache, when one is used.
        self.latents_npz: Optional[str] = None
        self.latents_original_size: Optional[Tuple[int, int]] = None
        # Crop box as (left, top, right, bottom); consumers index up to element 2.
        self.latents_crop_ltrb: Optional[Tuple[int, int, int, int]] = None

        self.cond_img_path: Optional[str] = None
        self.image: Optional[Image.Image] = None

        # Text-encoder output cache: on-disk path and in-memory tensors.
        self.text_encoder_outputs_npz: Optional[str] = None
        self.text_encoder_outputs1: Optional[torch.Tensor] = None
        self.text_encoder_outputs2: Optional[torch.Tensor] = None
        self.text_encoder_pool2: Optional[torch.Tensor] = None
class BucketManager:
    """Assigns images to resolution buckets and stores the members of each bucket.

    Two modes are supported by ``select_bucket``:

    * ``no_upscale`` false: the bucket is picked from a predefined resolution
      list (``make_buckets`` / ``set_predefined_resos`` must have been called).
    * ``no_upscale`` true: the bucket is derived from the image size itself,
      shrinking only images whose area exceeds ``max_reso``.
    """

    def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
        # Sanity checks only make sense when an upper bound is given.
        if max_size is not None:
            if max_reso is not None:
                assert max_size >= max_reso[0], "the max_size should be larger than the width of max_reso"
                assert max_size >= max_reso[1], "the max_size should be larger than the height of max_reso"
            if min_size is not None:
                assert max_size >= min_size, "the max_size should be larger than the min_size"

        self.no_upscale = no_upscale
        self.max_reso = max_reso
        self.max_area = None if max_reso is None else max_reso[0] * max_reso[1]
        self.min_size = min_size
        self.max_size = max_size
        self.reso_steps = reso_steps

        # Parallel structures: resos[i] is the resolution of buckets[i]; reso_to_id maps back to i.
        self.resos = []
        self.reso_to_id = {}
        self.buckets = []

    def add_image(self, reso, image_or_info):
        """Append an item to the bucket registered for ``reso`` (KeyError if it is unknown)."""
        self.buckets[self.reso_to_id[reso]].append(image_or_info)

    def shuffle(self):
        """Shuffle the contents of every bucket in place using the global ``random`` state."""
        for members in self.buckets:
            random.shuffle(members)

    def sort(self):
        """Reorder buckets by ascending resolution tuple and renumber their ids."""
        ordered = self.resos.copy()
        ordered.sort()
        reordered_buckets = [self.buckets[self.reso_to_id[reso]] for reso in ordered]
        self.reso_to_id = {reso: new_id for new_id, reso in enumerate(ordered)}
        self.buckets = reordered_buckets
        self.resos = ordered

    def make_buckets(self):
        """Generate the predefined resolution list via ``model_util`` and install it."""
        self.set_predefined_resos(
            model_util.make_bucket_resolutions(self.max_reso, self.min_size, self.max_size, self.reso_steps)
        )

    def set_predefined_resos(self, resos):
        """Install the candidate resolutions and precompute their aspect ratios (width / height)."""
        self.predefined_resos = resos.copy()
        self.predefined_resos_set = set(resos)
        self.predefined_aspect_ratios = np.array([w / h for w, h in resos])

    def add_if_new_reso(self, reso):
        """Create an empty bucket for ``reso`` unless one already exists."""
        if reso in self.reso_to_id:
            return
        self.reso_to_id[reso] = len(self.resos)
        self.resos.append(reso)
        self.buckets.append([])

    def round_to_steps(self, x):
        """Round ``x`` to the nearest integer, then down to a multiple of ``reso_steps``."""
        nearest = int(x + 0.5)
        return nearest - nearest % self.reso_steps

    def _bucket_from_predefined(self, image_width, image_height, aspect_ratio):
        """Pick a predefined bucket and the resized size that covers it."""
        reso = (image_width, image_height)
        if reso not in self.predefined_resos_set:
            # Fall back to the predefined resolution with the closest aspect ratio.
            nearest_id = np.abs(self.predefined_aspect_ratios - aspect_ratio).argmin()
            reso = self.predefined_resos[nearest_id]

        # Scale along the axis that makes the resized image cover the bucket.
        if aspect_ratio > reso[0] / reso[1]:
            scale = reso[1] / image_height
        else:
            scale = reso[0] / image_width
        resized_size = (int(image_width * scale + 0.5), int(image_height * scale + 0.5))
        return reso, resized_size

    def _bucket_without_upscale(self, image_width, image_height, aspect_ratio):
        """Derive the bucket from the image size, shrinking only when the area exceeds ``max_area``."""
        if image_width * image_height > self.max_area:
            resized_width = math.sqrt(self.max_area * aspect_ratio)
            resized_height = self.max_area / resized_width
            assert abs(resized_width / resized_height - aspect_ratio) < 1e-2, "aspect is illegal"

            # Candidate A: round the width first and derive the height from it.
            width_first = self.round_to_steps(resized_width)
            ar_width_first = width_first / self.round_to_steps(width_first / aspect_ratio)
            # Candidate B: round the height first and derive the width from it.
            height_first = self.round_to_steps(resized_height)
            ar_height_first = self.round_to_steps(height_first * aspect_ratio) / height_first

            # Keep whichever candidate stays closer to the original aspect ratio.
            if abs(ar_width_first - aspect_ratio) < abs(ar_height_first - aspect_ratio):
                resized_size = (width_first, int(width_first / aspect_ratio + 0.5))
            else:
                resized_size = (int(height_first * aspect_ratio + 0.5), height_first)
        else:
            resized_size = (image_width, image_height)

        step = self.reso_steps
        reso = (resized_size[0] - resized_size[0] % step, resized_size[1] - resized_size[1] % step)
        return reso, resized_size

    def select_bucket(self, image_width, image_height):
        """Choose the bucket for an image.

        Returns ``(reso, resized_size, ar_error)`` where ``ar_error`` is the
        bucket aspect ratio minus the image aspect ratio. The bucket is
        registered if it was not known yet.
        """
        aspect_ratio = image_width / image_height
        if self.no_upscale:
            reso, resized_size = self._bucket_without_upscale(image_width, image_height, aspect_ratio)
        else:
            reso, resized_size = self._bucket_from_predefined(image_width, image_height, aspect_ratio)

        self.add_if_new_reso(reso)
        ar_error = (reso[0] / reso[1]) - aspect_ratio
        return reso, resized_size, ar_error

    @staticmethod
    def get_crop_ltrb(bucket_reso: Tuple[int, int], image_size: Tuple[int, int]):
        """Return (left, top, right, bottom) of the image area fitted and centred inside the bucket.

        Values may be floats because the fitted size comes from a true division.
        """
        bucket_w, bucket_h = bucket_reso[0], bucket_reso[1]
        image_ar = image_size[0] / image_size[1]
        if bucket_w / bucket_h > image_ar:
            # Bucket is wider than the image: fit the height.
            fitted_w, fitted_h = bucket_h * image_ar, bucket_h
        else:
            # Bucket is taller (or equal): fit the width.
            fitted_w, fitted_h = bucket_w, bucket_w / image_ar
        left = (bucket_w - fitted_w) // 2
        top = (bucket_h - fitted_h) // 2
        return left, top, left + fitted_w, top + fitted_h
# Locates one batch: which bucket it belongs to, the batch size used for that
# bucket, and the batch's ordinal position inside the bucket.
BucketBatchIndex = NamedTuple(
    "BucketBatchIndex",
    [
        ("bucket_index", int),
        ("bucket_batch_size", int),
        ("batch_index", int),
    ],
)
class AugHelper:
    """Provides the optional random colour augmentation applied to training images."""

    def __init__(self):
        pass

    def color_aug(self, image: np.ndarray):
        """Randomly perturb hue or gamma and return ``{"image": image}``.

        The image is left untouched unless the first random draw is <= 0.33.
        Otherwise a second draw picks between a hue shift (draw > 0.5) and a
        power-law adjustment. The hue path converts with cv2's BGR<->HSV codes.
        """
        max_hue_shift = 8

        if random.random() > 0.33:
            return {"image": image}

        if random.random() > 0.5:
            hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
            shift = random.uniform(-max_hue_shift, max_hue_shift)
            if shift < 0:
                # Express a negative shift as the equivalent positive rotation modulo 180.
                shift = 180 + shift
            hsv[:, :, 0] = (hsv[:, :, 0] + shift) % 180
            image = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
        else:
            exponent = random.uniform(0.95, 1.05)
            # The power is applied to the raw pixel values, then clipped back to 0..255.
            image = np.clip(image**exponent, 0, 255).astype(np.uint8)

        return {"image": image}

    def get_augmentor(self, use_color_aug: bool):
        """Return the ``color_aug`` bound method when enabled, else ``None``."""
        if use_color_aug:
            return self.color_aug
        return None
class BaseSubset:
    """Settings shared by every kind of dataset subset.

    The constructor only stores its arguments; ``img_count`` starts at 0.
    """

    def __init__(
        self,
        image_dir: Optional[str],
        num_repeats: int,
        shuffle_caption: bool,
        caption_separator: str,
        keep_tokens: int,
        color_aug: bool,
        flip_aug: bool,
        face_crop_aug_range: Optional[Tuple[float, float]],
        random_crop: bool,
        caption_dropout_rate: float,
        caption_dropout_every_n_epochs: int,
        caption_tag_dropout_rate: float,
        caption_prefix: Optional[str],
        caption_suffix: Optional[str],
        token_warmup_min: int,
        token_warmup_step: Union[float, int],
    ) -> None:
        # Source location and repetition.
        self.image_dir = image_dir
        self.num_repeats = num_repeats
        self.img_count = 0

        # Image augmentation switches.
        self.color_aug = color_aug
        self.flip_aug = flip_aug
        self.face_crop_aug_range = face_crop_aug_range
        self.random_crop = random_crop

        # Caption processing options.
        self.shuffle_caption = shuffle_caption
        self.caption_separator = caption_separator
        self.keep_tokens = keep_tokens
        self.caption_prefix = caption_prefix
        self.caption_suffix = caption_suffix
        self.caption_dropout_rate = caption_dropout_rate
        self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs
        self.caption_tag_dropout_rate = caption_tag_dropout_rate

        # Token warm-up schedule.
        self.token_warmup_min = token_warmup_min
        self.token_warmup_step = token_warmup_step
class DreamBoothSubset(BaseSubset):
    """Subset backed by an image directory, with optional class tokens and caption files.

    Two instances compare equal when their ``image_dir`` values match.
    """

    def __init__(
        self,
        image_dir: str,
        is_reg: bool,
        class_tokens: Optional[str],
        caption_extension: str,
        num_repeats,
        shuffle_caption,
        caption_separator: str,
        keep_tokens,
        color_aug,
        flip_aug,
        face_crop_aug_range,
        random_crop,
        caption_dropout_rate,
        caption_dropout_every_n_epochs,
        caption_tag_dropout_rate,
        caption_prefix,
        caption_suffix,
        token_warmup_min,
        token_warmup_step,
    ) -> None:
        assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"

        super().__init__(
            image_dir=image_dir,
            num_repeats=num_repeats,
            shuffle_caption=shuffle_caption,
            caption_separator=caption_separator,
            keep_tokens=keep_tokens,
            color_aug=color_aug,
            flip_aug=flip_aug,
            face_crop_aug_range=face_crop_aug_range,
            random_crop=random_crop,
            caption_dropout_rate=caption_dropout_rate,
            caption_dropout_every_n_epochs=caption_dropout_every_n_epochs,
            caption_tag_dropout_rate=caption_tag_dropout_rate,
            caption_prefix=caption_prefix,
            caption_suffix=caption_suffix,
            token_warmup_min=token_warmup_min,
            token_warmup_step=token_warmup_step,
        )

        self.is_reg = is_reg
        self.class_tokens = class_tokens

        # Normalise a non-empty extension so that it always starts with a dot.
        if caption_extension and not caption_extension.startswith("."):
            caption_extension = "." + caption_extension
        self.caption_extension = caption_extension

    def __eq__(self, other) -> bool:
        if isinstance(other, DreamBoothSubset):
            return self.image_dir == other.image_dir
        return NotImplemented
class FineTuningSubset(BaseSubset):
    """Subset described by a metadata file.

    Two instances compare equal when their ``metadata_file`` values match.
    """

    def __init__(
        self,
        image_dir,
        metadata_file: str,
        num_repeats,
        shuffle_caption,
        caption_separator,
        keep_tokens,
        color_aug,
        flip_aug,
        face_crop_aug_range,
        random_crop,
        caption_dropout_rate,
        caption_dropout_every_n_epochs,
        caption_tag_dropout_rate,
        caption_prefix,
        caption_suffix,
        token_warmup_min,
        token_warmup_step,
    ) -> None:
        assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"

        super().__init__(
            image_dir=image_dir,
            num_repeats=num_repeats,
            shuffle_caption=shuffle_caption,
            caption_separator=caption_separator,
            keep_tokens=keep_tokens,
            color_aug=color_aug,
            flip_aug=flip_aug,
            face_crop_aug_range=face_crop_aug_range,
            random_crop=random_crop,
            caption_dropout_rate=caption_dropout_rate,
            caption_dropout_every_n_epochs=caption_dropout_every_n_epochs,
            caption_tag_dropout_rate=caption_tag_dropout_rate,
            caption_prefix=caption_prefix,
            caption_suffix=caption_suffix,
            token_warmup_min=token_warmup_min,
            token_warmup_step=token_warmup_step,
        )

        self.metadata_file = metadata_file

    def __eq__(self, other) -> bool:
        if isinstance(other, FineTuningSubset):
            return self.metadata_file == other.metadata_file
        return NotImplemented
class ControlNetSubset(BaseSubset):
    """Subset pairing an image directory with a conditioning-data directory.

    Two instances compare equal when both directories match.
    """

    def __init__(
        self,
        image_dir: str,
        conditioning_data_dir: str,
        caption_extension: str,
        num_repeats,
        shuffle_caption,
        caption_separator,
        keep_tokens,
        color_aug,
        flip_aug,
        face_crop_aug_range,
        random_crop,
        caption_dropout_rate,
        caption_dropout_every_n_epochs,
        caption_tag_dropout_rate,
        caption_prefix,
        caption_suffix,
        token_warmup_min,
        token_warmup_step,
    ) -> None:
        assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"

        super().__init__(
            image_dir=image_dir,
            num_repeats=num_repeats,
            shuffle_caption=shuffle_caption,
            caption_separator=caption_separator,
            keep_tokens=keep_tokens,
            color_aug=color_aug,
            flip_aug=flip_aug,
            face_crop_aug_range=face_crop_aug_range,
            random_crop=random_crop,
            caption_dropout_rate=caption_dropout_rate,
            caption_dropout_every_n_epochs=caption_dropout_every_n_epochs,
            caption_tag_dropout_rate=caption_tag_dropout_rate,
            caption_prefix=caption_prefix,
            caption_suffix=caption_suffix,
            token_warmup_min=token_warmup_min,
            token_warmup_step=token_warmup_step,
        )

        self.conditioning_data_dir = conditioning_data_dir

        # Normalise a non-empty extension so that it always starts with a dot.
        if caption_extension and not caption_extension.startswith("."):
            caption_extension = "." + caption_extension
        self.caption_extension = caption_extension

    def __eq__(self, other) -> bool:
        if isinstance(other, ControlNetSubset):
            same_images = self.image_dir == other.image_dir
            return same_images and self.conditioning_data_dir == other.conditioning_data_dir
        return NotImplemented
class BaseDataset(torch.utils.data.Dataset):
def __init__(
    self,
    tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]],
    max_token_length: int,
    resolution: Optional[Tuple[int, int]],
    debug_dataset: bool,
) -> None:
    """Initialise state shared by all dataset flavours.

    ``tokenizer`` may be a single tokenizer or a list; it is always stored as a
    list in ``self.tokenizers``. ``resolution`` is unpacked into
    ``self.width`` / ``self.height`` (both ``None`` when it is ``None``).
    """
    super().__init__()

    if isinstance(tokenizer, list):
        self.tokenizers = tokenizer
    else:
        self.tokenizers = [tokenizer]

    self.max_token_length = max_token_length
    # Without an explicit limit use the first tokenizer's own maximum; otherwise
    # add 2 to the requested length.
    if max_token_length is None:
        self.tokenizer_max_length = self.tokenizers[0].model_max_length
    else:
        self.tokenizer_max_length = max_token_length + 2

    if resolution is None:
        self.width, self.height = None, None
    else:
        self.width, self.height = resolution

    self.debug_dataset = debug_dataset

    self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = []
    self.image_data: Dict[str, ImageInfo] = {}
    self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}

    # Caption / tokenisation state.
    self.token_padding_disabled = False
    self.tag_frequency = {}
    self.XTI_layers = None
    self.token_strings = None
    self.replacements = {}

    # Bucketing configuration; subclasses are expected to fill these in.
    self.enable_bucket = False
    self.bucket_manager: BucketManager = None
    self.min_bucket_reso = None
    self.max_bucket_reso = None
    self.bucket_reso_steps = None
    self.bucket_no_upscale = None
    self.bucket_info = None

    # Training progress, updated through the set_* methods.
    self.current_epoch: int = 0
    self.current_step: int = 0
    self.max_train_steps: int = 0
    self.seed: int = 0

    self.aug_helper = AugHelper()
    self.image_transforms = IMAGE_TRANSFORMS

    # When not None, __getitem__ returns raw items for cache building instead of training batches.
    self.caching_mode = None
def set_seed(self, seed):
    """Store the base seed; ``shuffle_buckets`` adds the current epoch to it."""
    self.seed = seed
def set_caching_mode(self, mode):
    """Select the caching mode; any value other than ``None`` makes ``__getitem__`` return caching items."""
    self.caching_mode = mode
def set_current_epoch(self, epoch):
    """Record the epoch number, reshuffling the buckets first when it changed.

    The reshuffle runs before ``current_epoch`` is updated, so it still sees
    the previous epoch value.
    """
    epoch_changed = not self.current_epoch == epoch
    if epoch_changed:
        self.shuffle_buckets()
    self.current_epoch = epoch
def set_current_step(self, step):
    """Record the current training step (read by the token warm-up logic in ``process_caption``)."""
    self.current_step = step
def set_max_train_steps(self, max_train_steps):
    """Record the total step count (used to resolve a fractional ``token_warmup_step``)."""
    self.max_train_steps = max_train_steps
def set_tag_frequency(self, dir_name, captions):
    """Accumulate per-directory counts of comma-separated caption tags.

    Tags are stripped and lower-cased; empty tags are ignored. Counts are added
    to any existing entry for ``dir_name``.
    """
    counts = self.tag_frequency.setdefault(dir_name, {})
    for caption in captions:
        for raw_tag in caption.split(","):
            tag = raw_tag.strip()
            if not tag:
                continue
            tag = tag.lower()
            counts[tag] = counts.get(tag, 0) + 1
def disable_token_padding(self):
    """Switch ``__getitem__`` to tokenising each batch as a whole instead of per caption."""
    self.token_padding_disabled = True
def enable_XTI(self, layers=None, token_strings=None):
    """Store the layer names and token strings that ``__getitem__`` uses to build per-layer captions."""
    self.XTI_layers, self.token_strings = layers, token_strings
def add_replacement(self, str_from, str_to):
    """Register (or overwrite) a caption substitution applied by ``process_caption``."""
    self.replacements.update({str_from: str_to})
def process_caption(self, subset: "BaseSubset", caption):
    """Apply the subset's caption options and return the caption to train on.

    Steps, in order:

    1. Add ``caption_prefix`` / ``caption_suffix`` (joined with one space).
    2. Drop the caption entirely (return ``""``) by random dropout or by the
       every-n-epochs rule.
    3. Otherwise, when shuffling, warm-up or tag dropout is active: split on
       ``caption_separator``, truncate for warm-up, keep the first
       ``keep_tokens`` tokens fixed, shuffle / drop the rest and re-join with
       ``", "``.
    4. Apply ``self.replacements``. An empty-string key replaces the whole
       caption; a list value (including list subclasses) means "pick one at
       random".

    Note: a fractional ``subset.token_warmup_step`` (< 1) is converted in place
    to an absolute step count using ``self.max_train_steps``.
    """
    if subset.caption_prefix:
        caption = subset.caption_prefix + " " + caption
    if subset.caption_suffix:
        caption = caption + " " + subset.caption_suffix

    # The random draw only happens when a dropout rate is configured.
    is_drop_out = subset.caption_dropout_rate > 0 and random.random() < subset.caption_dropout_rate
    is_drop_out = is_drop_out or (
        subset.caption_dropout_every_n_epochs > 0
        and self.current_epoch % subset.caption_dropout_every_n_epochs == 0
    )
    if is_drop_out:
        return ""

    if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
        tokens = [t.strip() for t in caption.strip().split(subset.caption_separator)]

        # A value below 1 is a fraction of the total number of training steps.
        if subset.token_warmup_step < 1:
            subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
        if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
            # Grow linearly from token_warmup_min tokens to the full caption.
            tokens_len = (
                math.floor((self.current_step) * ((len(tokens) - subset.token_warmup_min) / (subset.token_warmup_step)))
                + subset.token_warmup_min
            )
            tokens = tokens[:tokens_len]

        def dropout_tags(candidates):
            if subset.caption_tag_dropout_rate <= 0:
                return candidates
            return [token for token in candidates if random.random() >= subset.caption_tag_dropout_rate]

        fixed_tokens = []
        flex_tokens = tokens[:]
        if subset.keep_tokens > 0:
            fixed_tokens = flex_tokens[: subset.keep_tokens]
            flex_tokens = tokens[subset.keep_tokens :]

        if subset.shuffle_caption:
            random.shuffle(flex_tokens)
        flex_tokens = dropout_tags(flex_tokens)

        caption = ", ".join(fixed_tokens + flex_tokens)

    for str_from, str_to in self.replacements.items():
        if str_from == "":
            # Whole-caption replacement. isinstance (rather than an exact type
            # comparison) so list subclasses are also treated as candidate lists.
            if isinstance(str_to, list):
                caption = random.choice(str_to)
            else:
                caption = str_to
        else:
            caption = caption.replace(str_from, str_to)

    return caption
def get_input_ids(self, caption, tokenizer=None):
    """Tokenise ``caption`` padded/truncated to ``self.tokenizer_max_length``.

    When that length exceeds the tokenizer's own ``model_max_length`` the ids
    are re-packed into stacked chunks of ``model_max_length`` ids, each made of
    the first id, a slice of the body and the last id. Defaults to the first
    tokenizer.
    """
    if tokenizer is None:
        tokenizer = self.tokenizers[0]

    encoded = tokenizer(
        caption, padding="max_length", truncation=True, max_length=self.tokenizer_max_length, return_tensors="pt"
    )
    input_ids = encoded.input_ids

    model_len = tokenizer.model_max_length
    if not self.tokenizer_max_length > model_len:
        return input_ids

    flat_ids = input_ids.squeeze(0)
    body_len = model_len - 2
    pad_id = tokenizer.pad_token_id
    eos_id = tokenizer.eos_token_id
    # Chunk boundaries only need patching when pad and eos are distinct ids.
    patch_boundaries = not pad_id == eos_id

    chunks = []
    for start in range(1, self.tokenizer_max_length - model_len + 2, body_len):
        chunk = torch.cat(
            (
                flat_ids[0].unsqueeze(0),
                flat_ids[start : start + body_len],
                flat_ids[-1].unsqueeze(0),
            )
        )
        if patch_boundaries:
            # Chunk ends in ordinary text: close it with eos.
            if chunk[-2] != eos_id and chunk[-2] != pad_id:
                chunk[-1] = eos_id
            # Chunk body starts with padding: mark it with eos instead.
            if chunk[1] == pad_id:
                chunk[1] = eos_id
        chunks.append(chunk)

    return torch.stack(chunks)
def register_image(self, info: "ImageInfo", subset: "BaseSubset"):
    """Index an image record and its owning subset under ``info.image_key``."""
    key = info.image_key
    self.image_data[key] = info
    self.image_to_subset[key] = subset
def make_buckets(self):
    """Assign every registered image to a bucket and build the batch index.

    Must be called even when bucketing is disabled (a single bucket is created
    in that case). min_size and max_size are ignored when enable_bucket is
    False.

    Side effects: fills missing ``image_size`` values, sets ``bucket_reso`` /
    ``resized_size`` on each image, (re)creates ``bucket_manager`` as needed,
    populates ``bucket_info`` (bucketing only), ``buckets_indices`` and
    ``_length``, and shuffles the buckets. Relies on ``self.batch_size`` being
    set by the caller/subclass.
    """
    print("loading image sizes.")
    for info in tqdm(self.image_data.values()):
        if info.image_size is None:
            info.image_size = self.get_image_size(info.absolute_path)

    if self.enable_bucket:
        print("make buckets")
    else:
        print("prepare dataset")

    if self.enable_bucket:
        if self.bucket_manager is None:
            self.bucket_manager = BucketManager(
                self.bucket_no_upscale,
                (self.width, self.height),
                self.min_bucket_reso,
                self.max_bucket_reso,
                self.bucket_reso_steps,
            )
            if not self.bucket_no_upscale:
                self.bucket_manager.make_buckets()
            else:
                print(
                    "min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます"
                )

        img_ar_errors = []
        for image_info in self.image_data.values():
            image_width, image_height = image_info.image_size
            image_info.bucket_reso, image_info.resized_size, ar_error = self.bucket_manager.select_bucket(
                image_width, image_height
            )
            img_ar_errors.append(abs(ar_error))

        self.bucket_manager.sort()
    else:
        # Bucketing disabled: one predefined bucket at the training resolution.
        self.bucket_manager = BucketManager(False, (self.width, self.height), None, None, None)
        self.bucket_manager.set_predefined_resos([(self.width, self.height)])
        for image_info in self.image_data.values():
            image_width, image_height = image_info.image_size
            image_info.bucket_reso, image_info.resized_size, _ = self.bucket_manager.select_bucket(image_width, image_height)

    # Each image key is added once per repeat.
    for image_info in self.image_data.values():
        for _ in range(image_info.num_repeats):
            self.bucket_manager.add_image(image_info.bucket_reso, image_info.image_key)

    if self.enable_bucket:
        self.bucket_info = {"buckets": {}}
        print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)")
        for i, (reso, bucket) in enumerate(zip(self.bucket_manager.resos, self.bucket_manager.buckets)):
            count = len(bucket)
            if count > 0:
                self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)}
                print(f"bucket {i}: resolution {reso}, count: {len(bucket)}")

        img_ar_errors = np.array(img_ar_errors)
        mean_img_ar_error = np.mean(np.abs(img_ar_errors))
        self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
        print(f"mean ar error (without repeats): {mean_img_ar_error}")

    # One entry per batch; the last batch of a bucket may be partial.
    # (Annotation fixed: subscript form, not a call expression.)
    self.buckets_indices: List[BucketBatchIndex] = []
    for bucket_index, bucket in enumerate(self.bucket_manager.buckets):
        batch_count = int(math.ceil(len(bucket) / self.batch_size))
        for batch_index in range(batch_count):
            self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index))

    self.shuffle_buckets()
    self._length = len(self.buckets_indices)
def shuffle_buckets(self):
    """Deterministically reshuffle batch order and bucket contents for the current epoch.

    Reseeds the global ``random`` module with ``seed + current_epoch``.
    """
    epoch_seed = self.seed + self.current_epoch
    random.seed(epoch_seed)

    # Batch order first, then the members inside each bucket.
    random.shuffle(self.buckets_indices)
    self.bucket_manager.shuffle()
def verify_bucket_reso_steps(self, min_steps: int):
    """Assert that ``bucket_reso_steps`` (when set) is a multiple of ``min_steps``."""
    if self.bucket_reso_steps is None:
        return
    is_multiple = self.bucket_reso_steps % min_steps == 0
    assert is_multiple, (
        f"bucket_reso_steps is {self.bucket_reso_steps}. it must be divisible by {min_steps}.\n"
        + f"bucket_reso_stepsが{self.bucket_reso_steps}です。{min_steps}で割り切れる必要があります"
    )
def is_latent_cacheable(self):
    """Return True when no subset uses ``color_aug`` or ``random_crop``."""
    for subset in self.subsets:
        if subset.color_aug or subset.random_crop:
            return False
    return True
def is_text_encoder_output_cacheable(self):
    """Return True when no subset changes captions at training time.

    Caption dropout, caption shuffling, token warm-up and tag dropout each
    disqualify caching.
    """
    for subset in self.subsets:
        caption_varies = (
            subset.caption_dropout_rate > 0
            or subset.shuffle_caption
            or subset.token_warmup_step > 0
            or subset.caption_tag_dropout_rate > 0
        )
        if caption_varies:
            return False
    return True
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
    """Encode images to latents in batches and cache them in memory or on disk.

    Images are visited in order of bucket area. Images that already have
    ``latents_npz`` set are skipped. With ``cache_to_disk`` the ``.npz`` path
    is recorded on every image, but only the main process checks existing
    cache files and encodes; other processes return after recording paths.

    A batch only contains images that share the same bucket resolution and the
    same subset ``flip_aug`` / ``random_crop`` flags, and those flags are the
    ones handed to ``cache_batch_latents``. (Previously the flags of whichever
    subset was looked up last were used for every batch, and an empty dataset
    raised NameError.)
    """
    print("caching latents.")

    image_infos = list(self.image_data.values())
    image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1])

    batches = []  # list of (infos, flip_aug, random_crop)
    batch = []
    batch_key = None  # (bucket_reso, flip_aug, random_crop) of the batch being filled

    print("checking cache validity...")
    for info in tqdm(image_infos):
        subset = self.image_to_subset[info.image_key]

        if info.latents_npz is not None:
            continue

        if cache_to_disk:
            info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz"
            if not is_main_process:
                # Non-main processes only record the cache path.
                continue
            if is_disk_cached_latents_is_expected(info.bucket_reso, info.latents_npz, subset.flip_aug):
                continue

        key = (info.bucket_reso, subset.flip_aug, subset.random_crop)
        # Close the running batch when resolution or augmentation flags change.
        if batch and key != batch_key:
            batches.append((batch, batch_key[1], batch_key[2]))
            batch = []

        batch.append(info)
        batch_key = key

        if len(batch) >= vae_batch_size:
            batches.append((batch, batch_key[1], batch_key[2]))
            batch = []

    if batch:
        batches.append((batch, batch_key[1], batch_key[2]))

    if cache_to_disk and not is_main_process:
        return

    print("caching latents...")
    for batch, flip_aug, random_crop in tqdm(batches, smoothing=1, total=len(batches)):
        cache_batch_latents(vae, cache_to_disk, batch, flip_aug, random_crop)
def cache_text_encoder_outputs(
    self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
):
    """Run the text encoders over every caption and cache the outputs.

    Requires exactly two tokenizers. With ``cache_to_disk`` the cache path is
    recorded on every image; only the main process encodes, and it skips images
    whose cache file already exists. Captions are tokenised as stored (no
    ``process_caption``) and grouped into batches of ``self.batch_size``.
    """
    assert len(tokenizers) == 2, "only support SDXL"
    print("caching text encoder outputs.")
    image_infos = list(self.image_data.values())

    print("checking cache existence...")
    pending_infos = []
    for info in tqdm(image_infos):
        if cache_to_disk:
            npz_path = os.path.splitext(info.absolute_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
            info.text_encoder_outputs_npz = npz_path
            # Non-main processes only record the path; existing files need no work.
            if not is_main_process or os.path.exists(npz_path):
                continue
        pending_infos.append(info)

    if cache_to_disk and not is_main_process:
        return

    for encoder in text_encoders:
        encoder.to(device)
        if weight_dtype is not None:
            encoder.to(dtype=weight_dtype)

    batches = []
    current = []
    for info in pending_infos:
        ids1 = self.get_input_ids(info.caption, tokenizers[0])
        ids2 = self.get_input_ids(info.caption, tokenizers[1])
        current.append((info, ids1, ids2))
        if len(current) >= self.batch_size:
            batches.append(current)
            current = []
    if len(current) > 0:
        batches.append(current)

    print("caching text encoder outputs...")
    for entries in tqdm(batches):
        infos, ids1_seq, ids2_seq = zip(*entries)
        stacked_ids1 = torch.stack(ids1_seq, dim=0)
        stacked_ids2 = torch.stack(ids2_seq, dim=0)
        cache_batch_text_encoder_outputs(
            infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, stacked_ids1, stacked_ids2, weight_dtype
        )
def get_image_size(self, image_path):
    """Return the ``(width, height)`` of the image at ``image_path``.

    The file is opened through a context manager so the handle is closed as
    soon as the size has been read, instead of lingering until garbage
    collection (this method is called once per image in the dataset).
    """
    with Image.open(image_path) as image:
        return image.size
def load_image_with_face_info(self, subset: "BaseSubset", image_path: str):
    """Load an image and, if face cropping is enabled, parse face info from its file name.

    When ``subset.face_crop_aug_range`` is set and the file stem has at least
    five ``_``-separated parts, the last four parts are read as integers
    ``face_cx, face_cy, face_w, face_h``. Otherwise all four are 0.

    Returns ``(img, face_cx, face_cy, face_w, face_h)``.
    """
    img = load_image(image_path)

    face_info = (0, 0, 0, 0)
    if subset.face_crop_aug_range is not None:
        stem = os.path.splitext(os.path.basename(image_path))[0]
        parts = stem.split("_")
        if len(parts) >= 5:
            face_info = tuple(int(part) for part in parts[-4:])

    face_cx, face_cy, face_w, face_h = face_info
    return img, face_cx, face_cy, face_w, face_h
def crop_target(self, subset: "BaseSubset", image, face_cx, face_cy, face_w, face_h):
    """Resize and crop ``image`` to ``(self.height, self.width)`` around the face centre.

    An image that already has the target size is returned unchanged. Otherwise
    a scale is chosen (randomly within limits derived from
    ``subset.face_crop_aug_range``), the image is resized, and each axis is
    cropped around the face position, with random jitter depending on
    ``subset.random_crop``.
    """
    height, width = image.shape[0:2]
    if height == self.height and width == self.width:
        return image

    face_size = max(face_w, face_h)
    short_side = min(self.height, self.width)
    aug_low, aug_high = subset.face_crop_aug_range[0], subset.face_crop_aug_range[1]

    # The scale may never shrink the image below the target size, nor exceed 1.0.
    cover_scale = max(self.height / height, self.width / width)
    min_scale = min(1.0, max(cover_scale, short_side / (face_size * aug_high)))
    max_scale = min(1.0, max(min_scale, short_side / (face_size * aug_low)))
    scale = min_scale if min_scale >= max_scale else random.uniform(min_scale, max_scale)

    nh = int(height * scale + 0.5)
    nw = int(width * scale + 0.5)
    assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}"
    image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA)
    face_cx = int(face_cx * scale + 0.5)
    face_cy = int(face_cy * scale + 0.5)
    height, width = nh, nw

    # Axis 0 crops rows (height), axis 1 crops columns (width).
    per_axis = zip((self.height, self.width), (height, width), (face_cy, face_cx))
    for axis, (target_size, length, face_p) in enumerate(per_axis):
        start = face_p - target_size // 2

        if subset.random_crop:
            # Sum of two uniform draws, centred on zero offset.
            spread = max(length - face_p, face_p)
            start = start + (random.randint(0, spread) + random.randint(0, spread)) - spread
        elif aug_low != aug_high:
            # Small jitter, only for faces that are large enough.
            if face_size > short_side // 10 and face_size >= 40:
                start = start + random.randint(-face_size // 20, +face_size // 20)

        # Keep the crop window inside the image.
        start = max(0, min(start, length - target_size))

        if axis == 0:
            image = image[start : start + target_size, :]
        else:
            image = image[:, start : start + target_size]

    return image
def __len__(self):
    """Return the number of batches computed by ``make_buckets``."""
    batch_count = self._length
    return batch_count
def __getitem__(self, index):
    """Build one whole batch (not a single sample) for batch index ``index``.

    In caching mode the work is delegated to ``get_item_for_caching``.
    Otherwise the returned dict holds ``loss_weights``, ``input_ids`` /
    ``input_ids2`` or cached text-encoder outputs, ``images`` or ``latents``
    (the other is ``None``), ``captions``, the size tensors
    ``original_sizes_hw`` / ``crop_top_lefts`` / ``target_sizes_hw``,
    ``flippeds`` and, when ``debug_dataset`` is set, ``image_keys``.
    """
    bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
    bucket_batch_size = self.buckets_indices[index].bucket_batch_size
    image_index = self.buckets_indices[index].batch_index * bucket_batch_size

    if self.caching_mode is not None:
        return self.get_item_for_caching(bucket, bucket_batch_size, image_index)

    loss_weights = []
    captions = []
    input_ids_list = []
    input_ids2_list = []
    latents_list = []
    images = []
    original_sizes_hw = []
    crop_top_lefts = []
    target_sizes_hw = []
    flippeds = []
    text_encoder_outputs1_list = []
    text_encoder_outputs2_list = []
    text_encoder_pool2_list = []

    for image_key in bucket[image_index : image_index + bucket_batch_size]:
        image_info = self.image_data[image_key]
        subset = self.image_to_subset[image_key]
        loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)

        flipped = subset.flip_aug and random.random() < 0.5

        # --- image / latents: memory cache, then disk cache, then the image file ---
        if image_info.latents is not None:
            original_size = image_info.latents_original_size
            crop_ltrb = image_info.latents_crop_ltrb
            if not flipped:
                latents = image_info.latents
            else:
                latents = image_info.latents_flipped
            image = None
        elif image_info.latents_npz is not None:
            latents, original_size, crop_ltrb, flipped_latents = load_latents_from_disk(image_info.latents_npz)
            if flipped:
                latents = flipped_latents
                del flipped_latents
            latents = torch.FloatTensor(latents)
            image = None
        else:
            img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path)
            im_h, im_w = img.shape[0:2]

            if self.enable_bucket:
                img, original_size, crop_ltrb = trim_and_resize_if_required(
                    subset.random_crop, img, image_info.bucket_reso, image_info.resized_size
                )
            else:
                if face_cx > 0:
                    img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h)
                elif im_h > self.height or im_w > self.width:
                    assert (
                        subset.random_crop
                    ), f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
                    if im_h > self.height:
                        p = random.randint(0, im_h - self.height)
                        img = img[p : p + self.height]
                    if im_w > self.width:
                        p = random.randint(0, im_w - self.width)
                        img = img[:, p : p + self.width]

                im_h, im_w = img.shape[0:2]
                assert (
                    im_h == self.height and im_w == self.width
                ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"

                original_size = [im_w, im_h]
                crop_ltrb = (0, 0, 0, 0)

            aug = self.aug_helper.get_augmentor(subset.color_aug)
            if aug is not None:
                img = aug(image=img)["image"]

            if flipped:
                # Reverse axis 1; copy so the result does not keep a negative stride.
                img = img[:, ::-1, :].copy()
            latents = None
            image = self.image_transforms(img)

        images.append(image)
        latents_list.append(latents)

        # Latent dimensions are multiplied by 8 to obtain the pixel size.
        target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8)

        if not flipped:
            crop_left_top = (crop_ltrb[0], crop_ltrb[1])
        else:
            # After a horizontal flip the left offset is measured from the former right edge.
            crop_left_top = (target_size[0] - crop_ltrb[2], crop_ltrb[1])

        original_sizes_hw.append((int(original_size[1]), int(original_size[0])))
        crop_top_lefts.append((int(crop_left_top[1]), int(crop_left_top[0])))
        target_sizes_hw.append((int(target_size[1]), int(target_size[0])))
        flippeds.append(flipped)

        # --- caption / text-encoder inputs ---
        caption = image_info.caption
        if image_info.text_encoder_outputs1 is not None:
            text_encoder_outputs1_list.append(image_info.text_encoder_outputs1)
            text_encoder_outputs2_list.append(image_info.text_encoder_outputs2)
            text_encoder_pool2_list.append(image_info.text_encoder_pool2)
            captions.append(caption)
        elif image_info.text_encoder_outputs_npz is not None:
            text_encoder_outputs1, text_encoder_outputs2, text_encoder_pool2 = load_text_encoder_outputs_from_disk(
                image_info.text_encoder_outputs_npz
            )
            text_encoder_outputs1_list.append(text_encoder_outputs1)
            text_encoder_outputs2_list.append(text_encoder_outputs2)
            text_encoder_pool2_list.append(text_encoder_pool2)
            captions.append(caption)
        else:
            caption = self.process_caption(subset, image_info.caption)
            if self.XTI_layers:
                # One caption per layer, with the token strings suffixed by the layer name.
                caption_layer = []
                for layer in self.XTI_layers:
                    token_strings_from = " ".join(self.token_strings)
                    token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
                    caption_ = caption.replace(token_strings_from, token_strings_to)
                    caption_layer.append(caption_)
                captions.append(caption_layer)
            else:
                captions.append(caption)

            if not self.token_padding_disabled:
                if self.XTI_layers:
                    token_caption = self.get_input_ids(caption_layer, self.tokenizers[0])
                else:
                    token_caption = self.get_input_ids(caption, self.tokenizers[0])
                input_ids_list.append(token_caption)

                if len(self.tokenizers) > 1:
                    if self.XTI_layers:
                        token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1])
                    else:
                        token_caption2 = self.get_input_ids(caption, self.tokenizers[1])
                    input_ids2_list.append(token_caption2)

    example = {}
    example["loss_weights"] = torch.FloatTensor(loss_weights)

    if len(text_encoder_outputs1_list) == 0:
        if self.token_padding_disabled:
            # Tokenise the whole batch at once with padding=True.
            # Fixed: the attribute is ``self.tokenizers``; ``self.tokenizer`` does not exist.
            example["input_ids"] = self.tokenizers[0](captions, padding=True, truncation=True, return_tensors="pt").input_ids
            if len(self.tokenizers) > 1:
                example["input_ids2"] = self.tokenizers[1](
                    captions, padding=True, truncation=True, return_tensors="pt"
                ).input_ids
            else:
                example["input_ids2"] = None
        else:
            example["input_ids"] = torch.stack(input_ids_list)
            example["input_ids2"] = torch.stack(input_ids2_list) if len(self.tokenizers) > 1 else None
        example["text_encoder_outputs1_list"] = None
        example["text_encoder_outputs2_list"] = None
        example["text_encoder_pool2_list"] = None
    else:
        example["input_ids"] = None
        example["input_ids2"] = None
        example["text_encoder_outputs1_list"] = torch.stack(text_encoder_outputs1_list)
        example["text_encoder_outputs2_list"] = torch.stack(text_encoder_outputs2_list)
        example["text_encoder_pool2_list"] = torch.stack(text_encoder_pool2_list)

    # The first element decides whether the batch carries images or latents.
    if images[0] is not None:
        images = torch.stack(images)
        images = images.to(memory_format=torch.contiguous_format).float()
    else:
        images = None
    example["images"] = images

    example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None
    example["captions"] = captions

    example["original_sizes_hw"] = torch.stack([torch.LongTensor(x) for x in original_sizes_hw])
    example["crop_top_lefts"] = torch.stack([torch.LongTensor(x) for x in crop_top_lefts])
    example["target_sizes_hw"] = torch.stack([torch.LongTensor(x) for x in target_sizes_hw])
    example["flippeds"] = flippeds

    if self.debug_dataset:
        # Same slice as the main loop above, so keys line up with the batch contents.
        example["image_keys"] = bucket[image_index : image_index + bucket_batch_size]
    return example
def get_item_for_caching(self, bucket, bucket_batch_size, image_index):
    """Collect the raw inputs needed to cache one batch of a bucket.

    Args:
        bucket: sequence of image keys belonging to a single bucket.
        bucket_batch_size: number of keys to take from ``bucket``.
        image_index: index in ``bucket`` of the first key of the batch.

    Returns:
        dict with the keys ``images``, ``captions``, ``input_ids1_list``,
        ``input_ids2_list``, ``absolute_paths``, ``resized_sizes``,
        ``flip_aug``, ``random_crop`` and ``bucket_reso``. ``images`` is
        ``None`` when the first image of the batch was not loaded.

    Raises:
        AssertionError: if the images of the batch do not share the same
            ``flip_aug``, ``random_crop`` and ``bucket_reso``.
    """
    want_latents = self.caching_mode == "latents"
    want_text = self.caching_mode == "text"

    batch_captions = []
    batch_images = []
    batch_ids1 = []
    batch_ids2 = []
    batch_paths = []
    batch_resized = []

    # batch-wide settings, taken from the first image and verified on the others
    flip_aug = None
    random_crop = None
    bucket_reso = None

    batch_keys = bucket[image_index : image_index + bucket_batch_size]
    for key in batch_keys:
        info = self.image_data[key]
        owner = self.image_to_subset[key]

        if flip_aug is not None:
            assert flip_aug == owner.flip_aug, "flip_aug must be same in a batch"
            assert random_crop == owner.random_crop, "random_crop must be same in a batch"
            assert bucket_reso == info.bucket_reso, "bucket_reso must be same in a batch"
        else:
            flip_aug = owner.flip_aug
            random_crop = owner.random_crop
            bucket_reso = info.bucket_reso

        text = info.caption

        # pixels are only needed when caching latents
        pixels = load_image(info.absolute_path) if want_latents else None

        # token ids are only needed when caching text encoder outputs
        if want_text:
            ids1 = self.get_input_ids(text, self.tokenizers[0])
            ids2 = self.get_input_ids(text, self.tokenizers[1])
        else:
            ids1 = ids2 = None

        batch_captions.append(text)
        batch_images.append(pixels)
        batch_ids1.append(ids1)
        batch_ids2.append(ids2)
        batch_paths.append(info.absolute_path)
        batch_resized.append(info.resized_size)

    if batch_images[0] is None:
        batch_images = None

    return {
        "images": batch_images,
        "captions": batch_captions,
        "input_ids1_list": batch_ids1,
        "input_ids2_list": batch_ids2,
        "absolute_paths": batch_paths,
        "resized_sizes": batch_resized,
        "flip_aug": flip_aug,
        "random_crop": random_crop,
        "bucket_reso": bucket_reso,
    }
class DreamBoothDataset(BaseDataset):
    """Dataset built from image directories, one directory per subset.

    Each image gets its caption from a caption file next to it, or from the
    subset's ``class_tokens`` when no caption file exists. Subsets flagged
    ``is_reg`` provide regularization images; their repeat counts are raised
    until they cover the number of (repeated) training images.
    """

    def __init__(
        self,
        subsets: Sequence[DreamBoothSubset],
        batch_size: int,
        tokenizer,
        max_token_length,
        resolution,
        enable_bucket: bool,
        min_bucket_reso: int,
        max_bucket_reso: int,
        bucket_reso_steps: int,
        bucket_no_upscale: bool,
        prior_loss_weight: float,
        debug_dataset,
    ) -> None:
        """Scan every subset directory and register its images.

        Args:
            subsets: subsets to load; subsets with ``num_repeats < 1``, duplicated
                subsets and subsets without images are ignored with a message.
            batch_size: stored as ``self.batch_size``.
            tokenizer, max_token_length, resolution, debug_dataset: forwarded to
                the base class constructor. ``resolution`` must not be None.
            enable_bucket: when true the bucket parameters are validated against
                ``resolution`` and stored; otherwise they are reset.
            min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale:
                bucket settings, only kept when ``enable_bucket`` is true.
            prior_loss_weight: stored as ``self.prior_loss_weight``.

        Raises:
            AssertionError: if ``resolution`` is None, if the bucket limits do not
                enclose ``resolution``, or if a caption file is empty.
        """
        super().__init__(tokenizer, max_token_length, resolution, debug_dataset)

        assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"

        self.batch_size = batch_size
        self.size = min(self.width, self.height)  # the shorter side
        self.prior_loss_weight = prior_loss_weight
        self.latents_cache = None

        self.enable_bucket = enable_bucket
        if self.enable_bucket:
            assert (
                min(resolution) >= min_bucket_reso
            ), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
            assert (
                max(resolution) <= max_bucket_reso
            ), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
            self.min_bucket_reso = min_bucket_reso
            self.max_bucket_reso = max_bucket_reso
            self.bucket_reso_steps = bucket_reso_steps
            self.bucket_no_upscale = bucket_no_upscale
        else:
            self.min_bucket_reso = None
            self.max_bucket_reso = None
            self.bucket_reso_steps = None
            self.bucket_no_upscale = False

        def read_caption(img_path, caption_extension):
            """Return the first line of the caption file for ``img_path``, or None."""
            # candidate caption files: same base name, and the base name with the
            # last four "_"-separated tokens removed
            base_name = os.path.splitext(img_path)[0]
            base_name_face_det = base_name
            tokens = base_name.split("_")
            if len(tokens) >= 5:
                base_name_face_det = "_".join(tokens[:-4])
            cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension]

            caption = None
            for cap_path in cap_paths:
                if os.path.isfile(cap_path):
                    with open(cap_path, "rt", encoding="utf-8") as f:
                        try:
                            lines = f.readlines()
                        except UnicodeDecodeError as e:
                            print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
                            raise e
                        assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
                        caption = lines[0].strip()
                    break
            return caption

        def load_dreambooth_dir(subset: DreamBoothSubset):
            """Return ``(image paths, captions)`` for one subset directory."""
            if not os.path.isdir(subset.image_dir):
                print(f"not directory: {subset.image_dir}")
                return [], []

            img_paths = glob_images(subset.image_dir, "*")
            print(f"found directory {subset.image_dir} contains {len(img_paths)} image files")

            # use the caption file when present, otherwise fall back to class tokens
            captions = []
            missing_captions = []
            for img_path in img_paths:
                cap_for_img = read_caption(img_path, subset.caption_extension)
                if cap_for_img is None and subset.class_tokens is None:
                    print(
                        f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
                    )
                    captions.append("")
                    missing_captions.append(img_path)
                else:
                    if cap_for_img is None:
                        captions.append(subset.class_tokens)
                        missing_captions.append(img_path)
                    else:
                        captions.append(cap_for_img)

            self.set_tag_frequency(os.path.basename(subset.image_dir), captions)

            if missing_captions:
                number_of_missing_captions = len(missing_captions)
                number_of_missing_captions_to_show = 5
                remaining_missing_captions = number_of_missing_captions - number_of_missing_captions_to_show

                print(
                    f"No caption file found for {number_of_missing_captions} images. Training will continue without captions for these images. If class token exists, it will be used. / {number_of_missing_captions}枚の画像にキャプションファイルが見つかりませんでした。これらの画像についてはキャプションなしで学習を続行します。class tokenが存在する場合はそれを使います。"
                )
                # list only the first few paths, then summarize the rest
                for i, missing_caption in enumerate(missing_captions):
                    if i >= number_of_missing_captions_to_show:
                        print(missing_caption + f"... and {remaining_missing_captions} more")
                        break
                    print(missing_caption)
            return img_paths, captions

        print("prepare images.")
        num_train_images = 0
        num_reg_images = 0
        # Each regularization image is kept together with the subset it came
        # from, so that it is later registered with its own subset and not with
        # whichever subset happened to be iterated last.
        reg_infos: List[Tuple[ImageInfo, DreamBoothSubset]] = []
        for subset in subsets:
            if subset.num_repeats < 1:
                print(
                    f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}"
                )
                continue

            if subset in self.subsets:
                print(
                    f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します"
                )
                continue

            img_paths, captions = load_dreambooth_dir(subset)
            if len(img_paths) < 1:
                print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します")
                continue

            if subset.is_reg:
                num_reg_images += subset.num_repeats * len(img_paths)
            else:
                num_train_images += subset.num_repeats * len(img_paths)

            for img_path, caption in zip(img_paths, captions):
                info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
                if subset.is_reg:
                    reg_infos.append((info, subset))
                else:
                    self.register_image(info, subset)

            subset.img_count = len(img_paths)
            self.subsets.append(subset)

        print(f"{num_train_images} train images with repeating.")
        self.num_train_images = num_train_images

        print(f"{num_reg_images} reg images.")
        if num_train_images < num_reg_images:
            print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")

        if num_reg_images == 0:
            print("no regularization images / 正則化画像が見つかりませんでした")
        else:
            # Raise the repeat counts of the reg images until they cover the
            # number of training images. The first pass registers each image,
            # later passes only bump num_repeats on the registered info.
            n = 0
            first_loop = True
            while n < num_train_images:
                for info, reg_subset in reg_infos:
                    if first_loop:
                        self.register_image(info, reg_subset)
                        n += info.num_repeats
                    else:
                        info.num_repeats += 1
                        n += 1
                    if n >= num_train_images:
                        break
                first_loop = False

        self.num_reg_images = num_reg_images
class FineTuningDataset(BaseDataset):
    """Dataset whose images, captions and tags are listed in a JSON metadata file per subset.

    When the metadata carries a ``train_resolution`` for every image, those
    resolutions are used as predefined buckets. Pre-computed ``.npz`` latents
    next to the images are used only if they exist for every image and no
    subset uses ``color_aug`` or ``random_crop``.
    """

    def __init__(
        self,
        subsets: Sequence[FineTuningSubset],
        batch_size: int,
        tokenizer,
        max_token_length,
        resolution,
        enable_bucket: bool,
        min_bucket_reso: int,
        max_bucket_reso: int,
        bucket_reso_steps: int,
        bucket_no_upscale: bool,
        debug_dataset,
    ) -> None:
        """Load every subset's metadata file and register its images.

        Raises:
            ValueError: if a subset's metadata file does not exist.
            AssertionError: if an image (or its npz) cannot be found, if the
                metadata has no bucket info and ``resolution`` is None, or if the
                metadata has bucket info and ``bucket_no_upscale`` is set.
        """
        super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
        self.batch_size = batch_size
        self.num_train_images = 0
        self.num_reg_images = 0
        for subset in subsets:
            if subset.num_repeats < 1:
                print(
                    f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}"
                )
                continue
            if subset in self.subsets:
                print(
                    f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します"
                )
                continue
            # load the metadata (a dict: image key -> per-image dict)
            if os.path.exists(subset.metadata_file):
                print(f"loading existing metadata: {subset.metadata_file}")
                with open(subset.metadata_file, "rt", encoding="utf-8") as f:
                    metadata = json.load(f)
            else:
                raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}")
            if len(metadata) < 1:
                print(f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します")
                continue
            tags_list = []
            for image_key, img_md in metadata.items():
                # resolve the path: the key itself first, then a glob inside image_dir
                abs_path = None
                if os.path.exists(image_key):
                    abs_path = image_key
                else:
                    paths = glob_images(subset.image_dir, image_key)
                    if len(paths) > 0:
                        abs_path = paths[0]
                # no image found: fall back to an npz file with the same name
                if abs_path is None:
                    if os.path.exists(os.path.splitext(image_key)[0] + ".npz"):
                        abs_path = os.path.splitext(image_key)[0] + ".npz"
                    else:
                        npz_path = os.path.join(subset.image_dir, image_key + ".npz")
                        if os.path.exists(npz_path):
                            abs_path = npz_path
                assert abs_path is not None, f"no image / 画像がありません: {image_key}"
                # caption is "caption, tags" when both exist, otherwise whichever exists
                caption = img_md.get("caption")
                tags = img_md.get("tags")
                if caption is None:
                    caption = tags
                elif tags is not None and len(tags) > 0:
                    caption = caption + ", " + tags
                    tags_list.append(tags)
                if caption is None:
                    caption = ""
                image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path)
                image_info.image_size = img_md.get("train_resolution")
                if not subset.color_aug and not subset.random_crop:
                    # record npz latents (normal and flipped) when they exist
                    image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key)
                self.register_image(image_info, subset)
            self.num_train_images += len(metadata) * subset.num_repeats
            self.set_tag_frequency(os.path.basename(subset.metadata_file), tags_list)
            subset.img_count = len(metadata)
            self.subsets.append(subset)
        # npz latents can only be used when no subset applies color_aug / random_crop
        use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets])
        if use_npz_latents:
            flip_aug_in_subset = False
            npz_any = False
            npz_all = True
            for image_info in self.image_data.values():
                subset = self.image_to_subset[image_info.image_key]
                has_npz = image_info.latents_npz is not None
                npz_any = npz_any or has_npz
                if subset.flip_aug:
                    # with flip_aug the flipped npz is required as well
                    has_npz = has_npz and image_info.latents_npz_flipped is not None
                    flip_aug_in_subset = True
                npz_all = npz_all and has_npz
                if npz_any and not npz_all:
                    # mixed state already detected; no need to look further
                    break
            if not npz_any:
                use_npz_latents = False
                print(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します")
            elif not npz_all:
                use_npz_latents = False
                print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
                if flip_aug_in_subset:
                    print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
        # collect the bucket resolutions stored in the metadata; sizes becomes None
        # as soon as one image has no stored size
        sizes = set()
        resos = set()
        for image_info in self.image_data.values():
            if image_info.image_size is None:
                sizes = None
                break
            sizes.add(image_info.image_size[0])
            sizes.add(image_info.image_size[1])
            resos.add(tuple(image_info.image_size))
        if sizes is None:
            # no bucket info in metadata: npz files cannot be trusted, use the arguments
            if use_npz_latents:
                use_npz_latents = False
                print(f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します")
            assert (
                resolution is not None
            ), "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください"
            self.enable_bucket = enable_bucket
            # NOTE(review): the bucket attributes below are only assigned here when
            # enable_bucket is true — confirm the base class provides defaults otherwise
            if self.enable_bucket:
                self.min_bucket_reso = min_bucket_reso
                self.max_bucket_reso = max_bucket_reso
                self.bucket_reso_steps = bucket_reso_steps
                self.bucket_no_upscale = bucket_no_upscale
        else:
            # metadata has bucket info: bucketing is forced on with predefined resolutions
            if not enable_bucket:
                print("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします")
            print("using bucket info in metadata / メタデータ内のbucket情報を使います")
            self.enable_bucket = True
            assert (
                not bucket_no_upscale
            ), "if metadata has bucket info, bucket reso is precalculated, so bucket_no_upscale cannot be used / メタデータ内にbucket情報がある場合はbucketの解像度は計算済みのため、bucket_no_upscaleは使えません"
            self.bucket_manager = BucketManager(False, None, None, None, None)
            self.bucket_manager.set_predefined_resos(resos)
        # clear npz references when they are not going to be used
        if not use_npz_latents:
            for image_info in self.image_data.values():
                image_info.latents_npz = image_info.latents_npz_flipped = None

    def image_key_to_npz_file(self, subset: FineTuningSubset, image_key):
        """Return ``(npz path, flipped npz path)`` for an image key; missing files are None.

        The key is first treated as a path (extension replaced by ``.npz`` /
        ``_flip.npz``); if no such file exists, ``image_key + ".npz"`` is looked up
        inside ``subset.image_dir``.
        """
        base_name = os.path.splitext(image_key)[0]
        npz_file_norm = base_name + ".npz"
        if os.path.exists(npz_file_norm):
            # the key is a full path: the flipped file is optional
            npz_file_flip = base_name + "_flip.npz"
            if not os.path.exists(npz_file_flip):
                npz_file_flip = None
            return npz_file_norm, npz_file_flip
        # without image_dir there is nowhere else to look
        if subset.image_dir is None:
            return None, None
        # the key is relative: look inside image_dir
        npz_file_norm = os.path.join(subset.image_dir, image_key + ".npz")
        npz_file_flip = os.path.join(subset.image_dir, image_key + "_flip.npz")
        if not os.path.exists(npz_file_norm):
            npz_file_norm = None
            npz_file_flip = None
        elif not os.path.exists(npz_file_flip):
            npz_file_flip = None
        return npz_file_norm, npz_file_flip
class ControlNetDataset(BaseDataset):
    """Dataset pairing each training image with a conditioning image of the same file name.

    Image loading, bucketing and latent caching are delegated to an internal
    ``DreamBoothDataset``; this class adds ``conditioning_images`` to each batch.
    """

    def __init__(
        self,
        subsets: Sequence[ControlNetSubset],
        batch_size: int,
        tokenizer,
        max_token_length,
        resolution,
        enable_bucket: bool,
        min_bucket_reso: int,
        max_bucket_reso: int,
        bucket_reso_steps: int,
        bucket_no_upscale: bool,
        debug_dataset,
    ) -> None:
        """Build the delegate dataset and match every image with its conditioning image.

        Raises:
            AssertionError: if an image has no conditioning image, or a
                conditioning directory contains images that match no training image.
        """
        super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
        # mirror each ControlNet subset as a DreamBooth subset (not a reg subset, no class tokens)
        db_subsets = []
        for subset in subsets:
            db_subset = DreamBoothSubset(
                subset.image_dir,
                False,
                None,
                subset.caption_extension,
                subset.num_repeats,
                subset.shuffle_caption,
                subset.keep_tokens,
                subset.color_aug,
                subset.flip_aug,
                subset.face_crop_aug_range,
                subset.random_crop,
                subset.caption_dropout_rate,
                subset.caption_dropout_every_n_epochs,
                subset.caption_tag_dropout_rate,
                subset.caption_prefix,
                subset.caption_suffix,
                subset.token_warmup_min,
                subset.token_warmup_step,
            )
            db_subsets.append(db_subset)
        self.dreambooth_dataset_delegate = DreamBoothDataset(
            db_subsets,
            batch_size,
            tokenizer,
            max_token_length,
            resolution,
            enable_bucket,
            min_bucket_reso,
            max_bucket_reso,
            bucket_reso_steps,
            bucket_no_upscale,
            1.0,
            debug_dataset,
        )
        # expose the delegate's image data and counters on this object
        self.image_data = self.dreambooth_dataset_delegate.image_data
        self.batch_size = batch_size
        self.num_train_images = self.dreambooth_dataset_delegate.num_train_images
        self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
        # check that every image has a conditioning image with the same base name
        missing_imgs = []
        cond_imgs_with_img = set()
        for image_key, info in self.dreambooth_dataset_delegate.image_data.items():
            db_subset = self.dreambooth_dataset_delegate.image_to_subset[image_key]
            # find the original ControlNet subset by its image_dir
            subset = None
            for s in subsets:
                if s.image_dir == db_subset.image_dir:
                    subset = s
                    break
            assert subset is not None, "internal error: subset not found"
            if not os.path.isdir(subset.conditioning_data_dir):
                print(f"not directory: {subset.conditioning_data_dir}")
                continue
            img_basename = os.path.basename(info.absolute_path)
            ctrl_img_path = os.path.join(subset.conditioning_data_dir, img_basename)
            if not os.path.exists(ctrl_img_path):
                missing_imgs.append(img_basename)
            info.cond_img_path = ctrl_img_path
            cond_imgs_with_img.add(ctrl_img_path)
        # conditioning images that belong to no training image are reported as extra
        extra_imgs = []
        for subset in subsets:
            conditioning_img_paths = glob_images(subset.conditioning_data_dir, "*")
            extra_imgs.extend(
                [cond_img_path for cond_img_path in conditioning_img_paths if cond_img_path not in cond_imgs_with_img]
            )
        assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}"
        assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}"
        self.conditioning_image_transforms = IMAGE_TRANSFORMS

    def make_buckets(self):
        """Build buckets on the delegate and share its bucket manager and indices."""
        self.dreambooth_dataset_delegate.make_buckets()
        self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager
        self.buckets_indices = self.dreambooth_dataset_delegate.buckets_indices

    def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
        """Forward latent caching to the delegate dataset."""
        return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)

    def __len__(self):
        """Return the delegate's length."""
        return self.dreambooth_dataset_delegate.__len__()

    def __getitem__(self, index):
        """Return the delegate's batch with ``conditioning_images`` added.

        Each conditioning image is resized/cropped to the target size of its
        training image and flipped when the training image was flipped.
        """
        example = self.dreambooth_dataset_delegate[index]
        # recompute which image keys the delegate used for this batch
        bucket = self.dreambooth_dataset_delegate.bucket_manager.buckets[
            self.dreambooth_dataset_delegate.buckets_indices[index].bucket_index
        ]
        bucket_batch_size = self.dreambooth_dataset_delegate.buckets_indices[index].bucket_batch_size
        image_index = self.dreambooth_dataset_delegate.buckets_indices[index].batch_index * bucket_batch_size
        conditioning_images = []
        for i, image_key in enumerate(bucket[image_index : image_index + bucket_batch_size]):
            image_info = self.dreambooth_dataset_delegate.image_data[image_key]
            target_size_hw = example["target_sizes_hw"][i]
            original_size_hw = example["original_sizes_hw"][i]
            crop_top_left = example["crop_top_lefts"][i]
            flipped = example["flippeds"][i]
            cond_img = load_image(image_info.cond_img_path)
            if self.dreambooth_dataset_delegate.enable_bucket:
                # the conditioning image must have the same size as the original image
                assert (
                    cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1]
                ), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}"
                cond_img = cv2.resize(cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA)
                # center crop to the target size
                h, w = target_size_hw
                ct = (cond_img.shape[0] - h) // 2
                cl = (cond_img.shape[1] - w) // 2
                cond_img = cond_img[ct : ct + h, cl : cl + w]
            else:
                # without bucketing just resize to the target size when it differs
                if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]:
                    cond_img = cv2.resize(
                        cond_img, (int(target_size_hw[1]), int(target_size_hw[0])), interpolation=cv2.INTER_LANCZOS4
                    )
            if flipped:
                # mirror horizontally; copy() makes the reversed view contiguous
                cond_img = cond_img[:, ::-1, :].copy()
            cond_img = self.conditioning_image_transforms(cond_img)
            conditioning_images.append(cond_img)
        example["conditioning_images"] = torch.stack(conditioning_images).to(memory_format=torch.contiguous_format).float()
        return example
class DatasetGroup(torch.utils.data.ConcatDataset):
    """Concatenation of several datasets that forwards control calls to each member."""

    def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]):
        """Concatenate ``datasets`` and merge their image data and image counters."""
        self.datasets: List[Union[DreamBoothDataset, FineTuningDataset]]

        super().__init__(datasets)

        merged_image_data = {}
        train_total = 0
        reg_total = 0
        for member in datasets:
            merged_image_data.update(member.image_data)
            train_total += member.num_train_images
            reg_total += member.num_reg_images

        self.image_data = merged_image_data
        self.num_train_images = train_total
        self.num_reg_images = reg_total

    def add_replacement(self, str_from, str_to):
        """Forward ``add_replacement`` to every member dataset."""
        for member in self.datasets:
            member.add_replacement(str_from, str_to)

    def enable_XTI(self, *args, **kwargs):
        """Forward ``enable_XTI`` with the same arguments to every member dataset."""
        for member in self.datasets:
            member.enable_XTI(*args, **kwargs)

    def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
        """Cache latents dataset by dataset, printing a header for each one."""
        for position, member in enumerate(self.datasets):
            print(f"[Dataset {position}]")
            member.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)

    def cache_text_encoder_outputs(
        self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
    ):
        """Cache text encoder outputs dataset by dataset, printing a header for each one."""
        for position, member in enumerate(self.datasets):
            print(f"[Dataset {position}]")
            member.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process)

    def set_caching_mode(self, caching_mode):
        """Forward ``set_caching_mode`` to every member dataset."""
        for member in self.datasets:
            member.set_caching_mode(caching_mode)

    def verify_bucket_reso_steps(self, min_steps: int):
        """Forward ``verify_bucket_reso_steps`` to every member dataset."""
        for member in self.datasets:
            member.verify_bucket_reso_steps(min_steps)

    def is_latent_cacheable(self) -> bool:
        """Return True only if every member dataset reports cacheable latents."""
        answers = [member.is_latent_cacheable() for member in self.datasets]
        return all(answers)

    def is_text_encoder_output_cacheable(self) -> bool:
        """Return True only if every member dataset reports cacheable text encoder outputs."""
        answers = [member.is_text_encoder_output_cacheable() for member in self.datasets]
        return all(answers)

    def set_current_epoch(self, epoch):
        """Forward the current epoch to every member dataset."""
        for member in self.datasets:
            member.set_current_epoch(epoch)

    def set_current_step(self, step):
        """Forward the current step to every member dataset."""
        for member in self.datasets:
            member.set_current_step(step)

    def set_max_train_steps(self, max_train_steps):
        """Forward the maximum number of training steps to every member dataset."""
        for member in self.datasets:
            member.set_max_train_steps(max_train_steps)

    def disable_token_padding(self):
        """Forward ``disable_token_padding`` to every member dataset."""
        for member in self.datasets:
            member.disable_token_padding()
def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
    """Check whether an npz latents cache exists and matches the expected layout.

    Args:
        reso: bucket resolution; the expected latents size is
            ``(reso[1] // 8, reso[0] // 8)``.
        npz_path: path of the cache file.
        flip_aug: when true the file must also contain ``latents_flipped`` of
            the expected size.

    Returns:
        True if the file exists, has the ``latents``, ``original_size`` and
        ``crop_ltrb`` entries and the latents have the expected size.
    """
    expected_latents_size = (reso[1] // 8, reso[0] // 8)

    if not os.path.exists(npz_path):
        return False

    # the archive keeps a file handle open; close it deterministically
    with np.load(npz_path) as npz:
        if "latents" not in npz or "original_size" not in npz or "crop_ltrb" not in npz:
            return False
        if npz["latents"].shape[1:3] != expected_latents_size:
            return False

        if flip_aug:
            if "latents_flipped" not in npz:
                return False
            if npz["latents_flipped"].shape[1:3] != expected_latents_size:
                return False

    return True
def load_latents_from_disk(
    npz_path,
) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor]]:
    """Read cached latents and their size information from an npz file.

    Returns:
        ``(latents, original_size, crop_ltrb, flipped_latents)``. The latents
        are returned as stored in the archive (arrays read by ``np.load``; no
        tensor conversion happens here), the two size entries are converted to
        lists, and ``flipped_latents`` is None when the file has no
        ``latents_flipped`` entry.

    Raises:
        ValueError: if the file has no ``latents`` entry (old cache format).
    """
    # the archive keeps a file handle open; read everything, then close it
    with np.load(npz_path) as npz:
        if "latents" not in npz:
            raise ValueError(f"error: npz is old format. please re-generate {npz_path}")

        latents = npz["latents"]
        original_size = npz["original_size"].tolist()
        crop_ltrb = npz["crop_ltrb"].tolist()
        flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None
    return latents, original_size, crop_ltrb, flipped_latents
def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None):
    """Write latents and their size information to an npz file.

    The archive gets the entries ``latents``, ``original_size`` and
    ``crop_ltrb``, plus ``latents_flipped`` when ``flipped_latents_tensor`` is
    given. Tensors are converted to float CPU arrays before saving.
    """
    optional_entries = (
        {}
        if flipped_latents_tensor is None
        else {"latents_flipped": flipped_latents_tensor.float().cpu().numpy()}
    )
    latents_array = latents_tensor.float().cpu().numpy()
    size_array = np.array(original_size)
    crop_array = np.array(crop_ltrb)
    np.savez(
        npz_path,
        latents=latents_array,
        original_size=size_array,
        crop_ltrb=crop_array,
        **optional_entries,
    )
def debug_dataset(train_dataset, show_input_ids=False):
    """Iterate over a dataset in random order and print each sample for inspection.

    On Windows (``os.name == "nt"``) every image is shown in an OpenCV window
    and the pressed key decides how to continue: ``S`` next step, ``E`` next
    epoch, Escape to quit. When batches carry no images the loop stops after
    nine steps.

    Args:
        train_dataset: dataset providing ``__len__``, ``__getitem__``,
            ``set_current_epoch``, ``set_current_step`` and ``image_data``; its
            examples must contain ``image_keys``.
        show_input_ids: also print the token ids of each caption. Entries that
            are ``None`` in the example (``input_ids`` / ``input_ids2``) are
            tolerated instead of raising ``TypeError``.
    """
    print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
    print("`S` for next step, `E` for next epoch no. , Escape for exit. / Sキーで次のステップ、Eキーで次のエポック、Escキーで中断、終了します")

    # nothing to show; without this guard the epoch loop below would never end
    if len(train_dataset) == 0:
        return

    epoch = 1
    while True:
        print(f"\nepoch: {epoch}")

        steps = (epoch - 1) * len(train_dataset) + 1
        indices = list(range(len(train_dataset)))
        random.shuffle(indices)

        k = 0
        for i, idx in enumerate(indices):
            train_dataset.set_current_epoch(epoch)
            train_dataset.set_current_step(steps)
            print(f"steps: {steps} ({i + 1}/{len(train_dataset)})")

            example = train_dataset[idx]
            if example["latents"] is not None:
                print(f"sample has latents from npz file: {example['latents'].size()}")

            # input_ids is None when text encoder outputs are cached; substitute
            # placeholders so the per-sample zip below still works
            input_ids = example["input_ids"]
            if input_ids is None:
                input_ids = [None] * len(example["captions"])
            # the key may be present with a None value (single tokenizer)
            input_ids2 = example["input_ids2"] if "input_ids2" in example else None

            for j, (ik, cap, lw, iid, orgsz, crptl, trgsz, flpdz) in enumerate(
                zip(
                    example["image_keys"],
                    example["captions"],
                    example["loss_weights"],
                    input_ids,
                    example["original_sizes_hw"],
                    example["crop_top_lefts"],
                    example["target_sizes_hw"],
                    example["flippeds"],
                )
            ):
                print(
                    f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop top left: {crptl}, target size: {trgsz}, flipped: {flpdz}'
                )

                if show_input_ids:
                    print(f"input ids: {iid}")
                    if input_ids2 is not None:
                        print(f"input ids2: {input_ids2[j]}")
                if example["images"] is not None:
                    im = example["images"][j]
                    print(f"image size: {im.size()}")
                    # map [-1, 1] to [0, 255], CHW -> HWC, RGB -> BGR for OpenCV
                    im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
                    im = np.transpose(im, (1, 2, 0))
                    im = im[:, :, ::-1]

                    if "conditioning_images" in example:
                        cond_img = example["conditioning_images"][j]
                        print(f"conditioning image size: {cond_img.size()}")
                        cond_img = ((cond_img.numpy() + 1.0) * 127.5).astype(np.uint8)
                        cond_img = np.transpose(cond_img, (1, 2, 0))
                        cond_img = cond_img[:, :, ::-1]
                        if os.name == "nt":
                            cv2.imshow("cond_img", cond_img)

                    if os.name == "nt":  # windows are only shown on Windows
                        cv2.imshow("img", im)
                        k = cv2.waitKey()
                        cv2.destroyAllWindows()
                    if k == 27 or k == ord("s") or k == ord("e"):
                        break
            steps += 1

            if k == ord("e"):
                break
            # without images there is no key input, so stop after a few steps
            if k == 27 or (example["images"] is None and i >= 8):
                k = 27
                break
        if k == 27:
            break

        epoch += 1
def glob_images(directory, base="*"):
    """Return the sorted, de-duplicated image paths in ``directory``.

    ``base`` is combined with every extension in ``IMAGE_EXTENSIONS``. With the
    default ``"*"`` only the directory part is glob-escaped so the wildcard
    stays active; for any other ``base`` the whole path is escaped, i.e. the
    file name is matched literally.
    """
    match_any = base == "*"
    found = set()
    for suffix in IMAGE_EXTENSIONS:
        if match_any:
            pattern = os.path.join(glob.escape(directory), base + suffix)
        else:
            pattern = glob.escape(os.path.join(directory, base + suffix))
        found.update(glob.glob(pattern))
    return sorted(found)
def glob_images_pathlib(dir_path, recursive):
    """Return the sorted, de-duplicated image paths under ``dir_path``.

    ``dir_path`` must offer ``glob`` / ``rglob`` (as ``pathlib.Path`` does);
    ``recursive`` selects ``rglob``. One pattern per entry of
    ``IMAGE_EXTENSIONS`` is searched.
    """
    search = dir_path.rglob if recursive else dir_path.glob
    found = set()
    for suffix in IMAGE_EXTENSIONS:
        found.update(search("*" + suffix))
    return sorted(found)
class MinimalDataset(BaseDataset):
    """Base class for user-defined datasets.

    The instance registers itself as its only dataset and only subset, so it
    carries the attributes read from both. Subclasses must implement
    ``__len__`` and ``__getitem__``.
    """

    def __init__(self, tokenizer, max_token_length, resolution, debug_dataset=False):
        """Initialize the base dataset and the dataset-group / subset attributes."""
        super().__init__(tokenizer, max_token_length, resolution, debug_dataset)

        # dataset-level attributes
        self.num_train_images = 0
        self.num_reg_images = 0
        self.datasets = [self]
        self.batch_size = 1

        # subset-level attributes: this object also serves as its own subset
        self.subsets = [self]
        self.num_repeats = 1
        self.img_count = 1
        self.bucket_info = {}
        self.is_reg = False
        self.image_dir = "dummy"

    def verify_bucket_reso_steps(self, min_steps: int):
        """Do nothing; no bucket step verification is performed here."""
        pass

    def is_latent_cacheable(self) -> bool:
        """Return False: latents are never cached for this dataset."""
        return False

    def __len__(self):
        """Must be implemented by the subclass."""
        raise NotImplementedError

    def set_current_epoch(self, epoch):
        """Remember the current epoch in ``self.current_epoch``."""
        self.current_epoch = epoch

    def __getitem__(self, idx):
        r"""
        Must be implemented by the subclass and return one batch as a dict.

        The subclass may provide ``image_data`` (a dict of ImageInfo objects),
        which is used by debug_dataset. An implementation looks like this:

            for i in range(batch_size):
                image_key = ...  # whatever hashable
                image_keys.append(image_key)

                image = ...  # PIL Image
                img_tensor = self.image_transforms(img)
                images.append(img_tensor)

                caption = ...  # str
                input_ids = self.get_input_ids(caption)
                input_ids_list.append(input_ids)

                captions.append(caption)

            images = torch.stack(images, dim=0)
            input_ids_list = torch.stack(input_ids_list, dim=0)
            example = {
                "images": images,
                "input_ids": input_ids_list,
                "captions": captions,  # for debug_dataset
                "latents": None,
                "image_keys": image_keys,  # for debug_dataset
                "loss_weights": torch.ones(batch_size, dtype=torch.float32),
            }
            return example
        """
        raise NotImplementedError
def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset:
    """Instantiate the dataset class named by ``args.dataset_class``.

    ``args.dataset_class`` is a dotted path whose last component is the class
    name and whose preceding components are the module to import. The class is
    called with ``(tokenizer, args.max_token_length, args.resolution,
    args.debug_dataset)``.
    """
    module_path, _, class_name = args.dataset_class.rpartition(".")
    dataset_cls = getattr(importlib.import_module(module_path), class_name)
    train_dataset_group: MinimalDataset = dataset_cls(
        tokenizer,
        args.max_token_length,
        args.resolution,
        args.debug_dataset,
    )
    return train_dataset_group
def load_image(image_path):
    """Open an image file and return it as a uint8 array in RGB mode."""
    pil_image = Image.open(image_path)
    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")
    return np.array(pil_image, np.uint8)
def trim_and_resize_if_required(
    random_crop: bool, image: Image.Image, reso, resized_size: Tuple[int, int]
) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]:
    """Resize an image to ``resized_size`` and trim it to the bucket resolution.

    Args:
        random_crop: pick the trim offset at random instead of centering.
        image: image array indexed as (height, width, ...); the code reads
            ``image.shape`` and slices it, so despite the annotation an array is
            expected.
        reso: bucket resolution as (width, height).
        resized_size: (width, height) to resize to before trimming.

    Returns:
        ``(trimmed image, original (width, height), crop_ltrb)`` where
        ``crop_ltrb`` comes from ``BucketManager.get_crop_ltrb``.

    Raises:
        AssertionError: if the result does not have the bucket resolution.
    """
    src_h, src_w = image.shape[0:2]
    original_size = (src_w, src_h)  # size before any resize

    already_sized = src_w == resized_size[0] and src_h == resized_size[1]
    if not already_sized:
        image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)

    cur_h, cur_w = image.shape[0:2]

    # trim the width first, then the height
    if cur_w > reso[0]:
        excess = cur_w - reso[0]
        left = random.randint(0, excess) if random_crop else excess // 2
        image = image[:, left : left + reso[0]]

    if cur_h > reso[1]:
        excess = cur_h - reso[1]
        top = random.randint(0, excess) if random_crop else excess // 2
        image = image[top : top + reso[1]]

    crop_ltrb = BucketManager.get_crop_ltrb(reso, original_size)

    assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
    return image, original_size, crop_ltrb
def cache_batch_latents(
    vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, random_crop: bool
) -> None:
    r"""
    Encode one batch of images with the VAE and cache the resulting latents.

    requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz
    optionally requires image_infos to have: image
    if cache_to_disk is True, set info.latents_npz
        flipped latents is also saved if flip_aug is True
    if cache_to_disk is False, set info.latents
        latents_flipped is also set if flip_aug is True
    latents_original_size and latents_crop_ltrb are also set

    Raises:
        RuntimeError: if the latents of an image contain NaN; the message names
            the image whose latents are invalid.
    """
    images = []
    for info in image_infos:
        # load the image unless it is already attached to the info
        image = load_image(info.absolute_path) if info.image is None else np.array(info.image, np.uint8)
        image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
        image = IMAGE_TRANSFORMS(image)
        images.append(image)

        info.latents_original_size = original_size
        info.latents_crop_ltrb = crop_ltrb

    img_tensors = torch.stack(images, dim=0)
    img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)

    with torch.no_grad():
        latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")

    if flip_aug:
        # flip along the last dimension and encode again
        img_tensors = torch.flip(img_tensors, dims=[3])
        with torch.no_grad():
            flipped_latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
    else:
        flipped_latents = [None] * len(latents)

    for info, latent, flipped_latent in zip(image_infos, latents, flipped_latents):
        # check this image's own latent, so the error points at the right file
        if torch.isnan(latent).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()):
            raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")

        if cache_to_disk:
            save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent)
        else:
            info.latents = latent
            if flip_aug:
                info.latents_flipped = flipped_latent

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def cache_batch_text_encoder_outputs(
    image_infos, tokenizers, text_encoders, max_token_length, cache_to_disk, input_ids1, input_ids2, dtype
):
    """Run both text encoders on one batch and cache the outputs per image.

    The token ids are moved to the device of their text encoder, encoded with
    ``get_hidden_states_sdxl`` and the results moved to the CPU. With
    ``cache_to_disk`` each image's outputs are written to
    ``info.text_encoder_outputs_npz``; otherwise they are stored on the info as
    ``text_encoder_outputs1``, ``text_encoder_outputs2`` and
    ``text_encoder_pool2``.
    """
    encoder1 = text_encoders[0]
    encoder2 = text_encoders[1]
    input_ids1 = input_ids1.to(encoder1.device)
    input_ids2 = input_ids2.to(encoder2.device)

    with torch.no_grad():
        encoded = get_hidden_states_sdxl(
            max_token_length,
            input_ids1,
            input_ids2,
            tokenizers[0],
            tokenizers[1],
            encoder1,
            encoder2,
            dtype,
        )
        batch_hidden1, batch_hidden2, batch_pool2 = encoded

        # detach and move to the CPU before handing the results out
        batch_hidden1 = batch_hidden1.detach().to("cpu")
        batch_hidden2 = batch_hidden2.detach().to("cpu")
        batch_pool2 = batch_pool2.detach().to("cpu")

    per_image = zip(image_infos, batch_hidden1, batch_hidden2, batch_pool2)
    for info, hidden1, hidden2, pooled in per_image:
        if not cache_to_disk:
            info.text_encoder_outputs1 = hidden1
            info.text_encoder_outputs2 = hidden2
            info.text_encoder_pool2 = pooled
            continue
        save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, hidden1, hidden2, pooled)
def save_text_encoder_outputs_to_disk(npz_path, hidden_state1, hidden_state2, pool2):
    """Write the three text encoder outputs to an npz file as float arrays.

    The archive entries are named ``hidden_state1``, ``hidden_state2`` and
    ``pool2``.
    """
    arrays = {
        "hidden_state1": hidden_state1.cpu().float().numpy(),
        "hidden_state2": hidden_state2.cpu().float().numpy(),
        "pool2": pool2.cpu().float().numpy(),
    }
    np.savez(npz_path, **arrays)
def load_text_encoder_outputs_from_disk(npz_path):
    """Read cached text encoder outputs from an npz file as tensors.

    Returns:
        ``(hidden_state1, hidden_state2, pool2)``; the last two are None when
        the archive has no entry of that name.
    """
    with np.load(npz_path) as archive:
        hidden_state1 = torch.from_numpy(archive["hidden_state1"])

        hidden_state2 = None
        if "hidden_state2" in archive:
            hidden_state2 = torch.from_numpy(archive["hidden_state2"])

        pool2 = None
        if "pool2" in archive:
            pool2 = torch.from_numpy(archive["pool2"])

    return hidden_state1, hidden_state2, pool2
"""
高速化のためのモジュール入れ替え
"""
EPSILON = 1e-6
def exists(val):
    """Return True unless ``val`` is None."""
    return not (val is None)


def default(val, d):
    """Return ``val`` unless it is None, in which case return ``d``."""
    if exists(val):
        return val
    return d
def model_hash(filename):
    """Old model hash used by stable-diffusion-webui.

    Hashes the 0x10000 bytes starting at offset 0x100000 with SHA-256 and
    returns the first 8 hex digits. Returns ``"NOFILE"`` when the file does not
    exist and ``"IsADirectory"`` when it is a directory or cannot be opened for
    lack of permission.
    """
    try:
        with open(filename, "rb") as handle:
            handle.seek(0x100000)
            window = handle.read(0x10000)
    except FileNotFoundError:
        return "NOFILE"
    except (IsADirectoryError, PermissionError):
        return "IsADirectory"
    return hashlib.sha256(window).hexdigest()[:8]
def calculate_sha256(filename):
    """New model hash used by stable-diffusion-webui.

    Returns the SHA-256 hex digest of the whole file, read in 1 MiB blocks.
    Returns ``"NOFILE"`` when the file does not exist and ``"IsADirectory"``
    when it is a directory or cannot be opened for lack of permission.
    """
    block_size = 1024 * 1024
    digest = hashlib.sha256()
    try:
        with open(filename, "rb") as handle:
            while True:
                block = handle.read(block_size)
                if block == b"":
                    break
                digest.update(block)
    except FileNotFoundError:
        return "NOFILE"
    except (IsADirectoryError, PermissionError):
        return "IsADirectory"
    return digest.hexdigest()
def precalculate_safetensors_hashes(tensors, metadata):
    """Precalculate the model hashes needed by sd-webui-additional-networks to
    save time on indexing the model later.

    Only metadata entries whose key starts with ``ss_`` take part in the
    serialization that is hashed.

    Returns:
        ``(new hash, legacy hash)`` of the serialized safetensors data.
    """
    training_metadata = {}
    for key, value in metadata.items():
        if key.startswith("ss_"):
            training_metadata[key] = value

    serialized = safetensors.torch.save(tensors, training_metadata)
    buffer = BytesIO(serialized)

    new_hash = addnet_hash_safetensors(buffer)
    legacy_hash = addnet_hash_legacy(buffer)
    return new_hash, legacy_hash
def addnet_hash_legacy(b):
    """Old model hash used by sd-webui-additional-networks for .safetensors format files.

    ``b`` is a seekable binary stream; the 0x10000 bytes starting at offset
    0x100000 are hashed with SHA-256 and the first 8 hex digits returned.
    """
    b.seek(0x100000)
    window = b.read(0x10000)
    return hashlib.sha256(window).hexdigest()[:8]
def addnet_hash_safetensors(b):
    """New model hash used by sd-webui-additional-networks for .safetensors format files.

    ``b`` is a seekable binary stream. Its first 8 bytes are read as a
    little-endian integer ``n``; everything after offset ``n + 8`` is hashed
    with SHA-256 in 1 MiB blocks and the hex digest returned.
    """
    block_size = 1024 * 1024
    digest = hashlib.sha256()

    b.seek(0)
    header_length = int.from_bytes(b.read(8), "little")
    b.seek(header_length + 8)

    while True:
        block = b.read(block_size)
        if block == b"":
            break
        digest.update(block)

    return digest.hexdigest()
def get_git_revision_hash() -> str:
    """Return the git commit hash of the checkout containing this file.

    Runs ``git rev-parse HEAD`` in this file's directory. Any ordinary failure
    (git missing, not a repository, undecodable output) yields ``"(unknown)"``.
    ``KeyboardInterrupt`` and ``SystemExit`` are not swallowed.
    """
    try:
        return subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=os.path.dirname(__file__)).decode("ascii").strip()
    except Exception:
        # best effort only: the revision is informational
        return "(unknown)"
def replace_unet_modules(unet: UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
    """Switch the U-Net attention implementation according to the given flags.

    The flags are checked in priority order ``mem_eff_attn``, ``xformers``,
    ``sdpa``; only the first truthy one takes effect. If none is set the
    U-Net is left untouched.

    Raises:
        ImportError: if ``xformers`` is requested but the package cannot be imported.
    """
    if mem_eff_attn:
        print("Enable memory efficient attention for U-Net")
        unet.set_use_memory_efficient_attention(False, True)
        return

    if xformers:
        print("Enable xformers for U-Net")
        # Fail early with a clear message when the package is unavailable.
        try:
            import xformers.ops
        except ImportError:
            raise ImportError("No xformers / xformersがインストールされていないようです")
        unet.set_use_memory_efficient_attention(True, False)
        return

    if sdpa:
        print("Enable SDPA for U-Net")
        unet.set_use_sdpa(True)
"""
def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers):
# vae is not used currently, but it is here for future use
if mem_eff_attn:
replace_vae_attn_to_memory_efficient()
elif xformers:
# とりあえずDiffusersのxformersを使う。AttentionがあるのはMidBlockのみ
print("Use Diffusers xformers for VAE")
vae.encoder.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True)
vae.decoder.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True)
def replace_vae_attn_to_memory_efficient():
print("AttentionBlock.forward has been replaced to FlashAttention (not xformers)")
flash_func = FlashAttentionFunction
def forward_flash_attn(self, hidden_states):
print("forward_flash_attn")
q_bucket_size = 512
k_bucket_size = 1024
residual = hidden_states
batch, channel, height, width = hidden_states.shape
# norm
hidden_states = self.group_norm(hidden_states)
hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
# proj to q, k, v
query_proj = self.query(hidden_states)
key_proj = self.key(hidden_states)
value_proj = self.value(hidden_states)
query_proj, key_proj, value_proj = map(
lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj)
)
out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size)
out = rearrange(out, "b h n d -> b n (h d)")
# compute next hidden_states
hidden_states = self.proj_attn(hidden_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
# res connect and rescale
hidden_states = (hidden_states + residual) / self.rescale_output_factor
return hidden_states
diffusers.models.attention.AttentionBlock.forward = forward_flash_attn
"""
def load_metadata_from_safetensors(safetensors_file: str) -> dict:
    """Read the metadata dict stored in a .safetensors file.

    This method locks the file. see https://github.com/huggingface/safetensors/issues/164
    If the file isn't .safetensors or doesn't have metadata, return empty dict.
    """
    _, extension = os.path.splitext(safetensors_file)
    if extension != ".safetensors":
        return {}

    with safetensors.safe_open(safetensors_file, framework="pt", device="cpu") as handle:
        stored = handle.metadata()
    return {} if stored is None else stored
# Names of the "ss_"-prefixed metadata keys used for network metadata.
SS_METADATA_KEY_V2 = "ss_v2"
SS_METADATA_KEY_BASE_MODEL_VERSION = "ss_base_model_version"
SS_METADATA_KEY_NETWORK_MODULE = "ss_network_module"
SS_METADATA_KEY_NETWORK_DIM = "ss_network_dim"
SS_METADATA_KEY_NETWORK_ALPHA = "ss_network_alpha"
SS_METADATA_KEY_NETWORK_ARGS = "ss_network_args"

# All of the key names above, collected as the "minimum" metadata key set.
SS_METADATA_MINIMUM_KEYS = [
    SS_METADATA_KEY_V2,
    SS_METADATA_KEY_BASE_MODEL_VERSION,
    SS_METADATA_KEY_NETWORK_MODULE,
    SS_METADATA_KEY_NETWORK_DIM,
    SS_METADATA_KEY_NETWORK_ALPHA,
    SS_METADATA_KEY_NETWORK_ARGS,
]
def build_minimum_network_metadata(
    v2: Optional[bool],
    base_model: Optional[str],
    network_module: str,
    network_dim: str,
    network_alpha: str,
    network_args: Optional[dict],
):
    """Build the minimum metadata dict describing a network.

    The module, dim and alpha entries are always present. The ``v2``,
    ``base_model`` and ``network_args`` entries are added only when the
    corresponding argument is not None; ``network_args`` is stored as a JSON
    string.
    """
    metadata = {
        SS_METADATA_KEY_NETWORK_MODULE: network_module,
        SS_METADATA_KEY_NETWORK_DIM: network_dim,
        SS_METADATA_KEY_NETWORK_ALPHA: network_alpha,
    }

    # Optional entries, listed in the order they are inserted.
    optional_entries = (
        (SS_METADATA_KEY_V2, v2),
        (SS_METADATA_KEY_BASE_MODEL_VERSION, base_model),
        (SS_METADATA_KEY_NETWORK_ARGS, None if network_args is None else json.dumps(network_args)),
    )
    for key, value in optional_entries:
        if value is not None:
            metadata[key] = value

    return metadata
def get_sai_model_spec(
    state_dict: dict,
    args: argparse.Namespace,
    sdxl: bool,
    lora: bool,
    textual_inversion: bool,
    is_stable_diffusion_ckpt: Optional[bool] = None,
):
    """Collect model-spec metadata via ``sai_model_spec.build_metadata``.

    The title falls back to ``args.output_name`` when ``args.metadata_title``
    is None. A ``(min, max)`` timestep pair is passed only when at least one
    of ``args.min_timestep`` / ``args.max_timestep`` is set; the missing bound
    defaults to 0 / 1000 respectively. The current time is used as timestamp.
    """
    timestamp = time.time()

    title = args.output_name if args.metadata_title is None else args.metadata_title

    if args.min_timestep is None and args.max_timestep is None:
        timesteps = None
    else:
        lower_bound = 0 if args.min_timestep is None else args.min_timestep
        upper_bound = 1000 if args.max_timestep is None else args.max_timestep
        timesteps = (lower_bound, upper_bound)

    return sai_model_spec.build_metadata(
        state_dict,
        args.v2,
        args.v_parameterization,
        sdxl,
        lora,
        textual_inversion,
        timestamp,
        title=title,
        reso=args.resolution,
        is_stable_diffusion_ckpt=is_stable_diffusion_ckpt,
        author=args.metadata_author,
        description=args.metadata_description,
        license=args.metadata_license,
        tags=args.metadata_tags,
        timesteps=timesteps,
        clip_skip=args.clip_skip,
    )
def add_sd_models_arguments(parser: argparse.ArgumentParser):
    """Register the command line options that select the base model and tokenizer cache."""
    add = parser.add_argument

    add("--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む")
    add("--v_parameterization", action="store_true", help="enable v-parameterization training / v-parameterization学習を有効にする")
    add(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル",
    )
    add(
        "--tokenizer_cache_dir",
        type=str,
        default=None,
        help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)",
    )
def add_optimizer_arguments(parser: argparse.ArgumentParser):
    """Register the command line options for the optimizer and the learning rate scheduler."""
    add = parser.add_argument

    # --- optimizer selection ---
    add(
        "--optimizer_type",
        type=str,
        default="",
        help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor",
    )
    add(
        "--use_8bit_adam",
        action="store_true",
        help="use 8bit AdamW optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)",
    )
    add(
        "--use_lion_optimizer",
        action="store_true",
        help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)",
    )

    # --- optimizer hyperparameters ---
    add("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
    add("--max_grad_norm", default=1.0, type=float, help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない")
    add(
        "--optimizer_args",
        type=str,
        default=None,
        nargs="*",
        help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")',
    )

    # --- learning rate scheduler ---
    add("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 使用するスケジューラ")
    add(
        "--lr_scheduler_args",
        type=str,
        default=None,
        nargs="*",
        help='additional arguments for scheduler (like "T_max=100") / スケジューラの追加引数(例: "T_max100")',
    )
    add(
        "--lr_scheduler",
        type=str,
        default="constant",
        help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor",
    )
    add(
        "--lr_warmup_steps",
        type=int,
        default=0,
        help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)",
    )
    add(
        "--lr_scheduler_num_cycles",
        type=int,
        default=1,
        help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数",
    )
    add(
        "--lr_scheduler_power",
        type=float,
        default=1,
        help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power",
    )
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
    """Register the general training command line options.

    Covers output/saving, Hugging Face upload, resume, attention backend,
    precision, logging, noise options, sample generation, config file and
    model metadata options. ``--prior_loss_weight`` is added only when
    ``support_dreambooth`` is true.
    """
    add = parser.add_argument

    # --- output location ---
    add("--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
    add("--output_name", type=str, default=None, help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名")

    # --- Hugging Face upload / resume ---
    add("--huggingface_repo_id", type=str, default=None, help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名")
    add("--huggingface_repo_type", type=str, default=None, help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類")
    add(
        "--huggingface_path_in_repo",
        type=str,
        default=None,
        help="huggingface model path to upload files / huggingfaceにアップロードするファイルのパス",
    )
    add("--huggingface_token", type=str, default=None, help="huggingface token / huggingfaceのトークン")
    add(
        "--huggingface_repo_visibility",
        type=str,
        default=None,
        help="huggingface repository visibility ('public' for public, 'private' or None for private) / huggingfaceにアップロードするリポジトリの公開設定('public'で公開、'private'またはNoneで非公開)",
    )
    add("--save_state_to_huggingface", action="store_true", help="save state to huggingface / huggingfaceにstateを保存する")
    add(
        "--resume_from_huggingface",
        action="store_true",
        help="resume from huggingface (ex: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type}) / huggingfaceから学習を再開する(例: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type})",
    )
    add(
        "--async_upload",
        action="store_true",
        help="upload to huggingface asynchronously / huggingfaceに非同期でアップロードする",
    )

    # --- checkpoint saving ---
    add(
        "--save_precision",
        type=str,
        default=None,
        choices=[None, "float", "fp16", "bf16"],
        help="precision in saving / 保存時に精度を変更して保存する",
    )
    add("--save_every_n_epochs", type=int, default=None, help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
    add("--save_every_n_steps", type=int, default=None, help="save checkpoint every N steps / 学習中のモデルを指定ステップごとに保存する")
    add(
        "--save_n_epoch_ratio",
        type=int,
        default=None,
        help="save checkpoint N epoch ratio (for example 5 means save at least 5 files total) / 学習中のモデルを指定のエポック割合で保存する(たとえば5を指定すると最低5個のファイルが保存される)",
    )
    add(
        "--save_last_n_epochs",
        type=int,
        default=None,
        help="save last N checkpoints when saving every N epochs (remove older checkpoints) / 指定エポックごとにモデルを保存するとき最大Nエポック保存する(古いチェックポイントは削除する)",
    )
    add(
        "--save_last_n_epochs_state",
        type=int,
        default=None,
        help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きする)",
    )
    add(
        "--save_last_n_steps",
        type=int,
        default=None,
        help="save checkpoints until N steps elapsed (remove older checkpoints if N steps elapsed) / 指定ステップごとにモデルを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する)",
    )
    add(
        "--save_last_n_steps_state",
        type=int,
        default=None,
        help="save states until N steps elapsed (remove older states if N steps elapsed, overrides --save_last_n_steps) / 指定ステップごとにstateを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する。--save_last_n_stepsを上書きする)",
    )
    add(
        "--save_state",
        action="store_true",
        help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する",
    )
    add("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")

    # --- batch, text encoder and attention ---
    add("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
    add(
        "--max_token_length",
        type=int,
        default=None,
        choices=[None, 77, 150, 225, 512],
        help="max token length of text encoder (default for 77, 150, 225 or 512) / text encoderのトークンの最大長(未指定で75、150または225が指定可)",
    )
    add(
        "--mem_eff_attn",
        action="store_true",
        help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う",
    )
    add("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
    add(
        "--sdpa",
        action="store_true",
        help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使う(PyTorch 2.0が必要)",
    )
    add("--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")

    # --- training length and data loading ---
    add("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
    add(
        "--max_train_epochs",
        type=int,
        default=None,
        help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)",
    )
    add(
        "--max_data_loader_n_workers",
        type=int,
        default=8,
        help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)",
    )
    add(
        "--persistent_data_loader_workers",
        action="store_true",
        help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)",
    )
    add("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")

    # --- gradients and precision ---
    add("--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / grandient checkpointingを有効にする")
    add(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数",
    )
    add("--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
    add("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
    add("--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する")
    add(
        "--ddp_timeout",
        type=int,
        default=None,
        help="DDP timeout (min, None for default of accelerate) / DDPのタイムアウト(分、Noneでaccelerateのデフォルト)",
    )
    add(
        "--clip_skip",
        type=int,
        default=None,
        help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)",
    )

    # --- logging ---
    add(
        "--logging_dir",
        type=str,
        default=None,
        help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する",
    )
    add(
        "--log_with",
        type=str,
        default=None,
        choices=["tensorboard", "wandb", "all"],
        help="what logging tool(s) to use (if 'all', TensorBoard and WandB are both used) / ログ出力に使用するツール (allを指定するとTensorBoardとWandBの両方が使用される)",
    )
    add("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
    add(
        "--log_tracker_name",
        type=str,
        default=None,
        help="name of tracker to use for logging, default is script-specific default name / ログ出力に使用するtrackerの名前、省略時はスクリプトごとのデフォルト名",
    )
    add(
        "--log_tracker_config",
        type=str,
        default=None,
        help="path to tracker config file to use for logging / ログ出力に使用するtrackerの設定ファイルのパス",
    )
    add(
        "--wandb_api_key",
        type=str,
        default=None,
        help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインする(オプション)",
    )

    # --- noise and timesteps ---
    add(
        "--noise_offset",
        type=float,
        default=None,
        help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)",
    )
    add(
        "--multires_noise_iterations",
        type=int,
        default=None,
        help="enable multires noise with this number of iterations (if enabled, around 6-10 is recommended) / Multires noiseを有効にしてこのイテレーション数を設定する(有効にする場合は6-10程度を推奨)",
    )
    add(
        "--ip_noise_gamma",
        type=float,
        default=None,
        help="enable input perturbation noise. used for regularization. recommended value: around 0.1 (from arxiv.org/abs/2301.11706) "
        + "/ input perturbation noiseを有効にする。正則化に使用される。推奨値: 0.1程度 (arxiv.org/abs/2301.11706 より)",
    )
    add(
        "--multires_noise_discount",
        type=float,
        default=0.3,
        help="set discount value for multires noise (has no effect without --multires_noise_iterations) / Multires noiseのdiscount値を設定する(--multires_noise_iterations指定時のみ有効)",
    )
    add(
        "--adaptive_noise_scale",
        type=float,
        default=None,
        help="add `latent mean absolute value * this value` to noise_offset (disabled if None, default) / latentの平均値の絶対値 * この値をnoise_offsetに加算する(Noneの場合は無効、デフォルト)",
    )
    add(
        "--zero_terminal_snr",
        action="store_true",
        help="fix noise scheduler betas to enforce zero terminal SNR / noise schedulerのbetasを修正して、zero terminal SNRを強制する",
    )
    add(
        "--min_timestep",
        type=int,
        default=None,
        help="set minimum time step for U-Net training (0~999, default is 0) / U-Net学習時のtime stepの最小値を設定する(0~999で指定、省略時はデフォルト値(0)) ",
    )
    add(
        "--max_timestep",
        type=int,
        default=None,
        help="set maximum time step for U-Net training (1~1000, default is 1000) / U-Net学習時のtime stepの最大値を設定する(1~1000で指定、省略時はデフォルト値(1000))",
    )
    add(
        "--lowram",
        action="store_true",
        help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)",
    )

    # --- sample image generation ---
    add("--sample_every_n_steps", type=int, default=None, help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する")
    add(
        "--sample_every_n_epochs",
        type=int,
        default=None,
        help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)",
    )
    add("--sample_prompts", type=str, default=None, help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル")
    add(
        "--sample_sampler",
        type=str,
        default="ddim",
        choices=[
            "ddim",
            "pndm",
            "lms",
            "euler",
            "euler_a",
            "heun",
            "dpm_2",
            "dpm_2_a",
            "dpmsolver",
            "dpmsolver++",
            "dpmsingle",
            "k_lms",
            "k_euler",
            "k_euler_a",
            "k_dpm_2",
            "k_dpm_2_a",
        ],
        help=f"sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類",
    )

    # --- config file ---
    add(
        "--config_file",
        type=str,
        default=None,
        help="using .toml instead of args to pass hyperparameter / ハイパーパラメータを引数ではなく.tomlファイルで渡す",
    )
    add("--output_config", action="store_true", help="output command line args to given .toml file / 引数を.tomlファイルに出力する")

    # --- model metadata ---
    add(
        "--metadata_title",
        type=str,
        default=None,
        help="title for model metadata (default is output_name) / メタデータに書き込まれるモデルタイトル、省略時はoutput_name",
    )
    add(
        "--metadata_author",
        type=str,
        default=None,
        help="author name for model metadata / メタデータに書き込まれるモデル作者名",
    )
    add(
        "--metadata_description",
        type=str,
        default=None,
        help="description for model metadata / メタデータに書き込まれるモデル説明",
    )
    add(
        "--metadata_license",
        type=str,
        default=None,
        help="license for model metadata / メタデータに書き込まれるモデルライセンス",
    )
    add(
        "--metadata_tags",
        type=str,
        default=None,
        help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
    )

    if support_dreambooth:
        # DreamBooth-only option.
        add("--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み")
def verify_training_args(args: argparse.Namespace):
    """Check parsed training arguments for inconsistent combinations.

    Questionable combinations only print a warning. ``cache_latents`` is
    switched on in place when ``cache_latents_to_disk`` is set.

    Raises:
        ValueError: if ``adaptive_noise_scale`` is given without
            ``noise_offset``, if ``scale_v_pred_loss_like_noise_pred`` is set
            without ``v_parameterization``, or if ``v_pred_like_loss`` is
            combined with ``v_parameterization``.
    """
    v_param = args.v_parameterization

    if v_param and not args.v2:
        print("v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません")

    if args.v2 and args.clip_skip is not None:
        print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")

    if args.cache_latents_to_disk and not args.cache_latents:
        # Disk caching implies caching; enable it on the caller's namespace.
        args.cache_latents = True
        print(
            "cache_latents_to_disk is enabled, so cache_latents is also enabled / cache_latents_to_diskが有効なため、cache_latentsを有効にします"
        )

    if args.adaptive_noise_scale is not None and args.noise_offset is None:
        raise ValueError("adaptive_noise_scale requires noise_offset / adaptive_noise_scaleを使用するにはnoise_offsetが必要です")

    if args.scale_v_pred_loss_like_noise_pred and not v_param:
        raise ValueError(
            "scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます"
        )

    if args.v_pred_like_loss and v_param:
        raise ValueError(
            "v_pred_like_loss cannot be enabled with v_parameterization / v_pred_like_lossはv_parameterizationが有効なときには有効にできません"
        )

    if args.zero_terminal_snr and not v_param:
        print(
            f"zero_terminal_snr is enabled, but v_parameterization is not enabled. training will be unexpected"
            + " / zero_terminal_snrが有効ですが、v_parameterizationが有効ではありません。学習結果は想定外になる可能性があります"
        )
def add_dataset_arguments(
    parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
):
    """Register the dataset-related command line options.

    Args:
        parser: parser the options are added to.
        support_dreambooth: when true, also add ``--reg_data_dir``.
        support_caption: when true, also add ``--in_json`` and ``--dataset_repeats``.
        support_caption_dropout: when true, also add the three caption dropout options.
    """
    # --- captions ---
    parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
    parser.add_argument(
        "--shuffle_caption", action="store_true", help="shuffle separated caption / 区切られたcaptionの各要素をshuffleする"
    )
    parser.add_argument(
        "--caption_separator", type=str, default=",", help="separator for caption / captionの区切り文字"
    )
    parser.add_argument(
        "--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子"
    )
    # The misspelled option name is kept on purpose for backward compatibility.
    parser.add_argument(
        "--caption_extention",
        type=str,
        default=None,
        help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)",
    )
    parser.add_argument(
        "--keep_tokens",
        type=int,
        default=0,
        help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す(トークンはカンマ区切りの各部分を意味する)",
    )
    parser.add_argument(
        "--caption_prefix",
        type=str,
        default=None,
        help="prefix for caption text / captionのテキストの先頭に付ける文字列",
    )
    parser.add_argument(
        "--caption_suffix",
        type=str,
        default=None,
        help="suffix for caption text / captionのテキストの末尾に付ける文字列",
    )

    # --- augmentation ---
    parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
    parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
    parser.add_argument(
        "--face_crop_aug_range",
        type=str,
        default=None,
        help="enable face-centered crop augmentation and its range (e.g. 2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する(例:2.0,4.0)",
    )
    parser.add_argument(
        "--random_crop",
        action="store_true",
        help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)",
    )
    parser.add_argument(
        "--debug_dataset", action="store_true", help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)"
    )

    # --- resolution, latent caching and bucketing ---
    parser.add_argument(
        "--resolution",
        type=str,
        default=None,
        help="resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)",
    )
    parser.add_argument(
        "--cache_latents",
        action="store_true",
        help="cache latents to main memory to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをメインメモリにcacheする(augmentationは使用不可) ",
    )
    parser.add_argument("--vae_batch_size", type=int, default=1, help="batch size for caching latents / latentのcache時のバッチサイズ")
    parser.add_argument(
        "--cache_latents_to_disk",
        action="store_true",
        help="cache latents to disk to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをディスクにcacheする(augmentationは使用不可)",
    )
    parser.add_argument(
        "--enable_bucket", action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする"
    )
    parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
    parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度")
    parser.add_argument(
        "--bucket_reso_steps",
        type=int,
        default=64,
        help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します",
    )
    parser.add_argument(
        "--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
    )

    # --- token warmup ---
    # Help text fixed: the original contained the garbled word "strinfloatgs".
    parser.add_argument(
        "--token_warmup_min",
        type=int,
        default=1,
        help="start learning at N tags (token means comma separated strings) / タグ数をN個から増やしながら学習する",
    )
    parser.add_argument(
        "--token_warmup_step",
        type=float,
        default=0,
        help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)",
    )

    # --- custom dataset, tokenizers and NPU attention ---
    parser.add_argument(
        "--dataset_class",
        type=str,
        default=None,
        help="dataset class for arbitrary dataset (package.module.Class) / 任意のデータセットを用いるときのクラス名 (package.module.Class)",
    )
    parser.add_argument("--tokenizer1_path", type=str, default=None, help="tokenizer1 path")
    parser.add_argument("--tokenizer2_path", type=str, default=None, help="tokenizer2 path")
    parser.add_argument("--enable_npu_flash_attention", action="store_true", help="enable npu flash attention")

    if support_caption_dropout:
        parser.add_argument(
            "--caption_dropout_rate", type=float, default=0.0, help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合"
        )
        parser.add_argument(
            "--caption_dropout_every_n_epochs",
            type=int,
            default=0,
            help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする",
        )
        parser.add_argument(
            "--caption_tag_dropout_rate",
            type=float,
            default=0.0,
            help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合",
        )

    if support_dreambooth:
        parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ")

    if support_caption:
        parser.add_argument("--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル")
        parser.add_argument(
            "--dataset_repeats", type=int, default=1, help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数"
        )
def add_sd_saving_arguments(parser: argparse.ArgumentParser):
    """Register the command line options that choose the format of saved models."""
    add = parser.add_argument

    add(
        "--save_model_as",
        type=str,
        default=None,
        choices=[None, "ckpt", "safetensors", "diffusers", "diffusers_safetensors"],
        help="format to save the model (default is same to original) / モデル保存時の形式(未指定時は元モデルと同じ)",
    )
    add(
        "--use_safetensors",
        action="store_true",
        help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)",
    )
def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentParser):
    """Load arguments from, or dump arguments to, the .toml file named by ``--config_file``.

    Behavior:
      * No ``config_file``: ``args`` is returned unchanged.
      * ``output_config`` set: the arguments that differ from the parser
        defaults are written to the file and the process exits with status 0
        (status 1 if the file already exists). ``config_file``,
        ``output_config`` and ``wandb_api_key`` are never written.
      * Otherwise: the file is read (exit status 1 if missing), its sections
        are flattened one level, and the command line is re-parsed on top of
        those values.

    The file is always read and written as UTF-8, so non-ASCII values do not
    depend on the platform's locale encoding.

    Args:
        args: the namespace from an initial parse of the command line.
        parser: the parser used to compute defaults and to re-parse.

    Returns:
        The resulting namespace; ``config_file`` has its extension removed.

    Raises:
        SystemExit: as described above.
    """
    if not args.config_file:
        return args

    config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file

    if args.output_config:
        # Never overwrite an existing config file.
        if os.path.exists(config_path):
            print(f"Config file already exists. Aborting... / 出力先の設定ファイルが既に存在します: {config_path}")
            raise SystemExit(1)

        # Work on a copy so the caller's namespace is not modified.
        args_dict = dict(vars(args))

        # Drop bookkeeping options and the API key, which must not be persisted.
        for key in ["config_file", "output_config", "wandb_api_key"]:
            if key in args_dict:
                del args_dict[key]

        # Keep only the values that differ from the parser defaults.
        default_args = vars(parser.parse_args([]))
        for key, value in list(args_dict.items()):
            if key in default_args and value == default_args[key]:
                del args_dict[key]

        # Path objects are stored as plain strings.
        for key, value in args_dict.items():
            if isinstance(value, pathlib.Path):
                args_dict[key] = str(value)

        with open(config_path, "w", encoding="utf-8") as f:
            toml.dump(args_dict, f)

        print(f"Saved config file / 設定ファイルを保存しました: {config_path}")
        raise SystemExit(0)

    if not os.path.exists(config_path):
        print(f"{config_path} not found.")
        raise SystemExit(1)

    print(f"Loading settings from {config_path}...")
    with open(config_path, "r", encoding="utf-8") as f:
        config_dict = toml.load(f)

    # Flatten one level: section tables contribute their keys directly,
    # top-level scalar entries are kept as they are.
    ignore_nesting_dict = {}
    for section_name, section_dict in config_dict.items():
        if not isinstance(section_dict, dict):
            ignore_nesting_dict[section_name] = section_dict
            continue
        for key, value in section_dict.items():
            ignore_nesting_dict[key] = value

    # Re-parse the command line with the file values pre-populated in the namespace.
    config_args = argparse.Namespace(**ignore_nesting_dict)
    args = parser.parse_args(namespace=config_args)
    args.config_file = os.path.splitext(args.config_file)[0]
    print(args.config_file)

    return args
def resume_from_local_or_hf_if_specified(accelerator, args):
    """Load a saved training state into ``accelerator`` when ``args.resume`` is set.

    Without ``args.resume_from_huggingface`` the value of ``args.resume`` is
    passed straight to ``accelerator.load_state``. Otherwise it is parsed as
    ``{repo_id}/{path_in_repo}[:{revision}[:{repo_type}]]`` (``repo_id`` being
    the first two "/"-separated parts), every file listed under that path is
    downloaded, and the state is loaded from the directory of the first
    downloaded file.

    Raises:
        ValueError: if no files were found at the given repo/path/revision.
    """
    if not args.resume:
        return

    if not args.resume_from_huggingface:
        print(f"resume training from local state: {args.resume}")
        accelerator.load_state(args.resume)
        return

    print(f"resume training from huggingface state: {args.resume}")

    segments = args.resume.split("/")
    repo_id = segments[0] + "/" + segments[1]
    path_in_repo = "/".join(segments[2:])

    revision = None
    repo_type = None
    if ":" in path_in_repo:
        pieces = path_in_repo.split(":")
        if len(pieces) == 2:
            # Only a revision was given; the repo type falls back to "model".
            path_in_repo, revision = pieces
            repo_type = "model"
        else:
            path_in_repo, revision, repo_type = pieces

    print(f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}")

    list_files = huggingface_util.list_dir(
        repo_id=repo_id,
        subfolder=path_in_repo,
        revision=revision,
        token=args.huggingface_token,
        repo_type=repo_type,
    )

    def fetch_blocking(filename):
        return hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            revision=revision,
            repo_type=repo_type,
            token=args.huggingface_token,
        )

    async def download(filename) -> str:
        # Run the blocking download in the loop's default executor.
        return await asyncio.get_event_loop().run_in_executor(None, fetch_blocking, filename)

    loop = asyncio.get_event_loop()
    pending = [download(filename=entry.rfilename) for entry in list_files]
    results = loop.run_until_complete(asyncio.gather(*pending))
    if len(results) == 0:
        raise ValueError("No files found in the specified repo id/path/revision / 指定されたリポジトリID/パス/リビジョンにファイルが見つかりませんでした")

    accelerator.load_state(os.path.dirname(results[0]))
def get_optimizer(args, trainable_params):
optimizer_type = args.optimizer_type
if args.use_8bit_adam:
assert (
not args.use_lion_optimizer
), "both option use_8bit_adam and use_lion_optimizer are specified / use_8bit_adamとuse_lion_optimizerの両方のオプションが指定されています"
assert (
optimizer_type is None or optimizer_type == ""
), "both option use_8bit_adam and optimizer_type are specified / use_8bit_adamとoptimizer_typeの両方のオプションが指定されています"
optimizer_type = "AdamW8bit"
elif args.use_lion_optimizer:
assert (
optimizer_type is None or optimizer_type == ""
), "both option use_lion_optimizer and optimizer_type are specified / use_lion_optimizerとoptimizer_typeの両方のオプションが指定されています"
optimizer_type = "Lion"
if optimizer_type is None or optimizer_type == "":
optimizer_type = "AdamW"
optimizer_type = optimizer_type.lower()
optimizer_kwargs = {}
if args.optimizer_args is not None and len(args.optimizer_args) > 0:
for arg in args.optimizer_args:
key, value = arg.split("=")
value = ast.literal_eval(value)
optimizer_kwargs[key] = value
lr = args.learning_rate
optimizer = None
if optimizer_type == "Lion".lower():
try:
import lion_pytorch
except ImportError:
raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
print(f"use Lion optimizer | {optimizer_kwargs}")
optimizer_class = lion_pytorch.Lion
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type.endswith("8bit".lower()):
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
if optimizer_type == "AdamW8bit".lower():
print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}")
optimizer_class = bnb.optim.AdamW8bit
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type == "SGDNesterov8bit".lower():
print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}")
if "momentum" not in optimizer_kwargs:
print(
f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します"
)
optimizer_kwargs["momentum"] = 0.9
optimizer_class = bnb.optim.SGD8bit
optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
elif optimizer_type == "Lion8bit".lower():
print(f"use 8-bit Lion optimizer | {optimizer_kwargs}")
try:
optimizer_class = bnb.optim.Lion8bit
except AttributeError:
raise AttributeError(
"No Lion8bit. The version of bitsandbytes installed seems to be old. Please install 0.38.0 or later. / Lion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.38.0以上をインストールしてください"
)
elif optimizer_type == "PagedAdamW8bit".lower():
print(f"use 8-bit PagedAdamW optimizer | {optimizer_kwargs}")
try:
optimizer_class = bnb.optim.PagedAdamW8bit
except AttributeError:
raise AttributeError(
"No PagedAdamW8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedAdamW8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
)
elif optimizer_type == "PagedLion8bit".lower():
print(f"use 8-bit Paged Lion optimizer | {optimizer_kwargs}")
try:
optimizer_class = bnb.optim.PagedLion8bit
except AttributeError:
raise AttributeError(
"No PagedLion8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedLion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
)
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type == "PagedAdamW32bit".lower():
print(f"use 32-bit PagedAdamW optimizer | {optimizer_kwargs}")
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
try:
optimizer_class = bnb.optim.PagedAdamW32bit
except AttributeError:
raise AttributeError(
"No PagedAdamW32bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedAdamW32bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
)
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type == "SGDNesterov".lower():
print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}")
if "momentum" not in optimizer_kwargs:
print(f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
optimizer_kwargs["momentum"] = 0.9
optimizer_class = torch.optim.SGD
optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
elif optimizer_type.startswith("DAdapt".lower()) or optimizer_type == "Prodigy".lower():
actual_lr = lr
lr_count = 1
if type(trainable_params) == list and type(trainable_params[0]) == dict:
lrs = set()
actual_lr = trainable_params[0].get("lr", actual_lr)
for group in trainable_params:
lrs.add(group.get("lr", actual_lr))
lr_count = len(lrs)
if actual_lr <= 0.1:
print(
f"learning rate is too low. If using D-Adaptation or Prodigy, set learning rate around 1.0 / 学習率が低すぎるようです。D-AdaptationまたはProdigyの使用時は1.0前後の値を指定してください: lr={actual_lr}"
)
print("recommend option: lr=1.0 / 推奨は1.0です")
if lr_count > 1:
print(
f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-AdaptationまたはProdigyで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}"
)
if optimizer_type.startswith("DAdapt".lower()):
try:
import dadaptation
import dadaptation.experimental as experimental
except ImportError:
raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
if optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdamPreprint".lower():
optimizer_class = experimental.DAdaptAdamPreprint
print(f"use D-Adaptation AdamPreprint optimizer | {optimizer_kwargs}")
elif optimizer_type == "DAdaptAdaGrad".lower():
optimizer_class = dadaptation.DAdaptAdaGrad
print(f"use D-Adaptation AdaGrad optimizer | {optimizer_kwargs}")
elif optimizer_type == "DAdaptAdam".lower():
optimizer_class = dadaptation.DAdaptAdam
print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
elif optimizer_type == "DAdaptAdan".lower():
optimizer_class = dadaptation.DAdaptAdan
print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}")
elif optimizer_type == "DAdaptAdanIP".lower():
optimizer_class = experimental.DAdaptAdanIP
print(f"use D-Adaptation AdanIP optimizer | {optimizer_kwargs}")
elif optimizer_type == "DAdaptLion".lower():
optimizer_class = dadaptation.DAdaptLion
print(f"use D-Adaptation Lion optimizer | {optimizer_kwargs}")
elif optimizer_type == "DAdaptSGD".lower():
optimizer_class = dadaptation.DAdaptSGD
print(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}")
else:
raise ValueError(f"Unknown optimizer type: {optimizer_type}")
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
else:
try:
import prodigyopt
except ImportError:
raise ImportError("No Prodigy / Prodigy がインストールされていないようです")
print(f"use Prodigy optimizer | {optimizer_kwargs}")
optimizer_class = prodigyopt.Prodigy
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type == "Adafactor".lower():
if "relative_step" not in optimizer_kwargs:
optimizer_kwargs["relative_step"] = True
if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False):
print(f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします")
optimizer_kwargs["relative_step"] = True
print(f"use Adafactor optimizer | {optimizer_kwargs}")
if optimizer_kwargs["relative_step"]:
print(f"relative_step is true / relative_stepがtrueです")
if lr != 0.0:
print(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます")
args.learning_rate = None
if type(trainable_params) == list and type(trainable_params[0]) == dict:
has_group_lr = False
for group in trainable_params:
p = group.pop("lr", None)
has_group_lr = has_group_lr or (p is not None)
if has_group_lr:
print(f"unet_lr and text_encoder_lr are ignored / unet_lrとtext_encoder_lrは無視されます")
args.unet_lr = None
args.text_encoder_lr = None
if args.lr_scheduler != "adafactor":
print(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します")
args.lr_scheduler = f"adafactor:{lr}"
lr = None
else:
if args.max_grad_norm != 0.0:
print(
f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません"
)
if args.lr_scheduler != "constant_with_warmup":
print(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません")
if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0:
print(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません")
optimizer_class = transformers.optimization.Adafactor
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type == "AdamW".lower():
print(f"use AdamW optimizer | {optimizer_kwargs}")
optimizer_class = torch.optim.AdamW
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type == "OSS".lower():
print(f"use OSS optimizer | {optimizer_kwargs}")
from fairscale.optim import OSS
optimizer_class = OSS
optimizer = optimizer_class(trainable_params, optim=torch.optim.AdamW, lr=lr, **optimizer_kwargs)
if optimizer is None:
optimizer_type = args.optimizer_type
print(f"use {optimizer_type} | {optimizer_kwargs}")
if "." not in optimizer_type:
optimizer_module = torch.optim
else:
values = optimizer_type.split(".")
optimizer_module = importlib.import_module(".".join(values[:-1]))
optimizer_type = values[-1]
optimizer_class = getattr(optimizer_module, optimizer_type)
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])
return optimizer_name, optimizer_args, optimizer
def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
    """
    Unified API to get any scheduler from its name.

    Args:
        args: parsed training arguments. Reads ``lr_scheduler``, ``lr_warmup_steps``,
            ``max_train_steps``, ``lr_scheduler_num_cycles``, ``lr_scheduler_power``,
            ``lr_scheduler_args`` and ``lr_scheduler_type``.
        optimizer: optimizer the scheduler is built for.
        num_processes: multiplier applied to ``args.max_train_steps`` to obtain the
            number of training steps handed to the schedule function.

    Returns:
        The constructed learning rate scheduler.

    Raises:
        ValueError: if warmup steps are given to a scheduler that does not use them,
            if they are missing for a scheduler that needs them, or if an entry of
            ``lr_scheduler_args`` is not of the form ``key=value``.
    """
    name = args.lr_scheduler
    num_warmup_steps: Optional[int] = args.lr_warmup_steps
    num_training_steps = args.max_train_steps * num_processes
    num_cycles = args.lr_scheduler_num_cycles
    power = args.lr_scheduler_power

    # extra scheduler keyword arguments arrive as "key=value" strings whose value is a Python literal
    lr_scheduler_kwargs = {}
    if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
        for arg in args.lr_scheduler_args:
            # split on the first "=" only so that a literal value may itself contain "="
            key, value = arg.split("=", 1)
            value = ast.literal_eval(value)
            lr_scheduler_kwargs[key] = value

    def wrap_check_needless_num_warmup_steps(return_vals):
        # schedulers routed through here take no warmup; refuse a non-zero setting instead of ignoring it
        if num_warmup_steps is not None and num_warmup_steps != 0:
            raise ValueError(f"{name} does not require `num_warmup_steps`. Set None or 0.")
        return return_vals

    # an explicitly named scheduler class takes precedence over args.lr_scheduler
    if args.lr_scheduler_type:
        lr_scheduler_type = args.lr_scheduler_type
        print(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler")
        if "." not in lr_scheduler_type:
            # a bare class name is looked up in torch.optim.lr_scheduler
            lr_scheduler_module = torch.optim.lr_scheduler
        else:
            # "package.module.Class": import the module part, keep the class name
            values = lr_scheduler_type.split(".")
            lr_scheduler_module = importlib.import_module(".".join(values[:-1]))
            lr_scheduler_type = values[-1]
        lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type)
        lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs)
        return wrap_check_needless_num_warmup_steps(lr_scheduler)

    if name.startswith("adafactor"):
        assert (
            type(optimizer) == transformers.optimization.Adafactor
        ), f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
        # the name is expected to look like "adafactor:<initial_lr>"
        initial_lr = float(name.split(":")[1])
        return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr))

    name = SchedulerType(name)
    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]

    if name == SchedulerType.CONSTANT:
        return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs))

    if name == SchedulerType.PIECEWISE_CONSTANT:
        # everything this schedule needs is supplied through lr_scheduler_kwargs
        return schedule_func(optimizer, **lr_scheduler_kwargs)

    # every remaining scheduler needs `num_warmup_steps`
    if num_warmup_steps is None:
        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")

    if name == SchedulerType.CONSTANT_WITH_WARMUP:
        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **lr_scheduler_kwargs)

    # every remaining scheduler needs `num_training_steps`
    if num_training_steps is None:
        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

    if name == SchedulerType.COSINE_WITH_RESTARTS:
        return schedule_func(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
            num_cycles=num_cycles,
            **lr_scheduler_kwargs,
        )

    if name == SchedulerType.POLYNOMIAL:
        return schedule_func(
            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power, **lr_scheduler_kwargs
        )

    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **lr_scheduler_kwargs)
def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
    """Normalize dataset related arguments in place.

    * ``caption_extention`` (old spelling) is moved to ``caption_extension`` and cleared.
    * ``resolution`` is parsed from ``"size"`` or ``"width,height"`` into a 2-tuple of ints.
    * ``face_crop_aug_range`` is parsed from ``"lower,upper"`` into a 2-tuple of floats.
    * When ``support_metadata`` is true and ``in_json`` is combined with ``color_aug`` or
      ``random_crop``, a notice is printed.

    Nothing is returned; ``args`` is modified.
    """
    # keep accepting the old, misspelled option name
    legacy_extension = args.caption_extention
    if legacy_extension is not None:
        args.caption_extension = legacy_extension
        args.caption_extention = None

    if args.resolution is not None:
        sizes = tuple(int(part) for part in args.resolution.split(","))
        # a single number means a square resolution
        args.resolution = sizes * 2 if len(sizes) == 1 else sizes
        assert (
            len(args.resolution) == 2
        ), f"resolution must be 'size' or 'width,height' / resolution(解像度)は'サイズ'または'幅','高さ'で指定してください: {args.resolution}"

    if args.face_crop_aug_range is None:
        args.face_crop_aug_range = None
    else:
        bounds = tuple(float(part) for part in args.face_crop_aug_range.split(","))
        args.face_crop_aug_range = bounds
        assert (
            len(args.face_crop_aug_range) == 2 and args.face_crop_aug_range[0] <= args.face_crop_aug_range[1]
        ), f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}"

    if support_metadata and args.in_json is not None and (args.color_aug or args.random_crop):
        print(
            f"latents in npz is ignored when color_aug or random_crop is True / color_augまたはrandom_cropを有効にした場合、npzファイルのlatentsは無視されます"
        )
def load_tokenizer(args: argparse.Namespace):
    """Load the CLIP tokenizer, going through ``args.tokenizer_cache_dir`` when it is set.

    The source is ``V2_STABLE_DIFFUSION_PATH`` (subfolder ``tokenizer``) when ``args.v2``
    is set and ``TOKENIZER_PATH`` otherwise. With a cache directory configured, the
    tokenizer is read from it if a cached copy exists and written to it otherwise.

    Returns:
        The loaded ``CLIPTokenizer``.
    """
    print("prepare tokenizer")
    if args.v2:
        original_path = V2_STABLE_DIFFUSION_PATH
    else:
        original_path = TOKENIZER_PATH

    tokenizer: CLIPTokenizer = None
    cache_dir = args.tokenizer_cache_dir
    local_tokenizer_path = None
    if cache_dir:
        # the cache entry is named after the source path with "/" flattened to "_"
        local_tokenizer_path = os.path.join(cache_dir, original_path.replace("/", "_"))
        if os.path.exists(local_tokenizer_path):
            print(f"load tokenizer from cache: {local_tokenizer_path}")
            tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path)

    if tokenizer is None:
        # only the v2 source keeps the tokenizer in a subfolder
        extra_kwargs = {"subfolder": "tokenizer"} if args.v2 else {}
        tokenizer = CLIPTokenizer.from_pretrained(original_path, **extra_kwargs)

    if getattr(args, "max_token_length", None) is not None:
        print(f"update token length: {args.max_token_length}")

    if cache_dir and not os.path.exists(local_tokenizer_path):
        print(f"save Tokenizer to cache: {local_tokenizer_path}")
        tokenizer.save_pretrained(local_tokenizer_path)

    return tokenizer
def prepare_accelerator(args: argparse.Namespace):
    """Create the ``Accelerator`` according to the logging and precision arguments.

    A timestamped logging directory is derived from ``args.logging_dir`` and
    ``args.log_prefix``. When ``args.log_with`` is not given, tensorboard is used if a
    logging directory exists. For wandb the logging directory is created and exported
    as ``WANDB_DIR``, and ``args.wandb_api_key`` is used to log in when present.

    Raises:
        ValueError: tensorboard logging is requested without a logging directory.
        ImportError: wandb logging is requested but wandb cannot be imported.
    """
    logging_dir = None
    if args.logging_dir is not None:
        log_prefix = args.log_prefix if args.log_prefix is not None else ""
        timestamp = time.strftime("%Y%m%d%H%M%S", time.localtime())
        logging_dir = args.logging_dir + "/" + log_prefix + timestamp

    log_with = args.log_with
    if log_with is None:
        # no tracker requested: fall back to tensorboard whenever there is somewhere to log to
        if logging_dir is not None:
            log_with = "tensorboard"
    else:
        if log_with in ("tensorboard", "all") and logging_dir is None:
            raise ValueError("logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください")
        if log_with in ("wandb", "all"):
            try:
                import wandb
            except ImportError:
                raise ImportError("No wandb / wandb がインストールされていないようです")
            if logging_dir is not None:
                os.makedirs(logging_dir, exist_ok=True)
                os.environ["WANDB_DIR"] = logging_dir
            if args.wandb_api_key is not None:
                wandb.login(key=args.wandb_api_key)

    if args.ddp_timeout is None:
        kwargs_handlers = None
    else:
        # ddp_timeout is given in minutes
        timeout = datetime.timedelta(minutes=args.ddp_timeout)
        kwargs_handlers = [InitProcessGroupKwargs(timeout=timeout)]

    return Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=log_with,
        project_dir=logging_dir,
        kwargs_handlers=kwargs_handlers,
    )
def prepare_dtype(args: argparse.Namespace):
    """Translate the precision options into torch dtypes.

    Returns:
        ``(weight_dtype, save_dtype)``. ``weight_dtype`` follows ``args.mixed_precision``
        and falls back to ``torch.float32``; ``save_dtype`` follows ``args.save_precision``
        and is ``None`` when that option names none of the known precisions.
    """
    mixed_precision_dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16}
    save_precision_dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "float": torch.float32}

    weight_dtype = mixed_precision_dtypes.get(args.mixed_precision, torch.float32)
    save_dtype = save_precision_dtypes.get(args.save_precision)
    return weight_dtype, save_dtype
def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", unet_use_linear_projection_in_v2=False):
    """Load text encoder, VAE and U-Net from a checkpoint file or a Diffusers model.

    ``args.pretrained_model_name_or_path`` is treated as a StableDiffusion checkpoint
    when it is an existing file (symlinks are resolved first) and as a Diffusers model
    name or directory otherwise. A Diffusers U-Net is converted to the project's own
    ``UNet2DConditionModel``. ``args.vae``, when set, replaces the loaded VAE.

    Returns:
        ``(text_encoder, vae, unet, load_stable_diffusion_format)``.
    """
    name_or_path = args.pretrained_model_name_or_path
    if os.path.islink(name_or_path):
        name_or_path = os.path.realpath(name_or_path)

    load_stable_diffusion_format = os.path.isfile(name_or_path)
    if not load_stable_diffusion_format:
        # `device` is not forwarded on this path
        print(f"load Diffusers pretrained models: {name_or_path}")
        try:
            pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None)
        except EnvironmentError as ex:
            print(
                f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}"
            )
            raise ex
        text_encoder, vae, unet = pipe.text_encoder, pipe.vae, pipe.unet
        del pipe

        # rebuild the U-Net as the project's implementation and copy the weights over
        unet_config = unet.config
        original_unet = UNet2DConditionModel(
            unet_config.sample_size,
            unet_config.attention_head_dim,
            unet_config.cross_attention_dim,
            unet_config.use_linear_projection,
            unet_config.upcast_attention,
        )
        original_unet.load_state_dict(unet.state_dict())
        unet = original_unet
        print("U-Net converted to original U-Net")
    else:
        print(f"load StableDiffusion checkpoint: {name_or_path}")
        text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(
            args.v2, name_or_path, device, unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2
        )

    # an explicitly given VAE overrides the one that came with the model
    if args.vae is not None:
        vae = model_util.load_vae(args.vae, weight_dtype)
        print("additional VAE loaded")

    return text_encoder, vae, unet, load_stable_diffusion_format
def transform_if_model_is_DDP(text_encoder, unet, network=None):
    """Return a generator over the given models with any DDP wrapper removed.

    ``None`` entries are skipped, so the generator yields two or three items
    depending on whether ``network`` is given.
    """
    candidates = [text_encoder, unet, network]
    return ((entry.module if type(entry) is DDP else entry) for entry in candidates if entry is not None)
def transform_models_if_DDP(models):
    """Return a list of the given models with any DDP wrapper removed, skipping ``None``."""
    unwrapped = []
    for entry in models:
        if entry is None:
            continue
        unwrapped.append(entry.module if type(entry) is DDP else entry)
    return unwrapped
def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
    """Load the target models, staggering the loading across processes.

    The loop runs once per process index. In each iteration only the process whose
    ``local_process_index`` matches loads the models, and every process then waits
    at ``accelerator.wait_for_everyone()`` before the next iteration starts.

    Args:
        args: parsed arguments; ``args.lowram`` selects loading on the accelerator
            device instead of ``"cpu"``. Forwarded to ``_load_target_model``.
        weight_dtype: forwarded to ``_load_target_model``.
        accelerator: provides the process indices, the device and the barrier.
        unet_use_linear_projection_in_v2: forwarded to ``_load_target_model``.

    Returns:
        ``(text_encoder, vae, unet, load_stable_diffusion_format)``.
    """
    for pi in range(accelerator.state.num_processes):
        if pi == accelerator.state.local_process_index:
            print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
            text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model(
                args,
                weight_dtype,
                accelerator.device if args.lowram else "cpu",
                unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2,
            )
            # with lowram the models are moved to the accelerator device right away
            if args.lowram:
                text_encoder.to(accelerator.device)
                unet.to(accelerator.device)
                vae.to(accelerator.device)
            # release whatever loading left behind before the next process starts
            gc.collect()
            torch.cuda.empty_cache()
        # barrier: reached by every process in every iteration
        accelerator.wait_for_everyone()
    # strip DDP wrappers, if any, from the text encoder and the U-Net
    text_encoder, unet = transform_if_model_is_DDP(text_encoder, unet)
    return text_encoder, vae, unet, load_stable_diffusion_format
def patch_accelerator_for_fp16_training(accelerator):
    """Replace ``accelerator.scaler._unscale_grads_`` with a variant that always allows fp16.

    The replacement forwards ``optimizer``, ``inv_scale`` and ``found_inf`` to the
    previous implementation and passes ``True`` as the fourth argument, whatever the
    caller supplied for ``allow_fp16``.
    """
    scaler = accelerator.scaler
    wrapped = scaler._unscale_grads_

    def _always_allow_fp16(optimizer, inv_scale, found_inf, allow_fp16):
        # the incoming allow_fp16 is deliberately ignored
        return wrapped(optimizer, inv_scale, found_inf, True)

    scaler._unscale_grads_ = _always_allow_fp16
def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encoder, weight_dtype=None):
    """Encode token ids with the text encoder, handling clip skip and multi-window prompts.

    Args:
        args: reads ``clip_skip``, ``max_token_length`` and ``v2``.
        input_ids: token id tensor; its first dimension is taken as the batch size and
            its last dimension is compared with ``tokenizer.model_max_length``.
        tokenizer: provides ``model_max_length`` (and ``eos_token`` in the v2 branch).
        text_encoder: model called on the token ids.
        weight_dtype: if given, the result is converted to it with ``.to()``.

    Returns:
        The encoder hidden states.
    """
    # when the last dimension is not the tokenizer's max length, return the raw encoder output;
    # clip_skip, window merging and the dtype conversion below are all bypassed
    if input_ids.size()[-1] != tokenizer.model_max_length:
        return text_encoder(input_ids)[0]

    b_size = input_ids.size()[0]
    # flatten the windows into the batch dimension: (..., max_length) -> (-1, max_length)
    input_ids = input_ids.reshape((-1, tokenizer.model_max_length))

    if args.clip_skip is None:
        encoder_hidden_states = text_encoder(input_ids)[0]
    else:
        # take the hidden state `clip_skip` layers from the end and apply the final layer norm to it
        enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True)
        encoder_hidden_states = enc_out["hidden_states"][-args.clip_skip]
        encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)

    # restore the batch dimension; the windows of one sample end up concatenated along dim 1
    encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))

    if args.max_token_length is not None:
        # keep the first position, the interior of each window (window length - 2 positions)
        # and the last position, then concatenate them along dim 1
        if args.v2:
            states_list = [encoder_hidden_states[:, 0].unsqueeze(1)]
            for i in range(1, args.max_token_length, tokenizer.model_max_length):
                chunk = encoder_hidden_states[:, i : i + tokenizer.model_max_length - 2]
                # NOTE(review): i starts at 1, so this condition is always true
                if i > 0:
                    for j in range(len(chunk)):
                        # NOTE(review): this compares a token id with `tokenizer.eos_token`, which looks
                        # like the token string rather than its id; `eos_token_id` may have been
                        # intended. Also `j` indexes the flattened input_ids. TODO confirm.
                        if input_ids[j, 1] == tokenizer.eos_token:
                            # overwrite the first position of the chunk with the following one
                            chunk[j, 0] = chunk[j, 1]
                states_list.append(chunk)
            states_list.append(encoder_hidden_states[:, -1].unsqueeze(1))
            encoder_hidden_states = torch.cat(states_list, dim=1)
        else:
            states_list = [encoder_hidden_states[:, 0].unsqueeze(1)]
            for i in range(1, args.max_token_length, tokenizer.model_max_length):
                states_list.append(encoder_hidden_states[:, i : i + tokenizer.model_max_length - 2])
            states_list.append(encoder_hidden_states[:, -1].unsqueeze(1))
            encoder_hidden_states = torch.cat(states_list, dim=1)

    if weight_dtype is not None:
        encoder_hidden_states = encoder_hidden_states.to(weight_dtype)

    return encoder_hidden_states
def pool_workaround(
    text_encoder: CLIPTextModelWithProjection, last_hidden_state: torch.Tensor, input_ids: torch.Tensor, eos_token_id: int
):
    r"""
    Compute the pooled text embedding from the hidden state at the EOS token.

    CLIP's own pooling takes the hidden state at the position of the largest token id
    in each sequence (``input_ids.argmax(dim=-1)``) rather than at the EOS token. That
    differs once tokens with larger ids are present (e.g. with Textual Inversion), so
    the position is located here by comparing against ``eos_token_id`` instead.

    The hidden state at the first EOS position of every sequence is passed through
    ``text_encoder.text_projection`` (in the projection's weight dtype) and the result
    is converted back to the dtype of ``last_hidden_state``.
    """
    target_device = last_hidden_state.device

    # argmax over a 0/1 mask gives the first position holding the EOS token
    is_eos = (input_ids == eos_token_id).int()
    first_eos = torch.argmax(is_eos, dim=1).to(device=target_device)

    rows = torch.arange(last_hidden_state.shape[0], device=target_device)
    eos_states = last_hidden_state[rows, first_eos]

    projection = text_encoder.text_projection
    projected = projection(eos_states.to(projection.weight.dtype))
    return projected.to(last_hidden_state.dtype)
def get_hidden_states_sdxl(
    max_token_length: int,
    input_ids1: torch.Tensor,
    input_ids2: torch.Tensor,
    tokenizer1: CLIPTokenizer,
    tokenizer2: CLIPTokenizer,
    text_encoder1: CLIPTextModel,
    text_encoder2: CLIPTextModelWithProjection,
    weight_dtype: Optional[str] = None,
):
    """Encode the token ids with both text encoders.

    Hidden state index 11 is taken from the first encoder and index -2 from the
    second one; the pooled output of the second encoder comes from ``pool_workaround``.
    When ``max_token_length`` is given, the windows of each sample are merged (first
    position, interior of every window, last position) and one pooled vector per
    sample is kept.

    Returns:
        ``(hidden_states1, hidden_states2, pool2)``; the two hidden states are
        converted to ``weight_dtype`` when it is given.
    """

    def merge_windows(states, window):
        # keep position 0, then `window - 2` positions of every window, then the last position
        pieces = [states[:, 0].unsqueeze(1)]
        for start in range(1, max_token_length, window):
            pieces.append(states[:, start : start + window - 2])
        pieces.append(states[:, -1].unsqueeze(1))
        return torch.cat(pieces, dim=1)

    batch = input_ids1.size()[0]
    # flatten the windows into the batch dimension
    flat_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length))
    flat_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length))

    out1 = text_encoder1(flat_ids1, output_hidden_states=True, return_dict=True)
    hidden_states1 = out1["hidden_states"][11]

    out2 = text_encoder2(flat_ids2, output_hidden_states=True, return_dict=True)
    hidden_states2 = out2["hidden_states"][-2]
    pool2 = pool_workaround(text_encoder2, out2["last_hidden_state"], flat_ids2, tokenizer2.eos_token_id)

    # stride used to keep one pooled vector per sample
    if max_token_length is None:
        n_size = 1
    else:
        n_size = max_token_length // (tokenizer1.model_max_length - 2)

    # restore the batch dimension; the windows of one sample are concatenated along dim 1
    hidden_states1 = hidden_states1.reshape((batch, -1, hidden_states1.shape[-1]))
    hidden_states2 = hidden_states2.reshape((batch, -1, hidden_states2.shape[-1]))

    if max_token_length is not None:
        hidden_states1 = merge_windows(hidden_states1, tokenizer1.model_max_length)
        hidden_states2 = merge_windows(hidden_states2, tokenizer2.model_max_length)
        pool2 = pool2[::n_size]

    if weight_dtype is not None:
        hidden_states1 = hidden_states1.to(weight_dtype)
        hidden_states2 = hidden_states2.to(weight_dtype)

    return hidden_states1, hidden_states2, pool2
def get_hidden_states_sdxl_mt5(
    max_token_length: int,
    input_ids: torch.Tensor,
    tokenizer: CLIPTokenizer,
    text_encoder: CLIPTextModel,
    weight_dtype: Optional[str] = None,
    caption=None,
):
    """Tokenize ``caption`` and encode it with ``text_encoder``.

    The captions are tokenized here (padded/truncated to 512 tokens); the incoming
    ``input_ids`` only supply the batch size and the device.

    Args:
        max_token_length: accepted for signature compatibility and not used. The
            per-window merge this function used to compute was never part of the
            returned value, and the returned attention mask covers the unmerged
            sequence, so the merge is not applied.
        input_ids: tensor whose first dimension is the batch size and whose device
            is used for the encoder inputs.
        tokenizer: callable tokenizer returning ``input_ids`` and ``attention_mask``.
        text_encoder: model whose output exposes ``last_hidden_state``.
        weight_dtype: if given, the hidden states are converted to it with ``.to()``.
        caption: text(s) handed to the tokenizer.

    Returns:
        ``(hidden_states, None, attention_mask)``. The attention mask is returned as
        produced by the tokenizer (it is not moved to the device).
    """
    b_size = input_ids.size()[0]
    device = input_ids.device

    text_tokens_and_mask = tokenizer(
        caption,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt",
    )
    attention_mask = text_tokens_and_mask["attention_mask"]

    enc_out = text_encoder(
        input_ids=text_tokens_and_mask["input_ids"].to(device),
        attention_mask=attention_mask.to(device),
    )
    hidden_states = enc_out.last_hidden_state
    hidden_states = hidden_states.reshape((b_size, -1, hidden_states.shape[-1]))

    # the cast must land on the tensor that is returned; it used to be stored in an unused variable
    if weight_dtype is not None:
        hidden_states = hidden_states.to(weight_dtype)

    pool2 = None
    return hidden_states, pool2, attention_mask
def default_if_none(value, default):
    """Return ``value``, or ``default`` when ``value`` is ``None`` (other falsy values are kept)."""
    if value is None:
        return default
    return value
def get_epoch_ckpt_name(args: argparse.Namespace, ext: str, epoch_no: int):
    """Build the checkpoint file name for an epoch: ``EPOCH_FILE_NAME`` filled with the model name and epoch, plus ``ext``."""
    base_name = default_if_none(args.output_name, DEFAULT_EPOCH_NAME)
    stem = EPOCH_FILE_NAME.format(base_name, epoch_no)
    return stem + ext
def get_step_ckpt_name(args: argparse.Namespace, ext: str, step_no: int):
    """Build the checkpoint file name for a step: ``STEP_FILE_NAME`` filled with the model name and step, plus ``ext``."""
    base_name = default_if_none(args.output_name, DEFAULT_STEP_NAME)
    stem = STEP_FILE_NAME.format(base_name, step_no)
    return stem + ext
def get_last_ckpt_name(args: argparse.Namespace, ext: str):
    """Build the file name of the final checkpoint: the model name (default ``DEFAULT_LAST_OUTPUT_NAME``) plus ``ext``."""
    return default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME) + ext
def get_remove_epoch_no(args: argparse.Namespace, epoch_no: int):
    """Return the epoch number whose checkpoint should be removed, or ``None``.

    ``None`` is returned when ``args.save_last_n_epochs`` is unset or when the
    computed epoch number would be negative.
    """
    keep_count = args.save_last_n_epochs
    if keep_count is None:
        return None

    candidate = epoch_no - args.save_every_n_epochs * keep_count
    return None if candidate < 0 else candidate
def get_remove_step_no(args: argparse.Namespace, step_no: int):
    """Return the step number whose checkpoint should be removed, or ``None``.

    The candidate is ``step_no - save_last_n_steps - 1`` rounded down to a multiple
    of ``save_every_n_steps``. ``None`` is returned when ``args.save_last_n_steps``
    is unset or the candidate is negative.
    """
    keep_steps = args.save_last_n_steps
    if keep_steps is None:
        return None

    candidate = step_no - keep_steps - 1
    # snap down onto the saving grid
    candidate -= candidate % args.save_every_n_steps
    return None if candidate < 0 else candidate
def save_sd_model_on_epoch_end_or_stepwise(
    args: argparse.Namespace,
    on_epoch_end: bool,
    accelerator,
    src_path: str,
    save_stable_diffusion_format: bool,
    use_safetensors: bool,
    save_dtype: torch.dtype,
    epoch: int,
    num_train_epochs: int,
    global_step: int,
    text_encoder,
    unet,
    vae,
):
    """Save the model at the end of an epoch or at a step.

    Builds two saver callbacks bound to the given models and hands them to
    ``save_sd_model_on_epoch_end_or_stepwise_common``, which decides whether and
    where to save.
    """

    def sd_saver(ckpt_file, epoch_no, global_step):
        # single-file StableDiffusion checkpoint, with model-spec metadata
        sai_metadata = get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True)
        model_util.save_stable_diffusion_checkpoint(
            args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, sai_metadata, save_dtype, vae
        )

    def diffusers_saver(out_dir):
        # Diffusers directory layout; save_dtype is not passed on this path
        model_util.save_diffusers_checkpoint(
            args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
        )

    save_sd_model_on_epoch_end_or_stepwise_common(
        args,
        on_epoch_end,
        accelerator,
        save_stable_diffusion_format,
        use_safetensors,
        epoch,
        num_train_epochs,
        global_step,
        sd_saver,
        diffusers_saver,
    )
def save_sd_model_on_epoch_end_or_stepwise_common(
    args: argparse.Namespace,
    on_epoch_end: bool,
    accelerator,
    save_stable_diffusion_format: bool,
    use_safetensors: bool,
    epoch: int,
    num_train_epochs: int,
    global_step: int,
    sd_saver,
    diffusers_saver,
):
    """Save an intermediate model, upload it, and prune the old one.

    At epoch end, saving happens only when ``epoch + 1`` is a multiple of
    ``args.save_every_n_epochs`` and is not the final epoch; for step-wise saving
    the function always saves. The model is written through ``sd_saver`` (single
    checkpoint file) or ``diffusers_saver`` (directory), uploaded when
    ``args.huggingface_repo_id`` is set, and the output that fell out of the
    retention window is deleted. With ``args.save_state`` the training state is
    saved as well.
    """
    # select everything that differs between the epoch and the step variant once
    if on_epoch_end:
        epoch_no = epoch + 1
        if epoch_no % args.save_every_n_epochs != 0 or epoch_no >= num_train_epochs:
            return
        model_name = default_if_none(args.output_name, DEFAULT_EPOCH_NAME)
        remove_no = get_remove_epoch_no(args, epoch_no)
        save_no = epoch_no
        ckpt_namer = get_epoch_ckpt_name
        dir_pattern = EPOCH_DIFFUSERS_DIR_NAME
        state_saver = save_and_remove_state_on_epoch_end
    else:
        model_name = default_if_none(args.output_name, DEFAULT_STEP_NAME)
        epoch_no = epoch
        remove_no = get_remove_step_no(args, global_step)
        save_no = global_step
        ckpt_namer = get_step_ckpt_name
        dir_pattern = STEP_DIFFUSERS_DIR_NAME
        state_saver = save_and_remove_state_stepwise

    os.makedirs(args.output_dir, exist_ok=True)

    if save_stable_diffusion_format:
        ext = ".safetensors" if use_safetensors else ".ckpt"
        ckpt_name = ckpt_namer(args, ext, save_no)
        ckpt_file = os.path.join(args.output_dir, ckpt_name)
        print(f"\nsaving checkpoint: {ckpt_file}")
        sd_saver(ckpt_file, epoch_no, global_step)

        if args.huggingface_repo_id is not None:
            huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)

        # drop the checkpoint that is no longer inside the retention window
        if remove_no is not None:
            remove_ckpt_file = os.path.join(args.output_dir, ckpt_namer(args, ext, remove_no))
            if os.path.exists(remove_ckpt_file):
                print(f"removing old checkpoint: {remove_ckpt_file}")
                os.remove(remove_ckpt_file)
    else:
        out_dir = os.path.join(args.output_dir, dir_pattern.format(model_name, save_no))
        print(f"\nsaving model: {out_dir}")
        diffusers_saver(out_dir)

        if args.huggingface_repo_id is not None:
            huggingface_util.upload(args, out_dir, "/" + model_name)

        # drop the model directory that is no longer inside the retention window
        if remove_no is not None:
            remove_out_dir = os.path.join(args.output_dir, dir_pattern.format(model_name, remove_no))
            if os.path.exists(remove_out_dir):
                print(f"removing old model: {remove_out_dir}")
                shutil.rmtree(remove_out_dir)

    if args.save_state:
        state_saver(args, accelerator, save_no)
def save_and_remove_state_on_epoch_end(args: argparse.Namespace, accelerator, epoch_no):
    """Save the training state for ``epoch_no`` and delete the state that aged out.

    The state is written with ``accelerator.save_state`` into a directory named by
    ``EPOCH_STATE_NAME`` and uploaded when ``args.save_state_to_huggingface`` is set.
    Retention uses ``args.save_last_n_epochs_state`` and falls back to
    ``args.save_last_n_epochs``.
    """
    model_name = default_if_none(args.output_name, DEFAULT_EPOCH_NAME)

    print(f"\nsaving state at epoch {epoch_no}")
    os.makedirs(args.output_dir, exist_ok=True)

    state_name = EPOCH_STATE_NAME.format(model_name, epoch_no)
    state_dir = os.path.join(args.output_dir, state_name)
    accelerator.save_state(state_dir)
    if args.save_state_to_huggingface:
        print("uploading state to huggingface.")
        huggingface_util.upload(args, state_dir, "/" + state_name)

    last_n_epochs = args.save_last_n_epochs_state or args.save_last_n_epochs
    if last_n_epochs is None:
        return

    remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epochs
    state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no))
    if os.path.exists(state_dir_old):
        print(f"removing old state: {state_dir_old}")
        shutil.rmtree(state_dir_old)
def save_and_remove_state_stepwise(args: argparse.Namespace, accelerator, step_no):
    """Save the training state for ``step_no`` and delete the state that aged out.

    The state is written with ``accelerator.save_state`` into a directory named by
    ``STEP_STATE_NAME`` and uploaded when ``args.save_state_to_huggingface`` is set.
    Retention uses ``args.save_last_n_steps_state`` and falls back to
    ``args.save_last_n_steps``; the step to remove is rounded down to a multiple of
    ``args.save_every_n_steps`` and only acted on when it is positive.
    """
    model_name = default_if_none(args.output_name, DEFAULT_STEP_NAME)

    print(f"\nsaving state at step {step_no}")
    os.makedirs(args.output_dir, exist_ok=True)

    state_name = STEP_STATE_NAME.format(model_name, step_no)
    state_dir = os.path.join(args.output_dir, state_name)
    accelerator.save_state(state_dir)
    if args.save_state_to_huggingface:
        print("uploading state to huggingface.")
        huggingface_util.upload(args, state_dir, "/" + state_name)

    last_n_steps = args.save_last_n_steps_state or args.save_last_n_steps
    if last_n_steps is None:
        return

    remove_step_no = step_no - last_n_steps - 1
    # snap down onto the saving grid
    remove_step_no -= remove_step_no % args.save_every_n_steps
    if remove_step_no <= 0:
        return

    state_dir_old = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, remove_step_no))
    if os.path.exists(state_dir_old):
        print(f"removing old state: {state_dir_old}")
        shutil.rmtree(state_dir_old)
def save_state_on_train_end(args: argparse.Namespace, accelerator):
    """Save the final training state into ``args.output_dir`` and optionally upload it.

    The directory is named by ``LAST_STATE_NAME`` using ``args.output_name``
    (default ``DEFAULT_LAST_OUTPUT_NAME``).
    """
    model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME)

    print("\nsaving last state.")
    os.makedirs(args.output_dir, exist_ok=True)

    state_name = LAST_STATE_NAME.format(model_name)
    state_dir = os.path.join(args.output_dir, state_name)
    accelerator.save_state(state_dir)

    if not args.save_state_to_huggingface:
        return
    print("uploading last state to huggingface.")
    huggingface_util.upload(args, state_dir, "/" + state_name)
def save_sd_model_on_train_end(
    args: argparse.Namespace,
    src_path: str,
    save_stable_diffusion_format: bool,
    use_safetensors: bool,
    save_dtype: torch.dtype,
    epoch: int,
    global_step: int,
    text_encoder,
    unet,
    vae,
):
    """Save the final model when training ends.

    Builds two saver callbacks bound to the given models and hands them to
    ``save_sd_model_on_train_end_common``, which picks the output format and path.
    """

    def sd_saver(ckpt_file, epoch_no, global_step):
        # single-file StableDiffusion checkpoint, with model-spec metadata
        sai_metadata = get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True)
        model_util.save_stable_diffusion_checkpoint(
            args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, sai_metadata, save_dtype, vae
        )

    def diffusers_saver(out_dir):
        # Diffusers directory layout; save_dtype is not passed on this path
        model_util.save_diffusers_checkpoint(
            args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
        )

    save_sd_model_on_train_end_common(
        args, save_stable_diffusion_format, use_safetensors, epoch, global_step, sd_saver, diffusers_saver
    )
def save_sd_model_on_train_end_common(
    args: argparse.Namespace,
    save_stable_diffusion_format: bool,
    use_safetensors: bool,
    epoch: int,
    global_step: int,
    sd_saver,
    diffusers_saver,
):
    """Write the final model and upload it when a Hugging Face repo is configured.

    In StableDiffusion format a single file ``<model_name>.safetensors`` / ``.ckpt``
    is written through ``sd_saver``; otherwise ``diffusers_saver`` writes into the
    directory ``<output_dir>/<model_name>``. Uploads are requested with
    ``force_sync_upload=True``.
    """
    model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME)

    if not save_stable_diffusion_format:
        out_dir = os.path.join(args.output_dir, model_name)
        os.makedirs(out_dir, exist_ok=True)

        print(f"save trained model as Diffusers to {out_dir}")
        diffusers_saver(out_dir)

        if args.huggingface_repo_id is not None:
            huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
        return

    os.makedirs(args.output_dir, exist_ok=True)

    extension = ".safetensors" if use_safetensors else ".ckpt"
    ckpt_name = model_name + extension
    ckpt_file = os.path.join(args.output_dir, ckpt_name)

    print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
    sd_saver(ckpt_file, epoch, global_step)

    if args.huggingface_repo_id is not None:
        huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
    """Sample noise and timesteps for a batch of latents and produce the noised latents.

    Noise is drawn with the shape of ``latents`` and optionally modified by the noise
    offset and multi-resolution noise options. One timestep per batch item is drawn
    from ``[min_timestep, max_timestep)`` (defaults: 0 and the scheduler's
    ``num_train_timesteps``). With ``args.ip_noise_gamma`` an extra noise term scaled
    by gamma is added to the noise used for ``add_noise`` only.

    Returns:
        ``(noise, noisy_latents, timesteps)``; the returned ``noise`` does not include
        the ``ip_noise_gamma`` term.
    """
    device = latents.device

    noise = torch.randn_like(latents, device=device)
    if args.noise_offset:
        noise = custom_train_functions.apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
    if args.multires_noise_iterations:
        noise = custom_train_functions.pyramid_noise_like(
            noise, device, args.multires_noise_iterations, args.multires_noise_discount
        )

    batch = latents.shape[0]
    low = args.min_timestep if args.min_timestep is not None else 0
    high = args.max_timestep if args.max_timestep is not None else noise_scheduler.config.num_train_timesteps
    timesteps = torch.randint(low, high, (batch,), device=device).long()

    # the extra perturbation only affects the noised latents, not the returned noise
    applied_noise = noise
    if args.ip_noise_gamma:
        applied_noise = noise + args.ip_noise_gamma * torch.randn_like(latents)
    noisy_latents = noise_scheduler.add_noise(latents, applied_noise, timesteps)

    return noise, noisy_latents, timesteps
def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True):
    """Add learning-rate entries to ``logs`` under the default parameter-group names.

    The names are ``"unet"`` (only when ``including_unet`` is true) followed by
    ``"text_encoder1"`` and ``"text_encoder2"``; the actual logging is done by
    ``append_lr_to_logs_with_names``.
    """
    text_encoder_names = ["text_encoder1", "text_encoder2"]
    names = ["unet"] + text_encoder_names if including_unet else text_encoder_names
    append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names)
def append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names):
    """Write each learning rate reported by ``lr_scheduler`` into ``logs``.

    Args:
        logs: Dict updated in place with ``"lr/<name>"`` entries.
        lr_scheduler: Object providing ``get_last_lr()``; for DAdapt*/Prodigy
            optimizers its ``optimizers[-1].param_groups`` is read as well.
        optimizer_type: Optimizer name, compared case-insensitively.
        names: Names indexed by learning-rate position; must contain at least
            as many entries as ``get_last_lr()`` returns.
    """
    for lr_index, lr in enumerate(lr_scheduler.get_last_lr()):
        name = names[lr_index]
        logs["lr/" + name] = float(lr)

        lowered_type = optimizer_type.lower()
        uses_d_estimate = lowered_type.startswith("DAdapt".lower()) or lowered_type == "Prodigy".lower()
        if not uses_d_estimate:
            continue

        # For these optimizers the product of the group's "d" and "lr" is logged too.
        group = lr_scheduler.optimizers[-1].param_groups[lr_index]
        logs["lr/d*lr/" + name] = group["d"] * group["lr"]
# Settings passed to the scheduler that sample_images_common constructs.
SCHEDULER_LINEAR_START = 0.00085  # passed as beta_start
SCHEDULER_LINEAR_END = 0.0120  # passed as beta_end
SCHEDULER_TIMESTEPS = 1000  # passed as num_train_timesteps
# NOTE(review): the name is spelled "SCHEDLER" (missing "U"); kept as-is because
# code outside this section may reference it.
SCHEDLER_SCHEDULE = "scaled_linear"  # passed as beta_schedule
def sample_images(*args, **kwargs):
    """Generate sample images with ``StableDiffusionLongPromptWeightingPipeline``.

    All arguments are forwarded unchanged to ``sample_images_common`` after the
    pipeline class.
    """
    pipe_class = StableDiffusionLongPromptWeightingPipeline
    return sample_images_common(pipe_class, *args, **kwargs)
def _load_sample_prompts(prompt_file):
    """Read sample prompts from a .txt, .toml or .json file.

    Returns the list of prompts, or None when the file extension is not one of
    the supported formats.
    """
    if prompt_file.endswith(".txt"):
        with open(prompt_file, "r", encoding="utf-8") as f:
            lines = f.readlines()
        # Blank lines and lines whose first character is "#" are skipped.
        return [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
    if prompt_file.endswith(".toml"):
        with open(prompt_file, "r", encoding="utf-8") as f:
            data = toml.load(f)
        # One prompt dict per subset, combined with the keys of the "prompt" table.
        return [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
    if prompt_file.endswith(".json"):
        with open(prompt_file, "r", encoding="utf-8") as f:
            return json.load(f)
    return None


def _select_sample_scheduler(sampler_name):
    """Map a sampler name to ``(scheduler class, extra constructor kwargs)``.

    Unknown names fall back to ``DDIMScheduler``.
    """
    scheduler_classes = {
        "ddim": DDIMScheduler,
        "ddpm": DDPMScheduler,
        "pndm": PNDMScheduler,
        "lms": LMSDiscreteScheduler,
        "k_lms": LMSDiscreteScheduler,
        "euler": EulerDiscreteScheduler,
        "k_euler": EulerDiscreteScheduler,
        "euler_a": EulerAncestralDiscreteScheduler,
        "k_euler_a": EulerAncestralDiscreteScheduler,
        "dpmsolver": DPMSolverMultistepScheduler,
        "dpmsolver++": DPMSolverMultistepScheduler,
        "dpmsingle": DPMSolverSinglestepScheduler,
        "heun": HeunDiscreteScheduler,
        "dpm_2": KDPM2DiscreteScheduler,
        "k_dpm_2": KDPM2DiscreteScheduler,
        "dpm_2_a": KDPM2AncestralDiscreteScheduler,
        "k_dpm_2_a": KDPM2AncestralDiscreteScheduler,
    }
    sched_init_args = {}
    if sampler_name in ("dpmsolver", "dpmsolver++"):
        # The multistep solver receives the sampler name as its algorithm type.
        sched_init_args["algorithm_type"] = sampler_name
    return scheduler_classes.get(sampler_name, DDIMScheduler), sched_init_args


def _parse_sample_prompt(prompt):
    """Normalize one prompt entry (dict or ``"text --opt value ..."`` string) into a settings dict."""
    if isinstance(prompt, dict):
        return {
            "prompt": prompt.get("prompt"),
            "negative_prompt": prompt.get("negative_prompt"),
            "sample_steps": prompt.get("sample_steps", 30),
            "width": prompt.get("width", 512),
            "height": prompt.get("height", 512),
            "scale": prompt.get("scale", 7.5),
            "seed": prompt.get("seed"),
            "controlnet_image": prompt.get("controlnet_image"),
        }

    prompt_args = prompt.split(" --")
    settings = {
        "prompt": prompt_args[0],
        "negative_prompt": None,
        "sample_steps": 30,
        "width": 512,
        "height": 512,
        "scale": 7.5,
        "seed": None,
        "controlnet_image": None,
    }
    # (pattern, settings key, converter); the first matching pattern wins.
    option_parsers = (
        (r"w (\d+)", "width", int),
        (r"h (\d+)", "height", int),
        (r"d (\d+)", "seed", int),
        (r"s (\d+)", "sample_steps", lambda value: max(1, min(1000, int(value)))),
        (r"l ([\d\.]+)", "scale", float),
        (r"n (.+)", "negative_prompt", str),
        (r"cn (.+)", "controlnet_image", str),
    )
    # Only the segments after the first " --" are options; the prompt text
    # itself must not be interpreted as one.
    for parg in prompt_args[1:]:
        try:
            for pattern, key, convert in option_parsers:
                m = re.match(pattern, parg, re.IGNORECASE)
                if m:
                    settings[key] = convert(m.group(1))
                    break
        except ValueError as ex:
            print(f"Exception in parsing / 解析エラー: {parg}")
            print(ex)
    return settings


def sample_images_common(
    pipe_class,
    accelerator,
    args: argparse.Namespace,
    epoch,
    steps,
    device,
    vae,
    tokenizer,
    text_encoder,
    unet,
    prompt_replacement=None,
    controlnet=None,
):
    """Generate and save sample images during training.

    Uses a modified StableDiffusionLongPromptWeightingPipeline, so clip skip
    and prompt weighting are supported.

    Args:
        pipe_class: Pipeline class to instantiate.
        accelerator: Provides ``is_main_process``, ``autocast()`` and
            ``get_tracker()``; images are generated on the main process only.
        args: Parsed options; reads ``sample_every_n_steps``,
            ``sample_every_n_epochs``, ``sample_prompts``, ``sample_sampler``,
            ``v_parameterization``, ``clip_skip``, ``output_dir`` and
            ``output_name``.
        epoch: Epoch number, or None when called from a step boundary.
        steps: Current global step.
        device: Device the VAE and the pipeline are moved to for sampling.
        vae, tokenizer, text_encoder, unet: Components passed to the pipeline.
        prompt_replacement: Optional ``(old, new)`` pair applied with
            ``str.replace`` to the prompt and the negative prompt.
        controlnet: Forwarded to the pipeline call.

    Returns:
        None. Images are written to ``<output_dir>/sample``. The function
        returns early without sampling when the schedule does not match, the
        prompt file is missing, or its extension is unsupported.
    """
    if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
        return
    if args.sample_every_n_epochs is not None:
        # Epoch-based sampling: sample_every_n_steps is ignored.
        if epoch is None or epoch % args.sample_every_n_epochs != 0:
            return
    else:
        # Step-based sampling: skip non-matching steps and epoch-end calls.
        if steps % args.sample_every_n_steps != 0 or epoch is not None:
            return

    print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}")
    if not os.path.isfile(args.sample_prompts):
        print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
        return

    # Load the prompts before touching any model so that an unusable prompt
    # file leaves the VAE where it was.
    prompts = _load_sample_prompts(args.sample_prompts)
    if prompts is None:
        print(f"Unsupported prompt file format (expected .txt, .toml or .json): {args.sample_prompts}")
        return

    org_vae_device = vae.device
    vae.to(device)

    scheduler_cls, sched_init_args = _select_sample_scheduler(args.sample_sampler)
    if args.v_parameterization:
        sched_init_args["prediction_type"] = "v_prediction"

    scheduler = scheduler_cls(
        num_train_timesteps=SCHEDULER_TIMESTEPS,
        beta_start=SCHEDULER_LINEAR_START,
        beta_end=SCHEDULER_LINEAR_END,
        beta_schedule=SCHEDLER_SCHEDULE,
        **sched_init_args,
    )

    # Force clip_sample on for schedulers that expose it and have it disabled.
    if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
        scheduler.config.clip_sample = True

    pipeline = pipe_class(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=scheduler,
        safety_checker=None,
        feature_extractor=None,
        requires_safety_checker=False,
        clip_skip=args.clip_skip,
    )
    pipeline.to(device)

    save_dir = args.output_dir + "/sample"
    os.makedirs(save_dir, exist_ok=True)

    # Remember the RNG state so that seeding for a sample does not leak into
    # whatever runs after this function.
    rng_state = torch.get_rng_state()
    cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None

    with torch.no_grad():
        for i, prompt in enumerate(prompts):
            if not accelerator.is_main_process:
                continue

            settings = _parse_sample_prompt(prompt)
            prompt = settings["prompt"]
            negative_prompt = settings["negative_prompt"]
            sample_steps = settings["sample_steps"]
            width = settings["width"]
            height = settings["height"]
            scale = settings["scale"]
            seed = settings["seed"]
            controlnet_image = settings["controlnet_image"]

            if seed is not None:
                torch.manual_seed(seed)
                torch.cuda.manual_seed(seed)

            if prompt_replacement is not None:
                prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
                if negative_prompt is not None:
                    negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])

            if controlnet_image is not None:
                controlnet_image = Image.open(controlnet_image).convert("RGB")
                controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)

            # Round down to a multiple of 8, with a floor of 64.
            height = max(64, height - height % 8)
            width = max(64, width - width % 8)
            print(f"prompt: {prompt}")
            print(f"negative_prompt: {negative_prompt}")
            print(f"height: {height}")
            print(f"width: {width}")
            print(f"sample_steps: {sample_steps}")
            print(f"scale: {scale}")
            with accelerator.autocast():
                latents = pipeline(
                    prompt=prompt,
                    height=height,
                    width=width,
                    num_inference_steps=sample_steps,
                    guidance_scale=scale,
                    negative_prompt=negative_prompt,
                    controlnet=controlnet,
                    controlnet_image=controlnet_image,
                )

            image = pipeline.latents_to_image(latents)[0]

            ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
            num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
            seed_suffix = "" if seed is None else f"_{seed}"
            img_filename = (
                f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png"
            )

            image.save(os.path.join(save_dir, img_filename))

            # Best-effort wandb logging: any failure (tracker not configured,
            # wandb not installed, ...) is ignored, but KeyboardInterrupt and
            # SystemExit are allowed to propagate.
            try:
                wandb_tracker = accelerator.get_tracker("wandb")
                import wandb

                wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
            except Exception:
                pass

    # Drop the pipeline and cached blocks to reduce memory usage.
    del pipeline
    torch.cuda.empty_cache()

    torch.set_rng_state(rng_state)
    if cuda_rng_state is not None:
        torch.cuda.set_rng_state(cuda_rng_state)
    vae.to(org_vae_device)
class ImageLoadingDataset(torch.utils.data.Dataset):
    """Dataset that loads images from a list of paths.

    Each item is a ``(tensor, path)`` tuple, where the tensor comes from
    ``pil_to_tensor`` applied to the RGB-converted image. An item is ``None``
    when the image could not be loaded, so consumers must handle ``None``.
    """

    def __init__(self, image_paths):
        self.images = image_paths

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]

        try:
            # The context manager closes the underlying file deterministically;
            # convert() returns an independent image that stays usable afterwards.
            with Image.open(img_path) as opened:
                image = opened.convert("RGB")
            tensor_pil = transforms.functional.pil_to_tensor(image)
        except Exception as e:
            # A single unreadable image is reported and skipped rather than
            # aborting the whole iteration.
            print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
            return None

        return (tensor_pil, img_path)
class collator_class:
    """Collate callable that pushes the current epoch/step into the dataset.

    ``epoch`` and ``step`` only need a ``.value`` attribute (presumably
    ``multiprocessing.Value`` objects shared with the training loop - confirm
    against the caller).
    """

    def __init__(self, epoch, step, dataset):
        self.current_epoch = epoch
        self.current_step = step
        self.dataset = dataset

    def __call__(self, examples):
        # In a DataLoader worker, update that worker's dataset; otherwise
        # update the dataset given at construction time.
        worker_info = torch.utils.data.get_worker_info()
        target_dataset = self.dataset if worker_info is None else worker_info.dataset

        target_dataset.set_current_epoch(self.current_epoch.value)
        target_dataset.set_current_step(self.current_step.value)

        # Only the first example is returned (assumes each dataset item is
        # already a complete batch - TODO confirm).
        return examples[0]
class LossRecorder:
    """Track per-step losses and expose their running average.

    During epoch 0 every loss is appended. In later epochs the value stored
    for ``step`` is replaced, so the average always covers the most recent
    loss seen for each step index.
    """

    def __init__(self):
        self.loss_list: List[float] = []  # most recent loss for each step index
        self.loss_total: float = 0.0  # sum of the values in loss_list

    def add(self, *, epoch: int, step: int, loss: float) -> None:
        """Record ``loss`` for ``step`` of ``epoch``.

        Args:
            epoch: Epoch index; ``0`` appends, anything else replaces.
            step: Step index within the epoch; used only when ``epoch != 0``.
            loss: Loss value to record.
        """
        if epoch == 0:
            self.loss_list.append(loss)
        else:
            # The slot may not exist yet (e.g. this recorder never saw epoch 0,
            # or this epoch has more steps); pad with zeros so indexing cannot fail.
            missing = step + 1 - len(self.loss_list)
            if missing > 0:
                self.loss_list.extend([0.0] * missing)
            self.loss_total -= self.loss_list[step]
            self.loss_list[step] = loss
        self.loss_total += loss

    @property
    def moving_average(self) -> float:
        """Mean of the recorded losses, or ``0.0`` when nothing has been recorded."""
        if not self.loss_list:
            return 0.0
        return self.loss_total / len(self.loss_list)