
Commit 30a933e

SkyTNT authored and Prathik Rao committed
[Community Pipelines] Fix pad_tokens_and_weights in lpw_stable_diffusion (huggingface#925)
1 parent 8409fba commit 30a933e

2 files changed: +90 additions, −34 deletions

examples/community/lpw_stable_diffusion.py

Lines changed: 53 additions & 20 deletions
@@ -132,6 +132,7 @@ def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_length: int):
     """
     tokens = []
     weights = []
+    truncated = False
     for text in prompt:
         texts_and_weights = parse_prompt_attention(text)
         text_token = []
@@ -140,21 +141,21 @@ def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_length: int):
             # tokenize and discard the starting and the ending token
             token = pipe.tokenizer(word).input_ids[1:-1]
             text_token += token
-
             # copy the weight by length of token
             text_weight += [weight] * len(token)
-
             # stop if the text is too long (longer than truncation limit)
             if len(text_token) > max_length:
+                truncated = True
                 break
-
         # truncate
         if len(text_token) > max_length:
+            truncated = True
             text_token = text_token[:max_length]
             text_weight = text_weight[:max_length]
-
         tokens.append(text_token)
         weights.append(text_weight)
+    if truncated:
+        logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
     return tokens, weights
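A note on the new logic: rather than warning inside the per-word loop, which could fire many times for one batch, the function sets a single truncated flag at both places where tokens can be dropped and emits one warning at the end. A minimal standalone sketch of the same pattern, with made-up inputs (not the pipeline code):

import logging

logger = logging.getLogger(__name__)

def truncate_with_warning(tokens, weights, max_length):
    # Keep at most max_length entries and remember whether anything was cut.
    truncated = len(tokens) > max_length
    tokens, weights = tokens[:max_length], weights[:max_length]
    if truncated:
        # Same message the commit adds to get_prompts_with_weights.
        logger.warning(
            "Prompt was truncated. Try to shorten the prompt or "
            "increase max_embeddings_multiples"
        )
    return tokens, weights

tokens, weights = truncate_with_warning(list(range(80)), [1.0] * 80, 75)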

@@ -173,9 +174,9 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
         if len(weights[i]) == 0:
             w = [1.0] * weights_length
         else:
-            for j in range((len(weights[i]) - 1) // chunk_length + 1):
+            for j in range(max_embeddings_multiples):
                 w.append(1.0)  # weight for starting token in this chunk
-                w += weights[i][j * chunk_length : min(len(weights[i]), (j + 1) * chunk_length)]
+                w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
                 w.append(1.0)  # weight for ending token in this chunk
             w += [1.0] * (weights_length - len(w))
         weights[i] = w[:]
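This hunk is the substance of the fix named in the commit title. Each encoder chunk is laid out as BOS + (chunk_length - 2) content tokens + EOS, so the per-token weights must be consumed in strides of chunk_length - 2, with a neutral 1.0 inserted at every chunk's BOS and EOS slot; the old code strode through the weights by the full chunk_length, so any prompt longer than one chunk had its weights misaligned with its tokens. A standalone sketch of the corrected loop with toy numbers (not the pipeline code; chunk_length = 6 stands in for CLIP's 77):

# Toy numbers: chunk_length = 6 means BOS + 4 content tokens + EOS per chunk.
chunk_length = 6
max_embeddings_multiples = 2
weights = [0.5, 0.5, 1.0, 1.0, 1.2, 1.2, 1.2]  # 7 content-token weights
weights_length = chunk_length * max_embeddings_multiples

w = []
for j in range(max_embeddings_multiples):
    w.append(1.0)  # weight for the BOS slot of this chunk
    w += weights[j * (chunk_length - 2) : min(len(weights), (j + 1) * (chunk_length - 2))]
    w.append(1.0)  # weight for the EOS slot of this chunk
w += [1.0] * (weights_length - len(w))  # pad the tail with neutral weights

# Chunk 0 carries content weights [0.5, 0.5, 1.0, 1.0] and chunk 1 carries
# [1.2, 1.2, 1.2] plus padding; the old stride of chunk_length would have
# stuffed six weights into chunk 0's four content slots.
print(w)  # [1.0, 0.5, 0.5, 1.0, 1.0, 1.0, 1.0, 1.2, 1.2, 1.2, 1.0, 1.0]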
@@ -184,7 +185,10 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):


 def get_unweighted_text_embeddings(
-    pipe: DiffusionPipeline, text_input: torch.Tensor, chunk_length: int, no_boseos_middle: Optional[bool] = True
+    pipe: DiffusionPipeline,
+    text_input: torch.Tensor,
+    chunk_length: int,
+    no_boseos_middle: Optional[bool] = True,
 ):
     """
     When the length of tokens is a multiple of the capacity of the text encoder,
@@ -285,7 +289,8 @@ def get_weighted_text_embeddings(
         max_length = max(max_length, max([len(token) for token in uncond_tokens]))

     max_embeddings_multiples = min(
-        max_embeddings_multiples, (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1
+        max_embeddings_multiples,
+        (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
     )
     max_embeddings_multiples = max(1, max_embeddings_multiples)
     max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
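The min/max pair clamps max_embeddings_multiples to what the prompt actually needs. At this point max_length counts content tokens (BOS and EOS were stripped during tokenization above), and each chunk holds model_max_length - 2 of them, so (max_length - 1) // (model_max_length - 2) + 1 is a ceiling division. A quick check, assuming CLIP's usual model_max_length of 77:

model_max_length = 77                  # CLIP tokenizer default (assumed)
chunk_capacity = model_max_length - 2  # 75 content tokens per chunk

for content_tokens in (75, 76, 150, 151):
    multiples = (content_tokens - 1) // chunk_capacity + 1
    print(content_tokens, "tokens ->", multiples, "chunk(s)")
# 75 -> 1, 76 -> 2, 150 -> 2, 151 -> 3: a ceiling division over 75.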
@@ -317,12 +322,18 @@ def get_weighted_text_embeddings(

     # get the embeddings
     text_embeddings = get_unweighted_text_embeddings(
-        pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
+        pipe,
+        prompt_tokens,
+        pipe.tokenizer.model_max_length,
+        no_boseos_middle=no_boseos_middle,
     )
     prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
     if uncond_prompt is not None:
         uncond_embeddings = get_unweighted_text_embeddings(
-            pipe, uncond_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
+            pipe,
+            uncond_tokens,
+            pipe.tokenizer.model_max_length,
+            no_boseos_middle=no_boseos_middle,
         )
         uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
@@ -632,16 +643,29 @@ def __call__(
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (
+            batch_size * num_images_per_prompt,
+            self.unet.in_channels,
+            height // 8,
+            width // 8,
+        )

         if latents is None:
             if self.device.type == "mps":
                 # randn does not exist on mps
-                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
-                    self.device
-                )
+                latents = torch.randn(
+                    latents_shape,
+                    generator=generator,
+                    device="cpu",
+                    dtype=latents_dtype,
+                ).to(self.device)
             else:
-                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+                latents = torch.randn(
+                    latents_shape,
+                    generator=generator,
+                    device=self.device,
+                    dtype=latents_dtype,
+                )
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
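Both randn branches implement the same workaround: torch.randn was not implemented for the mps backend at the time, so noise is sampled on CPU with the seeded generator and the resulting tensor is moved to the target device. The pattern in isolation (a sketch with a made-up shape, not the pipeline code):

import torch

generator = torch.Generator(device="cpu").manual_seed(0)
latents_shape = (1, 4, 64, 64)  # hypothetical example shape

device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
if device.type == "mps":
    # Sample on CPU so the seeded generator yields the same sequence as on
    # other backends, then move the tensor to the mps device.
    latents = torch.randn(latents_shape, generator=generator, device="cpu").to(device)
else:
    latents = torch.randn(latents_shape, generator=generator, device=device)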
@@ -684,11 +708,19 @@ def __call__(
             # add noise to latents using the timesteps
             if self.device.type == "mps":
                 # randn does not exist on mps
-                noise = torch.randn(init_latents.shape, generator=generator, device="cpu", dtype=latents_dtype).to(
-                    self.device
-                )
+                noise = torch.randn(
+                    init_latents.shape,
+                    generator=generator,
+                    device="cpu",
+                    dtype=latents_dtype,
+                ).to(self.device)
             else:
-                noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
+                noise = torch.randn(
+                    init_latents.shape,
+                    generator=generator,
+                    device=self.device,
+                    dtype=latents_dtype,
+                )
             latents = self.scheduler.add_noise(init_latents, noise, timesteps)

             t_start = max(num_inference_steps - init_timestep + offset, 0)
@@ -741,7 +773,8 @@ def __call__(
                 self.device
             )
             image, has_nsfw_concept = self.safety_checker(
-                images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
+                images=image,
+                clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype),
             )
         else:
             has_nsfw_concept = None

examples/community/lpw_stable_diffusion_onnx.py

Lines changed: 37 additions & 14 deletions
@@ -130,6 +130,7 @@ def get_prompts_with_weights(pipe, prompt: List[str], max_length: int):
     """
     tokens = []
     weights = []
+    truncated = False
     for text in prompt:
         texts_and_weights = parse_prompt_attention(text)
         text_token = []
@@ -138,21 +139,21 @@ def get_prompts_with_weights(pipe, prompt: List[str], max_length: int):
             # tokenize and discard the starting and the ending token
             token = pipe.tokenizer(word, return_tensors="np").input_ids[0, 1:-1]
             text_token += list(token)
-
             # copy the weight by length of token
             text_weight += [weight] * len(token)
-
             # stop if the text is too long (longer than truncation limit)
             if len(text_token) > max_length:
+                truncated = True
                 break
-
         # truncate
         if len(text_token) > max_length:
+            truncated = True
             text_token = text_token[:max_length]
             text_weight = text_weight[:max_length]
-
         tokens.append(text_token)
         weights.append(text_weight)
+    if truncated:
+        logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
     return tokens, weights
@@ -171,9 +172,9 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
         if len(weights[i]) == 0:
             w = [1.0] * weights_length
         else:
-            for j in range((len(weights[i]) - 1) // chunk_length + 1):
+            for j in range(max_embeddings_multiples):
                 w.append(1.0)  # weight for starting token in this chunk
-                w += weights[i][j * chunk_length : min(len(weights[i]), (j + 1) * chunk_length)]
+                w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
                 w.append(1.0)  # weight for ending token in this chunk
             w += [1.0] * (weights_length - len(w))
         weights[i] = w[:]
@@ -182,7 +183,10 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):


 def get_unweighted_text_embeddings(
-    pipe, text_input: np.array, chunk_length: int, no_boseos_middle: Optional[bool] = True
+    pipe,
+    text_input: np.array,
+    chunk_length: int,
+    no_boseos_middle: Optional[bool] = True,
 ):
     """
     When the length of tokens is a multiple of the capacity of the text encoder,
@@ -276,7 +280,10 @@ def get_weighted_text_embeddings(
         uncond_tokens = [
             token[1:-1]
             for token in pipe.tokenizer(
-                uncond_prompt, max_length=max_length, truncation=True, return_tensors="np"
+                uncond_prompt,
+                max_length=max_length,
+                truncation=True,
+                return_tensors="np",
             ).input_ids
         ]
         uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
@@ -287,7 +294,8 @@ def get_weighted_text_embeddings(
         max_length = max(max_length, max([len(token) for token in uncond_tokens]))

     max_embeddings_multiples = min(
-        max_embeddings_multiples, (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1
+        max_embeddings_multiples,
+        (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
     )
     max_embeddings_multiples = max(1, max_embeddings_multiples)
     max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
@@ -319,12 +327,18 @@ def get_weighted_text_embeddings(

     # get the embeddings
     text_embeddings = get_unweighted_text_embeddings(
-        pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
+        pipe,
+        prompt_tokens,
+        pipe.tokenizer.model_max_length,
+        no_boseos_middle=no_boseos_middle,
     )
     prompt_weights = np.array(prompt_weights, dtype=text_embeddings.dtype)
     if uncond_prompt is not None:
         uncond_embeddings = get_unweighted_text_embeddings(
-            pipe, uncond_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
+            pipe,
+            uncond_tokens,
+            pipe.tokenizer.model_max_length,
+            no_boseos_middle=no_boseos_middle,
         )
         uncond_weights = np.array(uncond_weights, dtype=uncond_embeddings.dtype)
@@ -559,7 +573,12 @@ def __call__(
         noise = None

         if init_image is None:
-            latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
+            latents_shape = (
+                batch_size * num_images_per_prompt,
+                4,
+                height // 8,
+                width // 8,
+            )

             if latents is None:
                 latents = generator.randn(*latents_shape).astype(latents_dtype)
@@ -625,7 +644,9 @@ def __call__(

             # predict the noise residual
             noise_pred = self.unet(
-                sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings
+                sample=latent_model_input,
+                timestep=np.array([t]),
+                encoder_hidden_states=text_embeddings,
             )
             noise_pred = noise_pred[0]
@@ -640,7 +661,9 @@ def __call__(
             if mask is not None:
                 # masking
                 init_latents_proper = self.scheduler.add_noise(
-                    torch.from_numpy(init_latents_orig), torch.from_numpy(noise), torch.tensor([t])
+                    torch.from_numpy(init_latents_orig),
+                    torch.from_numpy(noise),
+                    torch.tensor([t]),
                 ).numpy()
                 latents = (init_latents_proper * mask) + (latents * (1 - mask))