@@ -145,8 +145,9 @@ def __call__(
145145 process. This is the image whose masked region will be inpainted.
146146 mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
147147 `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
148- replaced by noise and therefore repainted, while black pixels will be preserved. The mask image will be
149- converted to a single channel (luminance) before use.
148+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
149+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
150+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
150151 strength (`float`, *optional*, defaults to 0.8):
151152 Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
152153 is 1, the denoising process will be run on the masked area for the full number of iterations specified
@@ -202,10 +203,12 @@ def __call__(
202203 self .scheduler .set_timesteps (num_inference_steps , ** extra_set_kwargs )
203204
204205 # preprocess image
205- init_image = preprocess_image (init_image ).to (self .device )
206+ if not isinstance (init_image , torch .FloatTensor ):
207+ init_image = preprocess_image (init_image )
208+ init_image .to (self .device )
206209
207210 # encode the init image into latents and scale the latents
208- init_latent_dist = self .vae .encode (init_image . to ( self . device ) ).latent_dist
211+ init_latent_dist = self .vae .encode (init_image ).latent_dist
209212 init_latents = init_latent_dist .sample (generator = generator )
210213
211214 init_latents = 0.18215 * init_latents
@@ -215,8 +218,10 @@ def __call__(
215218 init_latents_orig = init_latents
216219
217220 # preprocess mask
218- mask = preprocess_mask (mask_image ).to (self .device )
219- mask = torch .cat ([mask ] * batch_size )
221+ if not isinstance (mask_image , torch .FloatTensor ):
222+ mask_image = preprocess_mask (mask_image )
223+ mask_image .to (self .device )
224+ mask = torch .cat ([mask_image ] * batch_size )
220225
221226 # check sizes
222227 if not mask .shape == init_latents .shape :
0 commit comments