
Conversation

@JoaoLages JoaoLages commented Sep 13, 2022

Hi there!
I love this package ❤️

I'm the author of diffusers-interpret and, while working on it, found these features that would be very useful to add to this main package:

  • Having the option to run DiffusionPipeline.__call__ while calculating gradients;
  • Having an output_latents flag (similar to output_scores/output_attentions/etc. from transformers) that adds a latents attribute to the output;
  • Deactivating the safety checker; removed this option (12ca969)
  • Passing text_embeddings directly instead of the text string;
  • Gradient checkpointing (this was already a feature in transformers too).

That's about it 😄
To start this PR I made the changes only for the StableDiffusionPipeline class, but I can port those changes to the other pipelines if you agree with them.
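For illustration, here is a minimal sketch of how the output-latents feature could be used once merged (hedged: `output_latents` and the `latents` output attribute are the names proposed in this PR, not part of the released diffusers API):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

# `output_latents=True` is this PR's proposed flag: it would attach the latents
# from every diffusion step to the returned StableDiffusionPipelineOutput.
output = pipe("an astronaut riding a horse", output_latents=True)
print(len(output.latents))  # one latent tensor per diffusion step
```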

Contributor

@keturn keturn left a comment

Hi João! I was just introduced to diffusers-interpret yesterday via the Discord! I have all the same questions, so I love seeing that sort of thing.

I have no authority to merge anything here, but I've taken the liberty of leaving a few notes.

Comment on lines 329 to 330
if not return_dict:
-    return (image, has_nsfw_concept)
+    return (image, has_nsfw_concept, all_latents)
Contributor

I think if PipelineOutput classes are the way forward and the tuple return format here is mainly for backwards compatibility, we should leave it the same size it was (a pair) and not worry about adding new features to it.

Author

Since this method had a @torch.no_grad decorator, I don't think this is for backwards compatibility 🤔
But looking at the transformers package, it seems that when return_dict_in_generate=False, options like output_scores/output_attentions don't matter, so it makes sense to remove latents from the tuple as you mention :)
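For reference, a small sketch of the transformers behaviour being compared to (the model name is only an example):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Hello", return_tensors="pt")

# With return_dict_in_generate=True, output_scores adds a `scores` field to the output object.
out = model.generate(**inputs, max_new_tokens=5, return_dict_in_generate=True, output_scores=True)
print(len(out.scores))

# With return_dict_in_generate=False (the default), only the generated token ids
# are returned and output_scores has no effect on the returned value.
ids = model.generate(**inputs, max_new_tokens=5, output_scores=True)
```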


images: Union[List[PIL.Image.Image], np.ndarray]
nsfw_content_detected: List[bool]
latents: Optional[List[torch.Tensor]] = None
Contributor

Fine with me! What do you think @pcuenca @anton-l @patil-suraj ?

Contributor

fine with me as well.

Args:
-    prompt (`str` or `List[str]`):
+    prompt (`str`, `List[str]` or `torch.Tensor`):
        The prompt or prompts to guide the image generation.
Contributor

Ufff, I don't think the input prompt should ever be a tensor; that's confusing and opens the door to hacky code - can't we just work with the latents inputs?

Author

What if we use inputs_embeds as some methods in transformers have?

Contributor

Agree with Patrick.

Author

The problem is that we can pass an image to StableDiffusionImg2ImgPipeline.__call__ as a tensor, but we can't pass text...
Would an extra inputs_embeds or prompt_embeds argument work as an alternative?
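For illustration, a minimal sketch of what that alternative could look like (hedged: `prompt_embeds` is a hypothetical argument name here, not the pipeline's current signature):

```python
import torch

def encode_prompt(pipe, prompt=None, prompt_embeds: torch.Tensor = None):
    """Return the text embeddings, computing them from `prompt` only when
    `prompt_embeds` was not passed in directly (hypothetical helper)."""
    if prompt_embeds is not None:
        return prompt_embeds
    if prompt is None:
        raise ValueError("Either `prompt` or `prompt_embeds` must be provided.")
    text_inputs = pipe.tokenizer(
        prompt,
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    return pipe.text_encoder(text_inputs.input_ids.to(pipe.device))[0]
```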

    batch_size = 1
elif isinstance(prompt, list):
    batch_size = len(prompt)
elif torch.is_tensor(prompt):
Contributor

I don't think prompt should ever be a tensor


if output_latents:
    # save latents from all diffusion steps
    all_latents.append(latents)
Contributor

fine with this! Think this makes a lot of sense

self.to(device)

# enable/disable grad
was_grad_enabled = torch.is_grad_enabled()
Contributor

This will also be a bit difficult for me to accept. It's a) a bit hacky to me and b) pipelines by definition should only be used for inference. I assume the gradients are needed for analysis and the idea is not to do training?
I'm not 100% sure whether there are enough use cases that warrant allowing gradient flow here; on the other hand, it also shouldn't really hurt if we leave the default at False. IMO working with function decorators and enable_grad + disable_grad functions is the way to go here though instead.

What do you think @patil-suraj @anton-l @pcuenca ?

Contributor

Same as Patrick, not really in favor of this.

IMO working with function decorators and enable_grad + disable_grad functions is the way to go here though instead.

+1

Author

IMO working with function decorators and enable_grad + disable_grad functions is the way to go here though instead.

By that, do you mean keeping the @torch.no_grad decorator on __call__ and putting the whole method under with torch.enable_grad() if enable_grad else nullcontext():?
Not sure what you meant.
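For illustration, a minimal sketch of the decorator-free pattern being discussed (hedged: `enable_grad` is the flag proposed in this PR, and this assumes __call__ no longer carries the @torch.no_grad decorator):

```python
import torch

def call_with_optional_grad(pipe, prompt, enable_grad: bool = False, **kwargs):
    # Default stays inference-only (no gradients); opt in explicitly for analysis.
    context = torch.enable_grad() if enable_grad else torch.no_grad()
    with context:
        return pipe(prompt, **kwargs)
```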

Contributor

@patrickvonplaten patrickvonplaten left a comment

Hey @JoaoLages,

Super cool interpret library btw! I like the idea of returning the latents and could also be convinced to allow gradient computation (even though I think the use case is too niche for now, and I would only be pro if the whole community really wants this feature, cc @hysts @pcuenca @patil-suraj @anton-l - wdyt?). I don't think we should allow the prompt to be a torch.Tensor.

Contributor

@patil-suraj patil-suraj left a comment

Thanks a lot for the PR @JoaoLages!

I have pretty much the same comments as Patrick.

  • We could definitely return intermediate latents.
  • Gradient checkpointing will be supported soon. It should not be included in the pipeline like this; pipelines are intended for inference only, so it's best to avoid training-related logic here.
  • Enabling gradients: we could add this if the community is really interested in it. Could you please open an issue for this?



@JoaoLages JoaoLages requested review from patil-suraj and patrickvonplaten and removed request for patrickvonplaten September 16, 2022 13:18
@JoaoLages
Author

  • Gradient checkpointing will be supported soon. It should not be included in the pipeline like this; pipelines are intended for inference only, so it's best to avoid training-related logic here.

🚀

  • Enabling gradients: we could add this if the community is really interested in it. Could you please open an issue for this?

There you go #529

@patrickvonplaten
Contributor

Closing this PR as it makes too many changes at once -> happy to continue the discussion in the individual PRs that were opened :-)
