diff --git a/examples/dreambooth/train_dreambooth_inpaint.py b/examples/dreambooth/train_dreambooth_inpaint.py index bb5672669d5b..d368453a4054 100644 --- a/examples/dreambooth/train_dreambooth_inpaint.py +++ b/examples/dreambooth/train_dreambooth_inpaint.py @@ -295,10 +295,15 @@ def __init__( else: self.class_data_root = None - self.image_transforms = transforms.Compose( + self.image_transforms_resize_and_crop = transforms.Compose( [ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + ] + ) + + self.image_transforms = transforms.Compose( + [ transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ] @@ -312,6 +317,7 @@ def __getitem__(self, index): instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) if not instance_image.mode == "RGB": instance_image = instance_image.convert("RGB") + instance_image = self.image_transforms_resize_and_crop(instance_image) example["PIL_images"] = instance_image example["instance_images"] = self.image_transforms(instance_image) @@ -327,6 +333,7 @@ def __getitem__(self, index): class_image = Image.open(self.class_images_path[index % self.num_class_images]) if not class_image.mode == "RGB": class_image = class_image.convert("RGB") + class_image = self.image_transforms_resize_and_crop(class_image) example["class_images"] = self.image_transforms(class_image) example["class_PIL_images"] = class_image example["class_prompt_ids"] = self.tokenizer( @@ -513,12 +520,6 @@ def main(): ) def collate_fn(examples): - image_transforms = transforms.Compose( - [ - transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), - ] - ) input_ids = [example["instance_prompt_ids"] for example in examples] pixel_values = [example["instance_images"] for example in examples] @@ -535,9 +536,6 @@ def collate_fn(examples): pil_image = example["PIL_images"] # generate a random mask mask = random_mask(pil_image.size, 1, False) - # apply transforms - mask = image_transforms(mask) - pil_image = image_transforms(pil_image) # prepare mask and masked image mask, masked_image = prepare_mask_and_masked_image(pil_image, mask) @@ -548,9 +546,6 @@ def collate_fn(examples): for pil_image in pior_pil: # generate a random mask mask = random_mask(pil_image.size, 1, False) - # apply transforms - mask = image_transforms(mask) - pil_image = image_transforms(pil_image) # prepare mask and masked image mask, masked_image = prepare_mask_and_masked_image(pil_image, mask)