@@ -129,19 +129,59 @@ def export(self, export_dir: Optional[str] = None) -> str:
print("###################### TEXT ENCODER 2 EXPORTED ######################")

- # # T5 TEXT ENCODER
- # example_inputs = {"input_ids": torch.zeros((bs, seq_len), dtype=torch.int64)}
+ # T5 TEXT ENCODER
+ example_inputs = {"input_ids": torch.zeros((bs, seq_len), dtype=torch.int64)}

- # dynamic_axes = {"input_ids": {0: "batch_size", 1: "seq_len"}}
+ dynamic_axes = {"input_ids": {0: "batch_size", 1: "seq_len"}}
+
+ output_names = ["last_hidden_state"]
+
+ ## Changes for the testing ####
+ wo_sfs = [
+     61,
+     203,
+     398,
+     615,
+     845,
+     1190,
+     1402,
+     2242,
+     1875,
+     2393,
+     3845,
+     3213,
+     3922,
+     4429,
+     5020,
+     5623,
+     6439,
+     6206,
+     5165,
+     4593,
+     2802,
+     2618,
+     1891,
+     1419,
+ ]
+
+ assert len(wo_sfs) == 24
+ with torch.no_grad():
+     prev_sf = 1
+     for i in range(len(self.text_encoder_3.model.encoder.block)):
+         wosf = wo_sfs[i]
+         self.text_encoder_3.model.encoder.block[i].layer[0].SelfAttention.o.weight *= 1 / wosf
+         self.text_encoder_3.model.encoder.block[i].layer[0].scaling_factor *= prev_sf / wosf
+         self.text_encoder_3.model.encoder.block[i].layer[1].DenseReluDense.wo.weight *= 1 / wosf
+         prev_sf = wosf

- # output_names = ["last_hidden_state"]
+ ### End ####

- # self.text_encoder_3_onnx_path = self.text_encoder_3.export(
- # inputs=example_inputs,
- # output_names=output_names,
- # dynamic_axes=dynamic_axes,
- # export_dir=export_dir,
- # )
+ self.text_encoder_3_onnx_path = self.text_encoder_3.export(
+     inputs=example_inputs,
+     output_names=output_names,
+     dynamic_axes=dynamic_axes,
+     export_dir=export_dir,
+ )

print("###################### TEXT ENCODER 3 EXPORTED ######################")
@@ -267,23 +307,23 @@ def compile(
print("###################### Text Encoder 2 Compiled #####################")

# # Compile text_encoder 3
- # seq_len= 256
-
- # specializations = [
- # {"batch_size": batch_size, "seq_len": seq_len},
- # ]
-
- # self.text_encoder_3_compile_path= self.text_encoder_3._compile(
- # onnx_path,
- # compile_dir,
- # compile_only=True,
- # specializations=specializations,
- # convert_to_fp16=True,
- # mxfp6_matmul=mxfp6_matmul,
- # mdp_ts_num_devices=num_devices_text_encoder,
- # aic_num_cores=num_cores,
- # **compiler_options,
- # )
+ seq_len = 256
+
+ specializations = [
+     {"batch_size": batch_size, "seq_len": seq_len},
+ ]
+
+ self.text_encoder_3_compile_path = self.text_encoder_3._compile(
+     onnx_path,
+     compile_dir,
+     compile_only=True,
+     specializations=specializations,
+     convert_to_fp16=True,
+     mxfp6_matmul=mxfp6_matmul,
+     mdp_ts_num_devices=num_devices_text_encoder,
+     aic_num_cores=num_cores,
+     **compiler_options,
+ )

print("###################### Text Encoder 3 Compiled #####################")

# Compile transformer
@@ -331,6 +371,7 @@ def compile(
    convert_to_fp16=True,
    mdp_ts_num_devices=num_devices_vae_decoder,
)
+ print("###################### vae_decoder Compiled #####################")

def _get_clip_prompt_embeds(
    self,
@@ -480,12 +521,27 @@ def _get_t5_prompt_embeds(
    "The following part of your input was truncated because `max_sequence_length` is set to "
    f" {max_sequence_length} tokens: {removed_text}"
)
- # if self.text_encoder_3.qpc_session is None:
- # self.text_encoder_3.qpc_session = QAICInferenceSession(str(self.text_encoder_3_compile_path))
+ if self.text_encoder_3.qpc_session is None:
+     self.text_encoder_3.qpc_session = QAICInferenceSession(str(self.text_encoder_3_compile_path))

prompt_embeds = self.text_encoder_3.model(text_input_ids.to(device))[0]
- # aic_text_input={"input_ids": text_input_ids.numpy().astype(np.int64)}
- # aic_embeddings= self.text_encoder_3.qpc_session.run(aic_text_input)
+ aic_text_input = {"input_ids": text_input_ids.numpy().astype(np.int64)}
+ aic_embeddings = torch.tensor(self.text_encoder_3.qpc_session.run(aic_text_input)["last_hidden_state"])
+ mad = torch.abs(prompt_embeds - aic_embeddings).mean()
+ print("Clip text-encoder-3 Pytorch vs AI 100:", mad)
+ prompt_embeds = aic_embeddings
+
+ # import onnxruntime as ort
+ # ort_session=ort.InferenceSession(self.text_encoder_3_onnx_path)
+ # input_names = [input.name for input in ort_session.get_inputs()]
+ # output_names = [output.name for output in ort_session.get_outputs()]
+ # inputs={input_names[0]: text_input_ids.numpy()}
+ # output=ort_session.run(output_names, inputs)
+ # prompt_embeds_ort = torch.from_numpy(output[0])
+
+ # # mad between prompt_embeds and prompt_embeds_ort
+ # mad=torch.abs(prompt_embeds-prompt_embeds_ort).mean()
+ # print("mad between ort and pytorch", mad)

_, seq_len, _ = prompt_embeds.shape
@@ -623,16 +679,32 @@ def __call__(
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 256,
- sigmas: Optional[List[float]] = None,
- skip_guidance_layers: List[int] = None,
- skip_layer_guidance_scale: float = 2.8,
- skip_layer_guidance_stop: float = 0.2,
- skip_layer_guidance_start: float = 0.01,
- mu: Optional[float] = None,
- vae_type="vae",
):
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
+ device = "cpu"
+
+ self.check_inputs(
+     prompt,
+     prompt_2,
+     prompt_3,
+     height,
+     width,
+     negative_prompt=negative_prompt,
+     negative_prompt_2=negative_prompt_2,
+     negative_prompt_3=negative_prompt_3,
+     prompt_embeds=prompt_embeds,
+     negative_prompt_embeds=negative_prompt_embeds,
+     pooled_prompt_embeds=pooled_prompt_embeds,
+     negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+     callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+     max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False

(
    prompt_embeds,
@@ -654,11 +726,6 @@ def __call__(
    max_sequence_length=max_sequence_length,
)

- self._guidance_scale = guidance_scale
- self._clip_skip = clip_skip
- self._joint_attention_kwargs = joint_attention_kwargs
- self._interrupt = False
-

# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
    batch_size = 1
@@ -667,34 +734,28 @@ def __call__(
else:
    batch_size = prompt_embeds.shape[0]

- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
- pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+ if self.do_classifier_free_guidance:
+     prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+     pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

- # 4. Prepare latent variables
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 5. Prepare latent variables
num_channels_latents = self.transformer.model.config.in_channels
latents = self.prepare_latents(
    batch_size * num_images_per_prompt,
    num_channels_latents,
    height,
    width,
    prompt_embeds.dtype,
-     "cpu",
+     device,
    generator,
    latents,
)

- # 5. Prepare timesteps
- scheduler_kwargs = {}
- timesteps, num_inference_steps = retrieve_timesteps(
-     self.scheduler,
-     num_inference_steps,
-     "cpu",
-     sigmas=sigmas,
-     **scheduler_kwargs,
- )
- num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
- self._num_timesteps = len(timesteps)
-

###### AIC related changes of transformers ######
if self.transformer.qpc_session is None:
    self.transformer.qpc_session = QAICInferenceSession(str(self.transformer.qpc_path))