Conversation

@jackiehimel commented Nov 16, 2025

What does this PR do?

Adds SDPA and FlashAttention-2 support to LayoutLMv3 using the unified attention interface pattern, following the same architecture used in BERT and other recent model implementations.

Fixes #35467

Implementation

  • Refactored LayoutLMv3SelfAttention to use ALL_ATTENTION_FUNCTIONS interface instead of separate attention classes
  • Created layoutlmv3_eager_attention_forward function that implements the CogView attention mechanism (alpha=32 scaling) with support for LayoutLMv3's relative position bias and spatial attention bias
  • Added _supports_sdpa = True and _supports_flash_attn = True flags to LayoutLMv3PreTrainedModel
  • Updated mask creation to use create_bidirectional_mask (replacing get_extended_attention_mask)
  • Threaded layer_idx parameter through attention classes for consistency
  • Added automatic enforcement in LayoutLMv3Config to set attn_implementation="eager" when relative or spatial attention biases are enabled (default behavior)

Note: SDPA and FlashAttention-2 are incompatible with LayoutLMv3's relative position bias and spatial attention bias. The config automatically enforces eager attention when these biases are enabled (the default). To use SDPA/FlashAttention-2, users must disable both biases (has_relative_attention_bias=False and has_spatial_attention_bias=False).
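For illustration, here is a minimal usage sketch of the opt-in path described above (assuming the config flags behave as in this PR; the checkpoint name and keyword overrides are just standard `from_pretrained` usage, not taken from the diff):

```python
from transformers import LayoutLMv3Model

# Default config keeps both biases enabled, so eager attention is enforced automatically.
model_eager = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base")

# Opting into SDPA (or "flash_attention_2") requires disabling both biases explicitly.
# Note that this changes the attention computation and may affect downstream accuracy.
model_sdpa = LayoutLMv3Model.from_pretrained(
    "microsoft/layoutlmv3-base",
    has_relative_attention_bias=False,
    has_spatial_attention_bias=False,
    attn_implementation="sdpa",
)
```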

Type of change

  • New feature (non-breaking change which adds functionality)

How has this change been tested?

  • CI tests pass; linting and formatting were validated locally.
  • Added test skips for SDPA/Flash comparison tests (LayoutLMv3 defaults to eager when biases are enabled) and overrode test_batching_equivalence to ensure eager attention is used.
  • Implementation follows the unified attention pattern from BERT.

Before submitting

  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

@vasqu - Thanks for the feedback on #41801! I've refactored this to use the unified attention interface pattern as you suggested. Would appreciate another look when you have time.

@ArthurZucker @Cyrilvallez - attention implementation reviewers

@jackiehimel jackiehimel marked this pull request as draft November 16, 2025 21:37
- Implement unified attention interface following BERT pattern
- Add layoutlmv3_eager_attention_forward with support for relative position bias and spatial attention bias
- Add support flags _supports_flash_attn and _supports_sdpa
- Update attention classes to use unified interface
- Automatically set _attn_implementation='eager' when relative/spatial biases are enabled in config
- Fix test configurations to use eager attention by default
- Override incompatible SDPA/FlashAttention tests with skipTest
- Fix missing case for spatial-only attention bias handling
- Fix position_ids expansion to support inputs_embeds
- Replace get_extended_attention_mask with create_bidirectional_mask

Fixes huggingface#35467
@jackiehimel jackiehimel force-pushed the layoutlmv3-sdpa-flash-attn2 branch from 69b372a to 26aa046 on November 16, 2025 22:14
@github-actions commented

[For maintainers] Suggested jobs to run (before merge)

run-slow: layoutlmv3

@jackiehimel jackiehimel marked this pull request as ready for review November 16, 2025 22:47
@vasqu left a comment

Sorry, I did go through the code and gave some comments, but I'm noticing that LayoutLM only uses the relative bias. If that's the case, then the usage of other attention flavors is questionable, as they won't be used either way.

Imo, it would make more sense to go for other models that are suitable. You already did an overall pretty good job over here.

# Take the dot product between "query" and "key" to get the raw attention scores.
# The attention scores QT K/√d could be significantly larger than input elements, and result in overflow.
# Changing the computational order into QT(K/√d) alleviates the problem. (https://huggingface.co/papers/2105.13290)
attention_scores = torch.matmul(query / math.sqrt(query.size(-1)), key.transpose(-1, -2))
@vasqu:

It's essentially the same as BERT:

attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

with flipped order: attn_weights = torch.matmul(query * scaling, key.transpose(2, 3))
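(As a quick illustrative sanity check of that equivalence; this is not part of the PR, and the shapes and seed are arbitrary:)

```python
import math

import torch

torch.manual_seed(0)
batch, heads, seq_len, head_dim = 2, 8, 16, 64
query = torch.randn(batch, heads, seq_len, head_dim)
key = torch.randn(batch, heads, seq_len, head_dim)
scaling = 1.0 / math.sqrt(head_dim)

# CogView-style order: scale the query first, then matmul (keeps intermediates small).
scores_scaled_query = torch.matmul(query * scaling, key.transpose(-1, -2))
# BERT-style order: matmul first, then scale the resulting scores.
scores_scaled_output = torch.matmul(query, key.transpose(-1, -2)) * scaling

print(torch.allclose(scores_scaled_query, scores_scaled_output, atol=1e-5))  # True
```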

Comment on lines +254 to +258
attention_scores = attention_scores + (rel_pos + rel_2d_pos) / math.sqrt(query.size(-1))
elif module.has_relative_attention_bias and rel_pos is not None:
attention_scores = attention_scores + rel_pos / math.sqrt(query.size(-1))
elif module.has_spatial_attention_bias and rel_2d_pos is not None:
attention_scores = attention_scores + rel_2d_pos / math.sqrt(query.size(-1))
@vasqu:

Are we applying the scaling twice here? Do the integration tests still pass? I.e., there's another `/ math.sqrt(query.size(-1))`.


self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.has_relative_attention_bias = config.has_relative_attention_bias
self.has_spatial_attention_bias = config.has_spatial_attention_bias
@vasqu:

It's missing an `is_causal` attribute; I would think it's not causal.
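(For reference, a minimal sketch of what adding that attribute could look like; illustrative only, not taken from the diff, with the rest of `__init__` elided:)

```python
import torch.nn as nn


class LayoutLMv3SelfAttention(nn.Module):
    def __init__(self, config, layer_idx=None):
        super().__init__()
        # ... existing projection / dropout setup elided ...
        self.has_relative_attention_bias = config.has_relative_attention_bias
        self.has_spatial_attention_bias = config.has_spatial_attention_bias
        # LayoutLMv3 is an encoder-only model, so attention should be bidirectional.
        self.is_causal = False
```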

self,
hidden_states,
attention_mask=None,
output_attentions=False,
@vasqu:

Suggested change:

```diff
-output_attentions=False,
```
Comment on lines +326 to +336
use_eager = self.config._attn_implementation == "eager"

if not use_eager:
# SDPA and Flash Attention don't support custom relative position bias and spatial attention bias
if self.has_relative_attention_bias or self.has_spatial_attention_bias:
raise ValueError(
f"You are using {self.config._attn_implementation} as attention type. However, LayoutLMv3's "
"relative position bias and spatial attention bias are not compatible with it. "
'Please load the model with `attn_implementation="eager"`.'
)
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
@vasqu:

Suggested change:

```diff
-use_eager = self.config._attn_implementation == "eager"
-if not use_eager:
-    # SDPA and Flash Attention don't support custom relative position bias and spatial attention bias
-    if self.has_relative_attention_bias or self.has_spatial_attention_bias:
-        raise ValueError(
-            f"You are using {self.config._attn_implementation} as attention type. However, LayoutLMv3's "
-            "relative position bias and spatial attention bias are not compatible with it. "
-            'Please load the model with `attn_implementation="eager"`.'
-        )
-    attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+if self.config._attn_implementation != "eager":
+    # SDPA and Flash Attention don't support custom relative position bias and spatial attention bias
+    if self.has_relative_attention_bias or self.has_spatial_attention_bias:
+        raise ValueError(
+            f"You are using {self.config._attn_implementation} as attention type. However, LayoutLMv3's "
+            "relative position bias and spatial attention bias are not compatible with it. "
+            'Please load the model with `attn_implementation="eager"`.'
+        )
+    attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
```

For now we raise a ValueError; I'm thinking about forcing a fallback instead. Let's keep it as is for now.
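(For illustration only, a rough sketch of what such a fallback could look like; this is a hypothetical helper, not part of this PR, and `select_attention_interface`, `all_attention_functions`, and `eager_fn` are made-up names:)

```python
from transformers.utils import logging

logger = logging.get_logger(__name__)


def select_attention_interface(module, all_attention_functions, eager_fn):
    """Pick an attention backend, falling back to eager when LayoutLMv3's
    relative/spatial attention biases rule out SDPA / Flash Attention."""
    impl = module.config._attn_implementation
    if impl != "eager" and (module.has_relative_attention_bias or module.has_spatial_attention_bias):
        logger.warning_once(
            f"LayoutLMv3's relative/spatial attention biases are incompatible with {impl}; "
            "falling back to eager attention."
        )
        return eager_fn
    return eager_fn if impl == "eager" else all_attention_functions[impl]
```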

Comment on lines +352 to 353
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
return outputs
@vasqu:

Suggested change:

```diff
-outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
-return outputs
+return outputs, attn_weights
```

You can take a look at

_can_record_outputs = {
"hidden_states": BertLayer,
"attentions": BertSelfAttention,

which takes care of the `output_xxx` handling. That will need more changes here, but it's better to go for the right thing from the get-go.
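(Roughly, the BERT-style pattern being pointed to looks like the sketch below, written under the assumption that LayoutLMv3 would adopt the same `_can_record_outputs` / `check_model_inputs` machinery and that the decorator is importable from `transformers.utils.generic` as in recent modeling files; the class names are illustrative and the forward body is elided:)

```python
from transformers.utils.generic import check_model_inputs

# (This would live inside modeling_layoutlmv3.py, where the referenced classes are defined.)


class LayoutLMv3Model(LayoutLMv3PreTrainedModel):
    # Maps output fields to the modules whose results should be recorded,
    # replacing the manual `output_attentions` / `output_hidden_states` plumbing.
    _can_record_outputs = {
        "hidden_states": LayoutLMv3Layer,
        "attentions": LayoutLMv3SelfAttention,
    }

    @check_model_inputs
    def forward(self, input_ids=None, bbox=None, attention_mask=None, **kwargs):
        ...  # unchanged model body
```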

@vasqu:

BERT is a good reference for this, and you need the same decorators; they are essential.



```diff
-# Copied from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Attention with LayoutLMv2->LayoutLMv3
+# Adapted from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Attention with LayoutLMv2->LayoutLMv3
```
@vasqu:

Is LayoutLMv2 vastly different? It might be worth changing it at the same time.

Comment on lines +182 to +183
# Ensure eager attention is set before model creation
config._attn_implementation = "eager"
@vasqu:

Is the change in `prepare_config_and_inputs` not enough? Same for below.

Comment on lines +334 to +337
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
@unittest.skip("LayoutLMv3's relative position bias and spatial attention bias are incompatible with SDPA.")
def test_eager_matches_sdpa_inference(self, *args):
pass
@vasqu:

Does that mean that LayoutLM always uses the relative bias? Then it might not even make sense to support SDPA/Flash Attention; it won't be used either way.

Linked issue: Support SDPA & Flash Attention 2 for LayoutLMv3 (#35467)