@@ -191,6 +191,7 @@ def __init__(
191191 transformer : FluxTransformer2DModel ,
192192 image_encoder : CLIPVisionModelWithProjection = None ,
193193 feature_extractor : CLIPImageProcessor = None ,
194+ variant : str = "flux" ,
194195 ):
195196 super ().__init__ ()
196197
@@ -213,6 +214,17 @@ def __init__(
213214 self .tokenizer .model_max_length if hasattr (self , "tokenizer" ) and self .tokenizer is not None else 77
214215 )
215216 self .default_sample_size = 128
217+ if variant not in {"flux" , "chroma" }:
218+ raise ValueError ("`variant` must be `'flux'` or `'chroma'`." )
219+
220+ self .variant = variant
221+
222+ def _get_chroma_attn_mask (self , length : torch .Tensor , max_sequence_length : int ) -> torch .Tensor :
223+ attention_mask = torch .zeros ((length .shape [0 ], max_sequence_length ), dtype = torch .bool , device = length .device )
224+ for i , n_tokens in enumerate (length ):
225+ n_tokens = torch .max (n_tokens + 1 , max_sequence_length )
226+ attention_mask [i , :n_tokens ] = True
227+ return attention_mask
216228
217229 def _get_t5_prompt_embeds (
218230 self ,
@@ -236,7 +248,7 @@ def _get_t5_prompt_embeds(
236248 padding = "max_length" ,
237249 max_length = max_sequence_length ,
238250 truncation = True ,
239- return_length = False ,
251+ return_length = ( self . variant == "chroma" ) ,
240252 return_overflowing_tokens = False ,
241253 return_tensors = "pt" ,
242254 )
@@ -250,7 +262,15 @@ def _get_t5_prompt_embeds(
250262 f" { max_sequence_length } tokens: { removed_text } "
251263 )
252264
253- prompt_embeds = self .text_encoder_2 (text_input_ids .to (device ), output_hidden_states = False )[0 ]
265+ prompt_embeds = self .text_encoder_2 (
266+ text_input_ids .to (device ),
267+ output_hidden_states = False ,
268+ attention_mask = (
269+ self ._get_chroma_attn_mask (text_inputs .length , max_sequence_length ).to (device )
270+ if self .variant == "chroma"
271+ else None
272+ ),
273+ )[0 ]
254274
255275 dtype = self .text_encoder_2 .dtype
256276 prompt_embeds = prompt_embeds .to (dtype = dtype , device = device )
0 commit comments