Skip to content

Commit 2fa9c69

Browse files
committed
hardcode what's needed for jitting
1 parent 31c58ea commit 2fa9c69

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def __call__(
195195
truncation=True,
196196
return_tensors="pt",
197197
)
198-
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
198+
text_embeddings = self.text_encoder(text_input.input_ids.to("cuda"))[0]
199199

200200
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
201201
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -207,7 +207,7 @@ def __call__(
207207
uncond_input = self.tokenizer(
208208
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
209209
)
210-
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
210+
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to("cuda"))[0]
211211

212212
# For classifier free guidance, we need to do two forward passes.
213213
# Here we concatenate the unconditional and text embeddings into a single batch
@@ -219,8 +219,8 @@ def __call__(
219219
# Unlike in other pipelines, latents need to be generated in the target device
220220
# for 1-to-1 results reproducibility with the CompVis implementation.
221221
# However this currently doesn't work in `mps`.
222-
latents_device = "cpu" if self.device.type == "mps" else self.device
223-
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
222+
latents_device = "cuda"
223+
latents_shape = (batch_size, 4, height // 8, width // 8)
224224
if latents is None:
225225
latents = torch.randn(
226226
latents_shape,
@@ -259,7 +259,7 @@ def __call__(
259259
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
260260

261261
# predict the noise residual
262-
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
262+
noise_pred = self.unet(latent_model_input, t, text_embeddings)[0] # TODO: fix for return_dict case
263263

264264
# perform guidance
265265
if do_classifier_free_guidance:
@@ -280,9 +280,9 @@ def __call__(
280280
image = image.cpu().permute(0, 2, 3, 1).numpy()
281281

282282
# run safety checker
283-
safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
283+
safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to("cuda")
284284
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values.to(text_embeddings.dtype))
285-
285+
286286
if output_type == "pil":
287287
image = self.numpy_to_pil(image)
288288

0 commit comments

Comments
 (0)