
Commit dd35f1b

up

1 parent dfad4a2 commit dd35f1b

File tree

1 file changed: +15 -20 lines changed

src/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 15 additions & 20 deletions
@@ -162,7 +162,7 @@ class FluxPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
             [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
     """
 
-    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
     _optional_components = []
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
 
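The offload sequence drops text_encoder_3 because Flux keeps only two text encoders (CLIP as text_encoder, T5 as text_encoder_2), unlike the three in the SD3 pipeline this file was adapted from. A minimal sketch of how the sequence is exercised, assuming a published Flux checkpoint id (not part of this commit):

import torch
from diffusers import FluxPipeline

# Checkpoint id is an assumption for illustration only.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
# Moves each module to the GPU only while it runs, in model_cpu_offload_seq order:
# text_encoder -> text_encoder_2 -> transformer -> vae
pipe.enable_model_cpu_offload()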

@@ -193,11 +193,7 @@ def __init__(
         self.tokenizer_max_length = (
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
-        self.default_sample_size = (
-            self.transformer.config.sample_size
-            if hasattr(self, "transformer") and self.transformer is not None
-            else 128
-        )
+        self.default_sample_size = 128
 
     def _get_t5_prompt_embeds(
         self,
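Hardcoding default_sample_size to 128 removes the dependency on transformer.config.sample_size. With the usual 8x VAE downsampling this corresponds to a 1024x1024 default output; a quick check of that arithmetic (the scale factor is an assumption about the surrounding pipeline, not read from this diff):

default_sample_size = 128  # value pinned by this commit
vae_scale_factor = 8       # assumed: 2 ** (len(vae.config.block_out_channels) - 1)
print(default_sample_size * vae_scale_factor)  # 1024 -> default 1024x1024 images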
@@ -213,7 +209,7 @@ def _get_t5_prompt_embeds(
         prompt = [prompt] if isinstance(prompt, str) else prompt
         batch_size = len(prompt)
 
-        if self.text_encoder_3 is None:
+        if self.text_encoder_2 is None:
             return torch.zeros(
                 (
                     batch_size * num_images_per_prompt,
@@ -224,7 +220,7 @@
                 dtype=dtype,
             )
 
-        text_inputs = self.tokenizer_3(
+        text_inputs = self.tokenizer_2(
             prompt,
             padding="max_length",
             max_length=max_sequence_length,
@@ -233,18 +229,18 @@
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
+        untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
 
         if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-            removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+            removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because `max_sequence_length` is set to "
                 f" {max_sequence_length} tokens: {removed_text}"
             )
 
-        prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
+        prompt_embeds = self.text_encoder_2(text_input_ids.to(device))[0]
 
-        dtype = self.text_encoder_3.dtype
+        dtype = self.text_encoder_2.dtype
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
         _, seq_len, _ = prompt_embeds.shape
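With the CLIP encoders pared down to one, the T5 stack moves from the *_3 slot to *_2, so _get_t5_prompt_embeds now reads tokenizer_2/text_encoder_2 throughout. The torch.zeros branch earlier in the function keeps downstream shapes valid when no T5 encoder is loaded; a self-contained sketch of that fallback, with illustrative sequence length and width (the real values come from max_sequence_length and transformer.config.joint_attention_dim):

import torch

def t5_embeds_or_zeros(text_encoder_2, batch_size, num_images_per_prompt=1,
                       max_sequence_length=512, joint_attention_dim=4096):
    # Mirrors the fallback: with no T5 encoder, hand the transformer an
    # all-zero encoder_hidden_states tensor of the expected shape.
    if text_encoder_2 is None:
        return torch.zeros((batch_size * num_images_per_prompt,
                            max_sequence_length, joint_attention_dim))
    raise NotImplementedError("otherwise: tokenize and encode as in _get_t5_prompt_embeds")

print(t5_embeds_or_zeros(None, batch_size=2).shape)  # torch.Size([2, 512, 4096])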
@@ -424,7 +420,6 @@ def encode_prompt(
             device=device,
             num_images_per_prompt=num_images_per_prompt,
             clip_skip=None,
-            clip_model_index=0,
         )
         t5_negative_prompt_embed = self._get_t5_prompt_embeds(
             prompt=negative_prompt_2,
@@ -553,9 +548,8 @@ def prepare_latents(
         latent_image_ids = torch.zeros(height // 2, width // 2, 3)
         latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
         latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
-        latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 0)
-
         latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+        latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
         latent_image_ids = latent_image_ids.reshape(
             batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
         )
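The reordering fixes two bugs at once: Tensor.repeat takes one repeat count per dimension, so calling repeat(batch_size, 0) on the unsqueezed 4-D tensor raises (and a count of 0 would empty a dimension rather than tile it), and the three-way shape unpack only works on the 3-D tensor, before the batch dimension is added. A minimal reproduction of the corrected flow with illustrative sizes:

import torch

batch_size, height, width = 2, 8, 8  # illustrative sizes
ids = torch.zeros(height // 2, width // 2, 3)
ids[..., 1] += torch.arange(height // 2)[:, None]
ids[..., 2] += torch.arange(width // 2)[None, :]

# Read the 3-D shape before adding the batch dimension, as the fix does.
h, w, c = ids.shape
ids = ids[None, :].repeat(batch_size, 1, 1, 1)  # one repeat count per dimension
ids = ids.reshape(batch_size, h * w, c)
print(ids.shape)  # torch.Size([2, 16, 3])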
@@ -787,11 +781,12 @@ def __call__(
 
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    t5_hidden_states=t5_prompt_embeds,
-                    text_ids=text_ids,
-                    latent_image_ids=latent_image_ids,
+                    # YiYi notes: divide by 1000 for now because we scale it by 1000 in the transformer model (we should not keep this, but I want to keep the inputs to the model the same for testing)
+                    timestep=timestep / 1000,
+                    pooled_projections=prompt_embeds,
+                    encoder_hidden_states=t5_prompt_embeds,
+                    txt_ids=text_ids,
+                    img_ids=latent_image_ids,
                     joint_attention_kwargs=self.joint_attention_kwargs,
                     return_dict=False,
                 )[0]
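Besides mapping the arguments onto the Flux transformer's signature (pooled_projections takes the CLIP embedding, encoder_hidden_states the T5 sequence, txt_ids/img_ids the positional ids), the call now divides the timestep by 1000; per the in-diff note, the transformer multiplies it back internally, so the model still sees the original values. A tiny sketch of that round trip (the factor comes from the note; the tensor is illustrative):

import torch

timesteps = torch.tensor([1000.0, 500.0, 250.0])  # scheduler-style timesteps
t_passed = timesteps / 1000                       # what the pipeline passes now
t_inside = t_passed * 1000                        # what the transformer scales back to
assert torch.equal(t_inside, timesteps)           # model inputs unchanged for testing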
