From 9b85f7c9dae6013fcd7eec2616935aac31a7cd9b Mon Sep 17 00:00:00 2001
From: Yuta Hayashibe
Date: Mon, 12 Sep 2022 13:30:47 +0900
Subject: [PATCH 1/7] Return encoded texts by DiffusionPipelines

---
 src/diffusers/pipelines/stable_diffusion/__init__.py        | 4 ++++
 .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 6 ++++--
 .../stable_diffusion/pipeline_stable_diffusion_img2img.py   | 6 ++++--
 .../stable_diffusion/pipeline_stable_diffusion_inpaint.py   | 6 ++++--
 .../stable_diffusion/pipeline_stable_diffusion_onnx.py      | 6 ++++--
 5 files changed, 20 insertions(+), 8 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py
index e41043e0ad53..f2dbfd32c2ec 100644
--- a/src/diffusers/pipelines/stable_diffusion/__init__.py
+++ b/src/diffusers/pipelines/stable_diffusion/__init__.py
@@ -5,6 +5,7 @@
 import PIL
 from PIL import Image
+from transformers import BatchEncoding

 from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_transformers_available

@@ -21,10 +22,13 @@ class StableDiffusionPipelineOutput(BaseOutput):
         nsfw_content_detected (`List[bool]`)
             List of flags denoting whether the corresponding generated image likely represents
             "not-safe-for-work" (nsfw) content.
+        enoded_text_input (`transformers.BatchEncoding`)
+            Encoded text by the tokenizer
     """

     images: Union[List[PIL.Image.Image], np.ndarray]
     nsfw_content_detected: List[bool]
+    enoded_text_input: BatchEncoding


 if is_transformers_available():

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index 9f1211b43013..e9f849bb7b7a 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -285,6 +285,8 @@ def __call__(
         image = self.numpy_to_pil(image)

         if not return_dict:
-            return (image, has_nsfw_concept)
+            return (image, has_nsfw_concept, text_input)

-        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+        return StableDiffusionPipelineOutput(
+            images=image, nsfw_content_detected=has_nsfw_concept, enoded_text_input=text_input
+        )

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
index e7adb4d1a33b..95de3888e174 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -297,6 +297,8 @@ def __call__(
         image = self.numpy_to_pil(image)

         if not return_dict:
-            return (image, has_nsfw_concept)
+            return (image, has_nsfw_concept, text_input)

-        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+        return StableDiffusionPipelineOutput(
+            images=image, nsfw_content_detected=has_nsfw_concept, enoded_text_input=text_input
+        )

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index b9ad36f1a2bf..44c4715adc23 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -320,6 +320,8 @@ def __call__(
         image = self.numpy_to_pil(image)

         if not return_dict:
-            return (image, has_nsfw_concept)
+            return (image, has_nsfw_concept, text_input)

-        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+        return StableDiffusionPipelineOutput(
+            images=image, nsfw_content_detected=has_nsfw_concept, enoded_text_input=text_input
+        )

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py
index ccba29ade5d3..be6ddbcbbdb6 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py
@@ -155,6 +155,8 @@ def __call__(
         image = self.numpy_to_pil(image)

         if not return_dict:
-            return (image, has_nsfw_concept)
+            return (image, has_nsfw_concept, text_input)

-        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+        return StableDiffusionPipelineOutput(
+            images=image, nsfw_content_detected=has_nsfw_concept, enoded_text_input=text_input
+        )
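The patch above makes the tokenizer's output observable from the caller's side. A minimal sketch of how a caller could use the new field once this patch is applied (the checkpoint name is only an example, and the field name is spelled exactly as in the patch):

```python
from diffusers import StableDiffusionPipeline

# Hypothetical usage; requires this patch and access to the model weights.
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
pipeline_output = pipe(prompt)

# `enoded_text_input` is the tokenizer's BatchEncoding; decoding its input_ids
# recovers the prompt as the text encoder actually received it (truncated if
# it exceeded the tokenizer's model_max_length).
used_prompt = pipe.tokenizer.batch_decode(
    pipeline_output.enoded_text_input["input_ids"], skip_special_tokens=True
)[0]
print(used_prompt)
```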
+ pipeline_output.enoded_text_input["input_ids"], + skip_special_tokens=True, + )[0] + if used_prompt != prompt: + # Too long prompt will be truncated + print(f"""The used prompt "{used_prompt}" differs from the given prompt "{prompt}".""") image.save("astronaut_rides_horse.png") ``` @@ -191,8 +223,15 @@ init_image = init_image.resize((768, 512)) prompt = "A fantasy landscape, trending on artstation" with autocast("cuda"): - images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images - + pipeline_output = pipe(prompt) + image = pipeline_output.images[0] + used_prompt = pipe.tokenizer.batch_decode( + pipeline_output.enoded_text_input["input_ids"], + skip_special_tokens=True, + )[0] + if used_prompt != prompt: + # Too long prompt will be truncated + print(f"""The used prompt "{used_prompt}" differs from the given prompt "{prompt}".""") images[0].save("fantasy_landscape.png") ``` You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) From 3081070d1b4a0a0dd8169d4294a48985239f1cd6 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Sat, 17 Sep 2022 20:53:07 +0900 Subject: [PATCH 3/7] Reverted examples in README.md --- README.md | 51 ++++++--------------------------------------------- 1 file changed, 6 insertions(+), 45 deletions(-) diff --git a/README.md b/README.md index d65a802de452..240034cec237 100644 --- a/README.md +++ b/README.md @@ -84,15 +84,7 @@ pipe = pipe.to("cuda") prompt = "a photo of an astronaut riding a horse on mars" with autocast("cuda"): - pipeline_output = pipe(prompt) - image = pipeline_output.images[0] - used_prompt = pipe.tokenizer.batch_decode( - pipeline_output.enoded_text_input["input_ids"], - skip_special_tokens=True, - )[0] - if used_prompt != prompt: - # Too long prompt will be truncated - print(f"""The used prompt "{used_prompt}" differs from the given prompt "{prompt}".""") + image = pipe(prompt).images[0] ``` **Note**: If you don't want to use the token, you can also simply download the model weights @@ -113,15 +105,7 @@ pipe = pipe.to("cuda") prompt = "a photo of an astronaut riding a horse on mars" with autocast("cuda"): - pipeline_output = pipe(prompt) - image = pipeline_output.images[0] - used_prompt = pipe.tokenizer.batch_decode( - pipeline_output.enoded_text_input["input_ids"], - skip_special_tokens=True, - )[0] - if used_prompt != prompt: - # Too long prompt will be truncated - print(f"""The used prompt "{used_prompt}" differs from the given prompt "{prompt}".""") + image = pipe(prompt).images[0] ``` If you are limited by GPU memory, you might want to consider using the model in `fp16` as @@ -140,15 +124,7 @@ pipe = pipe.to("cuda") prompt = "a photo of an astronaut riding a horse on mars" pipe.enable_attention_slicing() with autocast("cuda"): - pipeline_output = pipe(prompt) - image = pipeline_output.images[0] - used_prompt = pipe.tokenizer.batch_decode( - pipeline_output.enoded_text_input["input_ids"], - skip_special_tokens=True, - )[0] - if used_prompt != prompt: - # Too long prompt will be truncated - print(f"""The used prompt "{used_prompt}" differs from the given prompt "{prompt}".""") + image = pipe(prompt).images[0] ``` Finally, if you wish to use a different scheduler, you can simply instantiate @@ -174,15 +150,7 @@ pipe = pipe.to("cuda") prompt = "a photo of an astronaut riding a horse on mars" with autocast("cuda"): - pipeline_output 
= pipe(prompt) - image = pipeline_output.images[0] - used_prompt = pipe.tokenizer.batch_decode( - pipeline_output.enoded_text_input["input_ids"], - skip_special_tokens=True, - )[0] - if used_prompt != prompt: - # Too long prompt will be truncated - print(f"""The used prompt "{used_prompt}" differs from the given prompt "{prompt}".""") + image = pipe(prompt).images[0] image.save("astronaut_rides_horse.png") ``` @@ -223,15 +191,8 @@ init_image = init_image.resize((768, 512)) prompt = "A fantasy landscape, trending on artstation" with autocast("cuda"): - pipeline_output = pipe(prompt) - image = pipeline_output.images[0] - used_prompt = pipe.tokenizer.batch_decode( - pipeline_output.enoded_text_input["input_ids"], - skip_special_tokens=True, - )[0] - if used_prompt != prompt: - # Too long prompt will be truncated - print(f"""The used prompt "{used_prompt}" differs from the given prompt "{prompt}".""") + images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images + images[0].save("fantasy_landscape.png") ``` You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) From c95269139cab0f488ecdf32b849a62c32716bed6 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Fri, 23 Sep 2022 13:35:19 +0900 Subject: [PATCH 4/7] Reverted all --- README.md | 8 ++++---- src/diffusers/pipelines/stable_diffusion/__init__.py | 4 ---- .../stable_diffusion/pipeline_stable_diffusion.py | 6 ++---- .../stable_diffusion/pipeline_stable_diffusion_img2img.py | 6 ++---- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 6 ++---- .../stable_diffusion/pipeline_stable_diffusion_onnx.py | 6 ++---- 6 files changed, 12 insertions(+), 24 deletions(-) diff --git a/README.md b/README.md index 240034cec237..5a25ce501263 100644 --- a/README.md +++ b/README.md @@ -84,7 +84,7 @@ pipe = pipe.to("cuda") prompt = "a photo of an astronaut riding a horse on mars" with autocast("cuda"): - image = pipe(prompt).images[0] + image = pipe(prompt).images[0] ``` **Note**: If you don't want to use the token, you can also simply download the model weights @@ -105,7 +105,7 @@ pipe = pipe.to("cuda") prompt = "a photo of an astronaut riding a horse on mars" with autocast("cuda"): - image = pipe(prompt).images[0] + image = pipe(prompt).images[0] ``` If you are limited by GPU memory, you might want to consider using the model in `fp16` as @@ -124,7 +124,7 @@ pipe = pipe.to("cuda") prompt = "a photo of an astronaut riding a horse on mars" pipe.enable_attention_slicing() with autocast("cuda"): - image = pipe(prompt).images[0] + image = pipe(prompt).images[0] ``` Finally, if you wish to use a different scheduler, you can simply instantiate @@ -150,7 +150,7 @@ pipe = pipe.to("cuda") prompt = "a photo of an astronaut riding a horse on mars" with autocast("cuda"): - image = pipe(prompt).images[0] + image = pipe(prompt).images[0] image.save("astronaut_rides_horse.png") ``` diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index f2dbfd32c2ec..e41043e0ad53 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -5,7 +5,6 @@ import PIL from PIL import Image -from transformers import BatchEncoding from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_transformers_available @@ -22,13 +21,10 
@@ class StableDiffusionPipelineOutput(BaseOutput): nsfw_content_detected (`List[bool]`) List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content. - enoded_text_input (`transformers.BatchEncoding`) - Encoded text by the tokenizer """ images: Union[List[PIL.Image.Image], np.ndarray] nsfw_content_detected: List[bool] - enoded_text_input: BatchEncoding if is_transformers_available(): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index e9f849bb7b7a..9f1211b43013 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -285,8 +285,6 @@ def __call__( image = self.numpy_to_pil(image) if not return_dict: - return (image, has_nsfw_concept, text_input) + return (image, has_nsfw_concept) - return StableDiffusionPipelineOutput( - images=image, nsfw_content_detected=has_nsfw_concept, enoded_text_input=text_input - ) + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 95de3888e174..e7adb4d1a33b 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -297,8 +297,6 @@ def __call__( image = self.numpy_to_pil(image) if not return_dict: - return (image, has_nsfw_concept, text_input) + return (image, has_nsfw_concept) - return StableDiffusionPipelineOutput( - images=image, nsfw_content_detected=has_nsfw_concept, enoded_text_input=text_input - ) + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 44c4715adc23..b9ad36f1a2bf 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -320,8 +320,6 @@ def __call__( image = self.numpy_to_pil(image) if not return_dict: - return (image, has_nsfw_concept, text_input) + return (image, has_nsfw_concept) - return StableDiffusionPipelineOutput( - images=image, nsfw_content_detected=has_nsfw_concept, enoded_text_input=text_input - ) + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py index be6ddbcbbdb6..ccba29ade5d3 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py @@ -155,8 +155,6 @@ def __call__( image = self.numpy_to_pil(image) if not return_dict: - return (image, has_nsfw_concept, text_input) + return (image, has_nsfw_concept) - return StableDiffusionPipelineOutput( - images=image, nsfw_content_detected=has_nsfw_concept, enoded_text_input=text_input - ) + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) From 006e3aef4aa9d30986ef755c12d48b5740697d0b Mon Sep 17 00:00:00 2001 
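With everything reverted, the series switches approaches: instead of returning the encoding through the pipeline output, the remaining patches log a warning when a prompt is too long. The underlying problem is that with `truncation=True` the CLIP tokenizer silently drops everything past `model_max_length` (77 tokens for the CLIP tokenizer used by Stable Diffusion), so part of the prompt is ignored with no feedback. A small sketch of that silent truncation using the standalone tokenizer from `transformers` (the checkpoint name is illustrative):

```python
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

long_prompt = "a photo of an astronaut riding a horse on mars, " * 20

# With truncation enabled, the encoding is silently capped at model_max_length.
capped = tokenizer(long_prompt, truncation=True, max_length=tokenizer.model_max_length)
# Without it, the full sequence comes back, so an over-length prompt is detectable.
full = tokenizer(long_prompt)

print(tokenizer.model_max_length)  # 77 for this checkpoint
print(len(capped.input_ids))       # 77 -- the tail was dropped with no warning
print(len(full.input_ids))         # > 77 -- the overflow is visible to the caller
```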
From 006e3aef4aa9d30986ef755c12d48b5740697d0b Mon Sep 17 00:00:00 2001
From: Yuta Hayashibe
Date: Fri, 23 Sep 2022 13:47:14 +0900
Subject: [PATCH 5/7] Warning for long prompts

---
 .../pipeline_stable_diffusion.py          | 15 +++++++++++++--
 .../pipeline_stable_diffusion_img2img.py  | 15 +++++++++++++--
 .../pipeline_stable_diffusion_inpaint.py  | 12 ++++++++++--
 .../pipeline_stable_diffusion_onnx.py     | 17 ++++++++++++++---
 4 files changed, 50 insertions(+), 9 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index 9f1211b43013..1888543af9af 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -12,6 +12,10 @@
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
+from .utils import logging
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


 class StableDiffusionPipeline(DiffusionPipeline):
@@ -188,13 +192,20 @@ def __call__(
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

         # get prompt text embeddings
-        text_input = self.tokenizer(
+        text_inputs = self.tokenizer(
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
-            truncation=True,
             return_tensors="pt",
         )
+        text_input_ids = text_inputs.input_ids
+
+        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+            removed_text = self.tokenizer.batch_decode(text_input_ids[self.tokenizer_model_max_length :])
+            logger.warning(
+                f"The following part of your input was truncated because CLIP can only handle sequences up to {self.tokenizer_model_max_length} tokens: {removed_text}"
+            )
+            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
index e7adb4d1a33b..194d5e13e021 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -14,6 +14,10 @@
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
+from .utils import logging
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


 def preprocess(image):
@@ -216,13 +220,20 @@ def __call__(
         init_latents = self.scheduler.add_noise(init_latents, noise, timesteps).to(self.device)

         # get prompt text embeddings
-        text_input = self.tokenizer(
+        text_inputs = self.tokenizer(
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
-            truncation=True,
             return_tensors="pt",
         )
+        text_input_ids = text_inputs.input_ids
+
+        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+            removed_text = self.tokenizer.batch_decode(text_input_ids[self.tokenizer_model_max_length :])
+            logger.warning(
+                f"The following part of your input was truncated because CLIP can only handle sequences up to {self.tokenizer_model_max_length} tokens: {removed_text}"
+            )
+            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index b9ad36f1a2bf..c2380035894b 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -16,6 +16,7 @@
 from ...utils import logging
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
+from .utils import logging


 logger = logging.get_logger(__name__)
@@ -249,13 +250,20 @@ def __call__(
         init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)

         # get prompt text embeddings
-        text_input = self.tokenizer(
+        text_inputs = self.tokenizer(
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
-            truncation=True,
             return_tensors="pt",
         )
+        text_input_ids = text_inputs.input_ids
+
+        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+            removed_text = self.tokenizer.batch_decode(text_input_ids[self.tokenizer_model_max_length :])
+            logger.warning(
+                f"The following part of your input was truncated because CLIP can only handle sequences up to {self.tokenizer_model_max_length} tokens: {removed_text}"
+            )
+            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py
index ccba29ade5d3..f40d427850c3 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py
@@ -9,6 +9,10 @@
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from . import StableDiffusionPipelineOutput
+from .utils import logging
+
+
+logger = logging.get_logger(__name__)


 class StableDiffusionOnnxPipeline(DiffusionPipeline):
@@ -66,13 +70,20 @@ def __call__(
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

         # get prompt text embeddings
-        text_input = self.tokenizer(
+        text_inputs = self.tokenizer(
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
-            truncation=True,
-            return_tensors="np",
+            return_tensors="pt",
         )
+        text_input_ids = text_inputs.input_ids
+
+        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+            removed_text = self.tokenizer.batch_decode(text_input_ids[self.tokenizer_model_max_length :])
+            logger.warning(
+                f"The following part of your input was truncated because CLIP can only handle sequences up to {self.tokenizer_model_max_length} tokens: {removed_text}"
+            )
+            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(input_ids=text_input.input_ids.astype(np.int32))[0]

         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
From 6448321eae79454a4b3d9116459ff9864bcf0562 Mon Sep 17 00:00:00 2001
From: Yuta Hayashibe
Date: Fri, 23 Sep 2022 13:56:43 +0900
Subject: [PATCH 6/7] Fix bugs

---
 .../stable_diffusion/pipeline_stable_diffusion.py      | 10 +++++-----
 .../pipeline_stable_diffusion_img2img.py               | 10 +++++-----
 .../pipeline_stable_diffusion_inpaint.py               |  9 ++++-----
 .../stable_diffusion/pipeline_stable_diffusion_onnx.py | 10 +++++-----
 4 files changed, 19 insertions(+), 20 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index 1888543af9af..307dd27ec46b 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -10,9 +10,9 @@
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...utils import logging
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
-from .utils import logging


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -201,12 +201,12 @@ def __call__(
         text_input_ids = text_inputs.input_ids

         if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[self.tokenizer_model_max_length :])
+            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
             logger.warning(
-                f"The following part of your input was truncated because CLIP can only handle sequences up to {self.tokenizer_model_max_length} tokens: {removed_text}"
+                f"The following part of your input was truncated because CLIP can only handle sequences up to {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -214,7 +214,7 @@ def __call__(
         do_classifier_free_guidance = guidance_scale > 1.0
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
-            max_length = text_input.input_ids.shape[-1]
+            max_length = text_input_ids.shape[-1]
             uncond_input = self.tokenizer(
                 [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
             )

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
index 194d5e13e021..634360cff4e6 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -12,9 +12,9 @@
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...utils import logging
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
-from .utils import logging


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -229,12 +229,12 @@ def __call__(
         text_input_ids = text_inputs.input_ids

         if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[self.tokenizer_model_max_length :])
+            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
             logger.warning(
-                f"The following part of your input was truncated because CLIP can only handle sequences up to {self.tokenizer_model_max_length} tokens: {removed_text}"
+                f"The following part of your input was truncated because CLIP can only handle sequences up to {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -242,7 +242,7 @@ def __call__(
         do_classifier_free_guidance = guidance_scale > 1.0
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
-            max_length = text_input.input_ids.shape[-1]
+            max_length = text_input_ids.shape[-1]
             uncond_input = self.tokenizer(
                 [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
             )

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index c2380035894b..5b22b3a4f889 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -16,7 +16,6 @@
 from ...utils import logging
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
-from .utils import logging


 logger = logging.get_logger(__name__)
@@ -259,12 +258,12 @@ def __call__(
         text_input_ids = text_inputs.input_ids

         if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[self.tokenizer_model_max_length :])
+            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
             logger.warning(
-                f"The following part of your input was truncated because CLIP can only handle sequences up to {self.tokenizer_model_max_length} tokens: {removed_text}"
+                f"The following part of your input was truncated because CLIP can only handle sequences up to {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -272,7 +271,7 @@ def __call__(
         do_classifier_free_guidance = guidance_scale > 1.0
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
-            max_length = text_input.input_ids.shape[-1]
+            max_length = text_input_ids.shape[-1]
             uncond_input = self.tokenizer(
                 [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
             )

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py
index f40d427850c3..31d3c645490b 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py
@@ -8,8 +8,8 @@
 from ...onnx_utils import OnnxRuntimeModel
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...utils import logging
 from . import StableDiffusionPipelineOutput
-from .utils import logging


 logger = logging.get_logger(__name__)
@@ -79,12 +79,12 @@ def __call__(
         text_input_ids = text_inputs.input_ids

         if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[self.tokenizer_model_max_length :])
+            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
             logger.warning(
-                f"The following part of your input was truncated because CLIP can only handle sequences up to {self.tokenizer_model_max_length} tokens: {removed_text}"
+                f"The following part of your input was truncated because CLIP can only handle sequences up to {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-        text_embeddings = self.text_encoder(input_ids=text_input.input_ids.astype(np.int32))[0]
+        text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]

         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -92,7 +92,7 @@ def __call__(
         do_classifier_free_guidance = guidance_scale > 1.0
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
-            max_length = text_input.input_ids.shape[-1]
+            max_length = text_input_ids.shape[-1]
             uncond_input = self.tokenizer(
                 [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
             )
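After these fixes the detection logic is correct: slice off the overflow along the sequence dimension, decode it, warn, then truncate manually. The same pattern, extracted here into a standalone helper purely for illustration (the helper name is ours, not part of the diffusers API):

```python
import logging

from transformers import CLIPTokenizer

logger = logging.getLogger(__name__)


def encode_prompt_ids(tokenizer: CLIPTokenizer, prompt: str):
    # Tokenize without truncation so over-length prompts remain detectable.
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids

    if text_input_ids.shape[-1] > tokenizer.model_max_length:
        # Decode only the part that will be cut off, warn, then truncate manually.
        removed_text = tokenizer.batch_decode(text_input_ids[:, tokenizer.model_max_length :])
        logger.warning(
            "The following part of your input was truncated because CLIP can only handle sequences up to"
            f" {tokenizer.model_max_length} tokens: {removed_text}"
        )
        text_input_ids = text_input_ids[:, : tokenizer.model_max_length]
    return text_input_ids


tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
ids = encode_prompt_ids(tokenizer, "a photo of an astronaut riding a horse on mars, " * 20)
print(ids.shape)  # torch.Size([1, 77]) -- always capped, with a warning when truncated
```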
From d651eaa3b4a5ed174e3a543d9e5802620acf694c Mon Sep 17 00:00:00 2001
From: Yuta Hayashibe
Date: Fri, 23 Sep 2022 14:25:56 +0900
Subject: [PATCH 7/7] Formatted

---
 .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 3 ++-
 .../stable_diffusion/pipeline_stable_diffusion_img2img.py   | 3 ++-
 .../stable_diffusion/pipeline_stable_diffusion_inpaint.py   | 3 ++-
 .../stable_diffusion/pipeline_stable_diffusion_onnx.py      | 3 ++-
 4 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index 307dd27ec46b..12c629d66cd6 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -203,7 +203,8 @@ def __call__(
         if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
             removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
             logger.warning(
-                f"The following part of your input was truncated because CLIP can only handle sequences up to {self.tokenizer.model_max_length} tokens: {removed_text}"
+                "The following part of your input was truncated because CLIP can only handle sequences up to"
+                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
index 634360cff4e6..200b84736659 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -231,7 +231,8 @@ def __call__(
         if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
             removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
             logger.warning(
-                f"The following part of your input was truncated because CLIP can only handle sequences up to {self.tokenizer.model_max_length} tokens: {removed_text}"
+                "The following part of your input was truncated because CLIP can only handle sequences up to"
+                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index 5b22b3a4f889..33d96fae1b44 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -260,7 +260,8 @@ def __call__(
         if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
             removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
             logger.warning(
-                f"The following part of your input was truncated because CLIP can only handle sequences up to {self.tokenizer.model_max_length} tokens: {removed_text}"
+                "The following part of your input was truncated because CLIP can only handle sequences up to"
+                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py
index 31d3c645490b..ba09f7274cc6 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py
@@ -81,7 +81,8 @@ def __call__(
         if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
             removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
             logger.warning(
-                f"The following part of your input was truncated because CLIP can only handle sequences up to {self.tokenizer.model_max_length} tokens: {removed_text}"
+                "The following part of your input was truncated because CLIP can only handle sequences up to"
+                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
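The final patch is purely stylistic: Python concatenates adjacent string literals at compile time, so the long message can be split across source lines and only the piece containing placeholders needs the `f` prefix. A minimal illustration with stand-in values:

```python
# Stand-in values; in the patch these come from the tokenizer and the prompt.
max_length = 77
removed_text = [", trending on artstation"]

# Adjacent string literals are joined into one string at compile time, so only
# the second piece, which contains the placeholders, needs to be an f-string.
message = (
    "The following part of your input was truncated because CLIP can only handle sequences up to"
    f" {max_length} tokens: {removed_text}"
)
print(message)
```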