3 changes: 3 additions & 0 deletions src/diffusers/pipelines/stable_diffusion/__init__.py
@@ -21,10 +21,13 @@ class StableDiffusionPipelineOutput(BaseOutput):
nsfw_content_detected (`List[bool]`)
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content.
latents (`List[torch.Tensor]`, *optional*, returned when `output_latents=True` and `return_dict=True` are passed)
List of `torch.Tensor` (the initial latents followed by one element per diffusion step), each of shape `(batch_size, in_channels, height // 8, width // 8)`
"""

images: Union[List[PIL.Image.Image], np.ndarray]
nsfw_content_detected: List[bool]
latents: Optional[List[torch.Tensor]] = None
Contributor: Fine with me! What do you think @pcuenca @anton-l @patil-suraj?

Contributor: Fine with me as well.
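
For context, a minimal sketch of how a caller would consume the new latents field under this PR (the model id and prompt are illustrative):

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
out = pipe("a corgi wearing a top hat", output_latents=True)
# the initial noise plus one entry per diffusion step, each of shape
# (batch_size, in_channels, height // 8, width // 8)
for step_latents in out.latents:
    print(step_latents.shape)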



if is_transformers_available():
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -90,10 +90,9 @@ def disable_attention_slicing(self):
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)

@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
prompt: Union[str, List[str], torch.Tensor],
height: Optional[int] = 512,
width: Optional[int] = 512,
num_inference_steps: Optional[int] = 50,
@@ -103,14 +102,17 @@ def __call__(
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
output_latents: bool = False,
enable_grad: bool = False,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.

Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
prompt (`str`, `List[str]` or `torch.Tensor`):
Contributor: Ufff, I don't think the input prompt should ever be a tensor; that's confusing and opens the door to hacky code. Can't we just work with the latents inputs?

Author: What if we use inputs_embeds, as some methods in transformers have?

Contributor: Agree with Patrick.

Author: The problem is that we can pass an image to StableDiffusionImg2ImgPipeline.__call__ as a tensor, but we can't pass a text prompt that way...
Would an extra inputs_embeds or prompt_embeds argument work as an alternative?
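
To make the discussion concrete, a hypothetical sketch of the tensor form as this PR implements it (the model id and prompt text are illustrative; under the prompt_embeds proposal the tensor would move to a separate argument instead):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

# precompute text embeddings with the pipeline's own tokenizer and encoder
text_input = pipe.tokenizer(
    "a photo of an astronaut riding a horse",
    padding="max_length",
    max_length=pipe.tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)
with torch.no_grad():
    embeddings = pipe.text_encoder(text_input.input_ids.to(pipe.device))[0]

# with this PR, the (batch_size, sequence_len, embedding_dim) tensor is accepted as `prompt`
image = pipe(prompt=embeddings).images[0]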

The prompt or prompts to guide the image generation. If a `torch.Tensor` is provided, it should
have shape `(sequence_len, embedding_dim)` or `(batch_size, sequence_len, embedding_dim)`.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
@@ -140,6 +142,11 @@ def __call__(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
output_latents (`bool`, *optional*, defaults to `False`):
Whether or not to return the latents from all the diffusion steps. See `latents` under returned tensors
for more details.
enable_grad (`bool`, *optional*, defaults to `False`):
Whether or not to enable gradient calculation during the diffusion process.

Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
@@ -161,33 +168,52 @@
device = "cuda" if torch.cuda.is_available() else "cpu"
self.to(device)

# enable/disable grad
was_grad_enabled = torch.is_grad_enabled()
Contributor: This will also be a bit difficult to accept for me. It's a) a bit hacky and b) pipelines by definition should only be used for inference. I assume the gradients are needed for analysis and the idea is not to do training?
I'm not 100% sure whether there are enough use cases to warrant allowing gradient flow here; on the other hand, it shouldn't really hurt if we leave the default at False. IMO, working with function decorators and enable_grad + disable_grad functions is the way to go here instead.

What do you think @patil-suraj @anton-l @pcuenca ?

Contributor: Same as Patrick, not really in favor of this.

> IMO working with function decorators and enable_grad + disable_grad functions is the way to go here though instead.

+1

Author:
> IMO working with function decorators and enable_grad + disable_grad functions is the way to go here though instead.

By that, do you mean keeping the @torch.no_grad decorator on __call__ and putting the whole method body under `with torch.enable_grad() if enable_grad else nullcontext():`?
Not sure what you meant.
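
For what it's worth, a minimal self-contained sketch of the pattern being discussed; a toy function standing in for __call__, not the pipeline itself:

import contextlib
import torch

@torch.no_grad()  # keeps inference as the default
def run(x: torch.Tensor, enable_grad: bool = False) -> torch.Tensor:
    # torch.enable_grad() re-enables autograd even inside an outer no_grad
    # region, so gradients flow only when the caller opts in
    with torch.enable_grad() if enable_grad else contextlib.nullcontext():
        return x * 2

x = torch.ones(2, requires_grad=True)
print(run(x).requires_grad)                    # False
print(run(x, enable_grad=True).requires_grad)  # True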

torch.set_grad_enabled(enable_grad)

if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
elif torch.is_tensor(prompt):
Contributor: I don't think prompt should ever be a tensor.
if len(prompt.shape) == 2:
# Add batch dimension
prompt = prompt.unsqueeze(0)

if len(prompt.shape) != 3:
raise ValueError(
f"If `prompt` is of type `torch.Tensor`, it is expected to have a 2 dimensions "
f"(sequence len, embedding dim) or 3 dimensions (batch size, sequence len, embedding dim), "
f"but found tensor with shape {prompt.shape}"
)
batch_size = prompt.shape[0]
else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

# get prompt text embeddings
text_input = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
if torch.is_tensor(prompt):
text_embeddings = prompt
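# note: the tensor is used as-is; it is assumed to already be on the right device and dtype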
else:
text_input = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
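# further down (collapsed in this hunk) the two predictions are combined as:
#   noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# so `guidance_scale = 1` reduces to the plain text-conditioned prediction (no guidance)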
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
max_length = text_input.input_ids.shape[-1]
max_length = text_embeddings.shape[-2]
uncond_input = self.tokenizer(
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
)
@@ -237,6 +263,7 @@ def __call__(
if accepts_eta:
extra_step_kwargs["eta"] = eta

all_latents = [latents] if output_latents else None
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@@ -259,6 +286,10 @@
else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

if output_latents:
# save latents from all diffusion steps
all_latents.append(latents)
Contributor: Fine with this! I think this makes a lot of sense.

# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
@@ -276,4 +307,7 @@ def __call__(
if not return_dict:
return (image, has_nsfw_concept)

return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
# restore the caller's grad mode (note: the `return_dict=False` branch above returns early and skips this reset)
torch.set_grad_enabled(was_grad_enabled)

return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept, latents=all_latents)
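
Taken together, a hypothetical end-to-end use of the two new arguments (the model id and prompt are illustrative):

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

out = pipe(
    "a watercolor painting of a lighthouse",
    num_inference_steps=50,
    output_latents=True,  # collect the latents from every step
    enable_grad=False,    # default; set True only when gradients are needed
)
print(len(out.latents))      # num_inference_steps + 1 (initial noise included)
print(out.latents[0].shape)  # (1, 4, 64, 64) for the default 512x512 output
out.images[0].save("lighthouse.png")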