
Commit b4f5cb9

Put the encoder_attention_mask param back into the Simple* block forward interfaces, to keep the forward interface consistent across blocks.
1 parent: e35b7df

File tree: 2 files changed, +8 −0 lines


src/diffusers/models/unet_2d_blocks.py

Lines changed: 6 additions & 0 deletions
@@ -673,6 +673,8 @@ def forward(
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        # parameter exists only for interface-compatibility with other blocks. prefer attention_mask
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
     ):
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
         hidden_states = self.resnets[0](hidden_states, temb)

@@ -1524,6 +1526,8 @@ def forward(
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        # parameter exists only for interface-compatibility with other blocks. prefer attention_mask
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
     ):
         output_states = ()
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}

@@ -2623,6 +2627,8 @@ def forward(
         upsample_size: Optional[int] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        # parameter exists only for interface-compatibility with other blocks. prefer attention_mask
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
     ):
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
         for resnet, attn in zip(self.resnets, self.attentions):
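
Why an unused parameter is worth carrying: the surrounding UNet forwards the same keyword set to every down/mid/up block, so a block whose forward lacks encoder_attention_mask would fail with a TypeError even though it never uses the mask. A minimal sketch of that calling pattern, with hypothetical stand-in classes rather than the real diffusers blocks:

```python
from typing import Optional

import torch


class CrossAttnBlockSketch(torch.nn.Module):
    """Stand-in for a block that actually consumes encoder_attention_mask."""

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        # a real block would turn encoder_attention_mask into an attention bias here
        return hidden_states


class SimpleCrossAttnBlockSketch(torch.nn.Module):
    """Stand-in for the Simple* blocks touched by this commit."""

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        # accepted only for interface compatibility; this block relies on attention_mask
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        del encoder_attention_mask  # intentionally unused
        return hidden_states


blocks = [CrossAttnBlockSketch(), SimpleCrossAttnBlockSketch()]
sample = torch.randn(1, 4, 8, 8)
mask = torch.ones(1, 77, dtype=torch.bool)

# The caller hands every block the identical keyword set; without the
# compatibility parameter the second block would raise a TypeError
# (unexpected keyword argument 'encoder_attention_mask').
for block in blocks:
    sample = block(
        sample,
        encoder_hidden_states=None,
        attention_mask=None,
        encoder_attention_mask=mask,
    )
```

Accepting and discarding the argument keeps every call site free of per-block special cases; the comment added in the diff points callers at attention_mask instead.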

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 2 additions & 0 deletions
@@ -1707,6 +1707,8 @@ def forward(
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        # parameter exists only for interface-compatibility with other blocks. prefer attention_mask
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
     ):
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
         hidden_states = self.resnets[0](hidden_states, temb)
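
The modeling_text_unet.py hunk mirrors the same signature change in that pipeline's parallel copies of the blocks. Judging from the diff context, the three hunks in unet_2d_blocks.py appear to cover a mid, down, and up block respectively (the second initializes output_states, the third takes upsample_size). If one wanted to guard the invariant named in the commit message, a small check could assert that every affected forward exposes the parameter; a hedged sketch, assuming the Simple* class names implied by the commit message and a diffusers checkout that contains this change:

```python
import inspect

# Sketch of a consistency check for the blocks this commit seems to touch. The
# class names below are inferred from the "Simple block" wording and the diff
# context, not spelled out in the diff itself.
from diffusers.models.unet_2d_blocks import (
    SimpleCrossAttnDownBlock2D,
    SimpleCrossAttnUpBlock2D,
    UNetMidBlock2DSimpleCrossAttn,
)

for block_cls in (
    UNetMidBlock2DSimpleCrossAttn,
    SimpleCrossAttnDownBlock2D,
    SimpleCrossAttnUpBlock2D,
):
    params = inspect.signature(block_cls.forward).parameters
    assert "encoder_attention_mask" in params, f"{block_cls.__name__} is missing the kwarg"
```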
