@@ -244,7 +244,9 @@ def __init__(
         self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
         self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

-        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.to_out = nn.ModuleList([])
+        self.to_out.append(nn.Linear(inner_dim, query_dim))
+        self.to_out.append(nn.Dropout(dropout))

     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
@@ -283,7 +285,11 @@ def forward(self, hidden_states, context=None, mask=None):
         else:
             hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)

-        return self.to_out(hidden_states)
+        # linear proj
+        hidden_states = self.to_out[0](hidden_states)
+        # dropout
+        hidden_states = self.to_out[1](hidden_states)
+        return hidden_states

     def _attention(self, query, key, value):
         # TODO: use baddbmm for better performance
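The to_out change above swaps nn.Sequential for an nn.ModuleList whose members are applied one by one in forward. A minimal sketch of why the two layouts behave identically; the dimensions and the weight-copying check are illustrative assumptions, not values from the file:

import torch
import torch.nn as nn

inner_dim, query_dim, dropout = 64, 32, 0.0  # illustrative sizes only

# old layout: one nn.Sequential
old_to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

# new layout: an nn.ModuleList whose members are called one by one, as in the diff
new_to_out = nn.ModuleList([nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)])
new_to_out[0].load_state_dict(old_to_out[0].state_dict())  # copy the same weights

old_to_out.eval()  # dropout is a no-op in eval mode, so outputs are deterministic
new_to_out.eval()
x = torch.randn(2, 8, inner_dim)
assert torch.allclose(old_to_out(x), new_to_out[1](new_to_out[0](x)))

Both layouts register their parameters under the keys to_out.0.* and to_out.1.*, so existing checkpoints should keep loading unchanged.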
@@ -354,12 +360,19 @@ def __init__(
         super().__init__()
         inner_dim = int(dim * mult)
         dim_out = dim_out if dim_out is not None else dim
-        project_in = GEGLU(dim, inner_dim)
+        self.net = nn.ModuleList([])

-        self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
+        # project in
+        self.net.append(GEGLU(dim, inner_dim))
+        # project dropout
+        self.net.append(nn.Dropout(dropout))
+        # project out
+        self.net.append(nn.Linear(inner_dim, dim_out))

     def forward(self, hidden_states):
-        return self.net(hidden_states)
+        for module in self.net:
+            hidden_states = module(hidden_states)
+        return hidden_states


 # feedforward
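For reference, a self-contained sketch of the FeedForward block as it reads after this change. The GEGLU stub here is a simplified gated-GELU implementation, and the sizes in the usage example are assumptions for illustration only:

import torch
import torch.nn as nn
import torch.nn.functional as F


class GEGLU(nn.Module):
    # simplified stub of the GEGLU used above: one linear projection split into value and gate
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, hidden_states):
        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
        return hidden_states * F.gelu(gate)


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim
        self.net = nn.ModuleList([])
        self.net.append(GEGLU(dim, inner_dim))          # project in
        self.net.append(nn.Dropout(dropout))            # dropout
        self.net.append(nn.Linear(inner_dim, dim_out))  # project out

    def forward(self, hidden_states):
        # apply the sub-modules in order, exactly what nn.Sequential did before
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states


ff = FeedForward(dim=32)
out = ff(torch.randn(2, 8, 32))
assert out.shape == (2, 8, 32)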