
Commit 5ac1f61

Add an argument "negative_prompt" (#549)
* Add an argument "negative_prompt"
* Fix argument order
* Fix to use TypeError instead of ValueError
* Removed needless batch_size multiplying
* Fix to multiply by batch_size
* Add truncation=True for long negative prompt
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py (Co-authored-by: Patrick von Platen <[email protected]>)
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py (Co-authored-by: Patrick von Platen <[email protected]>)
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py (Co-authored-by: Patrick von Platen <[email protected]>)
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py (Co-authored-by: Patrick von Platen <[email protected]>)
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py (Co-authored-by: Patrick von Platen <[email protected]>)
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py (Co-authored-by: Patrick von Platen <[email protected]>)
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py (Co-authored-by: Patrick von Platen <[email protected]>)
* Fix styles
* Renamed ucond_tokens to uncond_tokens
* Added description about "negative_prompt"

Co-authored-by: Patrick von Platen <[email protected]>
1 parent 7e92c5b commit 5ac1f61
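
For context, the new argument is used like any other pipeline keyword. A minimal sketch, assuming a diffusers build that includes this commit; the checkpoint ID and device below are illustrative, not part of the change:

import torch
from diffusers import StableDiffusionPipeline

# Illustrative checkpoint; any Stable Diffusion weights work the same way.
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = pipe.to("cuda")

# negative_prompt steers sampling away from the listed concepts. It only has
# an effect when guidance_scale > 1.0, i.e. when classifier-free guidance runs.
result = pipe(
    prompt="a photograph of an astronaut riding a horse",
    negative_prompt="blurry, low quality, watermark",
    guidance_scale=7.5,
)
result.images[0].save("astronaut.png")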

File tree: 4 files changed, +105 −4 lines changed


src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 28 additions & 1 deletion
@@ -116,6 +116,7 @@ def __call__(
         width: int = 512,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
         eta: float = 0.0,
         generator: Optional[torch.Generator] = None,
         latents: Optional[torch.FloatTensor] = None,
@@ -144,6 +145,9 @@ def __call__(
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                if `guidance_scale` is less than `1`).
             eta (`float`, *optional*, defaults to 0.0):
                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                 [`schedulers.DDIMScheduler`], will be ignored for others.
@@ -217,9 +221,32 @@ def __call__(
         do_classifier_free_guidance = guidance_scale > 1.0
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
+            uncond_tokens: List[str]
+            if negative_prompt is None:
+                uncond_tokens = [""] * batch_size
+            elif type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif isinstance(negative_prompt, str):
+                uncond_tokens = [negative_prompt] * batch_size
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that the passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+            else:
+                uncond_tokens = negative_prompt
+
             max_length = text_input_ids.shape[-1]
             uncond_input = self.tokenizer(
-                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+                uncond_tokens,
+                padding="max_length",
+                max_length=max_length,
+                truncation=True,
+                return_tensors="pt",
             )
             uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
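
The uncond_embeddings computed above feed classifier-free guidance later in the pipeline. The function below is an illustrative restatement of that guidance step, not code from this commit; names are chosen to mirror the pipeline's variables:

import torch

def classifier_free_guidance(
    noise_pred_uncond: torch.Tensor,
    noise_pred_text: torch.Tensor,
    guidance_scale: float,
) -> torch.Tensor:
    # Extrapolate from the unconditional prediction toward the text-conditioned
    # one. A non-empty negative_prompt changes noise_pred_uncond, so this update
    # also pushes the sample away from the negated concepts.
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# Toy tensors just to show shapes; real predictions come from the UNet.
uncond = torch.zeros(1, 4, 64, 64)
text = torch.ones(1, 4, 64, 64)
assert classifier_free_guidance(uncond, text, 7.5).shape == (1, 4, 64, 64)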

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 24 additions & 1 deletion
@@ -128,6 +128,7 @@ def __call__(
         strength: float = 0.8,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
         output_type: Optional[str] = "pil",
@@ -160,6 +161,9 @@ def __call__(
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                if `guidance_scale` is less than `1`).
             eta (`float`, *optional*, defaults to 0.0):
                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                 [`schedulers.DDIMScheduler`], will be ignored for others.
@@ -258,9 +262,28 @@ def __call__(
         do_classifier_free_guidance = guidance_scale > 1.0
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
+            uncond_tokens: List[str]
+            if negative_prompt is None:
+                uncond_tokens = [""] * batch_size
+            elif type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif isinstance(negative_prompt, str):
+                uncond_tokens = [negative_prompt] * batch_size
+            elif batch_size != len(negative_prompt):
+                raise ValueError("The length of `negative_prompt` should be equal to batch_size.")
+            else:
+                uncond_tokens = negative_prompt
+
             max_length = text_input_ids.shape[-1]
             uncond_input = self.tokenizer(
-                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+                uncond_tokens,
+                padding="max_length",
+                max_length=max_length,
+                truncation=True,
+                return_tensors="pt",
             )
             uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
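
All four files share the same prompt/negative_prompt contract. A standalone restatement of the validation, useful for seeing what is accepted and what raises; this helper is hypothetical, not part of the diff:

from typing import List, Optional, Union

def resolve_uncond_tokens(
    prompt: Union[str, List[str]],
    negative_prompt: Optional[Union[str, List[str]]],
) -> List[str]:
    # Mirrors the branching added in this commit.
    batch_size = 1 if isinstance(prompt, str) else len(prompt)
    if negative_prompt is None:
        return [""] * batch_size                  # default: empty uncond prompt
    if type(prompt) is not type(negative_prompt):
        raise TypeError("`negative_prompt` should be the same type as `prompt`.")
    if isinstance(negative_prompt, str):
        return [negative_prompt] * batch_size     # broadcast one string
    if batch_size != len(negative_prompt):
        raise ValueError("The length of `negative_prompt` should be equal to batch_size.")
    return negative_prompt

print(resolve_uncond_tokens("a red fox", None))        # ['']
print(resolve_uncond_tokens(["a", "b"], ["x", "y"]))   # ['x', 'y']
# resolve_uncond_tokens(["a", "b"], "x")    -> TypeError (str vs list)
# resolve_uncond_tokens(["a", "b"], ["x"])  -> ValueError (1 != 2)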

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 28 additions & 1 deletion
@@ -144,6 +144,7 @@ def __call__(
         strength: float = 0.8,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
         output_type: Optional[str] = "pil",
@@ -180,6 +181,9 @@ def __call__(
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                if `guidance_scale` is less than `1`).
             eta (`float`, *optional*, defaults to 0.0):
                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                 [`schedulers.DDIMScheduler`], will be ignored for others.
@@ -292,9 +296,32 @@ def __call__(
         do_classifier_free_guidance = guidance_scale > 1.0
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
+            uncond_tokens: List[str]
+            if negative_prompt is None:
+                uncond_tokens = [""] * batch_size
+            elif type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif isinstance(negative_prompt, str):
+                uncond_tokens = [negative_prompt] * batch_size
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that the passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+            else:
+                uncond_tokens = negative_prompt
+
             max_length = text_input_ids.shape[-1]
             uncond_input = self.tokenizer(
-                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+                uncond_tokens,
+                padding="max_length",
+                max_length=max_length,
+                truncation=True,
+                return_tensors="pt",
             )
             uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
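
The truncation=True added to each tokenizer call matters because uncond_tokens can now be arbitrary user text rather than a fixed empty string. A minimal sketch with the CLIP tokenizer; the checkpoint ID is illustrative and transformers is assumed installed:

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

long_negative = "blurry, low quality, " * 50   # far past CLIP's 77-token window
enc = tokenizer(
    [long_negative],
    padding="max_length",
    max_length=tokenizer.model_max_length,      # 77 for CLIP
    truncation=True,                            # without this, ids exceed max_length
    return_tensors="pt",
)
print(enc.input_ids.shape)                      # torch.Size([1, 77])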

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py

Lines changed: 25 additions & 1 deletion
@@ -52,6 +52,7 @@ def __call__(
         width: Optional[int] = 512,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
         eta: Optional[float] = 0.0,
         latents: Optional[np.ndarray] = None,
         output_type: Optional[str] = "pil",
@@ -102,9 +103,32 @@ def __call__(
         do_classifier_free_guidance = guidance_scale > 1.0
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
+            uncond_tokens: List[str]
+            if negative_prompt is None:
+                uncond_tokens = [""] * batch_size
+            elif type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif isinstance(negative_prompt, str):
+                uncond_tokens = [negative_prompt] * batch_size
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that the passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+            else:
+                uncond_tokens = negative_prompt
+
             max_length = text_input_ids.shape[-1]
             uncond_input = self.tokenizer(
-                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
+                uncond_tokens,
+                padding="max_length",
+                max_length=max_length,
+                truncation=True,
+                return_tensors="np",
             )
             uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
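
The ONNX pipeline accepts the argument the same way, with numpy tensors instead of torch. A hedged sketch; the class name, revision, and execution provider reflect the library around this commit and may differ in later releases:

from diffusers import StableDiffusionOnnxPipeline

# Illustrative ONNX export of Stable Diffusion; runs on CPU here.
pipe = StableDiffusionOnnxPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="onnx",
    provider="CPUExecutionProvider",
)
result = pipe(
    prompt="a photograph of an astronaut riding a horse",
    negative_prompt="low resolution, deformed",
)
result.images[0].save("astronaut_onnx.png")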
