@@ -162,7 +162,7 @@ class FluxPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
             [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
     """
 
-    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
     _optional_components = []
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
 
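With text_encoder_3 dropped from the offload sequence, model CPU offload now cycles through two text encoders, the transformer, and the VAE. A minimal usage sketch, assuming a Flux checkpoint is available; the id below is a placeholder, not something named in this diff:

import torch
from diffusers import FluxPipeline

# "<flux-checkpoint>" is a placeholder id, not a checkpoint named in this diff.
pipe = FluxPipeline.from_pretrained("<flux-checkpoint>", torch_dtype=torch.bfloat16)

# Moves each sub-model to the GPU only while it runs, following
# model_cpu_offload_seq: text_encoder -> text_encoder_2 -> transformer -> vae.
pipe.enable_model_cpu_offload()

image = pipe("a tiny corgi astronaut").images[0]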
@@ -193,11 +193,7 @@ def __init__(
         self.tokenizer_max_length = (
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
-        self.default_sample_size = (
-            self.transformer.config.sample_size
-            if hasattr(self, "transformer") and self.transformer is not None
-            else 128
-        )
+        self.default_sample_size = 128
 
     def _get_t5_prompt_embeds(
         self,
@@ -213,7 +209,7 @@ def _get_t5_prompt_embeds(
         prompt = [prompt] if isinstance(prompt, str) else prompt
         batch_size = len(prompt)
 
-        if self.text_encoder_3 is None:
+        if self.text_encoder_2 is None:
             return torch.zeros(
                 (
                     batch_size * num_images_per_prompt,
@@ -224,7 +220,7 @@ def _get_t5_prompt_embeds(
                 dtype=dtype,
             )
 
-        text_inputs = self.tokenizer_3(
+        text_inputs = self.tokenizer_2(
             prompt,
             padding="max_length",
             max_length=max_sequence_length,
@@ -233,18 +229,18 @@ def _get_t5_prompt_embeds(
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
+        untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
 
         if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-            removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+            removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because `max_sequence_length` is set to "
                 f" {max_sequence_length} tokens: {removed_text}"
             )
 
-        prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
+        prompt_embeds = self.text_encoder_2(text_input_ids.to(device))[0]
 
-        dtype = self.text_encoder_3.dtype
+        dtype = self.text_encoder_2.dtype
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
         _, seq_len, _ = prompt_embeds.shape
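The renamed tokenizer_2 / text_encoder_2 path keeps the usual transformers truncation check: encode once with truncation at max_sequence_length, once with padding="longest", and warn with whatever only the longer pass kept. A standalone sketch of that check, using an illustrative T5 tokenizer and a deliberately small limit:

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-xxl")  # illustrative choice
max_sequence_length = 8

prompt = ["an extremely detailed prompt " * 10]
truncated = tokenizer(
    prompt, padding="max_length", max_length=max_sequence_length, truncation=True, return_tensors="pt"
).input_ids
full = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

# Different lengths make torch.equal return False, which flags truncation.
if full.shape[-1] >= truncated.shape[-1] and not torch.equal(truncated, full):
    removed = tokenizer.batch_decode(full[:, max_sequence_length - 1 : -1])
    print(f"Truncated past {max_sequence_length} tokens: {removed}")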
@@ -424,7 +420,6 @@ def encode_prompt(
                 device=device,
                 num_images_per_prompt=num_images_per_prompt,
                 clip_skip=None,
-                clip_model_index=0,
             )
             t5_negative_prompt_embed = self._get_t5_prompt_embeds(
                 prompt=negative_prompt_2,
@@ -553,9 +548,8 @@ def prepare_latents(
         latent_image_ids = torch.zeros(height // 2, width // 2, 3)
         latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
         latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
-        latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 0)
-
         latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+        latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
         latent_image_ids = latent_image_ids.reshape(
             batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
         )
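The reordering above matters because Tensor.repeat takes one repeat count per dimension: calling repeat(batch_size, 0) on a 3-D tensor raises a RuntimeError, and the 3-value shape unpack must happen before the batch dimension is added. A standalone shape check with toy sizes:

import torch

batch_size, height, width = 2, 8, 8  # toy latent-space sizes

ids = torch.zeros(height // 2, width // 2, 3)
ids[..., 1] += torch.arange(height // 2)[:, None]  # per-row index in channel 1
ids[..., 2] += torch.arange(width // 2)[None, :]   # per-column index in channel 2

h, w, c = ids.shape                             # unpack while still 3-D
ids = ids[None, :].repeat(batch_size, 1, 1, 1)  # (batch, h, w, 3)
ids = ids.reshape(batch_size, h * w, c)         # (batch, h*w, 3)
print(ids.shape)  # torch.Size([2, 16, 3])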
@@ -787,11 +781,12 @@ def __call__(
 
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    t5_hidden_states=t5_prompt_embeds,
-                    text_ids=text_ids,
-                    latent_image_ids=latent_image_ids,
+                    # YiYi notes: divide by 1000 for now because we scale it by 1000 in the transformer model (we should not keep this, but I want to keep the model inputs the same for testing)
+                    timestep=timestep / 1000,
+                    pooled_projections=prompt_embeds,
+                    encoder_hidden_states=t5_prompt_embeds,
+                    txt_ids=text_ids,
+                    img_ids=latent_image_ids,
                     joint_attention_kwargs=self.joint_attention_kwargs,
                     return_dict=False,
                 )[0]
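The new argument names route the pooled CLIP vector to pooled_projections and the T5 sequence to encoder_hidden_states, and the pipeline rescales scheduler timesteps from [0, 1000] into [0, 1] before the call. A sketch of that mapping with stub tensors; the shapes are illustrative, not read from any config:

import torch

batch = 2
prompt_embeds = torch.randn(batch, 768)           # pooled CLIP vector -> pooled_projections
t5_prompt_embeds = torch.randn(batch, 512, 4096)  # T5 hidden states -> encoder_hidden_states
text_ids = torch.zeros(batch, 512, 3)             # text position ids -> txt_ids
latent_image_ids = torch.zeros(batch, 256, 3)     # latent position ids -> img_ids

# Scheduler timesteps live in [0, 1000]; the transformer at this commit
# scales by 1000 internally, so the pipeline divides first.
t = torch.full((batch,), 981.0)
timestep = t / 1000
print(timestep)  # tensor([0.9810, 0.9810])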