
Commit 2779920

Fix-2
Signed-off-by: Amit Raj <[email protected]>
1 parent 6908f62 commit 2779920


9 files changed: +530, -230 lines changed


QEfficient/diffusers/models/attention.py

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@
 # ----------------------------------------------------------------------------

 import torch
+
 from diffusers.models.attention import JointTransformerBlock, _chunked_feed_forward


QEfficient/diffusers/models/attention_processor.py

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@
 from typing import Optional

 import torch
+
 from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0


QEfficient/diffusers/models/pytorch_transforms.py

Lines changed: 5 additions & 3 deletions
@@ -6,11 +6,13 @@
 # -----------------------------------------------------------------------------
 from typing import Tuple

-from diffusers.models.attention import JointTransformerBlock
-from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0
 from torch import nn

+from diffusers.models.attention import JointTransformerBlock
+from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0
+from diffusers.models.normalization import RMSNorm
 from QEfficient.base.pytorch_transforms import ModuleMappingTransform
+from QEfficient.customop.rms_norm import CustomRMSNormAIC
 from QEfficient.diffusers.models.attention import QEffJointTransformerBlock
 from QEfficient.diffusers.models.attention_processor import (
     QEffAttention,
@@ -19,7 +21,7 @@


 class CustomOpsTransform(ModuleMappingTransform):
-    _module_mapping = {}
+    _module_mapping = {RMSNorm: CustomRMSNormAIC}


 class AttentionTransform(ModuleMappingTransform):
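
Note: this change registers the diffusers RMSNorm under CustomOpsTransform so it gets swapped for CustomRMSNormAIC during export. The snippet below is a minimal sketch of what a module-mapping pass of this kind does, assuming the base ModuleMappingTransform re-classes matching modules in place; apply_module_mapping and the Toy* classes are illustrative names, not QEfficient API.

import torch.nn as nn

def apply_module_mapping(model: nn.Module, mapping: dict) -> nn.Module:
    """Illustrative only: swap the class of every mapped module in place, keeping its parameters."""
    for module in model.modules():
        if type(module) in mapping:
            module.__class__ = mapping[type(module)]
    return model

# Toy usage with stand-in classes (RMSNorm -> CustomRMSNormAIC in the real mapping):
class ToyNorm(nn.Module):
    def forward(self, x):
        return x

class ToyNormAIC(ToyNorm):
    pass

model = nn.Sequential(nn.Linear(4, 4), ToyNorm())
apply_module_mapping(model, {ToyNorm: ToyNormAIC})
assert isinstance(model[1], ToyNormAIC)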

QEfficient/diffusers/pipelines/pipeline_utils.py

Lines changed: 8 additions & 18 deletions
@@ -12,24 +12,18 @@

 from QEfficient.base.modeling_qeff import QEFFBaseModel
 from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform
-from QEfficient.diffusers.models.pytorch_transforms import AttentionTransform
+from QEfficient.diffusers.models.pytorch_transforms import AttentionTransform, CustomOpsTransform
 from QEfficient.transformers.models.pytorch_transforms import (
-    CustomOpsTransform,
     KVCacheExternalModuleMapperTransform,
     KVCacheTransform,
+    T5ModelTransform,
 )
 from QEfficient.transformers.quantizers.quant_transforms import AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform
 from QEfficient.utils.cache import to_hashable


 class QEffTextEncoder(QEFFBaseModel):
-    _pytorch_transforms = [
-        AwqToMatmulNbitsTransform,
-        GPTQToMatmulNbitsTransform,
-        CustomOpsTransform,
-        KVCacheTransform,
-        KVCacheExternalModuleMapperTransform,
-    ]
+    _pytorch_transforms = [CustomOpsTransform, T5ModelTransform]
     _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]

     def __init__(self, model: nn.modules):
@@ -129,8 +123,8 @@ def compile(
     def model_hash(self) -> str:
         # Compute the hash with: model_config, continuous_batching, transforms
         mhash = hashlib.sha256()
-        # mhash.update(to_hashable(dict(self.model.config)))
-        # mhash.update(to_hashable(self._transform_names()))
+        mhash.update(to_hashable(dict(self.model.config)))
+        mhash.update(to_hashable(self._transform_names()))
         mhash = mhash.hexdigest()[:16]
         return mhash

@@ -151,8 +145,6 @@ class QEffVAE(QEFFBaseModel):
         AwqToMatmulNbitsTransform,
         GPTQToMatmulNbitsTransform,
         CustomOpsTransform,
-        KVCacheTransform,
-        KVCacheExternalModuleMapperTransform,
     ]
     _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]

@@ -192,7 +184,7 @@ def compile(
     def model_hash(self) -> str:
         # Compute the hash with: model_config, continuous_batching, transforms
         mhash = hashlib.sha256()
-        # mhash.update(to_hashable(dict(self.model.config)))
+        mhash.update(to_hashable(dict(self.model.config)))
         mhash.update(to_hashable(self._transform_names()))
         mhash.update(to_hashable(self.type))
         mhash = mhash.hexdigest()[:16]
@@ -215,8 +207,6 @@ class QEffSafetyChecker(QEFFBaseModel):
         AwqToMatmulNbitsTransform,
         GPTQToMatmulNbitsTransform,
         CustomOpsTransform,
-        KVCacheTransform,
-        KVCacheExternalModuleMapperTransform,
     ]
     _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]

@@ -311,8 +301,8 @@ def compile(
     def model_hash(self) -> str:
         # Compute the hash with: model_config, continuous_batching, transforms
         mhash = hashlib.sha256()
-        # mhash.update(to_hashable(dict(self.model.config)))
-        # mhash.update(to_hashable(self._transform_names()))
+        mhash.update(to_hashable(dict(self.model.config)))
+        mhash.update(to_hashable(self._transform_names()))
         mhash = mhash.hexdigest()[:16]
         return mhash
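
Note: un-commenting these lines makes the cached-artifact key depend on the model config and on the list of applied transforms, so differently configured or differently transformed models no longer hash to the same QPC. A minimal sketch of that pattern, assuming to_hashable just turns its input into stable bytes (to_hashable_sketch and the config values below are illustrative, not QEfficient API):

import hashlib
import json

def to_hashable_sketch(obj) -> bytes:
    # Assumption: canonical JSON is one way to get deterministic bytes for a dict or list.
    return json.dumps(obj, sort_keys=True, default=str).encode()

mhash = hashlib.sha256()
mhash.update(to_hashable_sketch({"hidden_size": 4096, "num_layers": 24}))       # model config (example values)
mhash.update(to_hashable_sketch(["CustomOpsTransform", "T5ModelTransform"]))    # transform names
cache_key = mhash.hexdigest()[:16]
print(cache_key)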

QEfficient/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion3.py

Lines changed: 120 additions & 59 deletions
@@ -129,19 +129,59 @@ def export(self, export_dir: Optional[str] = None) -> str:

         print("###################### TEXT ENCODER 2 EXPORTED ######################")

-        # # T5 TEXT ENCODER
-        # example_inputs = {"input_ids": torch.zeros((bs, seq_len), dtype=torch.int64)}
+        # T5 TEXT ENCODER
+        example_inputs = {"input_ids": torch.zeros((bs, seq_len), dtype=torch.int64)}

-        # dynamic_axes = {"input_ids": {0: "batch_size", 1: "seq_len"}}
+        dynamic_axes = {"input_ids": {0: "batch_size", 1: "seq_len"}}
+
+        output_names = ["last_hidden_state"]
+
+        ## Changes for the testing ####
+        wo_sfs = [
+            61,
+            203,
+            398,
+            615,
+            845,
+            1190,
+            1402,
+            2242,
+            1875,
+            2393,
+            3845,
+            3213,
+            3922,
+            4429,
+            5020,
+            5623,
+            6439,
+            6206,
+            5165,
+            4593,
+            2802,
+            2618,
+            1891,
+            1419,
+        ]
+
+        assert len(wo_sfs) == 24
+        with torch.no_grad():
+            prev_sf = 1
+            for i in range(len(self.text_encoder_3.model.encoder.block)):
+                wosf = wo_sfs[i]
+                self.text_encoder_3.model.encoder.block[i].layer[0].SelfAttention.o.weight *= 1 / wosf
+                self.text_encoder_3.model.encoder.block[i].layer[0].scaling_factor *= prev_sf / wosf
+                self.text_encoder_3.model.encoder.block[i].layer[1].DenseReluDense.wo.weight *= 1 / wosf
+                prev_sf = wosf

-        # output_names = ["last_hidden_state"]
+        ### End ####

-        # self.text_encoder_3_onnx_path = self.text_encoder_3.export(
-        #     inputs=example_inputs,
-        #     output_names=output_names,
-        #     dynamic_axes=dynamic_axes,
-        #     export_dir=export_dir,
-        # )
+        self.text_encoder_3_onnx_path = self.text_encoder_3.export(
+            inputs=example_inputs,
+            output_names=output_names,
+            dynamic_axes=dynamic_axes,
+            export_dir=export_dir,
+        )

         print("###################### TEXT ENCODER 3 EXPORTED ######################")

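Note: the wo_sfs block rescales each T5 encoder layer's attention output projection and FFN output projection by 1/wo_sf and folds the running ratio into that layer's scaling_factor, which reads as a hand-tuned guard against FP16 overflow on the accelerator; the factors themselves come from the hunk above. The sketch below only illustrates the two properties such a rescaling relies on, with illustrative shapes: a bias-free linear layer's output scales with its weight, and RMS normalization is invariant to a uniform scale on its input.

import torch

torch.manual_seed(0)
lin = torch.nn.Linear(8, 8, bias=False)
x = torch.randn(2, 8)
s = 61.0  # first per-layer factor from the hunk above

# 1) Scaling the weight scales the output by the same factor.
y_ref = lin(x)
lin.weight.data *= 1 / s
assert torch.allclose(lin(x), y_ref / s, atol=1e-6)

# 2) RMS normalization is scale-invariant, so a uniform factor folded into the
#    weights can be tracked (prev_sf) instead of undone explicitly.
def rms_norm(h, eps=1e-12):
    return h / torch.sqrt(h.pow(2).mean(-1, keepdim=True) + eps)

assert torch.allclose(rms_norm(y_ref), rms_norm(y_ref / s), atol=1e-5)
print("weight rescaling behaves as expected")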

@@ -267,23 +307,23 @@ def compile(
         print("###################### Text Encoder 2 Compiled #####################")

         # # Compile text_encoder 3
-        # seq_len= 256
-
-        # specializations = [
-        #     {"batch_size": batch_size, "seq_len": seq_len},
-        # ]
-
-        # self.text_encoder_3_compile_path=self.text_encoder_3._compile(
-        #     onnx_path,
-        #     compile_dir,
-        #     compile_only=True,
-        #     specializations=specializations,
-        #     convert_to_fp16=True,
-        #     mxfp6_matmul=mxfp6_matmul,
-        #     mdp_ts_num_devices=num_devices_text_encoder,
-        #     aic_num_cores=num_cores,
-        #     **compiler_options,
-        # )
+        seq_len = 256
+
+        specializations = [
+            {"batch_size": batch_size, "seq_len": seq_len},
+        ]
+
+        self.text_encoder_3_compile_path = self.text_encoder_3._compile(
+            onnx_path,
+            compile_dir,
+            compile_only=True,
+            specializations=specializations,
+            convert_to_fp16=True,
+            mxfp6_matmul=mxfp6_matmul,
+            mdp_ts_num_devices=num_devices_text_encoder,
+            aic_num_cores=num_cores,
+            **compiler_options,
+        )
         print("###################### Text Encoder 3 Compiled #####################")

         # Compile transformer
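
Note: this compile step mirrors the export step above: the axes left dynamic at export (batch_size, seq_len) are pinned by a single specialization, and seq_len = 256 matches the max_sequence_length default used later in __call__. A purely illustrative consistency check of that pairing (nothing below is QEfficient API):

# Axes declared dynamic when the T5 encoder was exported (see the export hunk above).
dynamic_axes = {"input_ids": {0: "batch_size", 1: "seq_len"}}

# The specialization passed to _compile pins each of those axes to one concrete value.
specializations = [{"batch_size": 1, "seq_len": 256}]

dynamic_axis_names = {name for axes in dynamic_axes.values() for name in axes.values()}
for spec in specializations:
    missing = dynamic_axis_names - spec.keys()
    assert not missing, f"specialization is missing values for: {missing}"
print("every dynamic axis is pinned:", dynamic_axis_names)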
@@ -331,6 +371,7 @@ def compile(
             convert_to_fp16=True,
             mdp_ts_num_devices=num_devices_vae_decoder,
         )
+        print("###################### vae_decoder Compiled #####################")

     def _get_clip_prompt_embeds(
         self,
@@ -480,12 +521,27 @@ def _get_t5_prompt_embeds(
                 "The following part of your input was truncated because `max_sequence_length` is set to "
                 f" {max_sequence_length} tokens: {removed_text}"
             )
-        # if self.text_encoder_3.qpc_session is None:
-        #     self.text_encoder_3.qpc_session = QAICInferenceSession(str(self.text_encoder_3_compile_path))
+        if self.text_encoder_3.qpc_session is None:
+            self.text_encoder_3.qpc_session = QAICInferenceSession(str(self.text_encoder_3_compile_path))

         prompt_embeds = self.text_encoder_3.model(text_input_ids.to(device))[0]
-        # aic_text_input={"input_ids": text_input_ids.numpy().astype(np.int64)}
-        # aic_embeddings= self.text_encoder_3.qpc_session.run(aic_text_input)
+        aic_text_input = {"input_ids": text_input_ids.numpy().astype(np.int64)}
+        aic_embeddings = torch.tensor(self.text_encoder_3.qpc_session.run(aic_text_input)["last_hidden_state"])
+        mad = torch.abs(prompt_embeds - aic_embeddings).mean()
+        print("Clip text-encoder-3 Pytorch vs AI 100:", mad)
+        prompt_embeds = aic_embeddings
+
+        # import onnxruntime as ort
+        # ort_session=ort.InferenceSession(self.text_encoder_3_onnx_path)
+        # input_names = [input.name for input in ort_session.get_inputs()]
+        # output_names = [output.name for output in ort_session.get_outputs()]
+        # inputs={input_names[0]: text_input_ids.numpy()}
+        # output=ort_session.run(output_names, inputs)
+        # prompt_embeds_ort = torch.from_numpy(output[0])
+
+        # # mad between prompt_embeds and prompt_embeds_ort
+        # mad=torch.abs(prompt_embeds-prompt_embeds_ort).mean()
+        # print("mad between ort and pytorch", mad)

         _, seq_len, _ = prompt_embeds.shape
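
Note: this hunk runs the T5 encoder twice, in PyTorch and on the compiled QPC, prints the mean absolute deviation (MAD) between the two embeddings, and then keeps the AI 100 result. A self-contained sketch of the same comparison with stand-in tensors (shapes and the simulated error are illustrative; the real code compares text_encoder_3.model(...) against qpc_session.run(...)["last_hidden_state"]):

import torch

def mean_abs_deviation(reference: torch.Tensor, candidate: torch.Tensor) -> float:
    # Same metric as in the hunk above: mean of elementwise absolute differences.
    return torch.abs(reference - candidate).mean().item()

pt_embeds = torch.randn(1, 256, 4096)                            # stand-in for the PyTorch output
aic_embeds = pt_embeds + 1e-3 * torch.randn_like(pt_embeds)      # stand-in for the AI 100 output

print("MAD:", mean_abs_deviation(pt_embeds, aic_embeds))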

@@ -623,16 +679,32 @@ def __call__(
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 256,
-        sigmas: Optional[List[float]] = None,
-        skip_guidance_layers: List[int] = None,
-        skip_layer_guidance_scale: float = 2.8,
-        skip_layer_guidance_stop: float = 0.2,
-        skip_layer_guidance_start: float = 0.01,
-        mu: Optional[float] = None,
-        vae_type="vae",
     ):
         height = height or self.default_sample_size * self.vae_scale_factor
         width = width or self.default_sample_size * self.vae_scale_factor
+        device = "cpu"
+
+        self.check_inputs(
+            prompt,
+            prompt_2,
+            prompt_3,
+            height,
+            width,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
+            negative_prompt_3=negative_prompt_3,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+            max_sequence_length=max_sequence_length,
+        )
+
+        self._guidance_scale = guidance_scale
+        self._clip_skip = clip_skip
+        self._joint_attention_kwargs = joint_attention_kwargs
+        self._interrupt = False

         (
             prompt_embeds,
@@ -654,11 +726,6 @@ def __call__(
             max_sequence_length=max_sequence_length,
         )

-        self._guidance_scale = guidance_scale
-        self._clip_skip = clip_skip
-        self._joint_attention_kwargs = joint_attention_kwargs
-        self._interrupt = False
-
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -667,34 +734,28 @@ def __call__(
         else:
             batch_size = prompt_embeds.shape[0]

-        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
-        pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+        if self.do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

-        # 4. Prepare latent variables
+        # 4. Prepare timesteps
+        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+        self._num_timesteps = len(timesteps)
+
+        # 5. Prepare latent variables
         num_channels_latents = self.transformer.model.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
             height,
             width,
             prompt_embeds.dtype,
-            "cpu",
+            device,
             generator,
             latents,
         )

-        # 5. Prepare timesteps
-        scheduler_kwargs = {}
-        timesteps, num_inference_steps = retrieve_timesteps(
-            self.scheduler,
-            num_inference_steps,
-            "cpu",
-            sigmas=sigmas,
-            **scheduler_kwargs,
-        )
-        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
-        self._num_timesteps = len(timesteps)
-
         ###### AIC related changes of transformers ######
         if self.transformer.qpc_session is None:
             self.transformer.qpc_session = QAICInferenceSession(str(self.transformer.qpc_path))
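
Note: concatenating negative and positive embeddings is now gated on do_classifier_free_guidance, and timestep preparation moves ahead of latent preparation. The sketch below recalls the standard classifier-free guidance pattern the doubled batch exists for, as used in the diffusers pipelines; tensors and shapes are toy stand-ins.

import torch

guidance_scale = 7.0
do_classifier_free_guidance = guidance_scale > 1.0  # mirrors the diffusers convention

# Doubled batch: unconditional (negative) embeddings first, conditional second.
neg, pos = torch.randn(1, 16), torch.randn(1, 16)
prompt_embeds = torch.cat([neg, pos], dim=0) if do_classifier_free_guidance else pos

# After the transformer runs on the doubled batch, the halves are recombined:
noise_pred = torch.randn(prompt_embeds.shape[0], 16)  # stand-in for the model output
if do_classifier_free_guidance:
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)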

QEfficient/transformers/models/modeling_auto.py

Lines changed: 9 additions & 9 deletions
@@ -14,15 +14,6 @@
 import numpy as np
 import torch
 import torch.nn as nn
-from transformers import (
-    AutoModel,
-    AutoModelForCausalLM,
-    AutoModelForImageTextToText,
-    AutoModelForSpeechSeq2Seq,
-    PreTrainedTokenizer,
-    PreTrainedTokenizerFast,
-    TextStreamer,
-)

 import QEfficient
 from QEfficient.base.modeling_qeff import QEFFBaseModel
@@ -58,6 +49,15 @@
 )
 from QEfficient.utils.cache import to_hashable
 from QEfficient.utils.logging_utils import logger
+from transformers import (
+    AutoModel,
+    AutoModelForCausalLM,
+    AutoModelForImageTextToText,
+    AutoModelForSpeechSeq2Seq,
+    PreTrainedTokenizer,
+    PreTrainedTokenizerFast,
+    TextStreamer,
+)


 class QEFFTransformersBase(QEFFBaseModel):
