Skip to content

Commit d30f968

Browse files
committed
model can run in other precisions without autocast
1 parent 39994cc commit d30f968

File tree

3 files changed

+6
-5
lines changed

3 files changed

+6
-5
lines changed

src/diffusers/models/resnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ def forward(self, x, temb):
333333

334334
# make sure hidden states is in float32
335335
# when running in half-precision
336-
hidden_states = self.norm1(hidden_states.float()).type(hidden_states.dtype)
336+
hidden_states = self.norm1.float()(hidden_states.float()).type(hidden_states.dtype)
337337
hidden_states = self.nonlinearity(hidden_states)
338338

339339
if self.upsample is not None:
@@ -351,7 +351,7 @@ def forward(self, x, temb):
351351

352352
# make sure hidden states is in float32
353353
# when running in half-precision
354-
hidden_states = self.norm2(hidden_states.float()).type(hidden_states.dtype)
354+
hidden_states = self.norm2.float()(hidden_states.float()).type(hidden_states.dtype)
355355
hidden_states = self.nonlinearity(hidden_states)
356356

357357
hidden_states = self.dropout(hidden_states)

src/diffusers/models/unet_2d_condition.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def forward(
177177
timesteps = timesteps.expand(sample.shape[0])
178178

179179
t_emb = self.time_proj(timesteps)
180-
emb = self.time_embedding(t_emb)
180+
emb = self.time_embedding(t_emb.to(self.dtype))
181181

182182
# 2. pre-process
183183
sample = self.conv_in(sample)
@@ -215,7 +215,7 @@ def forward(
215215
# 6. post-process
216216
# make sure hidden states is in float32
217217
# when running in half-precision
218-
sample = self.conv_norm_out(sample.float()).type(sample.dtype)
218+
sample = self.conv_norm_out.float()(sample.float()).type(sample.dtype)
219219
sample = self.conv_act(sample)
220220
sample = self.conv_out(sample)
221221

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ def __call__(
204204
latents_shape,
205205
generator=generator,
206206
device=self.device,
207+
dtype=text_embeddings.dtype,
207208
)
208209
else:
209210
if latents.shape != latents_shape:
@@ -263,7 +264,7 @@ def __call__(
263264

264265
# run safety checker
265266
safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
266-
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
267+
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values.to(text_embeddings.dtype))
267268

268269
if output_type == "pil":
269270
image = self.numpy_to_pil(image)

0 commit comments

Comments
 (0)