
Conversation

@thomwolf (Member) commented on Aug 26, 2019

This PR tests how easy it would be to incorporate TF 2.0 models in the current library:

  • a few new models: TFBertPreTrainedModel, TFBertModel, TFBertForPretraining, TFBertForMaskedLM and TFBertForNextSentencePrediction;
  • a weights conversion script to convert the PyTorch weights (only the bert-base-uncased model is up on our AWS S3 bucket for the moment);
  • a few tests.

The library is (very) slightly reorganized to allow for this, mostly by spinning the configuration classes out of the (PyTorch) modeling classes so they can be reused by both the PyTorch and the TF 2.0 models.
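
As a rough illustration of what this enables, here is a minimal sketch (assuming the shared BertConfig class is exposed alongside both modeling classes, as on this branch):

from pytorch_transformers import BertConfig, BertModel, TFBertModel

# one configuration object drives both implementations
config = BertConfig(hidden_size=768, num_hidden_layers=12, num_attention_heads=12)
pytorch_model = BertModel(config)   # PyTorch nn.Module
tf_model = TFBertModel(config)      # TF 2.0 Keras model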

With the TF 2.0 Keras imperative interface and eager execution, the workflow and the models are surprisingly similar:

import numpy
import torch
import tensorflow as tf
from pytorch_transformers import BertModel, TFBertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
pytorch_model = BertModel.from_pretrained('bert-base-uncased')
tf_model = TFBertModel.from_pretrained('bert-base-uncased')

text = "[CLS] Who was Jim Henson ? Jim [MASK] was a puppeteer [SEP]"
tokens = tokenizer.encode(text)

pytorch_inputs = torch.tensor([tokens])
tf_inputs = tf.constant([tokens])

with torch.no_grad():
    pytorch_outputs = pytorch_model(pytorch_inputs)

tf_output = tf_model(tf_inputs, training=False)

numpy.amax(numpy.abs(pytorch_outputs[0].numpy() - tf_output[0].numpy()))
# >>> 2.861023e-06 => we are good, a difference of a few 1e-6 is expected
# between TF and PT, arising from internal computation ops
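
The same sanity check can also be written with an explicit tolerance; a small sketch, with atol=1e-5 chosen to comfortably cover the few-1e-6 discrepancies noted above:

assert numpy.allclose(pytorch_outputs[0].numpy(), tf_output[0].numpy(), atol=1e-5)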

If you want to play with this, you can install from the tf branch like this:

  • install TF 2.0: pip install tensorflow==2.0.0-rc0
  • install pytorch-transformers from the tf branch: pip install https://github.com/huggingface/pytorch-transformers/archive/tf.zip

@thomwolf changed the title from "TensorFlow 2.0 - Testing with Bert" to "TensorFlow 2.0 - Testing with a few Bert architectures" on Aug 26, 2019
@codecov-io

Codecov Report

Merging #1104 into master will decrease coverage by 0.49%.
The diff coverage is 81.56%.

Impacted file tree graph

@@            Coverage Diff            @@
##           master    #1104     +/-   ##
=========================================
- Coverage   79.61%   79.12%   -0.5%     
=========================================
  Files          42       56     +14     
  Lines        6898     7654    +756     
=========================================
+ Hits         5492     6056    +564     
- Misses       1406     1598    +192
Impacted Files Coverage Δ
pytorch_transformers/tests/modeling_xlnet_test.py 95.91% <100%> (+0.02%) ⬆️
pytorch_transformers/modeling_transfo_xl.py 55.2% <100%> (-2.33%) ⬇️
pytorch_transformers/tests/modeling_xlm_test.py 71.2% <100%> (+0.23%) ⬆️
pytorch_transformers/modeling_openai.py 73.35% <100%> (-1.42%) ⬇️
pytorch_transformers/tests/modeling_auto_test.py 96.15% <100%> (+0.15%) ⬆️
pytorch_transformers/tests/modeling_gpt2_test.py 85% <100%> (+0.78%) ⬆️
pytorch_transformers/tests/conftest.py 91.66% <100%> (+1.66%) ⬆️
pytorch_transformers/modeling_gpt2.py 74.74% <100%> (-1.1%) ⬇️
pytorch_transformers/modeling_roberta.py 75.45% <100%> (-0.44%) ⬇️
pytorch_transformers/modeling_xlnet.py 78.04% <100%> (-0.98%) ⬇️
... and 39 more

Continue to review full report at Codecov.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update df9d6ef...3231797. Read the comment docs.

@thomwolf closed this on Sep 24, 2019
@LysandreJik deleted the tf branch on April 27, 2022
qgallouedec pushed a commit to qgallouedec/transformers that referenced this pull request on May 17, 2025:
* Fix token decode in fill-mask pipeline

* Add support for ModernBERT

* Add modernbert unit tests

* Cleanup bert unit tests

* Add unit test for `sequence_length > local_attention_window`