from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import nn

from transformers import RobertaPreTrainedModel, XLMRobertaConfig, XLMRobertaModel
from transformers.utils import ModelOutput


@dataclass
class TransformationModelOutput(ModelOutput):
    """
    Base class for text model outputs that also contains a projection of the last hidden states.

    Args:
        projection_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.project_dim)`):
            The last hidden states projected to `config.project_dim` by the model's linear transformation layer.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    projection_state: Optional[torch.FloatTensor] = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class RobertaSeriesConfig(XLMRobertaConfig):
    def __init__(
        self,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        project_dim=512,
        pooler_fn="cls",
        learn_encoder=False,
        use_attention_mask=True,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
        # Output dimensionality of the projection layer applied on top of the encoder.
        self.project_dim = project_dim
        # The remaining options are stored on the config but are not used in this module.
        self.pooler_fn = pooler_fn
        self.learn_encoder = learn_encoder
        self.use_attention_mask = use_attention_mask


class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
    base_model_prefix = "roberta"
    config_class = RobertaSeriesConfig

    def __init__(self, config):
        super().__init__(config)
        self.roberta = XLMRobertaModel(config)
        # Token-wise projection from the encoder hidden size to `config.project_dim`.
        self.transformation = nn.Linear(config.hidden_size, config.project_dim)
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ):
        r"""Encode the input with XLM-RoBERTa and project the last hidden state to `config.project_dim`."""

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Project every token's last hidden state rather than a single pooled vector.
        projection_state = self.transformation(outputs.last_hidden_state)

        return TransformationModelOutput(
            projection_state=projection_state,
            last_hidden_state=outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
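
A minimal usage sketch of the classes above. It borrows the real `xlm-roberta-base` checkpoint only for its tokenizer and base config; the model itself is randomly initialized here (trained weights would be loaded with `from_pretrained` on the matching checkpoint), and `project_dim=768` is an illustrative choice rather than a value prescribed by this commit.

import torch
from transformers import XLMRobertaTokenizer

# The XLM-RoBERTa tokenizer requires the `sentencepiece` package.
tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")

# Build a config from the base checkpoint and set the projection dimension.
config = RobertaSeriesConfig.from_pretrained("xlm-roberta-base", project_dim=768)
model = RobertaSeriesModelWithTransformation(config)  # random weights for illustration
model.eval()

inputs = tokenizer(["a photo of an astronaut riding a horse"], return_tensors="pt")
with torch.no_grad():
    outputs = model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask)

# (batch_size, sequence_length, project_dim) -- one projected vector per token.
print(outputs.projection_state.shape)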