diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index bac85e2f39cf..e8ea37970e04 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -492,6 +492,8 @@ def forward(self, hidden_states, context=None, mask=None):
         # attention, what we cannot get enough of
         if self._use_memory_efficient_attention_xformers:
             hidden_states = self._memory_efficient_attention_xformers(query, key, value)
+            # Some versions of xformers return output in fp32, cast it back to the dtype of the input
+            hidden_states = hidden_states.to(query.dtype)
         else:
             if self._slice_size is None or query.shape[0] // self._slice_size == 1:
                 hidden_states = self._attention(query, key, value)
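
For context, a minimal standalone sketch of the issue the hunk addresses (it assumes a CUDA device and an xformers install; the tensor shapes are illustrative, not taken from the model): some xformers versions return the attention output in fp32 even for fp16/bf16 inputs, so the result is cast back to the query dtype before it flows into the following layers.

import torch
import xformers.ops

# Illustrative shapes only: (batch * heads, seq_len, head_dim)
query = torch.randn(16, 64, 40, dtype=torch.float16, device="cuda")
key = torch.randn(16, 64, 40, dtype=torch.float16, device="cuda")
value = torch.randn(16, 64, 40, dtype=torch.float16, device="cuda")

hidden_states = xformers.ops.memory_efficient_attention(query, key, value)

# Depending on the xformers version, hidden_states may come back as fp32 here,
# which would break or silently upcast subsequent fp16 ops, so cast it back
# to the input dtype, mirroring the added lines in the hunk above.
hidden_states = hidden_states.to(query.dtype)
assert hidden_states.dtype == query.dtype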