src/transformers/models/layoutlmv3/configuration_layoutlmv3.py
@@ -174,5 +174,12 @@ def __init__(
self.patch_size = patch_size
self.classifier_dropout = classifier_dropout

# LayoutLMv3's relative position bias and spatial attention bias are incompatible with SDPA/FlashAttention
# Automatically set eager attention when these biases are enabled, unless explicitly set by user
if has_relative_attention_bias or has_spatial_attention_bias:
# Only set if not already explicitly set via kwargs (attn_implementation is processed in super().__init__)
if self._attn_implementation is None:
self._attn_implementation = "eager"


__all__ = ["LayoutLMv3Config"]
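
A minimal usage sketch of the new fallback (assuming the public `microsoft/layoutlmv3-base` checkpoint, which enables both bias flags): leaving `attn_implementation` unset silently selects eager attention, while an explicit request for SDPA is honored and the self-attention forward in the modeling file below then raises.

```python
from transformers import LayoutLMv3Model

# No attn_implementation passed: the config falls back to eager attention
# because the checkpoint enables the relative and spatial attention biases.
model = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base")

# An explicit attn_implementation is respected; the ValueError added in
# LayoutLMv3SelfAttention.forward is then raised at forward time.
# model = LayoutLMv3Model.from_pretrained(
#     "microsoft/layoutlmv3-base", attn_implementation="sdpa"
# )
```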
195 changes: 124 additions & 71 deletions src/transformers/models/layoutlmv3/modeling_layoutlmv3.py
@@ -16,6 +16,7 @@

import collections
import math
from collections.abc import Callable
from typing import Optional, Union

import torch
@@ -25,16 +26,19 @@

from ... import initialization as init
from ...activations import ACT2FN
from ...masking_utils import create_bidirectional_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import apply_chunking_to_forward
from ...utils import (
TransformersKwargs,
auto_docstring,
logging,
torch_int,
@@ -203,6 +207,8 @@ class LayoutLMv3PreTrainedModel(PreTrainedModel):
config: LayoutLMv3Config
base_model_prefix = "layoutlmv3"
input_modalities = ["image", "text"]
_supports_flash_attn = True
_supports_sdpa = True

@torch.no_grad()
def _init_weights(self, module):
@@ -214,18 +220,80 @@ def _init_weights(self, module):
init.zeros_(module.pos_embed)


def layoutlmv3_eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: Optional[float] = None,
dropout: float = 0.0,
rel_pos: Optional[torch.Tensor] = None,
rel_2d_pos: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
):
"""
LayoutLMv3 eager attention with support for relative position bias and spatial attention bias.
Based on the CogView attention trick for training stability.
"""
if scaling is None:
scaling = 1.0 / math.sqrt(query.size(-1))

# Take the dot product between "query" and "key" to get the raw attention scores.
# The attention scores QT K/√d could be significantly larger than input elements, and result in overflow.
# Changing the computational order into QT(K/√d) alleviates the problem. (https://huggingface.co/papers/2105.13290)
attention_scores = torch.matmul(query / math.sqrt(query.size(-1)), key.transpose(-1, -2))
Contributor comment: It's essentially the same as BERT's
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
with the order flipped:
    attn_weights = torch.matmul(query * scaling, key.transpose(2, 3))
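
A quick numerical check of that equivalence (standalone sketch, torch only; shapes are arbitrary):

```python
import torch

q = torch.randn(1, 2, 5, 8)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 2, 5, 8)
scaling = q.size(-1) ** -0.5

# Scaling the product and scaling the query first give the same scores.
scores_a = torch.matmul(q, k.transpose(-1, -2)) * scaling
scores_b = torch.matmul(q * scaling, k.transpose(-1, -2))
assert torch.allclose(scores_a, scores_b, atol=1e-6)
```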


# Add relative position bias and spatial attention bias if available
if (
module.has_relative_attention_bias
and module.has_spatial_attention_bias
and rel_pos is not None
and rel_2d_pos is not None
):
attention_scores = attention_scores + (rel_pos + rel_2d_pos) / math.sqrt(query.size(-1))
elif module.has_relative_attention_bias and rel_pos is not None:
attention_scores = attention_scores + rel_pos / math.sqrt(query.size(-1))
elif module.has_spatial_attention_bias and rel_2d_pos is not None:
attention_scores = attention_scores + rel_2d_pos / math.sqrt(query.size(-1))
Contributor comment (on lines +254 to +258): Are we applying the scaling twice here? Do the integration tests still pass? I.e. there is another `/ math.sqrt(query.size(-1))`.


if attention_mask is not None:
# Apply the attention mask
attention_scores = attention_scores + attention_mask

# Normalize the attention scores to probabilities.
# Use the trick of the CogView paper to stabilize training
# https://huggingface.co/papers/2105.13290 Section 2.4
alpha = 32
scaled_attention_scores = attention_scores / alpha
max_value = scaled_attention_scores.amax(dim=(-1)).unsqueeze(-1)
new_attention_scores = (scaled_attention_scores - max_value) * alpha
attention_probs = nn.functional.softmax(new_attention_scores, dim=-1)

# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = nn.functional.dropout(attention_probs, p=dropout, training=module.training)

attn_output = torch.matmul(attention_probs, value)
attn_output = attn_output.transpose(1, 2).contiguous()

return attn_output, attention_probs
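
A small sanity-check sketch of the PB-Relax softmax above (standalone, torch only; the tolerance is an assumption), along the lines of the `torch.allclose` comparison mentioned in the removed `cogview_attention` docstring further down: softmax is shift-invariant per row, so subtracting the row maximum of the alpha-scaled scores and rescaling yields the same probabilities as a plain softmax, just with more numerical headroom.

```python
import torch

scores = 50.0 * torch.randn(2, 4, 8, 8)  # deliberately large attention scores
alpha = 32

scaled = scores / alpha
pb_relax = torch.softmax((scaled - scaled.amax(dim=-1, keepdim=True)) * alpha, dim=-1)
standard = torch.softmax(scores, dim=-1)

assert torch.allclose(standard, pb_relax, atol=1e-6)
```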


class LayoutLMv3SelfAttention(nn.Module):
def __init__(self, config):
def __init__(self, config, layer_idx=None):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
f"heads ({config.num_attention_heads})"
)
self.config = config

self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.scaling = self.attention_head_size**-0.5

self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
@@ -234,18 +302,7 @@ def __init__(self, config):
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.has_relative_attention_bias = config.has_relative_attention_bias
self.has_spatial_attention_bias = config.has_spatial_attention_bias
Contributor comment: It's missing an `is_causal` attribute; I would think it's not causal.
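
A sketch of what the reviewer asks for (attribute name as used by the library's non-eager attention backends; placement in `__init__` assumed):

```python
# In LayoutLMv3SelfAttention.__init__: the encoder attends bidirectionally,
# so SDPA/flash backends should not apply a causal mask.
self.is_causal = False
```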


def cogview_attention(self, attention_scores, alpha=32):
"""
https://huggingface.co/papers/2105.13290 Section 2.4 Stabilization of training: Precision Bottleneck Relaxation
(PB-Relax). A replacement of the original nn.Softmax(dim=-1)(attention_scores). Seems the new attention_probs
will result in a slower speed and a little bias. Can use torch.allclose(standard_attention_probs,
cogview_attention_probs, atol=1e-08) for comparison. The smaller atol (e.g., 1e-08), the better.
"""
scaled_attention_scores = attention_scores / alpha
max_value = scaled_attention_scores.amax(dim=(-1)).unsqueeze(-1)
new_attention_scores = (scaled_attention_scores - max_value) * alpha
return nn.Softmax(dim=-1)(new_attention_scores)
self.layer_idx = layer_idx

def forward(
self,
@@ -254,54 +311,45 @@ def forward(
output_attentions=False,
Contributor comment (suggested change): drop the `output_attentions=False,` parameter here.

rel_pos=None,
rel_2d_pos=None,
**kwargs: Unpack[TransformersKwargs],
):
batch_size, seq_length, _ = hidden_states.shape
query_layer = (
self.query(hidden_states)
.view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
.transpose(1, 2)
)
key_layer = (
self.key(hidden_states)
.view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
.transpose(1, 2)
)
value_layer = (
self.value(hidden_states)
.view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
.transpose(1, 2)
)

# Take the dot product between "query" and "key" to get the raw attention scores.
# The attention scores QT K/√d could be significantly larger than input elements, and result in overflow.
# Changing the computational order into QT(K/√d) alleviates the problem. (https://huggingface.co/papers/2105.13290)
attention_scores = torch.matmul(query_layer / math.sqrt(self.attention_head_size), key_layer.transpose(-1, -2))

if self.has_relative_attention_bias and self.has_spatial_attention_bias:
attention_scores += (rel_pos + rel_2d_pos) / math.sqrt(self.attention_head_size)
elif self.has_relative_attention_bias:
attention_scores += rel_pos / math.sqrt(self.attention_head_size)

if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in RobertaModel forward() function)
attention_scores = attention_scores + attention_mask

# Normalize the attention scores to probabilities.
# Use the trick of the CogView paper to stabilize training
attention_probs = self.cogview_attention(attention_scores)

# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)

context_layer = torch.matmul(attention_probs, value_layer)

context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.attention_head_size)

# Get query, key, value projections
query_layer = self.query(hidden_states).view(*hidden_shape).transpose(1, 2)
key_layer = self.key(hidden_states).view(*hidden_shape).transpose(1, 2)
value_layer = self.value(hidden_states).view(*hidden_shape).transpose(1, 2)

# Determine attention implementation
attention_interface: Callable = layoutlmv3_eager_attention_forward
use_eager = self.config._attn_implementation == "eager"

if not use_eager:
# SDPA and Flash Attention don't support custom relative position bias and spatial attention bias
if self.has_relative_attention_bias or self.has_spatial_attention_bias:
raise ValueError(
f"You are using {self.config._attn_implementation} as attention type. However, LayoutLMv3's "
"relative position bias and spatial attention bias are not compatible with it. "
'Please load the model with `attn_implementation="eager"`.'
)
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
Contributor comment (on lines +326 to +336), suggested change:

-        use_eager = self.config._attn_implementation == "eager"
-        if not use_eager:
-            # SDPA and Flash Attention don't support custom relative position bias and spatial attention bias
-            if self.has_relative_attention_bias or self.has_spatial_attention_bias:
-                raise ValueError(
-                    f"You are using {self.config._attn_implementation} as attention type. However, LayoutLMv3's "
-                    "relative position bias and spatial attention bias are not compatible with it. "
-                    'Please load the model with `attn_implementation="eager"`.'
-                )
-            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+        if self.config._attn_implementation != "eager":
+            # SDPA and Flash Attention don't support custom relative position bias and spatial attention bias
+            if self.has_relative_attention_bias or self.has_spatial_attention_bias:
+                raise ValueError(
+                    f"You are using {self.config._attn_implementation} as attention type. However, LayoutLMv3's "
+                    "relative position bias and spatial attention bias are not compatible with it. "
+                    'Please load the model with `attn_implementation="eager"`.'
+                )
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

For now we raise a ValueError; thinking about forcing a fallback instead. Let's keep it as is for now.
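
A hedged sketch of the fallback alternative mentioned here, not what the PR does: warn once and drop to the eager path instead of raising (`logger` is the module-level logger; `layoutlmv3_eager_attention_forward` is defined above).

```python
attention_interface = layoutlmv3_eager_attention_forward
if self.config._attn_implementation != "eager":
    if self.has_relative_attention_bias or self.has_spatial_attention_bias:
        logger.warning_once(
            "LayoutLMv3's relative position and spatial attention biases require "
            "eager attention; falling back to the eager implementation."
        )
    else:
        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
```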


outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
attn_output, attn_weights = attention_interface(
self,
query_layer,
key_layer,
value_layer,
attention_mask,
dropout=0.0 if not self.training else self.dropout.p,
scaling=self.scaling,
rel_pos=rel_pos,
rel_2d_pos=rel_2d_pos,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()

outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
return outputs
Contributor comment (on lines +352 to 353), suggested change:

-        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
-        return outputs
+        return outputs, attn_weights

You can take a look at

    _can_record_outputs = {
        "hidden_states": BertLayer,
        "attentions": BertSelfAttention,

which takes care of the `output_*` outputs. Will need more changes here then, but better to go for the right thing from the get-go.

Contributor comment: BERT is a good reference for this, plus you need the same decorators; they are essential.
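
A hedged sketch of what that could look like for this model, mirroring the BertModel attribute quoted above (class names taken from this file; the decorators the reviewer mentions are not shown here):

```python
class LayoutLMv3Model(LayoutLMv3PreTrainedModel):
    _can_record_outputs = {
        "hidden_states": LayoutLMv3Layer,
        "attentions": LayoutLMv3SelfAttention,
    }
```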



@@ -320,11 +368,11 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
return hidden_states


# Copied from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Attention with LayoutLMv2->LayoutLMv3
# Adapted from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Attention with LayoutLMv2->LayoutLMv3
Contributor comment: Is LayoutLMv2 vastly different? Might be worth changing it at the same time.

class LayoutLMv3Attention(nn.Module):
def __init__(self, config):
def __init__(self, config, layer_idx=None):
super().__init__()
self.self = LayoutLMv3SelfAttention(config)
self.self = LayoutLMv3SelfAttention(config, layer_idx=layer_idx)
self.output = LayoutLMv3SelfOutput(config)

def forward(
@@ -334,26 +382,28 @@ def forward(
output_attentions=False,
rel_pos=None,
rel_2d_pos=None,
**kwargs: Unpack[TransformersKwargs],
):
self_outputs = self.self(
hidden_states,
attention_mask,
output_attentions,
rel_pos=rel_pos,
rel_2d_pos=rel_2d_pos,
**kwargs,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs


# Copied from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Layer with LayoutLMv2->LayoutLMv3
# Adapted from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Layer with LayoutLMv2->LayoutLMv3
class LayoutLMv3Layer(GradientCheckpointingLayer):
def __init__(self, config):
def __init__(self, config, layer_idx=None):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = LayoutLMv3Attention(config)
self.attention = LayoutLMv3Attention(config, layer_idx=layer_idx)
self.intermediate = LayoutLMv3Intermediate(config)
self.output = LayoutLMv3Output(config)

@@ -364,13 +414,15 @@ def forward(
output_attentions=False,
rel_pos=None,
rel_2d_pos=None,
**kwargs: Unpack[TransformersKwargs],
):
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
output_attentions=output_attentions,
rel_pos=rel_pos,
rel_2d_pos=rel_2d_pos,
**kwargs,
)
attention_output = self_attention_outputs[0]

Expand All @@ -393,9 +445,8 @@ class LayoutLMv3Encoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList([LayoutLMv3Layer(config) for _ in range(config.num_hidden_layers)])
self.layer = nn.ModuleList([LayoutLMv3Layer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
self.gradient_checkpointing = False

self.has_relative_attention_bias = config.has_relative_attention_bias
self.has_spatial_attention_bias = config.has_spatial_attention_bias

@@ -803,18 +854,20 @@ def forward(
final_bbox = bbox
if self.config.has_relative_attention_bias:
position_ids = self.embeddings.position_ids[:, : input_shape[1]]
position_ids = position_ids.expand_as(input_ids)
position_ids = position_ids.expand(input_shape)
final_position_ids = position_ids

extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
attention_mask, None, device, dtype=embedding_output.dtype
attention_mask = create_bidirectional_mask(
config=self.config,
input_embeds=embedding_output,
attention_mask=attention_mask,
)

encoder_outputs = self.encoder(
embedding_output,
bbox=final_bbox,
position_ids=final_position_ids,
attention_mask=extended_attention_mask,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,