@@ -30,14 +30,17 @@ class VaeImageProcessor(ConfigMixin):
3030
3131 Args:
3232 do_resize (`bool`, *optional*, defaults to `True`):
33- Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
33+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
34+ `height` and `width` arguments from the `preprocess` method.
3435 vae_scale_factor (`int`, *optional*, defaults to `8`):
3536 VAE scale factor. If `do_resize` is True, the image will be automatically resized to multiples of this
3637 factor.
3738 resample (`str`, *optional*, defaults to `lanczos`):
3839 Resampling filter to use when resizing the image.
3940 do_normalize (`bool`, *optional*, defaults to `True`):
4041 Whether to normalize the image to [-1,1]
42+ do_convert_rgb (`bool`, *optional*, defaults to `False`):
43+ Whether to convert the images to RGB format.
4144 """
4245
4346 config_name = CONFIG_NAME
@@ -49,11 +52,12 @@ def __init__(
4952 vae_scale_factor : int = 8 ,
5053 resample : str = "lanczos" ,
5154 do_normalize : bool = True ,
55+ do_convert_rgb : bool = False ,
5256 ):
5357 super ().__init__ ()
5458
5559 @staticmethod
56- def numpy_to_pil (images ) :
60+ def numpy_to_pil (images : np . ndarray ) -> PIL . Image . Image :
5761 """
5862 Convert a numpy image or a batch of images to a PIL image.
5963 """
@@ -69,7 +73,19 @@ def numpy_to_pil(images):
6973 return pil_images
7074
7175 @staticmethod
72- def numpy_to_pt (images ):
76+ def pil_to_numpy (images : Union [List [PIL .Image .Image ], PIL .Image .Image ]) -> np .ndarray :
77+ """
78+ Convert a PIL image or a list of PIL images to a single stacked float32 numpy array (pixel values divided by 255).
79+ """
80+ if not isinstance (images , list ):
81+ images = [images ]
82+ images = [np .array (image ).astype (np .float32 ) / 255.0 for image in images ]
83+ images = np .stack (images , axis = 0 )
84+
85+ return images
86+
87+ @staticmethod
88+ def numpy_to_pt (images : np .ndarray ) -> torch .FloatTensor :
7389 """
7490 Convert a numpy image to a pytorch tensor
7591 """
@@ -80,7 +96,7 @@ def numpy_to_pt(images):
8096 return images
8197
8298 @staticmethod
83- def pt_to_numpy (images ) :
99+ def pt_to_numpy (images : torch . FloatTensor ) -> np . ndarray :
84100 """
85101 Convert a pytorch tensor to a numpy image
86102 """
@@ -101,18 +117,39 @@ def denormalize(images):
101117 """
102118 return (images / 2 + 0.5 ).clamp (0 , 1 )
103119
104- def resize (self , images : PIL .Image .Image ) -> PIL .Image .Image :
120+ @staticmethod
121+ def convert_to_rgb (image : PIL .Image .Image ) -> PIL .Image .Image :
122+ """
123+ Converts an image to RGB format.
124+ """
125+ image = image .convert ("RGB" )
126+ return image
127+
128+ def resize (
129+ self ,
130+ image : PIL .Image .Image ,
131+ height : Optional [int ] = None ,
132+ width : Optional [int ] = None ,
133+ ) -> PIL .Image .Image :
105134 """
106135 Resize a PIL image. Both height and width will be downscaled to the next integer multiple of `vae_scale_factor`
107136 """
108- w , h = images .size
109- w , h = (x - x % self .config .vae_scale_factor for x in (w , h )) # resize to integer multiple of vae_scale_factor
110- images = images .resize ((w , h ), resample = PIL_INTERPOLATION [self .config .resample ])
111- return images
137+ if height is None :
138+ height = image .height
139+ if width is None :
140+ width = image .width
141+
142+ width , height = (
143+ x - x % self .config .vae_scale_factor for x in (width , height )
144+ ) # resize to integer multiple of vae_scale_factor
145+ image = image .resize ((width , height ), resample = PIL_INTERPOLATION [self .config .resample ])
146+ return image
112147
113148 def preprocess (
114149 self ,
115150 image : Union [torch .FloatTensor , PIL .Image .Image , np .ndarray ],
151+ height : Optional [int ] = None ,
152+ width : Optional [int ] = None ,
116153 ) -> torch .Tensor :
117154 """
118155 Preprocess the image input; accepted formats are PIL images, numpy arrays or pytorch tensors.
@@ -126,10 +163,11 @@ def preprocess(
126163 )
127164
128165 if isinstance (image [0 ], PIL .Image .Image ):
166+ if self .config .do_convert_rgb :
167+ image = [self .convert_to_rgb (i ) for i in image ]
129168 if self .config .do_resize :
130- image = [self .resize (i ) for i in image ]
131- image = [np .array (i ).astype (np .float32 ) / 255.0 for i in image ]
132- image = np .stack (image , axis = 0 ) # to np
169+ image = [self .resize (i , height , width ) for i in image ]
170+ image = self .pil_to_numpy (image ) # to np
133171 image = self .numpy_to_pt (image ) # to pt
134172
135173 elif isinstance (image [0 ], np .ndarray ):
@@ -146,7 +184,12 @@ def preprocess(
146184
147185 elif isinstance (image [0 ], torch .Tensor ):
148186 image = torch .cat (image , axis = 0 ) if image [0 ].ndim == 4 else torch .stack (image , axis = 0 )
149- _ , _ , height , width = image .shape
187+ _ , channel , height , width = image .shape
188+
189+ # a 4-channel tensor is treated as latents and returned without any preprocessing
190+ if channel == 4 :
191+ return image
192+
150193 if self .config .do_resize and (
151194 height % self .config .vae_scale_factor != 0 or width % self .config .vae_scale_factor != 0
152195 ):
0 commit comments