-
Notifications
You must be signed in to change notification settings - Fork 31.2k
Add SDPA and FlashAttention-2 support to LayoutLMv3 #42225
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -16,6 +16,7 @@ | |||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| import collections | ||||||||||||||||||||||||||||||||||||||||||
| import math | ||||||||||||||||||||||||||||||||||||||||||
| from collections.abc import Callable | ||||||||||||||||||||||||||||||||||||||||||
| from typing import Optional, Union | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -25,16 +26,19 @@ | |||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| from ... import initialization as init | ||||||||||||||||||||||||||||||||||||||||||
| from ...activations import ACT2FN | ||||||||||||||||||||||||||||||||||||||||||
| from ...masking_utils import create_bidirectional_mask | ||||||||||||||||||||||||||||||||||||||||||
| from ...modeling_layers import GradientCheckpointingLayer | ||||||||||||||||||||||||||||||||||||||||||
| from ...modeling_outputs import ( | ||||||||||||||||||||||||||||||||||||||||||
| BaseModelOutput, | ||||||||||||||||||||||||||||||||||||||||||
| QuestionAnsweringModelOutput, | ||||||||||||||||||||||||||||||||||||||||||
| SequenceClassifierOutput, | ||||||||||||||||||||||||||||||||||||||||||
| TokenClassifierOutput, | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
| from ...modeling_utils import PreTrainedModel | ||||||||||||||||||||||||||||||||||||||||||
| from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel | ||||||||||||||||||||||||||||||||||||||||||
| from ...processing_utils import Unpack | ||||||||||||||||||||||||||||||||||||||||||
| from ...pytorch_utils import apply_chunking_to_forward | ||||||||||||||||||||||||||||||||||||||||||
| from ...utils import ( | ||||||||||||||||||||||||||||||||||||||||||
| TransformersKwargs, | ||||||||||||||||||||||||||||||||||||||||||
| auto_docstring, | ||||||||||||||||||||||||||||||||||||||||||
| logging, | ||||||||||||||||||||||||||||||||||||||||||
| torch_int, | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -203,6 +207,8 @@ class LayoutLMv3PreTrainedModel(PreTrainedModel): | |||||||||||||||||||||||||||||||||||||||||
| config: LayoutLMv3Config | ||||||||||||||||||||||||||||||||||||||||||
| base_model_prefix = "layoutlmv3" | ||||||||||||||||||||||||||||||||||||||||||
| input_modalities = ["image", "text"] | ||||||||||||||||||||||||||||||||||||||||||
| _supports_flash_attn = True | ||||||||||||||||||||||||||||||||||||||||||
| _supports_sdpa = True | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| @torch.no_grad() | ||||||||||||||||||||||||||||||||||||||||||
| def _init_weights(self, module): | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -214,18 +220,80 @@ def _init_weights(self, module): | |||||||||||||||||||||||||||||||||||||||||
| init.zeros_(module.pos_embed) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def layoutlmv3_eager_attention_forward( | ||||||||||||||||||||||||||||||||||||||||||
| module: nn.Module, | ||||||||||||||||||||||||||||||||||||||||||
| query: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||||
| key: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||||
| value: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||||
| attention_mask: Optional[torch.Tensor], | ||||||||||||||||||||||||||||||||||||||||||
| scaling: Optional[float] = None, | ||||||||||||||||||||||||||||||||||||||||||
| dropout: float = 0.0, | ||||||||||||||||||||||||||||||||||||||||||
| rel_pos: Optional[torch.Tensor] = None, | ||||||||||||||||||||||||||||||||||||||||||
| rel_2d_pos: Optional[torch.Tensor] = None, | ||||||||||||||||||||||||||||||||||||||||||
| **kwargs: Unpack[TransformersKwargs], | ||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||
| LayoutLMv3 eager attention with support for relative position bias and spatial attention bias. | ||||||||||||||||||||||||||||||||||||||||||
| Based on the CogView attention trick for training stability. | ||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||
| if scaling is None: | ||||||||||||||||||||||||||||||||||||||||||
| scaling = 1.0 / math.sqrt(query.size(-1)) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # Take the dot product between "query" and "key" to get the raw attention scores. | ||||||||||||||||||||||||||||||||||||||||||
| # The attention scores QT K/√d could be significantly larger than input elements, and result in overflow. | ||||||||||||||||||||||||||||||||||||||||||
| # Changing the computational order into QT(K/√d) alleviates the problem. (https://huggingface.co/papers/2105.13290) | ||||||||||||||||||||||||||||||||||||||||||
| attention_scores = torch.matmul(query / math.sqrt(query.size(-1)), key.transpose(-1, -2)) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # Add relative position bias and spatial attention bias if available | ||||||||||||||||||||||||||||||||||||||||||
| if ( | ||||||||||||||||||||||||||||||||||||||||||
| module.has_relative_attention_bias | ||||||||||||||||||||||||||||||||||||||||||
| and module.has_spatial_attention_bias | ||||||||||||||||||||||||||||||||||||||||||
| and rel_pos is not None | ||||||||||||||||||||||||||||||||||||||||||
| and rel_2d_pos is not None | ||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||
| attention_scores = attention_scores + (rel_pos + rel_2d_pos) / math.sqrt(query.size(-1)) | ||||||||||||||||||||||||||||||||||||||||||
| elif module.has_relative_attention_bias and rel_pos is not None: | ||||||||||||||||||||||||||||||||||||||||||
| attention_scores = attention_scores + rel_pos / math.sqrt(query.size(-1)) | ||||||||||||||||||||||||||||||||||||||||||
| elif module.has_spatial_attention_bias and rel_2d_pos is not None: | ||||||||||||||||||||||||||||||||||||||||||
| attention_scores = attention_scores + rel_2d_pos / math.sqrt(query.size(-1)) | ||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+254
to
+258
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are we applying the scaling twice here? Do the integration tests still pass? I.e. another |
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| if attention_mask is not None: | ||||||||||||||||||||||||||||||||||||||||||
| # Apply the attention mask | ||||||||||||||||||||||||||||||||||||||||||
| attention_scores = attention_scores + attention_mask | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # Normalize the attention scores to probabilities. | ||||||||||||||||||||||||||||||||||||||||||
| # Use the trick of the CogView paper to stabilize training | ||||||||||||||||||||||||||||||||||||||||||
| # https://huggingface.co/papers/2105.13290 Section 2.4 | ||||||||||||||||||||||||||||||||||||||||||
| alpha = 32 | ||||||||||||||||||||||||||||||||||||||||||
| scaled_attention_scores = attention_scores / alpha | ||||||||||||||||||||||||||||||||||||||||||
| max_value = scaled_attention_scores.amax(dim=(-1)).unsqueeze(-1) | ||||||||||||||||||||||||||||||||||||||||||
| new_attention_scores = (scaled_attention_scores - max_value) * alpha | ||||||||||||||||||||||||||||||||||||||||||
| attention_probs = nn.functional.softmax(new_attention_scores, dim=-1) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # This is actually dropping out entire tokens to attend to, which might | ||||||||||||||||||||||||||||||||||||||||||
| # seem a bit unusual, but is taken from the original Transformer paper. | ||||||||||||||||||||||||||||||||||||||||||
| attention_probs = nn.functional.dropout(attention_probs, p=dropout, training=module.training) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| attn_output = torch.matmul(attention_probs, value) | ||||||||||||||||||||||||||||||||||||||||||
| attn_output = attn_output.transpose(1, 2).contiguous() | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| return attn_output, attention_probs | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| class LayoutLMv3SelfAttention(nn.Module): | ||||||||||||||||||||||||||||||||||||||||||
| def __init__(self, config): | ||||||||||||||||||||||||||||||||||||||||||
| def __init__(self, config, layer_idx=None): | ||||||||||||||||||||||||||||||||||||||||||
| super().__init__() | ||||||||||||||||||||||||||||||||||||||||||
| if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): | ||||||||||||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||
| f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " | ||||||||||||||||||||||||||||||||||||||||||
| f"heads ({config.num_attention_heads})" | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
| self.config = config | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| self.num_attention_heads = config.num_attention_heads | ||||||||||||||||||||||||||||||||||||||||||
| self.attention_head_size = int(config.hidden_size / config.num_attention_heads) | ||||||||||||||||||||||||||||||||||||||||||
| self.all_head_size = self.num_attention_heads * self.attention_head_size | ||||||||||||||||||||||||||||||||||||||||||
| self.scaling = self.attention_head_size**-0.5 | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| self.query = nn.Linear(config.hidden_size, self.all_head_size) | ||||||||||||||||||||||||||||||||||||||||||
| self.key = nn.Linear(config.hidden_size, self.all_head_size) | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -234,18 +302,7 @@ def __init__(self, config): | |||||||||||||||||||||||||||||||||||||||||
| self.dropout = nn.Dropout(config.attention_probs_dropout_prob) | ||||||||||||||||||||||||||||||||||||||||||
| self.has_relative_attention_bias = config.has_relative_attention_bias | ||||||||||||||||||||||||||||||||||||||||||
| self.has_spatial_attention_bias = config.has_spatial_attention_bias | ||||||||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's missing an |
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def cogview_attention(self, attention_scores, alpha=32): | ||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||
| https://huggingface.co/papers/2105.13290 Section 2.4 Stabilization of training: Precision Bottleneck Relaxation | ||||||||||||||||||||||||||||||||||||||||||
| (PB-Relax). A replacement of the original nn.Softmax(dim=-1)(attention_scores). Seems the new attention_probs | ||||||||||||||||||||||||||||||||||||||||||
| will result in a slower speed and a little bias. Can use torch.allclose(standard_attention_probs, | ||||||||||||||||||||||||||||||||||||||||||
| cogview_attention_probs, atol=1e-08) for comparison. The smaller atol (e.g., 1e-08), the better. | ||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||
| scaled_attention_scores = attention_scores / alpha | ||||||||||||||||||||||||||||||||||||||||||
| max_value = scaled_attention_scores.amax(dim=(-1)).unsqueeze(-1) | ||||||||||||||||||||||||||||||||||||||||||
| new_attention_scores = (scaled_attention_scores - max_value) * alpha | ||||||||||||||||||||||||||||||||||||||||||
| return nn.Softmax(dim=-1)(new_attention_scores) | ||||||||||||||||||||||||||||||||||||||||||
| self.layer_idx = layer_idx | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def forward( | ||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -254,54 +311,45 @@ def forward( | |||||||||||||||||||||||||||||||||||||||||
| output_attentions=False, | ||||||||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||
| rel_pos=None, | ||||||||||||||||||||||||||||||||||||||||||
| rel_2d_pos=None, | ||||||||||||||||||||||||||||||||||||||||||
| **kwargs: Unpack[TransformersKwargs], | ||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||
| batch_size, seq_length, _ = hidden_states.shape | ||||||||||||||||||||||||||||||||||||||||||
| query_layer = ( | ||||||||||||||||||||||||||||||||||||||||||
| self.query(hidden_states) | ||||||||||||||||||||||||||||||||||||||||||
| .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) | ||||||||||||||||||||||||||||||||||||||||||
| .transpose(1, 2) | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
| key_layer = ( | ||||||||||||||||||||||||||||||||||||||||||
| self.key(hidden_states) | ||||||||||||||||||||||||||||||||||||||||||
| .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) | ||||||||||||||||||||||||||||||||||||||||||
| .transpose(1, 2) | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
| value_layer = ( | ||||||||||||||||||||||||||||||||||||||||||
| self.value(hidden_states) | ||||||||||||||||||||||||||||||||||||||||||
| .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) | ||||||||||||||||||||||||||||||||||||||||||
| .transpose(1, 2) | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # Take the dot product between "query" and "key" to get the raw attention scores. | ||||||||||||||||||||||||||||||||||||||||||
| # The attention scores QT K/√d could be significantly larger than input elements, and result in overflow. | ||||||||||||||||||||||||||||||||||||||||||
| # Changing the computational order into QT(K/√d) alleviates the problem. (https://huggingface.co/papers/2105.13290) | ||||||||||||||||||||||||||||||||||||||||||
| attention_scores = torch.matmul(query_layer / math.sqrt(self.attention_head_size), key_layer.transpose(-1, -2)) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| if self.has_relative_attention_bias and self.has_spatial_attention_bias: | ||||||||||||||||||||||||||||||||||||||||||
| attention_scores += (rel_pos + rel_2d_pos) / math.sqrt(self.attention_head_size) | ||||||||||||||||||||||||||||||||||||||||||
| elif self.has_relative_attention_bias: | ||||||||||||||||||||||||||||||||||||||||||
| attention_scores += rel_pos / math.sqrt(self.attention_head_size) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| if attention_mask is not None: | ||||||||||||||||||||||||||||||||||||||||||
| # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function) | ||||||||||||||||||||||||||||||||||||||||||
| attention_scores = attention_scores + attention_mask | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # Normalize the attention scores to probabilities. | ||||||||||||||||||||||||||||||||||||||||||
| # Use the trick of the CogView paper to stabilize training | ||||||||||||||||||||||||||||||||||||||||||
| attention_probs = self.cogview_attention(attention_scores) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # This is actually dropping out entire tokens to attend to, which might | ||||||||||||||||||||||||||||||||||||||||||
| # seem a bit unusual, but is taken from the original Transformer paper. | ||||||||||||||||||||||||||||||||||||||||||
| attention_probs = self.dropout(attention_probs) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| context_layer = torch.matmul(attention_probs, value_layer) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| context_layer = context_layer.permute(0, 2, 1, 3).contiguous() | ||||||||||||||||||||||||||||||||||||||||||
| new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) | ||||||||||||||||||||||||||||||||||||||||||
| context_layer = context_layer.view(*new_context_layer_shape) | ||||||||||||||||||||||||||||||||||||||||||
| input_shape = hidden_states.shape[:-1] | ||||||||||||||||||||||||||||||||||||||||||
| hidden_shape = (*input_shape, -1, self.attention_head_size) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # Get query, key, value projections | ||||||||||||||||||||||||||||||||||||||||||
| query_layer = self.query(hidden_states).view(*hidden_shape).transpose(1, 2) | ||||||||||||||||||||||||||||||||||||||||||
| key_layer = self.key(hidden_states).view(*hidden_shape).transpose(1, 2) | ||||||||||||||||||||||||||||||||||||||||||
| value_layer = self.value(hidden_states).view(*hidden_shape).transpose(1, 2) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # Determine attention implementation | ||||||||||||||||||||||||||||||||||||||||||
| attention_interface: Callable = layoutlmv3_eager_attention_forward | ||||||||||||||||||||||||||||||||||||||||||
| use_eager = self.config._attn_implementation == "eager" | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| if not use_eager: | ||||||||||||||||||||||||||||||||||||||||||
| # SDPA and Flash Attention don't support custom relative position bias and spatial attention bias | ||||||||||||||||||||||||||||||||||||||||||
| if self.has_relative_attention_bias or self.has_spatial_attention_bias: | ||||||||||||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||
| f"You are using {self.config._attn_implementation} as attention type. However, LayoutLMv3's " | ||||||||||||||||||||||||||||||||||||||||||
| "relative position bias and spatial attention bias are not compatible with it. " | ||||||||||||||||||||||||||||||||||||||||||
| 'Please load the model with `attn_implementation="eager"`.' | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
| attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] | ||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+326
to
+336
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
For now we raise a value error, thinking about forcing a fallback instead. Let's keep it as is for now |
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) | ||||||||||||||||||||||||||||||||||||||||||
| attn_output, attn_weights = attention_interface( | ||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||
| query_layer, | ||||||||||||||||||||||||||||||||||||||||||
| key_layer, | ||||||||||||||||||||||||||||||||||||||||||
| value_layer, | ||||||||||||||||||||||||||||||||||||||||||
| attention_mask, | ||||||||||||||||||||||||||||||||||||||||||
| dropout=0.0 if not self.training else self.dropout.p, | ||||||||||||||||||||||||||||||||||||||||||
| scaling=self.scaling, | ||||||||||||||||||||||||||||||||||||||||||
| rel_pos=rel_pos, | ||||||||||||||||||||||||||||||||||||||||||
| rel_2d_pos=rel_2d_pos, | ||||||||||||||||||||||||||||||||||||||||||
| **kwargs, | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
| attn_output = attn_output.reshape(*input_shape, -1).contiguous() | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) | ||||||||||||||||||||||||||||||||||||||||||
| return outputs | ||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+352
to
353
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
You can take a look at transformers/src/transformers/models/bert/modeling_bert.py Lines 560 to 562 in 47227f4
which takes care of the output xxx. Will need more changes here then but better to go for the right thing from the get go.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Bert is a good reference for this + you need the same decorators, they are essential. |
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -320,11 +368,11 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to | |||||||||||||||||||||||||||||||||||||||||
| return hidden_states | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # Copied from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Attention with LayoutLMv2->LayoutLMv3 | ||||||||||||||||||||||||||||||||||||||||||
| # Adapted from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Attention with LayoutLMv2->LayoutLMv3 | ||||||||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is layoutlm v2 vastly different, might be worth to change along at the same time. |
||||||||||||||||||||||||||||||||||||||||||
| class LayoutLMv3Attention(nn.Module): | ||||||||||||||||||||||||||||||||||||||||||
| def __init__(self, config): | ||||||||||||||||||||||||||||||||||||||||||
| def __init__(self, config, layer_idx=None): | ||||||||||||||||||||||||||||||||||||||||||
| super().__init__() | ||||||||||||||||||||||||||||||||||||||||||
| self.self = LayoutLMv3SelfAttention(config) | ||||||||||||||||||||||||||||||||||||||||||
| self.self = LayoutLMv3SelfAttention(config, layer_idx=layer_idx) | ||||||||||||||||||||||||||||||||||||||||||
| self.output = LayoutLMv3SelfOutput(config) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def forward( | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -334,26 +382,28 @@ def forward( | |||||||||||||||||||||||||||||||||||||||||
| output_attentions=False, | ||||||||||||||||||||||||||||||||||||||||||
| rel_pos=None, | ||||||||||||||||||||||||||||||||||||||||||
| rel_2d_pos=None, | ||||||||||||||||||||||||||||||||||||||||||
| **kwargs: Unpack[TransformersKwargs], | ||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||
| self_outputs = self.self( | ||||||||||||||||||||||||||||||||||||||||||
| hidden_states, | ||||||||||||||||||||||||||||||||||||||||||
| attention_mask, | ||||||||||||||||||||||||||||||||||||||||||
| output_attentions, | ||||||||||||||||||||||||||||||||||||||||||
| rel_pos=rel_pos, | ||||||||||||||||||||||||||||||||||||||||||
| rel_2d_pos=rel_2d_pos, | ||||||||||||||||||||||||||||||||||||||||||
| **kwargs, | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
| attention_output = self.output(self_outputs[0], hidden_states) | ||||||||||||||||||||||||||||||||||||||||||
| outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them | ||||||||||||||||||||||||||||||||||||||||||
| return outputs | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # Copied from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Layer with LayoutLMv2->LayoutLMv3 | ||||||||||||||||||||||||||||||||||||||||||
| # Adapted from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Layer with LayoutLMv2->LayoutLMv3 | ||||||||||||||||||||||||||||||||||||||||||
| class LayoutLMv3Layer(GradientCheckpointingLayer): | ||||||||||||||||||||||||||||||||||||||||||
| def __init__(self, config): | ||||||||||||||||||||||||||||||||||||||||||
| def __init__(self, config, layer_idx=None): | ||||||||||||||||||||||||||||||||||||||||||
| super().__init__() | ||||||||||||||||||||||||||||||||||||||||||
| self.chunk_size_feed_forward = config.chunk_size_feed_forward | ||||||||||||||||||||||||||||||||||||||||||
| self.seq_len_dim = 1 | ||||||||||||||||||||||||||||||||||||||||||
| self.attention = LayoutLMv3Attention(config) | ||||||||||||||||||||||||||||||||||||||||||
| self.attention = LayoutLMv3Attention(config, layer_idx=layer_idx) | ||||||||||||||||||||||||||||||||||||||||||
| self.intermediate = LayoutLMv3Intermediate(config) | ||||||||||||||||||||||||||||||||||||||||||
| self.output = LayoutLMv3Output(config) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -364,13 +414,15 @@ def forward( | |||||||||||||||||||||||||||||||||||||||||
| output_attentions=False, | ||||||||||||||||||||||||||||||||||||||||||
| rel_pos=None, | ||||||||||||||||||||||||||||||||||||||||||
| rel_2d_pos=None, | ||||||||||||||||||||||||||||||||||||||||||
| **kwargs: Unpack[TransformersKwargs], | ||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||
| self_attention_outputs = self.attention( | ||||||||||||||||||||||||||||||||||||||||||
| hidden_states, | ||||||||||||||||||||||||||||||||||||||||||
| attention_mask, | ||||||||||||||||||||||||||||||||||||||||||
| output_attentions=output_attentions, | ||||||||||||||||||||||||||||||||||||||||||
| rel_pos=rel_pos, | ||||||||||||||||||||||||||||||||||||||||||
| rel_2d_pos=rel_2d_pos, | ||||||||||||||||||||||||||||||||||||||||||
| **kwargs, | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
| attention_output = self_attention_outputs[0] | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -393,9 +445,8 @@ class LayoutLMv3Encoder(nn.Module): | |||||||||||||||||||||||||||||||||||||||||
| def __init__(self, config): | ||||||||||||||||||||||||||||||||||||||||||
| super().__init__() | ||||||||||||||||||||||||||||||||||||||||||
| self.config = config | ||||||||||||||||||||||||||||||||||||||||||
| self.layer = nn.ModuleList([LayoutLMv3Layer(config) for _ in range(config.num_hidden_layers)]) | ||||||||||||||||||||||||||||||||||||||||||
| self.layer = nn.ModuleList([LayoutLMv3Layer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) | ||||||||||||||||||||||||||||||||||||||||||
| self.gradient_checkpointing = False | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| self.has_relative_attention_bias = config.has_relative_attention_bias | ||||||||||||||||||||||||||||||||||||||||||
| self.has_spatial_attention_bias = config.has_spatial_attention_bias | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -803,18 +854,20 @@ def forward( | |||||||||||||||||||||||||||||||||||||||||
| final_bbox = bbox | ||||||||||||||||||||||||||||||||||||||||||
| if self.config.has_relative_attention_bias: | ||||||||||||||||||||||||||||||||||||||||||
| position_ids = self.embeddings.position_ids[:, : input_shape[1]] | ||||||||||||||||||||||||||||||||||||||||||
| position_ids = position_ids.expand_as(input_ids) | ||||||||||||||||||||||||||||||||||||||||||
| position_ids = position_ids.expand(input_shape) | ||||||||||||||||||||||||||||||||||||||||||
| final_position_ids = position_ids | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( | ||||||||||||||||||||||||||||||||||||||||||
| attention_mask, None, device, dtype=embedding_output.dtype | ||||||||||||||||||||||||||||||||||||||||||
| attention_mask = create_bidirectional_mask( | ||||||||||||||||||||||||||||||||||||||||||
| config=self.config, | ||||||||||||||||||||||||||||||||||||||||||
| input_embeds=embedding_output, | ||||||||||||||||||||||||||||||||||||||||||
| attention_mask=attention_mask, | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| encoder_outputs = self.encoder( | ||||||||||||||||||||||||||||||||||||||||||
| embedding_output, | ||||||||||||||||||||||||||||||||||||||||||
| bbox=final_bbox, | ||||||||||||||||||||||||||||||||||||||||||
| position_ids=final_position_ids, | ||||||||||||||||||||||||||||||||||||||||||
| attention_mask=extended_attention_mask, | ||||||||||||||||||||||||||||||||||||||||||
| attention_mask=attention_mask, | ||||||||||||||||||||||||||||||||||||||||||
| output_attentions=output_attentions, | ||||||||||||||||||||||||||||||||||||||||||
| output_hidden_states=output_hidden_states, | ||||||||||||||||||||||||||||||||||||||||||
| return_dict=return_dict, | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's essentially the same as bert
transformers/src/transformers/models/bert/modeling_bert.py
Line 131 in 47227f4
with flipped order:
attn_weights = torch.matmul(query * scaling, key.transpose(2, 3))