Warning for too long prompts in DiffusionPipelines (Resolve #447) #472
Conversation
The documentation is not available anymore as the PR was closed or merged.
Hey @shirayu, thanks for opening the PR! I think we can make our lives a bit easier here by simply catching whether the input was truncated beforehand - then we don't have to return the text embeddings and check after the fact :-). How about we replace these lines here with:

```python
text_inputs = self.tokenizer(
    prompt,
    padding="max_length",
    max_length=self.tokenizer.model_max_length,
    return_tensors="pt",
)
text_input_ids = text_inputs.input_ids

# Warn and truncate manually if the prompt exceeds CLIP's maximum length.
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
    removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length:])
    logger.warning(
        f"The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: {removed_text}"
    )
    text_input_ids = text_input_ids[:, :self.tokenizer.model_max_length]

text_input_ids = text_input_ids.to(self.device)
```

Then the user gets a nice warning and we are error-robust :-)
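For context, a minimal sketch of why this length check works (the checkpoint name is the one Stable Diffusion uses; the printed token count is illustrative): without `truncation=True`, the tokenizer returns the full sequence even past `model_max_length`, so the overflow can be detected by comparing shapes:

```python
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
long_prompt = "a photo of a cat " * 30  # well over 77 tokens

# No truncation=True here, so sequences longer than model_max_length come back in full.
ids = tokenizer(
    long_prompt,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    return_tensors="pt",
).input_ids
print(ids.shape[-1], ">", tokenizer.model_max_length)  # e.g. 152 > 77
```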
Thank you for the comment @patrickvonplaten! I think using …
Hey @shirayu, I think you can catch warnings with PyTorch: https://stackoverflow.com/questions/5644836/in-python-how-does-one-catch-warnings-as-if-they-were-exceptions

I'm not sure we want to return text embeddings just for a potential warning that could be displayed later - the use case is too small to warrant adding a new output tuple, which might break pipelines that expect only two outputs. Would it be OK for you to try catching the warning or adding a specific logger? https://stackoverflow.com/questions/14058453/making-python-loggers-output-all-messages-to-stdout-in-addition-to-log-file
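For illustration, a minimal sketch of the logger-based approach (the `WarningCollector` class is hypothetical, and attaching to the `"diffusers"` logger name is an assumption, not a documented API):

```python
import logging

# Hypothetical handler that records warning messages emitted by the pipeline.
class WarningCollector(logging.Handler):
    def __init__(self):
        super().__init__(level=logging.WARNING)
        self.messages = []

    def emit(self, record):
        self.messages.append(record.getMessage())

collector = WarningCollector()
# Assumes the pipeline's loggers propagate to the "diffusers" parent logger.
logging.getLogger("diffusers").addHandler(collector)

# ... run the pipeline here ...

if any("truncated" in message for message in collector.messages):
    print("The prompt was truncated:", collector.messages)
```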
Also @patil-suraj @pcuenca @anton-l, what do you think here?
```python
import PIL
from PIL import Image
from transformers import BatchEncoding
```
Could you please move this under line 34 (below `is_transformers_available()`)? Otherwise this breaks the init, as transformers is not a hard dependency.
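For reference, a sketch of the guarded-import pattern the reviewer is asking for (`is_transformers_available` exists in `diffusers.utils`; the exact placement in the file is an assumption):

```python
from .utils import is_transformers_available

# Only import from transformers when it is actually installed,
# since it is a soft dependency of diffusers.
if is_transformers_available():
    from transformers import BatchEncoding
```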
patrickvonplaten left a comment
@anton-l @patil-suraj would love to hear your opinion here.

I would have preferred to just throw a warning given the goal of this PR, but I also see why it could make sense to return text_embeddings - what do you think?

@shirayu - either way, we can only return the text embeddings optionally (add a `return_embeddings=True/False` flag to `__call__`) so as not to break the outputs.

Overall, I'm leaning towards not adding this functionality though, as it would add another argument to the `__call__` API and another output to stable diffusion for what is IMO quite an edge case.
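A hypothetical sketch of the opt-in flag being weighed here (`return_embeddings` and the helper methods are illustrative only; nothing like this was merged into the pipeline API):

```python
class StableDiffusionPipelineSketch:
    """Hypothetical sketch only; not the real StableDiffusionPipeline."""

    def __call__(self, prompt, num_inference_steps=50, return_embeddings=False, **kwargs):
        text_embeddings = self._encode_prompt(prompt)  # hypothetical helper
        images = self._denoise(text_embeddings, num_inference_steps, **kwargs)  # hypothetical helper
        # The flag defaults to False, so existing callers keep getting the same outputs.
        if return_embeddings:
            return images, text_embeddings
        return images
```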
Keen to hear your thoughts here @shirayu
Thanks for the PR @shirayu! Returning …
Thank you for your comments. Note, this will drop the last special token …
…ce#447) (huggingface#472)

* Return encoded texts by DiffusionPipelines
* Updated README to show how to use encoded_text_input
* Reverted examples in README.md
* Reverted all
* Warning for long prompts
* Fix bugs
* Formatted