Skip to content

Commit 80e69c8

Browse files
authored
Merge branch 'main' into main
2 parents 1d00de3 + 888468d commit 80e69c8

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

src/diffusers/models/attention.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,9 @@ def __init__(
244244
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
245245
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
246246

247-
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
247+
self.to_out = nn.ModuleList([])
248+
self.to_out.append(nn.Linear(inner_dim, query_dim))
249+
self.to_out.append(nn.Dropout(dropout))
248250

249251
def reshape_heads_to_batch_dim(self, tensor):
250252
batch_size, seq_len, dim = tensor.shape
@@ -283,7 +285,11 @@ def forward(self, hidden_states, context=None, mask=None):
283285
else:
284286
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
285287

286-
return self.to_out(hidden_states)
288+
# linear proj
289+
hidden_states = self.to_out[0](hidden_states)
290+
# dropout
291+
hidden_states = self.to_out[1](hidden_states)
292+
return hidden_states
287293

288294
def _attention(self, query, key, value):
289295
# TODO: use baddbmm for better performance
@@ -354,12 +360,19 @@ def __init__(
354360
super().__init__()
355361
inner_dim = int(dim * mult)
356362
dim_out = dim_out if dim_out is not None else dim
357-
project_in = GEGLU(dim, inner_dim)
363+
self.net = nn.ModuleList([])
358364

359-
self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
365+
# project in
366+
self.net.append(GEGLU(dim, inner_dim))
367+
# project dropout
368+
self.net.append(nn.Dropout(dropout))
369+
# project out
370+
self.net.append(nn.Linear(inner_dim, dim_out))
360371

361372
def forward(self, hidden_states):
362-
return self.net(hidden_states)
373+
for module in self.net:
374+
hidden_states = module(hidden_states)
375+
return hidden_states
363376

364377

365378
# feedforward

0 commit comments

Comments
 (0)