@@ -295,10 +295,15 @@ def __init__(
295295 else :
296296 self .class_data_root = None
297297
298- self .image_transforms = transforms .Compose (
298+ self .image_transforms_resize_and_crop = transforms .Compose (
299299 [
300300 transforms .Resize (size , interpolation = transforms .InterpolationMode .BILINEAR ),
301301 transforms .CenterCrop (size ) if center_crop else transforms .RandomCrop (size ),
302+ ]
303+ )
304+
305+ self .image_transforms = transforms .Compose (
306+ [
302307 transforms .ToTensor (),
303308 transforms .Normalize ([0.5 ], [0.5 ]),
304309 ]
@@ -312,6 +317,7 @@ def __getitem__(self, index):
312317 instance_image = Image .open (self .instance_images_path [index % self .num_instance_images ])
313318 if not instance_image .mode == "RGB" :
314319 instance_image = instance_image .convert ("RGB" )
320+ instance_image = self .image_transforms_resize_and_crop (instance_image )
315321
316322 example ["PIL_images" ] = instance_image
317323 example ["instance_images" ] = self .image_transforms (instance_image )
@@ -327,6 +333,7 @@ def __getitem__(self, index):
327333 class_image = Image .open (self .class_images_path [index % self .num_class_images ])
328334 if not class_image .mode == "RGB" :
329335 class_image = class_image .convert ("RGB" )
336+ class_image = self .image_transforms_resize_and_crop (class_image )
330337 example ["class_images" ] = self .image_transforms (class_image )
331338 example ["class_PIL_images" ] = class_image
332339 example ["class_prompt_ids" ] = self .tokenizer (
@@ -513,12 +520,6 @@ def main():
513520 )
514521
515522 def collate_fn (examples ):
516- image_transforms = transforms .Compose (
517- [
518- transforms .Resize (args .resolution , interpolation = transforms .InterpolationMode .BILINEAR ),
519- transforms .CenterCrop (args .resolution ) if args .center_crop else transforms .RandomCrop (args .resolution ),
520- ]
521- )
522523 input_ids = [example ["instance_prompt_ids" ] for example in examples ]
523524 pixel_values = [example ["instance_images" ] for example in examples ]
524525
@@ -535,9 +536,6 @@ def collate_fn(examples):
535536 pil_image = example ["PIL_images" ]
536537 # generate a random mask
537538 mask = random_mask (pil_image .size , 1 , False )
538- # apply transforms
539- mask = image_transforms (mask )
540- pil_image = image_transforms (pil_image )
541539 # prepare mask and masked image
542540 mask , masked_image = prepare_mask_and_masked_image (pil_image , mask )
543541
@@ -548,9 +546,6 @@ def collate_fn(examples):
548546 for pil_image in pior_pil :
549547 # generate a random mask
550548 mask = random_mask (pil_image .size , 1 , False )
551- # apply transforms
552- mask = image_transforms (mask )
553- pil_image = image_transforms (pil_image )
554549 # prepare mask and masked image
555550 mask , masked_image = prepare_mask_and_masked_image (pil_image , mask )
556551
0 commit comments