-
Notifications
You must be signed in to change notification settings - Fork 6.5k
Add better compatibility with diffusers-interpret (and possibly other use cases!)
#506
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
da5c092
eee018d
275ba79
259c149
12ca969
01ae240
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -90,10 +90,9 @@ def disable_attention_slicing(self): | |
| # set slice_size = `None` to disable `attention slicing` | ||
| self.enable_attention_slicing(None) | ||
|
|
||
| @torch.no_grad() | ||
| def __call__( | ||
| self, | ||
| prompt: Union[str, List[str]], | ||
| prompt: Union[str, List[str], torch.Tensor], | ||
| height: Optional[int] = 512, | ||
| width: Optional[int] = 512, | ||
| num_inference_steps: Optional[int] = 50, | ||
|
|
@@ -103,14 +102,17 @@ def __call__( | |
| latents: Optional[torch.FloatTensor] = None, | ||
| output_type: Optional[str] = "pil", | ||
| return_dict: bool = True, | ||
| output_latents: bool = False, | ||
| enable_grad: bool = False, | ||
| **kwargs, | ||
| ): | ||
| r""" | ||
| Function invoked when calling the pipeline for generation. | ||
|
|
||
| Args: | ||
| prompt (`str` or `List[str]`): | ||
| The prompt or prompts to guide the image generation. | ||
| prompt (`str`, `List[str]` or `torch.Tensor`): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ufff, I don't think the input
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What if we use
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agree with Patrick.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The problem is that we can pass an image to StableDiffusionImg2ImgPipeline.call as a tensor but we can't pass a text... |
||
| The prompt or prompts to guide the image generation. If a `torch.Tensor` is provided, it should | ||
| have the shape (sequence len, embedding dim) or (batch size, sequence len, embedding dim). | ||
| height (`int`, *optional*, defaults to 512): | ||
| The height in pixels of the generated image. | ||
| width (`int`, *optional*, defaults to 512): | ||
|
|
@@ -140,6 +142,11 @@ def __call__( | |
| return_dict (`bool`, *optional*, defaults to `True`): | ||
| Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a | ||
| plain tuple. | ||
| output_latents (`bool`, *optional*, defaults to `False`): | ||
| Whether or not to return the latents from all the diffusion steps. See `latents` under returned tensors | ||
| for more details. | ||
| enable_grad (`bool`, *optional*, defaults to `False`): | ||
| Whether or not to enable gradient calculation during diffusion process. | ||
|
|
||
| Returns: | ||
| [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: | ||
|
|
@@ -161,33 +168,52 @@ def __call__( | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | ||
| self.to(device) | ||
|
|
||
| # enable/disable grad | ||
| was_grad_enabled = torch.is_grad_enabled() | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This will also be a bit difficult to accept for me. It's a) a bit hacky to me and b) Pipelines by definition should only be used for inference. I assume the gradients are needed for analysis and the idea is not to do training? What do you think @patil-suraj @anton-l @pcuenca ?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as Patrick, not really in favor of this.
+1
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
By that means, keeping |
||
| torch.set_grad_enabled(enable_grad) | ||
|
|
||
| if isinstance(prompt, str): | ||
| batch_size = 1 | ||
| elif isinstance(prompt, list): | ||
| batch_size = len(prompt) | ||
| elif torch.is_tensor(prompt): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think |
||
| if len(prompt.shape) == 2: | ||
| # Add batch dimension | ||
| prompt = prompt.unsqueeze(0) | ||
|
|
||
| if len(prompt.shape) != 3: | ||
| raise ValueError( | ||
| f"If `prompt` is of type `torch.Tensor`, it is expected to have a 2 dimensions " | ||
| f"(sequence len, embedding dim) or 3 dimensions (batch size, sequence len, embedding dim), " | ||
| f"but found tensor with shape {prompt.shape}" | ||
JoaoLages marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| ) | ||
| batch_size = prompt.shape[0] | ||
| else: | ||
| raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") | ||
|
|
||
| if height % 8 != 0 or width % 8 != 0: | ||
| raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") | ||
|
|
||
| # get prompt text embeddings | ||
| text_input = self.tokenizer( | ||
| prompt, | ||
| padding="max_length", | ||
| max_length=self.tokenizer.model_max_length, | ||
| truncation=True, | ||
| return_tensors="pt", | ||
| ) | ||
| text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] | ||
| if torch.is_tensor(prompt): | ||
| text_embeddings = prompt | ||
| else: | ||
| text_input = self.tokenizer( | ||
| prompt, | ||
| padding="max_length", | ||
| max_length=self.tokenizer.model_max_length, | ||
| truncation=True, | ||
| return_tensors="pt", | ||
| ) | ||
| text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] | ||
|
|
||
| # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) | ||
| # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` | ||
| # corresponds to doing no classifier free guidance. | ||
| do_classifier_free_guidance = guidance_scale > 1.0 | ||
| # get unconditional embeddings for classifier free guidance | ||
| if do_classifier_free_guidance: | ||
| max_length = text_input.input_ids.shape[-1] | ||
| max_length = text_embeddings.shape[-2] | ||
| uncond_input = self.tokenizer( | ||
| [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" | ||
| ) | ||
|
|
@@ -237,6 +263,7 @@ def __call__( | |
| if accepts_eta: | ||
| extra_step_kwargs["eta"] = eta | ||
|
|
||
| all_latents = [latents] if output_latents else None | ||
| for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): | ||
| # expand the latents if we are doing classifier free guidance | ||
| latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents | ||
|
|
@@ -259,6 +286,10 @@ def __call__( | |
| else: | ||
| latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample | ||
|
|
||
| if output_latents: | ||
| # save latents from all diffusion steps | ||
| all_latents.append(latents) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fine with this! Think this makes a lot of sense |
||
|
|
||
| # scale and decode the image latents with vae | ||
| latents = 1 / 0.18215 * latents | ||
| image = self.vae.decode(latents).sample | ||
|
|
@@ -276,4 +307,7 @@ def __call__( | |
| if not return_dict: | ||
| return (image, has_nsfw_concept) | ||
|
|
||
| return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) | ||
| # reset | ||
| torch.set_grad_enabled(was_grad_enabled) | ||
|
|
||
| return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept, latents=all_latents) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fine with me! What do you think @pcuenca @anton-l @patil-suraj ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fine with me as well.