diff --git a/examples/community/stable_diffusion_controlnet_reference.py b/examples/community/stable_diffusion_controlnet_reference.py index ca06136d7829..e974898f5f6b 100644 --- a/examples/community/stable_diffusion_controlnet_reference.py +++ b/examples/community/stable_diffusion_controlnet_reference.py @@ -63,6 +63,31 @@ def torch_dfs(model: torch.nn.Module): class StableDiffusionControlNetReferencePipeline(StableDiffusionControlNetPipeline): + def _default_height_width(self, height, width, image): + # NOTE: It is possible that a list of images have different + # dimensions for each image, so just checking the first image + # is not _exactly_ correct, but it is simple. + while isinstance(image, list): + image = image[0] + + if height is None: + if isinstance(image, PIL.Image.Image): + height = image.height + elif isinstance(image, torch.Tensor): + height = image.shape[2] + + height = (height // 8) * 8 # round down to nearest multiple of 8 + + if width is None: + if isinstance(image, PIL.Image.Image): + width = image.width + elif isinstance(image, torch.Tensor): + width = image.shape[3] + + width = (width // 8) * 8 # round down to nearest multiple of 8 + + return height, width + def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance): refimage = refimage.to(device=device, dtype=dtype) @@ -230,8 +255,6 @@ def __call__( self.check_inputs( prompt, image, - height, - width, callback_steps, negative_prompt, prompt_embeds,