Skip to content

Commit cec5928

Browse files
committed
remove restriction to run conv_norm in fp32
1 parent 4e67675 commit cec5928

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

src/diffusers/models/resnet.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,8 @@ def forward(self, x, temb):
332332

333333
# make sure hidden states is in float32
334334
# when running in half-precision
335-
hidden_states = self.norm1.float()(hidden_states.float()).type(hidden_states.dtype)
335+
hidden_states = self.norm1(hidden_states).type(hidden_states.dtype)
336+
# hidden_states = self.norm1.float()(hidden_states.float()).type(hidden_states.dtype)
336337
hidden_states = self.nonlinearity(hidden_states)
337338

338339
if self.upsample is not None:
@@ -350,7 +351,8 @@ def forward(self, x, temb):
350351

351352
# make sure hidden states is in float32
352353
# when running in half-precision
353-
hidden_states = self.norm2.float()(hidden_states.float()).type(hidden_states.dtype)
354+
hidden_states = self.norm2(hidden_states).type(hidden_states.dtype)
355+
# hidden_states = self.norm2.float()(hidden_states.float()).type(hidden_states.dtype)
354356
hidden_states = self.nonlinearity(hidden_states)
355357

356358
hidden_states = self.dropout(hidden_states)

src/diffusers/models/unet_2d_condition.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,8 @@ def forward(
261261
# 6. post-process
262262
# make sure hidden states is in float32
263263
# when running in half-precision
264-
sample = self.conv_norm_out.float()(sample.float()).type(sample.dtype)
264+
# sample = self.conv_norm_out.float()(sample.float()).type(sample.dtype)
265+
sample = self.conv_norm_out(sample).type(sample.dtype)
265266
sample = self.conv_act(sample)
266267
sample = self.conv_out(sample)
267268

0 commit comments

Comments
 (0)