diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 91c450d4a581..b6f5158e515d 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -703,7 +703,13 @@ def __init__( self.transformer_index_for_condition = [1, 0] def forward( - self, hidden_states, encoder_hidden_states, timestep=None, attention_mask=None, return_dict: bool = True + self, + hidden_states, + encoder_hidden_states, + timestep=None, + attention_mask=None, + cross_attention_kwargs=None, + return_dict: bool = True, ): """ Args: @@ -738,6 +744,7 @@ def forward( input_states, encoder_hidden_states=condition_state, timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] encoded_states.append(encoded_state - input_states)