Skip to content

Commit cc9bc13

Browse files
committed
Revert "remove restriction to run conv_norm in fp32"
This reverts commit cec5928.
1 parent 47c668c commit cc9bc13

File tree

2 files changed

+3
-6
lines changed

2 files changed

+3
-6
lines changed

src/diffusers/models/resnet.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -332,8 +332,7 @@ def forward(self, x, temb):
332332

333333
# make sure hidden states is in float32
334334
# when running in half-precision
335-
hidden_states = self.norm1(hidden_states).type(hidden_states.dtype)
336-
# hidden_states = self.norm1.float()(hidden_states.float()).type(hidden_states.dtype)
335+
hidden_states = self.norm1.float()(hidden_states.float()).type(hidden_states.dtype)
337336
hidden_states = self.nonlinearity(hidden_states)
338337

339338
if self.upsample is not None:
@@ -351,8 +350,7 @@ def forward(self, x, temb):
351350

352351
# make sure hidden states is in float32
353352
# when running in half-precision
354-
hidden_states = self.norm2(hidden_states).type(hidden_states.dtype)
355-
# hidden_states = self.norm2.float()(hidden_states.float()).type(hidden_states.dtype)
353+
hidden_states = self.norm2.float()(hidden_states.float()).type(hidden_states.dtype)
356354
hidden_states = self.nonlinearity(hidden_states)
357355

358356
hidden_states = self.dropout(hidden_states)

src/diffusers/models/unet_2d_condition.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -261,8 +261,7 @@ def forward(
261261
# 6. post-process
262262
# make sure hidden states is in float32
263263
# when running in half-precision
264-
# sample = self.conv_norm_out.float()(sample.float()).type(sample.dtype)
265-
sample = self.conv_norm_out(sample).type(sample.dtype)
264+
sample = self.conv_norm_out.float()(sample.float()).type(sample.dtype)
266265
sample = self.conv_act(sample)
267266
sample = self.conv_out(sample)
268267

0 commit comments

Comments
 (0)