diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 1ffb73cee4a2..c553f75a7d89 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -31,6 +31,7 @@
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
+import torch.utils.data.sampler
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
@@ -298,6 +299,26 @@ def parse_args(input_args=None):
             " or to a folder containing files that 🤗 Datasets can understand."
         ),
     )
+    parser.add_argument(
+        "--bucket_size",
+        type=str,
+        default=None,
+        help=(
+            "Comma-separated list of bucket resolutions in WxH format, e.g."
+            " '768x1280,896x1120,1024x1024,1280x768'. Each image is assigned to the bucket whose aspect"
+            " ratio is closest to its own. If not set, a single (resolution, resolution) bucket is used."
+        ),
+    )
+    parser.add_argument(
+        "--local_config_file_name",
+        type=str,
+        default=None,
+        help="Key in the local metadata.jsonl holding the image file name. Leave as None to load every image in the instance data dir.",
+    )
+    parser.add_argument(
+        "--local_config_text_name",
+        type=str,
+        default=None,
+        help="Key in the local metadata.jsonl holding the caption of each image. Leave as None to use --instance_prompt for all images.",
+    )
     parser.add_argument(
         "--dataset_config_name",
         type=str,
@@ -421,7 +442,7 @@ def parse_args(input_args=None):
     )
     parser.add_argument(
         "--center_crop",
-        default=False,
+        default=True,
         action="store_true",
         help=(
            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
@@ -792,60 +813,138 @@ def __init__(
         self.instance_data_root = Path(instance_data_root)
         if not self.instance_data_root.exists():
             raise ValueError("Instance images root doesn't exists.")
-
-        instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
-        self.custom_instance_prompts = None
+        if args.local_config_file_name is None:
+            instance_files = [
+                f for f in os.listdir(self.instance_data_root) if f.lower().endswith((".png", ".jpg", ".jpeg", ".webp"))
+            ]
+            self.custom_instance_prompts = None
+        else:
+            # Read image file names and per-image captions from a local metadata.jsonl file.
+            config_path = self.instance_data_root / "metadata.jsonl"
+            instance_files = []
+            self.custom_instance_prompts = []
+            with open(config_path, "r", encoding="utf-8") as f:
+                for line in f:
+                    line = line.strip()
+                    if not line:
+                        continue
+                    data = json.loads(line)
+                    instance_files.append(data[args.local_config_file_name])
+                    self.custom_instance_prompts.append(data[args.local_config_text_name])
+        instance_images = [Image.open(os.path.join(self.instance_data_root, path)) for path in instance_files]

         self.instance_images = []
-        for img in instance_images:
-            self.instance_images.extend(itertools.repeat(img, repeats))
+        if args.image_column is not None:
+            for img in instance_images:
+                self.instance_images.extend(itertools.repeat(img, repeats))
+        # Parse the list of aspect-ratio buckets; fall back to a single square bucket.
+        self.bucket = []
+        if args.bucket_size:
+            for size in args.bucket_size.split(","):
+                w, h = size.split("x")
+                self.bucket.append((int(w), int(h)))
+        else:
+            self.bucket.append((args.resolution, args.resolution))

         # image processing to prepare for using SD-XL micro-conditioning
+        self.file_name = []
         self.original_sizes = []
         self.crop_top_lefts = []
         self.pixel_values = []
-
+        self.bucket_index = []
+        self.target_size = []
         interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
         if interpolation is None:
             raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
-        train_resize = transforms.Resize(size, interpolation=interpolation)
-        train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
-        train_flip = transforms.RandomHorizontalFlip(p=1.0)
-        train_transforms = transforms.Compose(
+        self.train_flip = transforms.RandomHorizontalFlip(p=1.0)
+        self.train_transforms = transforms.Compose(
             [
                 transforms.ToTensor(),
-                transforms.Normalize([0.5], [0.5]),
+                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
             ]
         )
-        for image in self.instance_images:
+        for image_index, image in enumerate(self.instance_images):
             image = exif_transpose(image)
             if not image.mode == "RGB":
                 image = image.convert("RGB")
             self.original_sizes.append((image.height, image.width))
-            image = train_resize(image)
+            # Pick the bucket whose aspect ratio is closest to the image, then resize so the bucket
+            # fits entirely inside the resized image before cropping.
+            min_ratio = float("inf")
+            target_w, target_h, bucket_index = -1, -1, -1
+            w_over_h = image.width / image.height
+            for i, (w, h) in enumerate(self.bucket):
+                w, h = int(w), int(h)
+                ratio_gap = abs(w_over_h - w / h)
+                if ratio_gap < min_ratio:
+                    min_ratio = ratio_gap
+                    target_w, target_h = w, h
+                    bucket_index = i
+            scale = max(target_w / image.width, target_h / image.height)
+            new_w = math.ceil(scale * image.width)
+            new_h = math.ceil(scale * image.height)
+            image = image.resize((new_w, new_h), Image.LANCZOS)
+            self.target_size.append((target_h, target_w))
             if args.random_flip and random.random() < 0.5:
                 # flip
-                image = train_flip(image)
+                image = self.train_flip(image)
             if args.center_crop:
-                y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
-                x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
-                image = train_crop(image)
+                y1 = max(0, int(round((new_h - target_h) / 2.0)))
+                x1 = max(0, int(round((new_w - target_w) / 2.0)))
             else:
-                y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
-                image = crop(image, y1, x1, h, w)
+                y1, x1, _, _ = transforms.RandomCrop.get_params(image, (target_h, target_w))
             crop_top_left = (y1, x1)
             self.crop_top_lefts.append(crop_top_left)
-            image = train_transforms(image)
-            self.pixel_values.append(image)
+            image = crop(image, y1, x1, target_h, target_w)
+            img_t = self.train_transforms(image)
+            self.pixel_values.append(img_t)
+            self.bucket_index.append(bucket_index)
+            self.file_name.append(instance_files[image_index])

         self.num_instance_images = len(self.instance_images)
         self._length = self.num_instance_images
-
+        self.class_images_by_bucket = {}
         if class_data_root is not None:
             self.class_data_root = Path(class_data_root)
             self.class_data_root.mkdir(parents=True, exist_ok=True)
             self.class_images_path = list(self.class_data_root.iterdir())
+            self.class_images_by_bucket = {i: [] for i in range(len(self.bucket))}
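+            # Class (prior-preservation) images are bucketed by aspect ratio as well, so every
+            # instance image can be paired with a class image of the same shape. Hypothetical
+            # example: with buckets 768x1280 and 1280x768, a 900x1600 class image (aspect 0.56)
+            # is assigned to the 768x1280 bucket (aspect 0.6).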
+            print(f"Preprocessing {len(self.class_images_path)} class images for bucketing...")
+            for path in tqdm(self.class_images_path, desc="Bucketing class images"):
+                try:
+                    image = Image.open(path)
+                    image = exif_transpose(image)
+                    if not image.mode == "RGB":
+                        image = image.convert("RGB")
+                except Exception as e:
+                    print(f"Could not load class image {path}, skipping. Error: {e}")
+                    continue
+
+                cls_w_over_h = image.width / image.height
+                min_ratio_gap = float("inf")
+                best_bucket_index = -1
+                for i, (w, h) in enumerate(self.bucket):
+                    bucket_ratio = int(w) / int(h)
+                    ratio_gap = abs(cls_w_over_h - bucket_ratio)
+                    if ratio_gap < min_ratio_gap:
+                        min_ratio_gap = ratio_gap
+                        best_bucket_index = i
+                if best_bucket_index != -1:
+                    self.class_images_by_bucket[best_bucket_index].append(path)
+            self.available_buckets = [k for k, v in self.class_images_by_bucket.items() if len(v) > 0]
+            if len(self.available_buckets) != len(self.bucket):
+                warnings.warn(
+                    "Some buckets do not have any corresponding class images. This might lead to errors if an"
+                    " instance image falls into one of these empty buckets."
+                )
+
             if class_num is not None:
                 self.num_class_images = min(len(self.class_images_path), class_num)
             else:
@@ -853,13 +952,10 @@ def __init__(
             self._length = max(self.num_class_images, self.num_instance_images)
         else:
             self.class_data_root = None
-
         self.image_transforms = transforms.Compose(
             [
-                transforms.Resize(size, interpolation=interpolation),
-                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                 transforms.ToTensor(),
-                transforms.Normalize([0.5], [0.5]),
+                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
             ]
         )

@@ -871,9 +967,13 @@ def __getitem__(self, index):
         instance_image = self.pixel_values[index % self.num_instance_images]
         original_size = self.original_sizes[index % self.num_instance_images]
         crop_top_left = self.crop_top_lefts[index % self.num_instance_images]
+        target_size = self.target_size[index % self.num_instance_images]
+        bucket_index = self.bucket_index[index % self.num_instance_images]
         example["instance_images"] = instance_image
         example["original_size"] = original_size
         example["crop_top_left"] = crop_top_left
+        example["target_size"] = target_size
+        example["bucket_index"] = bucket_index

         if self.custom_instance_prompts:
             caption = self.custom_instance_prompts[index % self.num_instance_images]
@@ -881,44 +981,78 @@ def __getitem__(self, index):
                 example["instance_prompt"] = caption
             else:
                 example["instance_prompt"] = self.instance_prompt
-
+        elif args.local_config_file_name is not None and args.local_config_text_name is not None:
+            example["instance_prompt"] = self.custom_instance_prompts[index % self.num_instance_images]
         else:  # custom prompts were provided, but length does not match size of image dataset
             example["instance_prompt"] = self.instance_prompt

         if self.class_data_root:
-            class_image = Image.open(self.class_images_path[index % self.num_class_images])
+            # Pair the instance image with a class image from the same aspect-ratio bucket.
+            bucket_images = self.class_images_by_bucket.get(bucket_index)
+            if not bucket_images:
+                random_bucket_index = random.choice(self.available_buckets)
+                bucket_images = self.class_images_by_bucket[random_bucket_index]
+                warnings.warn(f"Bucket {bucket_index} has no class images. Falling back to bucket {random_bucket_index}.")
+            cls_path = random.choice(bucket_images)
+            class_image = Image.open(cls_path)
             class_image = exif_transpose(class_image)
-
             if not class_image.mode == "RGB":
                 class_image = class_image.convert("RGB")
-            example["class_images"] = self.image_transforms(class_image)
+            if args.random_flip and random.random() < 0.5:
+                class_image = self.train_flip(class_image)
+            cls_original_size = (class_image.height, class_image.width)
+            example["class_original_size"] = cls_original_size
+            # Resize and center crop the class image to the same bucket resolution as the instance image.
+            target_h, target_w = target_size
+            h_scale = target_h / class_image.height
+            w_scale = target_w / class_image.width
+            scale = max(h_scale, w_scale)
+            new_h = math.ceil(scale * class_image.height)
+            new_w = math.ceil(scale * class_image.width)
+            class_image = class_image.resize((new_w, new_h), Image.LANCZOS)
+            y1 = max(0, int(round((new_h - target_h) / 2.0)))
+            x1 = max(0, int(round((new_w - target_w) / 2.0)))
+            class_image = crop(class_image, y1, x1, target_h, target_w)
+            class_image_tensor = self.train_transforms(class_image)
+            example["class_images"] = class_image_tensor
+            example["class_crop_top_left"] = (y1, x1)
             example["class_prompt"] = self.class_prompt

         return example


 def collate_fn(examples, with_prior_preservation=False):
+    # TODO: target_size and bucket_index are not extended for the class images below; the instance
+    # values are reused for the class half of the batch when computing the SDXL time ids.
     pixel_values = [example["instance_images"] for example in examples]
     prompts = [example["instance_prompt"] for example in examples]
     original_sizes = [example["original_size"] for example in examples]
     crop_top_lefts = [example["crop_top_left"] for example in examples]
+    bucket_index = [example["bucket_index"] for example in examples]
+    target_size = [example["target_size"] for example in examples]

     # Concat class and instance examples for prior preservation.
     # We do this to avoid doing two forward passes.
     if with_prior_preservation:
         pixel_values += [example["class_images"] for example in examples]
         prompts += [example["class_prompt"] for example in examples]
-        original_sizes += [example["original_size"] for example in examples]
-        crop_top_lefts += [example["crop_top_left"] for example in examples]
+        original_sizes += [example["class_original_size"] for example in examples]
+        crop_top_lefts += [example["class_crop_top_left"] for example in examples]

     pixel_values = torch.stack(pixel_values)
     pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
-
     batch = {
         "pixel_values": pixel_values,
         "prompts": prompts,
         "original_sizes": original_sizes,
         "crop_top_lefts": crop_top_lefts,
+        "bucket_index": bucket_index,
+        "target_size": target_size,
     }
     return batch

@@ -1196,7 +1330,6 @@ def main(args):
             unet.enable_xformers_memory_efficient_attention()
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")
-
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
         if args.train_text_encoder:
@@ -1461,22 +1594,58 @@ def load_model_hook(models, input_dir):
         center_crop=args.center_crop,
     )

-    train_dataloader = torch.utils.data.DataLoader(
-        train_dataset,
-        batch_size=args.train_batch_size,
-        shuffle=True,
-        collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
-        num_workers=args.dataloader_num_workers,
-    )
+    class BucketBatchSampler(torch.utils.data.Sampler):
+        """Yields batches of dataset indices such that every batch only contains images from one bucket."""
+
+        def __init__(self, bucket_column, bs, shuffle=True):
+            self.bucket_dict = {}
+            for i, index in enumerate(bucket_column):
+                if int(index) not in self.bucket_dict:
+                    self.bucket_dict[int(index)] = [i]
+                else:
+                    self.bucket_dict[int(index)].append(i)
+            self.bs = bs
+            self.shuffle = shuffle
+
+        def __iter__(self):
+            buckets = list(self.bucket_dict.items())
+            if self.shuffle:
+                random.shuffle(buckets)
+            for _, img_index in buckets:
+                # Copy so the mapping between dataset indices and buckets is not shuffled in place.
+                img_index_cp = img_index[:]
+                if self.shuffle:
+                    random.shuffle(img_index_cp)
+                for i in range(0, len(img_index_cp), self.bs):
+                    yield img_index_cp[i : i + self.bs]
+
+        def __len__(self):
+            total = 0
+            for i in self.bucket_dict.values():
+                total += math.ceil(len(i) / self.bs)
+            return total
+
+    if not args.bucket_size:
+        train_dataloader = torch.utils.data.DataLoader(
+            train_dataset,
+            batch_size=args.train_batch_size,
+            shuffle=True,
+            collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+            num_workers=args.dataloader_num_workers,
+        )
+    else:
+        bucket_sampler = BucketBatchSampler(train_dataset.bucket_index, args.train_batch_size)
+        train_dataloader = torch.utils.data.DataLoader(
+            train_dataset,
+            batch_sampler=bucket_sampler,
+            collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+        )
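+    # Note: when a batch_sampler is used, the DataLoader must not also receive batch_size, shuffle
+    # or drop_last. As a rough illustration (hypothetical indices, not from a real run): with
+    # bucket 0 -> [0, 3, 5], bucket 1 -> [1, 2, 4] and train_batch_size=2, the sampler yields the
+    # batches [0, 3], [5], [1, 2], [4], so every batch stays within a single bucket.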

     # Computes additional embeddings/ids required by the SDXL UNet.
     # regular text embeddings (when `train_text_encoder` is not True)
     # pooled text embeddings
     # time ids
-    def compute_time_ids(original_size, crops_coords_top_left):
+    def compute_time_ids(original_size, crops_coords_top_left, target_size, bucket_index):
         # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
-        target_size = (args.resolution, args.resolution)
+        # With bucketing, the target size differs per sample, so it is passed in instead of being
+        # fixed to (args.resolution, args.resolution).
         add_time_ids = list(original_size + crops_coords_top_left + target_size)
         add_time_ids = torch.tensor([add_time_ids])
         add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
@@ -1668,10 +1837,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
             with accelerator.accumulate(unet):
                 pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
                 prompts = batch["prompts"]
-
                 # encode batch prompts when custom prompts are provided for each image
-
                 if train_dataset.custom_instance_prompts:
                     if not args.train_text_encoder:
                         prompt_embeds, unet_add_text_embeds = compute_text_embeddings(
                             prompts, text_encoders, tokenizers
                         )
@@ -1725,24 +1896,34 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 # time ids
                 add_time_ids = torch.cat(
                     [
-                        compute_time_ids(original_size=s, crops_coords_top_left=c)
-                        for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])
+                        compute_time_ids(original_size=s, crops_coords_top_left=c, target_size=t, bucket_index=i)
+                        for s, c, t, i in zip(
+                            batch["original_sizes"], batch["crop_top_lefts"], batch["target_size"], batch["bucket_index"]
+                        )
                     ]
                 )

                 # Calculate the elements to repeat depending on the use of prior-preservation and custom captions.
                 if not train_dataset.custom_instance_prompts:
                     elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz
                 else:
                     elems_to_repeat_text_embeds = 1

                 # Predict the noise residual
                 if not args.train_text_encoder:
+                    # With per-image prompts and prior preservation, add_time_ids only covers the instance
+                    # entries, so it is expanded to match the full (instance + class) batch.
+                    if elems_to_repeat_text_embeds == 1 and args.with_prior_preservation:
+                        expanded_add_time_ids = add_time_ids.repeat_interleave(2, dim=0)
+                    else:
+                        expanded_add_time_ids = add_time_ids.repeat_interleave(elems_to_repeat_text_embeds, dim=0)
                     unet_added_conditions = {
-                        "time_ids": add_time_ids,
-                        "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1),
+                        "time_ids": expanded_add_time_ids,
+                        "text_embeds": unet_add_text_embeds.repeat_interleave(elems_to_repeat_text_embeds, dim=0),
                     }
-                    prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
+                    prompt_embeds_input = prompt_embeds.repeat_interleave(elems_to_repeat_text_embeds, dim=0)
                     model_pred = unet(
                         inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
                         timesteps,
                         prompt_embeds_input,
                         added_cond_kwargs=unet_added_conditions,
                         return_dict=False,
                     )[0]
                 else:
-                    unet_added_conditions = {"time_ids": add_time_ids}
+                    if elems_to_repeat_text_embeds == 1 and args.with_prior_preservation:
+                        expanded_add_time_ids = add_time_ids.repeat_interleave(2, dim=0)
+                    else:
+                        expanded_add_time_ids = add_time_ids.repeat_interleave(elems_to_repeat_text_embeds, dim=0)
+                    unet_added_conditions = {"time_ids": expanded_add_time_ids}
                     prompt_embeds, pooled_prompt_embeds = encode_prompt(
                         text_encoders=[text_encoder_one, text_encoder_two],
                         tokenizers=None,
                         prompt=None,
                         text_input_ids_list=[tokens_one, tokens_two],
                     )
                     unet_added_conditions.update(
-                        {"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)}
+                        {"text_embeds": pooled_prompt_embeds.repeat_interleave(elems_to_repeat_text_embeds, dim=0)}
                     )
-                    prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
+                    prompt_embeds_input = prompt_embeds.repeat_interleave(elems_to_repeat_text_embeds, dim=0)
                     model_pred = unet(
                         inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
                         timesteps,
                         prompt_embeds_input,
                         added_cond_kwargs=unet_added_conditions,
                         return_dict=False,
                     )[0]
@@ -1836,8 +2028,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                     # This is discussed in Section 4.2 of the same paper.
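+                    # With prior preservation, model_pred and target have already been chunked into the
+                    # instance and class halves above, while `timesteps` still spans the full batch, so
+                    # the SNR weights below are restricted to the instance half.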
snr = compute_snr(noise_scheduler, timesteps) + if args.with_prior_preservation: + snr = snr.chunk(2, dim=0)[0] + + # snr = compute_snr(noise_scheduler, timesteps).chunk(2, dim=0)[0] + # pred_shape = (timesteps.shape[0] // 2, *timesteps.shape[1:]) base_weight = ( - torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr + torch.stack([snr, args.snr_gamma * torch.ones_like(snr)], dim=1).min(dim=1)[0] / snr ) if noise_scheduler.config.prediction_type == "v_prediction": @@ -1848,6 +2045,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): mse_loss_weights = base_weight loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + # print(loss.shape, mse_loss_weights.shape) loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights loss = loss.mean() diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index 5fb1825f37d3..09f131ce3208 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -23,7 +23,8 @@ import shutil from contextlib import nullcontext from pathlib import Path - +import io +from PIL import Image as PILImage import datasets import numpy as np import torch @@ -205,6 +206,17 @@ def parse_args(input_args=None): required=False, help="Revision of pretrained model identifier from huggingface.co/models.", ) + parser.add_argument( + "--enable_bucketing", + action="store_true", + help="Enable aspect-ratio / size bucketing. When enabled, images are grouped into buckets and each batch contains images of the same bucket.", + ) + parser.add_argument( + "--buckets", + type=str, + default="512x512,768x768,1024x1024,1024x1536,1536x1024", + help="Comma-separated list of buckets W x H, e.g. '512x512,1024x1536'. If not specified, defaults provided.", + ) parser.add_argument( "--variant", type=str, @@ -300,7 +312,7 @@ def parse_args(input_args=None): ) parser.add_argument( "--center_crop", - default=False, + default=True, action="store_true", help=( "Whether to center crop the input images to the resolution. 
If not set, the images will be randomly" @@ -536,12 +548,18 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None): text_input_ids = text_input_ids_list[i] prompt_embeds = text_encoder( - text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False + text_input_ids.to(text_encoder.device), output_hidden_states=True, + # return_dict=False !!change ) # We are only ALWAYS interested in the pooled output of the final text encoder pooled_prompt_embeds = prompt_embeds[0] + #!!change + # pooled_prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds[-1][-2] + #!!change + # prompt_embeds = prompt_embeds.hidden_states[-2] bs_embed, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) prompt_embeds_list.append(prompt_embeds) @@ -869,6 +887,7 @@ def load_model_hook(models, input_dir): dataset = load_dataset( args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, data_dir=args.train_data_dir ) + else: data_files = {} if args.train_data_dir is not None: @@ -904,6 +923,59 @@ def load_model_hook(models, input_dir): f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" ) + def parse_buckets(buckets_str): + buckets = [] + for item in buckets_str.split(","): + item = item.strip() + if "x" in item: + try: + w, h = item.split("x") + buckets.append((int(w), int(h))) + except Exception: + continue + return buckets + + buckets = parse_buckets(args.buckets) + if len(buckets) == 0: + buckets = [(args.resolution, args.resolution)] + + # Pre-scan dataset["train"] sizes and assign bucket indices (run on main process) + bucket_assignments = None + with accelerator.main_process_first(): + sizes = [] + for idx in range(len(dataset["train"])): + example = dataset["train"][idx] + img_field = example[image_column] + try: + if isinstance(img_field, dict) and "path" in img_field: + with PILImage.open(img_field["path"]) as im: + w, h = im.size + elif hasattr(img_field, "size"): + w, h = img_field.size + elif isinstance(img_field, (bytes, bytearray)): + with PILImage.open(io.BytesIO(img_field)) as im: + w, h = im.size + else: + # fallback + with PILImage.open(img_field) as im: + w, h = im.size + except Exception: + w, h = args.resolution, args.resolution + sizes.append((w, h)) + + def aspect(w, h): + return float(w) / float(h) + + bucket_aspects = [aspect(w, h) for (w, h) in buckets] + bucket_assignments = [] + for (w, h) in sizes: + ar = aspect(w, h) + best_i = int(min(range(len(bucket_aspects)), key=lambda i: abs(ar - bucket_aspects[i]))) + bucket_assignments.append(best_i) + + # Add bucket column to dataset so transform can access it + dataset["train"] = dataset["train"].add_column("bucket", bucket_assignments) + # Preprocessing the datasets. # We need to tokenize input captions and transform the images. 
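+    # With bucketing enabled, preprocess_train below resizes each image so that its assigned bucket fits
+    # inside it and then crops to the bucket size instead of the fixed args.resolution. Hypothetical
+    # example: with buckets 512x512 and 1024x1536, a 768x1152 image (aspect ratio 0.67) is matched to the
+    # 1024x1536 bucket and is scaled to exactly 1024x1536 before the (no-op) crop.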
     def tokenize_captions(examples, is_train=True):
@@ -942,26 +1014,43 @@ def tokenize_captions(examples, is_train=True):
     def preprocess_train(examples):
         images = [image.convert("RGB") for image in examples[image_column]]
         # image aug
+        batch_buckets = examples.get("bucket", [None] * len(images))
         original_sizes = []
+        target_sizes = []
         all_images = []
         crop_top_lefts = []
-        for image in images:
+        for i, image in enumerate(images):
+            iw, ih = image.width, image.height
             original_sizes.append((image.height, image.width))
-            image = train_resize(image)
+            if args.enable_bucketing and batch_buckets[i] is not None:
+                try:
+                    target_w, target_h = buckets[batch_buckets[i]]
+                except Exception:
+                    target_w, target_h = args.resolution, args.resolution
+            else:
+                target_w, target_h = args.resolution, args.resolution
+            # Resize so the target bucket fits inside the image, then crop to the bucket size.
+            scale = max(target_w / float(iw), target_h / float(ih))
+            new_w = math.ceil(iw * scale)
+            new_h = math.ceil(ih * scale)
+            img = image.resize((new_w, new_h), PILImage.LANCZOS)
             if args.random_flip and random.random() < 0.5:
                 # flip
-                image = train_flip(image)
+                img = train_flip(img)
             if args.center_crop:
-                y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
-                x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
-                image = train_crop(image)
+                y1 = max(0, int(round((new_h - target_h) / 2.0)))
+                x1 = max(0, int(round((new_w - target_w) / 2.0)))
             else:
-                y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
-                image = crop(image, y1, x1, h, w)
-            crop_top_left = (y1, x1)
-            crop_top_lefts.append(crop_top_left)
-            image = train_transforms(image)
-            all_images.append(image)
+                y1, x1, _, _ = transforms.RandomCrop.get_params(img, (target_h, target_w))
+            crop_top_lefts.append((y1, x1))
+            img = crop(img, y1, x1, target_h, target_w)
+            img_t = train_transforms(img)
+            all_images.append(img_t)
+            target_sizes.append((target_h, target_w))

         examples["original_sizes"] = original_sizes
         examples["crop_top_lefts"] = crop_top_lefts
@@ -969,6 +1058,7 @@ def preprocess_train(examples):
         tokens_one, tokens_two = tokenize_captions(examples)
         examples["input_ids_one"] = tokens_one
         examples["input_ids_two"] = tokens_two
+        examples["target_sizes"] = target_sizes
         if args.debug_loss:
             fnames = [os.path.basename(image.filename) for image in examples[image_column] if image.filename]
             if fnames:
@@ -995,21 +1085,60 @@ def collate_fn(examples):
         "original_sizes": original_sizes,
         "crop_top_lefts": crop_top_lefts,
     }
-
+    target_sizes = [example["target_sizes"] for example in examples]
     filenames = [example["filenames"] for example in examples if "filenames" in example]
+    result["bucket_index"] = [example["bucket"] for example in examples]
+    result["target_sizes"] = target_sizes
     if filenames:
         result["filenames"] = filenames
     return result

-    # DataLoaders creation:
-    train_dataloader = torch.utils.data.DataLoader(
-        train_dataset,
-        shuffle=True,
-        collate_fn=collate_fn,
-        batch_size=args.train_batch_size,
-        num_workers=args.dataloader_num_workers,
-    )
-
+    class BucketBatchSample(torch.utils.data.Sampler):
+        """Yields batches of dataset indices so that every batch only contains images from one bucket."""
+
+        def __init__(self, bucket_column, batch_size, shuffle=True):
+            self.bucket_to_indices = {}
+            for i, b in enumerate(bucket_column):
+                if int(b) not in self.bucket_to_indices:
+                    self.bucket_to_indices[int(b)] = [i]
+                else:
+                    self.bucket_to_indices[int(b)].append(i)
+            self.batch_size = batch_size
+            self.shuffle = shuffle
+
+        def __iter__(self):
+            buckets = list(self.bucket_to_indices.items())
+            if self.shuffle:
+                random.shuffle(buckets)
+            for _, index in buckets:
+                # Copy so the mapping between dataset indices and buckets is not shuffled in place.
+                indx = index[:]
+                if self.shuffle:
+                    random.shuffle(indx)
+                for i in range(0, len(indx), self.batch_size):
+                    yield indx[i : i + self.batch_size]
+
+        def __len__(self):
+            total = 0
+            for idx in self.bucket_to_indices.values():
+                total += math.ceil(len(idx) / self.batch_size)
+            return total
+
+    if args.enable_bucketing:
+        bucket_col = dataset["train"]["bucket"]
+        batch_sampler = BucketBatchSample(bucket_col, args.train_batch_size)
+        train_dataloader = torch.utils.data.DataLoader(
+            train_dataset,
+            batch_sampler=batch_sampler,
+            collate_fn=collate_fn,
+            num_workers=args.dataloader_num_workers,
+        )
+    else:
+        train_dataloader = torch.utils.data.DataLoader(
+            train_dataset,
+            shuffle=True,
+            collate_fn=collate_fn,
+            batch_size=args.train_batch_size,
+            num_workers=args.dataloader_num_workers,
+        )
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1134,16 +1263,17 @@ def collate_fn(examples):
                 noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)

                 # time ids
-                def compute_time_ids(original_size, crops_coords_top_left):
+                def compute_time_ids(original_size, crops_coords_top_left, target_size, bucket_index):
                     # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
-                    target_size = (args.resolution, args.resolution)
+                    # With bucketing, the target size differs per sample, so it is passed in instead of
+                    # being fixed to (args.resolution, args.resolution).
                     add_time_ids = list(original_size + crops_coords_top_left + target_size)
                     add_time_ids = torch.tensor([add_time_ids])
                     add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
                     return add_time_ids

                 add_time_ids = torch.cat(
-                    [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
+                    [
+                        compute_time_ids(s, c, t, i)
+                        for s, c, t, i in zip(
+                            batch["original_sizes"], batch["crop_top_lefts"], batch["target_sizes"], batch["bucket_index"]
+                        )
+                    ]
                 )

                 # Predict the noise residual
@@ -1212,7 +1342,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
                 if accelerator.sync_gradients:
                     progress_bar.update(1)
                     global_step += 1
-                    accelerator.log({"train_loss": train_loss}, step=global_step)
+                    accelerator.log({"train_loss": train_loss, "lr": optimizer.param_groups[0]["lr"]}, step=global_step)
                     train_loss = 0.0

                 # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.
@@ -1240,6 +1370,10 @@ def compute_time_ids(original_size, crops_coords_top_left):
                             save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                             accelerator.save_state(save_path)
+                            # Convenience for Colab runs: mirror the latest checkpoint to Google Drive as a zip archive.
+                            des_path = "/content/drive/MyDrive/latest_lora"
+                            if os.path.exists(f"{des_path}.zip"):
+                                os.remove(f"{des_path}.zip")
+                            shutil.make_archive(des_path, "zip", save_path)
                             logger.info(f"Saved state to {save_path}")

             logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}