Add SDPA and FlashAttention-2 support to LayoutLMv3 #42225
Conversation
- Implement unified attention interface following BERT pattern
- Add layoutlmv3_eager_attention_forward with support for relative position bias and spatial attention bias
- Add support flags _supports_flash_attn and _supports_sdpa
- Update attention classes to use unified interface
- Automatically set _attn_implementation='eager' when relative/spatial biases are enabled in config
- Fix test configurations to use eager attention by default
- Override incompatible SDPA/FlashAttention tests with skipTest
- Fix missing case for spatial-only attention bias handling
- Fix position_ids expansion to support inputs_embeds
- Replace get_extended_attention_mask with create_bidirectional_mask

Fixes huggingface#35467
Force-pushed from 69b372a to 26aa046.
[For maintainers] Suggested jobs to run (before merge): run-slow: layoutlmv3
vasqu left a comment
Sorry, I did go through the code and gave some comments, but I'm noticing that LayoutLM only uses the relative bias. If that's the case, then the usage of other attention flavors is questionable since they won't be used either way.
Imo, it would make more sense to go for other models that are suitable. You already did an overall pretty good job over here.
```python
# Take the dot product between "query" and "key" to get the raw attention scores.
# The attention scores QT K/√d could be significantly larger than input elements, and result in overflow.
# Changing the computational order into QT(K/√d) alleviates the problem. (https://huggingface.co/papers/2105.13290)
attention_scores = torch.matmul(query / math.sqrt(query.size(-1)), key.transpose(-1, -2))
```
It's essentially the same as BERT's

```python
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
```

with flipped order: `attn_weights = torch.matmul(query * scaling, key.transpose(2, 3))`
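For what it's worth, the two orderings agree up to floating-point error; a quick standalone check (shapes and dtype here are arbitrary, just for illustration):

```python
import math
import torch

torch.manual_seed(0)
batch, heads, seq, head_dim = 2, 4, 8, 16
query = torch.randn(batch, heads, seq, head_dim)
key = torch.randn(batch, heads, seq, head_dim)
scaling = 1 / math.sqrt(head_dim)

# LayoutLMv3-style: scale the query before the matmul (the CogView ordering).
scores_scaled_first = torch.matmul(query * scaling, key.transpose(-1, -2))
# BERT-style: matmul first, then scale the scores.
scores_scaled_after = torch.matmul(query, key.transpose(-1, -2)) * scaling

# Expected to print True (equal up to floating-point error).
print(torch.allclose(scores_scaled_first, scores_scaled_after, atol=1e-6))
```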
```python
    attention_scores = attention_scores + (rel_pos + rel_2d_pos) / math.sqrt(query.size(-1))
elif module.has_relative_attention_bias and rel_pos is not None:
    attention_scores = attention_scores + rel_pos / math.sqrt(query.size(-1))
elif module.has_spatial_attention_bias and rel_2d_pos is not None:
    attention_scores = attention_scores + rel_2d_pos / math.sqrt(query.size(-1))
```
Are we applying the scaling twice here? Do the integration tests still pass? I.e. there's another `/ math.sqrt(query.size(-1))`.
```python
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.has_relative_attention_bias = config.has_relative_attention_bias
self.has_spatial_attention_bias = config.has_spatial_attention_bias
```
It's missing an `is_causal` attribute; I would think it's not causal.
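A minimal sketch of what that could look like, assuming the `__init__` signature used in this PR (only the quoted attributes plus the new flag are shown):

```python
import torch.nn as nn

class LayoutLMv3SelfAttention(nn.Module):
    def __init__(self, config, layer_idx=None):
        super().__init__()
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.has_relative_attention_bias = config.has_relative_attention_bias
        self.has_spatial_attention_bias = config.has_spatial_attention_bias
        # Encoder-style, bidirectional attention, so mark it as non-causal
        # for the shared attention interface (attribute placement is an assumption).
        self.is_causal = False
```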
```python
    self,
    hidden_states,
    attention_mask=None,
    output_attentions=False,
```
```python
    output_attentions=False,
```
```python
use_eager = self.config._attn_implementation == "eager"

if not use_eager:
    # SDPA and Flash Attention don't support custom relative position bias and spatial attention bias
    if self.has_relative_attention_bias or self.has_spatial_attention_bias:
        raise ValueError(
            f"You are using {self.config._attn_implementation} as attention type. However, LayoutLMv3's "
            "relative position bias and spatial attention bias are not compatible with it. "
            'Please load the model with `attn_implementation="eager"`.'
        )
    attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
```
Suggested change (drop the `use_eager` variable and check `self.config._attn_implementation` directly):

```diff
-use_eager = self.config._attn_implementation == "eager"
-if not use_eager:
+if self.config._attn_implementation != "eager":
     # SDPA and Flash Attention don't support custom relative position bias and spatial attention bias
     if self.has_relative_attention_bias or self.has_spatial_attention_bias:
         raise ValueError(
             f"You are using {self.config._attn_implementation} as attention type. However, LayoutLMv3's "
             "relative position bias and spatial attention bias are not compatible with it. "
             'Please load the model with `attn_implementation="eager"`.'
         )
     attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
```
For now we raise a ValueError; I'm thinking about forcing a fallback instead. Let's keep it as is for now.
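For reference, a fallback along those lines could look roughly like this. This is only a sketch of the idea, not what the PR currently does; it reuses the names introduced in this PR (`layoutlmv3_eager_attention_forward`, `ALL_ATTENTION_FUNCTIONS`) and assumes a module-level `logger` as in other transformers models:

```python
# Hypothetical fallback (not in this PR): warn and use eager attention instead of raising
# when the biases are enabled but a non-eager implementation was requested.
attention_interface = layoutlmv3_eager_attention_forward
if self.config._attn_implementation != "eager":
    if self.has_relative_attention_bias or self.has_spatial_attention_bias:
        logger.warning_once(
            "LayoutLMv3's relative position bias and spatial attention bias only work with eager "
            "attention; falling back to the eager implementation."
        )
    else:
        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
```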
```python
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
return outputs
```
Suggested change:

```diff
-outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
-return outputs
+return outputs, attn_weights
```
You can take a look at `transformers/src/transformers/models/bert/modeling_bert.py`, lines 560 to 562 in 47227f4:

```python
_can_record_outputs = {
    "hidden_states": BertLayer,
    "attentions": BertSelfAttention,
```

which takes care of the `output_*` handling. It will need more changes here then, but better to go for the right thing from the get-go.
BERT is a good reference for this, plus you need the same decorators; they are essential.
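A minimal sketch of how that pattern could carry over, mirroring the BERT snippet above; the class names (`LayoutLMv3Model`, `LayoutLMv3PreTrainedModel`, `LayoutLMv3Layer`, `LayoutLMv3SelfAttention`) and the exact placement are assumptions, not the final implementation:

```python
# Sketch only: declare which submodules the shared machinery should record
# outputs from, instead of threading output_attentions/output_hidden_states manually.
class LayoutLMv3Model(LayoutLMv3PreTrainedModel):
    _can_record_outputs = {
        "hidden_states": LayoutLMv3Layer,
        "attentions": LayoutLMv3SelfAttention,
    }
```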
```diff
-# Copied from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Attention with LayoutLMv2->LayoutLMv3
+# Adapted from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Attention with LayoutLMv2->LayoutLMv3
```
Is LayoutLMv2 vastly different? It might be worth changing it along at the same time.
```python
# Ensure eager attention is set before model creation
config._attn_implementation = "eager"
```
Is the change in `prepare_config_and_inputs` not enough? Same for below.
```python
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
@unittest.skip("LayoutLMv3's relative position bias and spatial attention bias are incompatible with SDPA.")
def test_eager_matches_sdpa_inference(self, *args):
    pass
```
Does that mean that LayoutLMv3 always uses the relative bias? Then it might not even make sense to support SDPA/FlashAttention; it won't be used either way.
What does this PR do?
Adds SDPA and FlashAttention-2 support to LayoutLMv3 using the unified attention interface pattern, following the same architecture used in BERT and other recent model implementations.
Fixes #35467
Implementation
- Refactored `LayoutLMv3SelfAttention` to use the `ALL_ATTENTION_FUNCTIONS` interface instead of separate attention classes
- Added a `layoutlmv3_eager_attention_forward` function that implements the CogView attention mechanism (alpha=32 scaling) with support for LayoutLMv3's relative position bias and spatial attention bias
- Added `_supports_sdpa = True` and `_supports_flash_attn = True` flags to `LayoutLMv3PreTrainedModel`
- Switched to `create_bidirectional_mask` (replacing `get_extended_attention_mask`)
- Threaded a `layer_idx` parameter through the attention classes for consistency
- Updated `LayoutLMv3Config` to set `attn_implementation="eager"` when relative or spatial attention biases are enabled (default behavior)

Note: SDPA and FlashAttention-2 are incompatible with LayoutLMv3's relative position bias and spatial attention bias. The config automatically enforces eager attention when these biases are enabled (the default). To use SDPA/FlashAttention-2, users must disable both biases (`has_relative_attention_bias=False` and `has_spatial_attention_bias=False`); a usage sketch follows below.
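For completeness, enabling SDPA would then look roughly like this. This is a sketch assuming the behavior described above once this PR lands; `microsoft/layoutlmv3-base` is just an example checkpoint:

```python
from transformers import AutoConfig, AutoModel

# Both biases must be disabled for SDPA (or FlashAttention-2) to be usable.
config = AutoConfig.from_pretrained(
    "microsoft/layoutlmv3-base",
    has_relative_attention_bias=False,
    has_spatial_attention_bias=False,
)
model = AutoModel.from_pretrained(
    "microsoft/layoutlmv3-base",
    config=config,
    attn_implementation="sdpa",
)
```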
Type of change
How has this change been tested?
- Updated `test_batching_equivalence` to ensure eager attention is used.

Before submitting
Who can review?
@vasqu - Thanks for the feedback on #41801! I've refactored this to use the unified attention interface pattern as you suggested. Would appreciate another look when you have time.
@ArthurZucker @Cyrilvallez - attention implementation reviewers