
Commit af8a038

Merge pull request #1 from huggingface/master
Pulling commits from main repo
2 parents dbbd6c7 + 68a889e commit af8a038

27 files changed: +755 -284 lines

.circleci/config.yml

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@ jobs:
 - run: sudo pip install --progress-bar off .
 - run: sudo pip install pytest ftfy spacy
 - run: sudo python -m spacy download en
-- run: python -m pytest -sv tests/
+- run: python -m pytest -sv tests/ --runslow
 build_py2:
 working_directory: ~/pytorch-pretrained-BERT
 docker:
@@ -20,7 +20,7 @@ jobs:
 - run: sudo pip install pytest spacy
 - run: sudo pip install ftfy==4.4.3
 - run: sudo python -m spacy download en
-- run: python -m pytest -sv tests/
+- run: python -m pytest -sv tests/ --runslow
 workflows:
 version: 2
 build_and_test:
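For context, the `--runslow` flag added above is the standard pytest custom-option pattern: a `conftest.py` hook registers the flag and skips slow-marked tests unless it is passed. A minimal sketch of that pattern (the repository's actual `tests/conftest.py` is not part of this diff, and the `slow` marker name is an assumption):

```python
# Hypothetical tests/conftest.py illustrating the usual --runslow pattern
import pytest

def pytest_addoption(parser):
    # Register the command-line flag used by the CI commands above
    parser.addoption("--runslow", action="store_true", default=False,
                     help="run tests marked as slow")

def pytest_collection_modifyitems(config, items):
    # Without --runslow, attach a skip marker to every @pytest.mark.slow test
    if config.getoption("--runslow"):
        return
    skip_slow = pytest.mark.skip(reason="need --runslow option to run")
    for item in items:
        if "slow" in item.keywords:
            item.add_marker(skip_slow)
```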

README.md

Lines changed: 124 additions & 15 deletions
@@ -131,6 +131,7 @@ This package comprises the following classes that can be imported in Python and
 - Configuration classes for BERT, OpenAI GPT and Transformer-XL (in the respective [`modeling.py`](./pytorch_pretrained_bert/modeling.py), [`modeling_openai.py`](./pytorch_pretrained_bert/modeling_openai.py), [`modeling_transfo_xl.py`](./pytorch_pretrained_bert/modeling_transfo_xl.py) files):
 - `BertConfig` - Configuration class to store the configuration of a `BertModel` with utilities to read and write from JSON configuration files.
 - `OpenAIGPTConfig` - Configuration class to store the configuration of a `OpenAIGPTModel` with utilities to read and write from JSON configuration files.
+- `GPT2Config` - Configuration class to store the configuration of a `GPT2Model` with utilities to read and write from JSON configuration files.
 - `TransfoXLConfig` - Configuration class to store the configuration of a `TransfoXLModel` with utilities to read and write from JSON configuration files.

 The repository further comprises:
@@ -461,10 +462,12 @@ Here is a detailed documentation of the classes in the package and how to use th

 | Sub-section | Description |
 |-|-|
-| [Loading Google AI's/OpenAI's pre-trained weights](#loading-google-ai-or-openai-pre-trained-weights-or-pytorch-dump) | How to load Google AI/OpenAI's pre-trained weight or a PyTorch saved instance |
-| [PyTorch models](#PyTorch-models) | API of the BERT, GPT, GPT-2 and Transformer-XL PyTorch model classes |
+| [Loading pre-trained weights](#loading-google-ai-or-openai-pre-trained-weights-or-pytorch-dump) | How to load Google AI/OpenAI's pre-trained weight or a PyTorch saved instance |
+| [Serialization best-practices](#serialization-best-practices) | How to save and reload a fine-tuned model |
+| [Configurations](#configurations) | API of the configuration classes for BERT, GPT, GPT-2 and Transformer-XL |
+| [Models](#models) | API of the PyTorch model classes for BERT, GPT, GPT-2 and Transformer-XL |
 | [Tokenizers](#tokenizers) | API of the tokenizers class for BERT, GPT, GPT-2 and Transformer-XL|
-| [Optimizers](#optimizerss) | API of the optimizers |
+| [Optimizers](#optimizers) | API of the optimizers |

 ### Loading Google AI or OpenAI pre-trained weights or PyTorch dump

@@ -524,7 +527,101 @@ model = GPT2Model.from_pretrained('gpt2')

 ```

-### PyTorch models
+### Serialization best-practices
+
+This section explains how you can save and reload a fine-tuned model (BERT, GPT, GPT-2 and Transformer-XL).
+There are three types of files you need to save to be able to reload a fine-tuned model:
+
+- the model itself, which should be saved following PyTorch serialization [best practices](https://pytorch.org/docs/stable/notes/serialization.html#best-practices),
+- the configuration file of the model, which is saved as a JSON file, and
+- the vocabulary (and the merges for the BPE-based models GPT and GPT-2).
+
+Here is the recommended way of saving the model, configuration and vocabulary to an `output_dir` directory and reloading the model and tokenizer afterwards:
+
+```python
+from pytorch_pretrained_bert import WEIGHTS_NAME, CONFIG_NAME
+
+output_dir = "./models/"
+
+# Step 1: Save a model, configuration and vocabulary that you have fine-tuned
+
+# If we have a distributed model, save only the encapsulated model
+# (it was wrapped in PyTorch DistributedDataParallel or DataParallel)
+model_to_save = model.module if hasattr(model, 'module') else model
+
+# If we save using the predefined names, we can load using `from_pretrained`
+output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
+output_config_file = os.path.join(output_dir, CONFIG_NAME)
+
+torch.save(model_to_save.state_dict(), output_model_file)
+model_to_save.config.to_json_file(output_config_file)
+tokenizer.save_vocabulary(output_dir)
+
+# Step 2: Re-load the saved model and vocabulary
+
+# Example for a Bert model
+model = BertForQuestionAnswering.from_pretrained(output_dir)
+tokenizer = BertTokenizer.from_pretrained(output_dir, do_lower_case=args.do_lower_case)  # Add specific options if needed
+# Example for a GPT model
+model = OpenAIGPTDoubleHeadsModel.from_pretrained(output_dir)
+tokenizer = OpenAIGPTTokenizer.from_pretrained(output_dir)
+```
+
+Here is another way you can save and reload the model if you want to use specific paths for each type of file:
+
+```python
+output_model_file = "./models/my_own_model_file.bin"
+output_config_file = "./models/my_own_config_file.bin"
+output_vocab_file = "./models/my_own_vocab_file.bin"
+
+# Step 1: Save a model, configuration and vocabulary that you have fine-tuned
+
+# If we have a distributed model, save only the encapsulated model
+# (it was wrapped in PyTorch DistributedDataParallel or DataParallel)
+model_to_save = model.module if hasattr(model, 'module') else model
+
+torch.save(model_to_save.state_dict(), output_model_file)
+model_to_save.config.to_json_file(output_config_file)
+tokenizer.save_vocabulary(output_vocab_file)
+
+# Step 2: Re-load the saved model and vocabulary
+
+# Since we didn't save using the predefined WEIGHTS_NAME and CONFIG_NAME names, we cannot load using `from_pretrained`.
+# Here is how to do it in this situation:
+
+# Example for a Bert model
+config = BertConfig.from_json_file(output_config_file)
+model = BertForQuestionAnswering(config)
+state_dict = torch.load(output_model_file)
+model.load_state_dict(state_dict)
+tokenizer = BertTokenizer(output_vocab_file, do_lower_case=args.do_lower_case)
+
+# Example for a GPT model
+config = OpenAIGPTConfig.from_json_file(output_config_file)
+model = OpenAIGPTDoubleHeadsModel(config)
+state_dict = torch.load(output_model_file)
+model.load_state_dict(state_dict)
+tokenizer = OpenAIGPTTokenizer(output_vocab_file)
+```
+
+### Configurations
+
+Models (BERT, GPT, GPT-2 and Transformer-XL) are defined and built from configuration classes which contain the parameters of the models (number of layers, dimensionalities...) and a few utilities to read and write from JSON configuration files. The respective configuration classes are:
+
+- `BertConfig` for `BertModel` and BERT class instances.
+- `OpenAIGPTConfig` for `OpenAIGPTModel` and OpenAI GPT class instances.
+- `GPT2Config` for `GPT2Model` and OpenAI GPT-2 class instances.
+- `TransfoXLConfig` for `TransfoXLModel` and Transformer-XL class instances.
+
+These configuration classes contain a few utilities to load and save configurations:
+
+- `from_dict(cls, json_object)`: A class method to construct a configuration from a Python dictionary of parameters. Returns an instance of the configuration class.
+- `from_json_file(cls, json_file)`: A class method to construct a configuration from a JSON file of parameters. Returns an instance of the configuration class.
+- `to_dict()`: Serializes an instance to a Python dictionary. Returns a dictionary.
+- `to_json_string()`: Serializes an instance to a JSON string. Returns a string.
+- `to_json_file(json_file_path)`: Save an instance to a JSON file.
+
+### Models

 #### 1. `BertModel`

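For illustration, the configuration utilities listed above compose into a simple JSON round trip. A minimal sketch, assuming the same package version as this README; the shortcut model name, `num_labels` value and output path are arbitrary choices, not part of the original text:

```python
import os

from pytorch_pretrained_bert.modeling import BertConfig, BertForSequenceClassification

# Load a pre-trained model; its configuration is available as `model.config`
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

os.makedirs("./models/", exist_ok=True)
config_path = "./models/bert_config.json"

# Serialize the configuration to JSON ...
model.config.to_json_file(config_path)

# ... and rebuild an equivalent configuration object from the file
config = BertConfig.from_json_file(config_path)
print(config.to_json_string())  # same hyper-parameters as model.config
```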
@@ -796,8 +893,7 @@ This model *outputs*:
 - `multiple_choice_logits`: the multiple choice logits as a torch.FloatTensor of size [batch_size, num_choices]
 - `presents`: a list of pre-computed hidden-states (key and values in each attention blocks) as a torch.FloatTensors. They can be reused to speed up sequential decoding (see the `run_gpt2.py` example).

-
-### Tokenizers:
+### Tokenizers

 #### `BertTokenizer`

@@ -816,6 +912,7 @@ and three methods:
 - `tokenize(text)`: convert a `str` in a list of `str` tokens by (1) performing basic tokenization and (2) WordPiece tokenization.
 - `convert_tokens_to_ids(tokens)`: convert a list of `str` tokens in a list of `int` indices in the vocabulary.
 - `convert_ids_to_tokens(tokens)`: convert a list of `int` indices in a list of `str` tokens in the vocabulary.
+- `save_vocabulary(directory_path)`: save the vocabulary file to `directory_path`. Return the path to the saved vocabulary file: `vocab_file_path`. The vocabulary can be reloaded with `BertTokenizer.from_pretrained('vocab_file_path')` or `BertTokenizer.from_pretrained('directory_path')`.

 Please refer to the doc strings and code in [`tokenization.py`](./pytorch_pretrained_bert/tokenization.py) for the details of the `BasicTokenizer` and `WordpieceTokenizer` classes. In general it is recommended to use `BertTokenizer` unless you know what you are doing.

@@ -832,18 +929,22 @@ This class has four arguments:

 and five methods:

-- `tokenize(text)`: convert a `str` in a list of `str` tokens by (1) performing basic tokenization and (2) WordPiece tokenization.
+- `tokenize(text)`: convert a `str` in a list of `str` tokens by performing BPE tokenization.
 - `convert_tokens_to_ids(tokens)`: convert a list of `str` tokens in a list of `int` indices in the vocabulary.
 - `convert_ids_to_tokens(tokens)`: convert a list of `int` indices in a list of `str` tokens in the vocabulary.
 - `set_special_tokens(self, special_tokens)`: update the list of special tokens (see above arguments)
+- `encode(text)`: convert a `str` in a list of `int` tokens by performing BPE encoding.
 - `decode(ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)`: decode a list of `int` indices in a string and do some post-processing if needed: (i) remove special tokens from the output and (ii) clean up tokenization spaces.
+- `save_vocabulary(directory_path)`: save the vocabulary, merge and special tokens files to `directory_path`. Return the path to the three files: `vocab_file_path`, `merge_file_path`, `special_tokens_file_path`. The vocabulary can be reloaded with `OpenAIGPTTokenizer.from_pretrained('directory_path')`.

 Please refer to the doc strings and code in [`tokenization_openai.py`](./pytorch_pretrained_bert/tokenization_openai.py) for the details of the `OpenAIGPTTokenizer`.

 #### `TransfoXLTokenizer`

 `TransfoXLTokenizer` performs word tokenization. This tokenizer can be used for adaptive softmax and has utilities for counting tokens in a corpus to create a vocabulary ordered by token frequency (for adaptive softmax). See the adaptive softmax paper ([Efficient softmax approximation for GPUs](http://arxiv.org/abs/1609.04309)) for more details.

+The API is similar to the API of `BertTokenizer` (see above).
+
 Please refer to the doc strings and code in [`tokenization_transfo_xl.py`](./pytorch_pretrained_bert/tokenization_transfo_xl.py) for the details of these additional methods in `TransfoXLTokenizer`.

 #### `GPT2Tokenizer`
@@ -858,13 +959,17 @@ This class has three arguments:

 and two methods:

+- `tokenize(text)`: convert a `str` in a list of `str` tokens by performing byte-level BPE.
+- `convert_tokens_to_ids(tokens)`: convert a list of `str` tokens in a list of `int` indices in the vocabulary.
+- `convert_ids_to_tokens(tokens)`: convert a list of `int` indices in a list of `str` tokens in the vocabulary.
+- `set_special_tokens(self, special_tokens)`: update the list of special tokens (see above arguments)
 - `encode(text)`: convert a `str` in a list of `int` tokens by performing byte-level BPE.
 - `decode(tokens)`: convert back a list of `int` tokens in a `str`.
+- `save_vocabulary(directory_path)`: save the vocabulary, merge and special tokens files to `directory_path`. Return the path to the three files: `vocab_file_path`, `merge_file_path`, `special_tokens_file_path`. The vocabulary can be reloaded with `GPT2Tokenizer.from_pretrained('directory_path')`.

 Please refer to [`tokenization_gpt2.py`](./pytorch_pretrained_bert/tokenization_gpt2.py) for more details on the `GPT2Tokenizer`.

-
-### Optimizers:
+### Optimizers

 #### `BertAdam`

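As a quick illustration of the `GPT2Tokenizer` methods documented above, here is a minimal sketch of the byte-level BPE round trip; the `'gpt2'` shortcut name follows the pre-trained loading section earlier in this README, and the sample sentence is arbitrary:

```python
from pytorch_pretrained_bert import GPT2Tokenizer

# Download and cache the pre-trained GPT-2 vocabulary and BPE merges
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

text = "Who was Jim Henson ? Jim Henson was a puppeteer"
indexed_tokens = tokenizer.encode(text)                   # str -> list of int indices
tokens = tokenizer.convert_ids_to_tokens(indexed_tokens)  # indices -> list of str tokens
decoded = tokenizer.decode(indexed_tokens)                # indices -> str
print(tokens)
print(decoded)
```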
@@ -1174,18 +1279,20 @@ To get these results we used a combination of:

 Here is the full list of hyper-parameters for this run:
 ```bash
+export SQUAD_DIR=/path/to/SQUAD
+
 python ./run_squad.py \
 --bert_model bert-large-uncased \
 --do_train \
 --do_predict \
 --do_lower_case \
---train_file $SQUAD_TRAIN \
---predict_file $SQUAD_EVAL \
+--train_file $SQUAD_DIR/train-v1.1.json \
+--predict_file $SQUAD_DIR/dev-v1.1.json \
 --learning_rate 3e-5 \
 --num_train_epochs 2 \
 --max_seq_length 384 \
 --doc_stride 128 \
---output_dir $OUTPUT_DIR \
+--output_dir /tmp/debug_squad/ \
 --train_batch_size 24 \
 --gradient_accumulation_steps 2
 ```
@@ -1194,18 +1301,20 @@ If you have a recent GPU (starting from NVIDIA Volta series), you should try **1

 Here is an example of hyper-parameters for a FP16 run we tried:
 ```bash
+export SQUAD_DIR=/path/to/SQUAD
+
 python ./run_squad.py \
 --bert_model bert-large-uncased \
 --do_train \
 --do_predict \
 --do_lower_case \
---train_file $SQUAD_TRAIN \
---predict_file $SQUAD_EVAL \
+--train_file $SQUAD_DIR/train-v1.1.json \
+--predict_file $SQUAD_DIR/dev-v1.1.json \
 --learning_rate 3e-5 \
 --num_train_epochs 2 \
 --max_seq_length 384 \
 --doc_stride 128 \
---output_dir $OUTPUT_DIR \
+--output_dir /tmp/debug_squad/ \
 --train_batch_size 24 \
 --fp16 \
 --loss_scale 128

examples/run_classifier.py

Lines changed: 19 additions & 14 deletions
@@ -35,14 +35,11 @@
 from scipy.stats import pearsonr, spearmanr
 from sklearn.metrics import matthews_corrcoef, f1_score

-from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
-from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME
+from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME
+from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig
 from pytorch_pretrained_bert.tokenization import BertTokenizer
 from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear

-logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
-datefmt = '%m/%d/%Y %H:%M:%S',
-level = logging.INFO)
 logger = logging.getLogger(__name__)

@@ -95,7 +92,7 @@ def get_labels(self):
 @classmethod
 def _read_tsv(cls, input_file, quotechar=None):
 """Reads a tab separated value file."""
-with open(input_file, "r") as f:
+with open(input_file, "r", encoding="utf-8") as f:
 reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
 lines = []
 for line in reader:
@@ -697,6 +694,11 @@ def main():
 n_gpu = 1
 # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
 torch.distributed.init_process_group(backend='nccl')
+
+logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+datefmt = '%m/%d/%Y %H:%M:%S',
+level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
+
 logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
 device, n_gpu, bool(args.local_rank != -1), args.fp16))

@@ -857,18 +859,21 @@ def main():
 optimizer.zero_grad()
 global_step += 1

-# Save a trained model and the associated configuration
+if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
+# Save a trained model, configuration and tokenizer
 model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
+
+# If we save using the predefined names, we can load using `from_pretrained`
 output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
-torch.save(model_to_save.state_dict(), output_model_file)
 output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
-with open(output_config_file, 'w') as f:
-f.write(model_to_save.config.to_json_string())

-# Load a trained model and config that you have fine-tuned
-config = BertConfig(output_config_file)
-model = BertForSequenceClassification(config, num_labels=num_labels)
-model.load_state_dict(torch.load(output_model_file))
+torch.save(model_to_save.state_dict(), output_model_file)
+model_to_save.config.to_json_file(output_config_file)
+tokenizer.save_vocabulary(args.output_dir)
+
+# Load a trained model and vocabulary that you have fine-tuned
+model = BertForSequenceClassification.from_pretrained(args.output_dir, num_labels=num_labels)
+tokenizer = BertTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
 else:
 model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
 model.to(device)
