25 changes: 15 additions & 10 deletions torchtext/models/roberta/bundler.py
@@ -1,5 +1,4 @@
from dataclasses import dataclass
-from functools import partial
from urllib.parse import urljoin

from typing import Optional, Callable, Dict, Union, Any
@@ -15,7 +14,7 @@
RobertaModel,
)

-from .transforms import get_xlmr_transform
+import torchtext.transforms as T

from torchtext import _TEXT_BUCKET

@@ -156,10 +155,13 @@ def encoderConf(self) -> RobertaEncoderConf:
XLMR_BASE_ENCODER = RobertaModelBundle(
    _path=urljoin(_TEXT_BUCKET, "xlmr.base.encoder.pt"),
    _encoder_conf=RobertaEncoderConf(vocab_size=250002),
-    transform=partial(get_xlmr_transform,
-                      vocab_path=urljoin(_TEXT_BUCKET, "xlmr.vocab.pt"),
-                      spm_model_path=urljoin(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model"),
-                      )
+    transform=lambda: T.Sequential(
Contributor: hmm, reading this, it might not

https://stackoverflow.com/a/27928036

Contributor: yet I feel overwhelmed by the fact that this lambda spans multiple lines.

Contributor (author): hmm, interesting, I wasn't aware of this guidance around lambda usage. Thanks for sharing. I honestly thought it looked nicer to be able to define a lazy transform in-place :). The alternative would be to define a private function that returns the composite transform.

Contributor: yeah, I am looking at all the possible reasoning, but I am also becoming convinced that the lambda is the simplest option here.

Contributor: Another way would be to compose a no-op lambda with the rest of the transform using partial, though it isn't as easy to read.
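For comparison, a minimal sketch (not part of the diff) of the two styles discussed in this thread: the multi-line lambda the PR lands on, and the private factory function the author mentions as the alternative. The imports, and the assumption that load_state_dict_from_url comes from torch.hub as used elsewhere in bundler.py, are illustrative only.

from urllib.parse import urljoin

import torchtext.transforms as T
from torch.hub import load_state_dict_from_url  # assumed import for this sketch
from torchtext import _TEXT_BUCKET

# Style 1: multi-line lambda, defined in-place (the form merged here).
transform_lambda = lambda: T.Sequential(
    T.SentencePieceTokenizer(urljoin(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model")),
    T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "xlmr.vocab.pt"))),
    T.Truncate(510),
    T.AddToken(token=0, begin=True),
    T.AddToken(token=2, begin=False),
)

# Style 2: a private factory function returning the same composite transform.
def _xlmr_base_transform() -> T.Sequential:
    return T.Sequential(
        T.SentencePieceTokenizer(urljoin(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model")),
        T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "xlmr.vocab.pt"))),
        T.Truncate(510),
        T.AddToken(token=0, begin=True),
        T.AddToken(token=2, begin=False),
    )

# Either way construction stays lazy: the vocab download inside
# load_state_dict_from_url only runs when the callable is invoked.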

+        T.SentencePieceTokenizer(urljoin(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model")),
+        T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "xlmr.vocab.pt"))),
+        T.Truncate(510),
Contributor: I am curious as to why we are truncating to a length of 510 exactly. In the XLMR transform that is being removed, I noticed that max_seq_len was an input arg with a default value of 514, and we would truncate to a length of self.max_seq_len - 2, which would be 512 instead of 510, right?

Contributor: I think 510 plus the BOS and EOS is right. The 514 seems like an off-by-two.

Contributor (author): Thanks for bringing this up. Right, I think the 514 added earlier was not correct. The final max length should be 512 including the special tokens. Let me also confirm with the author of XLM-RoBERTa that this is the right setting.

Contributor: Sounds like some tests would help.
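A test along those lines might look like the sketch below (not part of this PR). It assumes the bundle is importable as torchtext.models.XLMR_BASE_ENCODER and that the composed transform accepts a single string and returns a list of token ids.

from torchtext.models import XLMR_BASE_ENCODER


def test_xlmr_base_transform_max_length():
    transform = XLMR_BASE_ENCODER.transform()
    very_long_text = "hello world " * 10_000      # far longer than 510 sentencepieces
    token_ids = transform(very_long_text)
    assert len(token_ids) <= 512                  # 510 content ids + BOS + EOS
    assert token_ids[0] == 0                      # AddToken(token=0, begin=True)
    assert token_ids[-1] == 2                     # AddToken(token=2, begin=False)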

+        T.AddToken(token=0, begin=True),
+        T.AddToken(token=2, begin=False),
Comment on lines +162 to +163:

Contributor: Is it guaranteed that the BOS and EOS tokens will always have an index of 0 and 2 respectively? Could there be a use case where the user uses their own vocab instead of loading from xlmr.vocab.pt?

Contributor: Certainly plausible that they'd add, say, additional special tokens. Those would be appended to the vocab, though, so 0 and 2 should remain as BOS and EOS respectively. If they didn't, you'd break the "contract" with how RoBERTa was trained.

Contributor (author): These are pre-trained models, and the transform complies with what the model was originally trained with. Users can certainly bring in their own flavors of SentencePiece models, vocabs, etc., but then that transform is specific to the user's model and won't apply to the pre-trained weights.

Contributor: Got it, thanks for the clarification!
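To illustrate the point about user-supplied vocabularies, a small sketch (not from the PR): extra special tokens would have to be appended to the end of the pre-trained vocab so that indices 0 and 2 stay bound to BOS and EOS. The "<s>"/"</s>" token strings and the torch.hub import are assumptions about the serialized xlmr.vocab.pt.

from urllib.parse import urljoin

from torch.hub import load_state_dict_from_url  # assumed import for this sketch
from torchtext import _TEXT_BUCKET

vocab = load_state_dict_from_url(urljoin(_TEXT_BUCKET, "xlmr.vocab.pt"))
vocab.append_token("<my_extra_token>")   # hypothetical user addition, appended at the end

# The indices the checkpoint was trained with are unchanged (assuming
# fairseq-style "<s>"/"</s>" entries in the serialized vocab).
assert vocab["<s>"] == 0
assert vocab["</s>"] == 2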

+    )
)

XLMR_BASE_ENCODER.__doc__ = (
@@ -174,10 +176,13 @@ def encoderConf(self) -> RobertaEncoderConf:
XLMR_LARGE_ENCODER = RobertaModelBundle(
    _path=urljoin(_TEXT_BUCKET, "xlmr.large.encoder.pt"),
    _encoder_conf=RobertaEncoderConf(vocab_size=250002, embedding_dim=1024, ffn_dimension=4096, num_attention_heads=16, num_encoder_layers=24),
-    transform=partial(get_xlmr_transform,
-                      vocab_path=urljoin(_TEXT_BUCKET, "xlmr.vocab.pt"),
-                      spm_model_path=urljoin(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model"),
-                      )
+    transform=lambda: T.Sequential(
+        T.SentencePieceTokenizer(urljoin(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model")),
+        T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "xlmr.vocab.pt"))),
+        T.Truncate(510),
+        T.AddToken(token=0, begin=True),
+        T.AddToken(token=2, begin=False),
+    )
)

XLMR_LARGE_ENCODER.__doc__ = (
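End to end, the bundle is consumed roughly as in the sketch below (illustrative, not from the diff; the shape comment assumes the base encoder's 768-dimensional embeddings):

import torch
from torchtext.models import XLMR_BASE_ENCODER

model = XLMR_BASE_ENCODER.get_model()
model.eval()

transform = XLMR_BASE_ENCODER.transform()   # the lambda is invoked here, building the pipeline
token_ids = transform("XLM-R is a multilingual encoder")   # list of <= 512 ids, wrapped in 0 ... 2

with torch.no_grad():
    features = model(torch.tensor([token_ids]))   # expected shape: (1, seq_len, 768) for the base encoder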
82 changes: 0 additions & 82 deletions torchtext/models/roberta/transforms.py

This file was deleted.
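For context on what was removed, a rough sketch reconstructed from the review discussion above (not the actual 82 deleted lines): the old module bundled the same steps behind vocab_path / spm_model_path arguments and truncated to max_seq_len - 2, with max_seq_len defaulting to 514. The class name and the use of torchtext.transforms here are assumptions.

from typing import List

import torch.nn as nn
import torchtext.transforms as T
from torch.hub import load_state_dict_from_url  # assumed import for this sketch


class XLMRTransform(nn.Module):                        # hypothetical name
    def __init__(self, vocab_path: str, spm_model_path: str, max_seq_len: int = 514):
        super().__init__()
        self.tokenizer = T.SentencePieceTokenizer(spm_model_path)
        self.vocab = T.VocabTransform(load_state_dict_from_url(vocab_path))
        self.truncate = T.Truncate(max_seq_len - 2)    # reserve two slots for BOS/EOS
        self.add_bos = T.AddToken(token=0, begin=True)
        self.add_eos = T.AddToken(token=2, begin=False)

    def forward(self, text: str) -> List[int]:
        ids = self.vocab(self.tokenizer(text))
        return self.add_eos(self.add_bos(self.truncate(ids)))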