From a2e863df34451faeed227e5cb8dba444a4dadde1 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Tue, 8 Nov 2022 17:10:29 +0100
Subject: [PATCH] handle dtype xformers

---
 src/diffusers/models/attention.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index bac85e2f39cf..e8ea37970e04 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -492,6 +492,8 @@ def forward(self, hidden_states, context=None, mask=None):
         # attention, what we cannot get enough of
         if self._use_memory_efficient_attention_xformers:
             hidden_states = self._memory_efficient_attention_xformers(query, key, value)
+            # Some versions of xformers return output in fp32, cast it back to the dtype of the input
+            hidden_states = hidden_states.to(query.dtype)
         else:
             if self._slice_size is None or query.shape[0] // self._slice_size == 1:
                 hidden_states = self._attention(query, key, value)
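
Context for reviewers: per the comment added in the hunk, some xformers versions return the memory-efficient attention output in fp32 even when query/key/value are fp16 or bf16, so the result is cast back to query.dtype before it reaches the half-precision output projection. A minimal sketch of the same pattern outside diffusers, where the wrapper name and the attn_fn parameter are illustrative and not part of the library API:

    import torch

    def attention_with_dtype_fix(query, key, value, attn_fn):
        # attn_fn stands in for a memory-efficient attention kernel
        # (e.g. xformers.ops.memory_efficient_attention) that may upcast to fp32.
        hidden_states = attn_fn(query, key, value)
        # Cast back to the input dtype so subsequent fp16/bf16 layers
        # do not fail with a dtype mismatch.
        return hidden_states.to(query.dtype)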