Commit 5786b0e

handle dtype xformers attention (#1196)
handle dtype xformers
1 parent 32b0736 commit 5786b0e

File tree

1 file changed: +2 −0 lines


src/diffusers/models/attention.py

Lines changed: 2 additions & 0 deletions
@@ -492,6 +492,8 @@ def forward(self, hidden_states, context=None, mask=None):
         # attention, what we cannot get enough of
         if self._use_memory_efficient_attention_xformers:
             hidden_states = self._memory_efficient_attention_xformers(query, key, value)
+            # Some versions of xformers return output in fp32, cast it back to the dtype of the input
+            hidden_states = hidden_states.to(query.dtype)
         else:
             if self._slice_size is None or query.shape[0] // self._slice_size == 1:
                 hidden_states = self._attention(query, key, value)
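
For context, the cast matters because the memory-efficient kernel can upcast internally and hand back an fp32 tensor even when the inputs are fp16 or bf16. Below is a minimal standalone sketch of the same pattern, not the diffusers source: attention_with_dtype_fix is a hypothetical helper, and the (batch * heads, seq_len, head_dim) layout is an assumption about how the tensors arrive at this point.

import torch

try:
    import xformers.ops as xops
except ImportError:  # xformers is optional; raise at call time if missing
    xops = None


def attention_with_dtype_fix(query, key, value):
    # query/key/value: (batch * heads, seq_len, head_dim) tensors
    # (an assumed layout; xformers also accepts other shapes).
    if xops is None:
        raise RuntimeError("xformers is not installed")
    hidden_states = xops.memory_efficient_attention(query, key, value)
    # Some xformers builds return fp32 even for fp16/bf16 inputs; cast the
    # result back to the input dtype so downstream half-precision layers
    # do not hit dtype-mismatch errors.
    return hidden_states.to(query.dtype)

Without the final .to(query.dtype), a pipeline running in fp16 would feed an fp32 tensor into the fp16 output projection, which is the kind of mismatch this commit guards against.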

0 commit comments
