From 0c78a19fd158148258721752da72a17222081ba6 Mon Sep 17 00:00:00 2001
From: Nouamane Tazi
Date: Fri, 30 Sep 2022 15:58:40 +0000
Subject: [PATCH 1/2] revert using baddbmm in attention - to fix `test_stable_diffusion_memory_chunking` test

---
 src/diffusers/models/attention.py | 8 +-------
 1 file changed, 1 insertion(+), 7 deletions(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index b4e5f2e07f7d..9bfe85b01e54 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -274,13 +274,7 @@ def forward(self, hidden_states, context=None, mask=None):
         return self.to_out(hidden_states)
 
     def _attention(self, query, key, value):
-        attention_scores = torch.baddbmm(
-            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
-            query,
-            key.transpose(-1, -2),
-            beta=0,
-            alpha=self.scale,
-        )
+        attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale #TODO: use baddbmm for better performance
         attention_probs = attention_scores.softmax(dim=-1)
         # compute attention output
         hidden_states = torch.matmul(attention_probs, value)

From 12fc28fe122d5b6ea722a331bdb4aa0047ec92b8 Mon Sep 17 00:00:00 2001
From: Nouamane Tazi
Date: Fri, 30 Sep 2022 16:03:34 +0000
Subject: [PATCH 2/2] styling

---
 src/diffusers/models/attention.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 9bfe85b01e54..c2f27bd9282d 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -274,7 +274,8 @@ def forward(self, hidden_states, context=None, mask=None):
         return self.to_out(hidden_states)
 
     def _attention(self, query, key, value):
-        attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale #TODO: use baddbmm for better performance
+        # TODO: use baddbmm for better performance
+        attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
         attention_probs = attention_scores.softmax(dim=-1)
         # compute attention output
         hidden_states = torch.matmul(attention_probs, value)