@@ -673,6 +673,8 @@ def forward(
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        # parameter exists only for interface-compatibility with other blocks. prefer attention_mask
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
     ):
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
         hidden_states = self.resnets[0](hidden_states, temb)
@@ -1524,6 +1526,8 @@ def forward(
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        # parameter exists only for interface-compatibility with other blocks. prefer attention_mask
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
     ):
         output_states = ()
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
@@ -2623,6 +2627,8 @@ def forward(
         upsample_size: Optional[int] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        # parameter exists only for interface-compatibility with other blocks. prefer attention_mask
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
     ):
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
         for resnet, attn in zip(self.resnets, self.attentions):
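
The same signature change is applied to all three block types. As a minimal sketch (not the diffusers implementation; the class and function names below are hypothetical) of why an otherwise-unused parameter is worth accepting: the parent UNet forwards one uniform keyword set to every block, so a block whose attention path only uses `attention_mask` must still accept `encoder_attention_mask` to avoid a `TypeError`.

```python
from typing import Any, Dict, Optional

import torch


class SimpleCrossAttnishBlock:
    """Hypothetical block that only consumes `attention_mask`;
    `encoder_attention_mask` is accepted purely for interface compatibility."""

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        # accepted but unused, so the caller can pass one uniform kwarg set
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        # real blocks would apply resnets / attention here
        return hidden_states


def run_blocks(blocks, sample, **block_kwargs):
    # A UNet-style loop forwards the same kwargs to every block,
    # regardless of whether a given block actually consumes them.
    for block in blocks:
        sample = block.forward(sample, **block_kwargs)
    return sample


sample = torch.randn(1, 4, 8, 8)
out = run_blocks(
    [SimpleCrossAttnishBlock()],
    sample,
    encoder_hidden_states=torch.randn(1, 77, 768),
    attention_mask=None,
    encoder_attention_mask=None,  # silently ignored by this block type
)
```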