diff --git a/examples/BERT/README.rst b/examples/BERT/README.rst
new file mode 100644
index 0000000000..8cf0a55281
--- /dev/null
+++ b/examples/BERT/README.rst
@@ -0,0 +1,152 @@
+BERT with torchtext
++++++++++++++++++++
+
+This example shows how to train a BERT model using only PyTorch and torchtext, and then fine-tune the pre-trained BERT model for the question-answer task.
+
+
+Generate pre-trained BERT
+-------------------------
+
+Train the BERT model on the masked language modeling (MLM) task and the next-sentence (NS) task. Run the tasks on a local GPU or CPU::
+
+    python mlm_task.py
+    python ns_task.py
+
+or run the tasks on a SLURM-powered cluster with Distributed Data Parallel (DDP)::
+
+    srun --label --ntasks-per-node=1 --time=4000 --mem-per-cpu=5120 --gres=gpu:8 --cpus-per-task 80 --nodes=1 --pty python mlm_task.py --parallel DDP --log-interval 600 --dataset BookCorpus
+
+    srun --label --ntasks-per-node=1 --time=4000 --mem-per-cpu=5120 --gres=gpu:8 --cpus-per-task 80 --nodes=1 --pty python ns_task.py --parallel DDP --bert-model mlm_bert.pt --dataset BookCorpus
+
+The resulting perplexity of mlm_task on the test set is 18.97899.
+The resulting loss of ns_task on the test set is 0.05446.
+
+Fine-tune pre-trained BERT for question-answer task
+---------------------------------------------------
+
+With the SQuAD dataset, the pre-trained BERT model is fine-tuned for the question-answer task::
+
+    python qa_task.py --bert-model ns_bert.pt --epochs 30
+
+The pre-trained BERT models and vocab are available:
+
+* `bert_vocab.pt `_
+* `mlm_bert.pt `_
+* `ns_bert.pt `_
+
+An example train/valid/test printout with the pre-trained BERT model on the question-answer task::
+
+    | epoch 1 | 200/ 1055 batches | lr 5.00000 | ms/batch 746.33 | loss 3.70 | ppl 40.45
+    | epoch 1 | 400/ 1055 batches | lr 5.00000 | ms/batch 746.78 | loss 3.06 | ppl 21.25
+    | epoch 1 | 600/ 1055 batches | lr 5.00000 | ms/batch 746.83 | loss 2.84 | ppl 17.15
+    | epoch 1 | 800/ 1055 batches | lr 5.00000 | ms/batch 746.55 | loss 2.69 | ppl 14.73
+    | epoch 1 | 1000/ 1055 batches | lr 5.00000 | ms/batch 745.48 | loss 2.55 | ppl 12.85
+    -----------------------------------------------------------------------------------------
+    | end of epoch 1 | time: 821.25s | valid loss 2.33 | exact 40.052% | f1 52.595%
+    -----------------------------------------------------------------------------------------
+    | epoch 2 | 200/ 1055 batches | lr 5.00000 | ms/batch 748.17 | loss 2.33 | ppl 10.25
+    | epoch 2 | 400/ 1055 batches | lr 5.00000 | ms/batch 745.52 | loss 2.28 | ppl 9.75
+    | epoch 2 | 600/ 1055 batches | lr 5.00000 | ms/batch 745.50 | loss 2.24 | ppl 9.37
+    | epoch 2 | 800/ 1055 batches | lr 5.00000 | ms/batch 745.10 | loss 2.22 | ppl 9.18
+    | epoch 2 | 1000/ 1055 batches | lr 5.00000 | ms/batch 744.61 | loss 2.16 | ppl 8.66
+    -----------------------------------------------------------------------------------------
+    | end of epoch 2 | time: 820.75s | valid loss 2.12 | exact 44.632% | f1 57.965%
+    -----------------------------------------------------------------------------------------
+    | epoch 3 | 200/ 1055 batches | lr 5.00000 | ms/batch 748.88 | loss 2.00 | ppl 7.41
+    | epoch 3 | 400/ 1055 batches | lr 5.00000 | ms/batch 746.46 | loss 1.99 | ppl 7.29
+    | epoch 3 | 600/ 1055 batches | lr 5.00000 | ms/batch 745.35 | loss 1.99 | ppl 7.30
+    | epoch 3 | 800/ 1055 batches | lr 5.00000 | ms/batch 746.03 | loss 1.98 | ppl 7.27
+    | epoch 3 | 1000/ 1055 batches | lr 5.00000 | ms/batch 746.01 | loss 1.98 | ppl 7.26
+    -----------------------------------------------------------------------------------------
+    | end of epoch 3 | time: 821.98s | valid loss 1.96 | exact 48.001% | f1 61.036%
+    -----------------------------------------------------------------------------------------
+    | epoch 4 | 200/ 1055 batches | lr 5.00000 | ms/batch 748.72 | loss 1.82 | ppl 6.19
+    | epoch 4 | 400/ 1055 batches | lr 5.00000 | ms/batch 745.86 | loss 1.84 | ppl 6.28
+    | epoch 4 | 600/ 1055 batches | lr 5.00000 | ms/batch 745.63 | loss 1.85 | ppl 6.34
+    | epoch 4 | 800/ 1055 batches | lr 5.00000 | ms/batch 745.59 | loss 1.82 | ppl 6.20
+    | epoch 4 | 1000/ 1055 batches | lr 5.00000 | ms/batch 745.55 | loss 1.83 | ppl 6.21
+    -----------------------------------------------------------------------------------------
+    | end of epoch 4 | time: 821.10s | valid loss 1.95 | exact 49.149% | f1 62.040%
+    -----------------------------------------------------------------------------------------
+    | epoch 5 | 200/ 1055 batches | lr 5.00000 | ms/batch 748.40 | loss 1.66 | ppl 5.24
+    | epoch 5 | 400/ 1055 batches | lr 5.00000 | ms/batch 756.09 | loss 1.69 | ppl 5.44
+    | epoch 5 | 600/ 1055 batches | lr 5.00000 | ms/batch 769.19 | loss 1.70 | ppl 5.46
+    | epoch 5 | 800/ 1055 batches | lr 5.00000 | ms/batch 764.96 | loss 1.72 | ppl 5.58
+    | epoch 5 | 1000/ 1055 batches | lr 5.00000 | ms/batch 773.25 | loss 1.70 | ppl 5.49
+    -----------------------------------------------------------------------------------------
+    | end of epoch 5 | time: 844.20s | valid loss 1.99 | exact 49.509% | f1 61.994%
+    -----------------------------------------------------------------------------------------
+    | epoch 6 | 200/ 1055 batches | lr 0.50000 | ms/batch 765.25 | loss 1.50 | ppl 4.49
+    | epoch 6 | 400/ 1055 batches | lr 0.50000 | ms/batch 749.64 | loss 1.45 | ppl 4.25
+    | epoch 6 | 600/ 1055 batches | lr 0.50000 | ms/batch 768.16 | loss 1.40 | ppl 4.06
+    | epoch 6 | 800/ 1055 batches | lr 0.50000 | ms/batch 745.69 | loss 1.43 | ppl 4.18
+    | epoch 6 | 1000/ 1055 batches | lr 0.50000 | ms/batch 744.90 | loss 1.40 | ppl 4.07
+    -----------------------------------------------------------------------------------------
+    | end of epoch 6 | time: 829.55s | valid loss 1.97 | exact 51.182% | f1 63.437%
+    -----------------------------------------------------------------------------------------
+    | epoch 7 | 200/ 1055 batches | lr 0.50000 | ms/batch 747.73 | loss 1.36 | ppl 3.89
+    | epoch 7 | 400/ 1055 batches | lr 0.50000 | ms/batch 744.50 | loss 1.37 | ppl 3.92
+    | epoch 7 | 600/ 1055 batches | lr 0.50000 | ms/batch 744.20 | loss 1.35 | ppl 3.86
+    | epoch 7 | 800/ 1055 batches | lr 0.50000 | ms/batch 743.85 | loss 1.36 | ppl 3.89
+    | epoch 7 | 1000/ 1055 batches | lr 0.50000 | ms/batch 744.01 | loss 1.34 | ppl 3.83
+    -----------------------------------------------------------------------------------------
+    | end of epoch 7 | time: 820.02s | valid loss 2.01 | exact 51.507% | f1 63.885%
+    -----------------------------------------------------------------------------------------
+    | epoch 8 | 200/ 1055 batches | lr 0.50000 | ms/batch 747.40 | loss 1.31 | ppl 3.72
+    | epoch 8 | 400/ 1055 batches | lr 0.50000 | ms/batch 744.33 | loss 1.30 | ppl 3.68
+    | epoch 8 | 600/ 1055 batches | lr 0.50000 | ms/batch 745.76 | loss 1.31 | ppl 3.69
+    | epoch 8 | 800/ 1055 batches | lr 0.50000 | ms/batch 745.04 | loss 1.31 | ppl 3.69
+    | epoch 8 | 1000/ 1055 batches | lr 0.50000 | ms/batch 745.13 | loss 1.31 | ppl 3.72
+    -----------------------------------------------------------------------------------------
+    | end of epoch 8 | time: 820.40s | valid loss 2.02 | exact 51.260% | f1 63.762%
+    -----------------------------------------------------------------------------------------
+    | epoch 9 | 200/ 1055 batches | lr 0.05000 | ms/batch 748.36 | loss 1.26 | ppl 3.54
+    | epoch 9 | 400/ 1055 batches | lr 0.05000 | ms/batch 744.55 | loss 1.26 | ppl 3.52
+    | epoch 9 | 600/ 1055 batches | lr 0.05000 | ms/batch 745.46 | loss 1.23 | ppl 3.44
+    | epoch 9 | 800/ 1055 batches | lr 0.05000 | ms/batch 745.23 | loss 1.26 | ppl 3.52
+    | epoch 9 | 1000/ 1055 batches | lr 0.05000 | ms/batch 744.69 | loss 1.24 | ppl 3.47
+    -----------------------------------------------------------------------------------------
+    | end of epoch 9 | time: 820.41s | valid loss 2.02 | exact 51.578% | f1 63.704%
+    -----------------------------------------------------------------------------------------
+    | epoch 10 | 200/ 1055 batches | lr 0.00500 | ms/batch 749.25 | loss 1.25 | ppl 3.50
+    | epoch 10 | 400/ 1055 batches | lr 0.00500 | ms/batch 745.81 | loss 1.24 | ppl 3.47
+    | epoch 10 | 600/ 1055 batches | lr 0.00500 | ms/batch 744.89 | loss 1.26 | ppl 3.51
+    | epoch 10 | 800/ 1055 batches | lr 0.00500 | ms/batch 746.02 | loss 1.23 | ppl 3.42
+    | epoch 10 | 1000/ 1055 batches | lr 0.00500 | ms/batch 746.61 | loss 1.25 | ppl 3.50
+    -----------------------------------------------------------------------------------------
+    | end of epoch 10 | time: 821.85s | valid loss 2.05 | exact 51.648% | f1 63.811%
+    -----------------------------------------------------------------------------------------
+    =========================================================================================
+    | End of training | test loss 2.05 | exact 51.337% | f1 63.645%
+    =========================================================================================
+
+Structure of the example
+------------------------
+
+model.py
+========
+
+This file defines the Transformer and MultiheadAttention models used for BERT. The embedding layer includes the PositionalEncoding and TokenTypeEncoding layers. MLMTask, NextSentenceTask, and QuestionAnswerTask are the models for the three tasks mentioned above.
+
+data.py
+=======
+
+This file provides the datasets required to train the BERT model and the question-answer task. Please note that the BookCorpus dataset is not publicly available.
+
+
+mlm_task.py, ns_task.py, qa_task.py
+===================================
+
+These three files define the train/valid/test process for the corresponding tasks.
+
+
+metrics.py
+==========
+
+This file provides two metrics (F1 and exact score) for the question-answer task.
+
+
+utils.py
+========
+
+This file provides a few utilities shared by the three tasks.
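+
+Loading the pre-trained BERT
+----------------------------
+
+The snippet below is a minimal, illustrative sketch of how the released checkpoints fit together for inference; it is not one of the training scripts. It assumes the vocab and ``ns_bert.pt`` files listed above have been downloaded into the working directory and that the default hyper-parameters were used for pre-training (``ns_bert.pt`` holds the state dict of ``BertModel`` saved by ``ns_task.py``)::
+
+    import torch
+    from model import BertModel, QuestionAnswerTask
+
+    vocab = torch.load('bert_vocab.pt')  # the training scripts save/load it as torchtext_bert_vocab.pt by default
+    bert = BertModel(ntoken=len(vocab), ninp=768, nhead=12, nhid=3072, nlayers=12, dropout=0.2)
+    bert.load_state_dict(torch.load('ns_bert.pt'))  # pre-trained encoder weights
+    qa_model = QuestionAnswerTask(bert).eval()
+
+    # QuestionAnswerTask expects token ids of shape (seq_len, batch) plus token-type ids
+    # (qa_task.py uses 0 for the question segment and 1 for the context segment).
+    seq_input = torch.randint(len(vocab), (128, 1))
+    tok_type = torch.zeros((128, 1), dtype=torch.long)
+    with torch.no_grad():
+        start_logits, end_logits = qa_model(seq_input, token_type_input=tok_type)  # each (batch, seq_len)
+
+``qa_task.py`` builds exactly this ``BertModel`` + ``QuestionAnswerTask`` stack before fine-tuning; the sketch above only illustrates the expected input and output shapes.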
diff --git a/examples/BERT/data.py b/examples/BERT/data.py new file mode 100644 index 0000000000..2f62459c1a --- /dev/null +++ b/examples/BERT/data.py @@ -0,0 +1,54 @@ +import glob +import torch +import logging +from torchtext.data.utils import get_tokenizer +import random +from torchtext.experimental.datasets import LanguageModelingDataset + + +################################################################### +# Set up dataset for book corpus +################################################################### +def BookCorpus(vocab, tokenizer=get_tokenizer("basic_english"), + data_select=('train', 'test', 'valid'), removed_tokens=[], + min_sentence_len=None): + + if isinstance(data_select, str): + data_select = [data_select] + if not set(data_select).issubset(set(('train', 'test', 'valid'))): + raise TypeError('data_select is not supported!') + + extracted_files = glob.glob('/datasets01/bookcorpus/021819/*/*.txt') + random.seed(1000) + random.shuffle(extracted_files) + + num_files = len(extracted_files) + _path = {'train': extracted_files[:(num_files // 20 * 17)], + 'test': extracted_files[(num_files // 20 * 17):(num_files // 20 * 18)], + 'valid': extracted_files[(num_files // 20 * 18):]} + + data = {} + for item in _path.keys(): + data[item] = [] + logging.info('Creating {} data'.format(item)) + tokens = [] + for txt_file in _path[item]: + with open(txt_file, 'r', encoding="utf8", errors='ignore') as f: + for line in f.readlines(): + _tokens = tokenizer(line.strip()) + if min_sentence_len: + if len(_tokens) >= min_sentence_len: + tokens.append([vocab.stoi[token] for token in _tokens]) + else: + tokens += [vocab.stoi[token] for token in _tokens] + data[item] = tokens + + for key in data_select: + if data[key] == []: + raise TypeError('Dataset {} is empty!'.format(key)) + if min_sentence_len: + return tuple(LanguageModelingDataset(data[d], vocab, lambda x: x, False) + for d in data_select) + else: + return tuple(LanguageModelingDataset(torch.tensor(data[d]).long(), vocab, lambda x: x, False) + for d in data_select) diff --git a/examples/BERT/metrics.py b/examples/BERT/metrics.py new file mode 100644 index 0000000000..dba20bb753 --- /dev/null +++ b/examples/BERT/metrics.py @@ -0,0 +1,72 @@ +import collections +import re +import string + + +def compute_qa_exact(ans_pred_tokens_samples): + + ''' + Input: ans_pred_tokens_samples: [([ans1_tokens_candidate1, ans1_tokens_candidate2], pred1_tokens), + ([ans2_tokens_candidate1, ans2_tokens_candidate2], pred2_tokens), + ... 
+ ([ansn_tokens_candidate1, ansn_tokens_candidate2], predn_tokens)] + ans1_tokens_candidate1 = ['this', 'is', 'an', 'sample', 'example'] + Output: exact score of the samples + ''' + + def normalize_txt(text): + # lower case + text = text.lower() + + # remove punc + exclude = set(string.punctuation) + text = "".join(ch for ch in text if ch not in exclude) + + # remove articles + regex = re.compile(r"\b(a|an|the)\b", re.UNICODE) + text = re.sub(regex, " ", text) + + # white space fix + return " ".join(text.split()) + + exact_scores = [] + for (ans_tokens, pred_tokens) in ans_pred_tokens_samples: + pred_str = " ".join(pred_tokens) + candidate_score = [] + for item in ans_tokens: + ans_str = " ".join(item) + candidate_score.append(int(normalize_txt(ans_str) == normalize_txt(pred_str))) + exact_scores.append(max(candidate_score)) + return 100.0 * sum(exact_scores) / len(exact_scores) + + +def compute_qa_f1(ans_pred_tokens_samples): + + ''' + Input: ans_pred_tokens_samples: [([ans1_tokens_candidate1, ans1_tokens_candidate2], pred1_tokens), + ([ans2_tokens_candidate1, ans2_tokens_candidate2], pred2_tokens), + ... + ([ansn_tokens_candidate1, ansn_tokens_candidate2], predn_tokens)] + ans1_tokens_candidate1 = ['this', 'is', 'an', 'sample', 'example'] + Output: f1 score of the samples + ''' + def sample_f1(ans_tokens, pred_tokens): + common = collections.Counter(ans_tokens) & collections.Counter(pred_tokens) + num_same = sum(common.values()) + if len(ans_tokens) == 0 or len(pred_tokens) == 0: + # If either is no-answer, then F1 is 1 if they agree, 0 otherwise + return int(ans_tokens == pred_tokens) + if num_same == 0: + return 0 + precision = 1.0 * num_same / len(pred_tokens) + recall = 1.0 * num_same / len(ans_tokens) + f1 = (2 * precision * recall) / (precision + recall) + return f1 + + f1_scores = [] + for (ans_tokens, pred_tokens) in ans_pred_tokens_samples: + candidate_score = [] + for item in ans_tokens: + candidate_score.append(sample_f1(item, pred_tokens)) + f1_scores.append(max(candidate_score)) + return 100.0 * sum(f1_scores) / len(f1_scores) diff --git a/examples/BERT/mlm_task.py b/examples/BERT/mlm_task.py new file mode 100644 index 0000000000..d0623a9607 --- /dev/null +++ b/examples/BERT/mlm_task.py @@ -0,0 +1,263 @@ +import argparse +import time +import math +import torch +import torch.nn as nn +from model import MLMTask +from utils import run_demo, run_ddp, wrap_up +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import DataLoader + + +def collate_batch(batch_data, args, mask_id, cls_id): + batch_data = torch.tensor(batch_data).long().view(args.batch_size, -1).t().contiguous() + # Generate masks with args.mask_frac + data_len = batch_data.size(0) + ones_num = int(data_len * args.mask_frac) + zeros_num = data_len - ones_num + lm_mask = torch.cat([torch.zeros(zeros_num), torch.ones(ones_num)]) + lm_mask = lm_mask[torch.randperm(data_len)] + batch_data = torch.cat((torch.tensor([[cls_id] * batch_data.size(1)]).long(), batch_data)) + lm_mask = torch.cat((torch.tensor([0.0]), lm_mask)) + + targets = torch.stack([batch_data[i] for i in range(lm_mask.size(0)) if lm_mask[i]]).view(-1) + batch_data = batch_data.masked_fill(lm_mask.bool().unsqueeze(1), mask_id) + return batch_data, lm_mask, targets + + +def process_raw_data(raw_data, args): + _num = raw_data.size(0) // (args.batch_size * args.bptt) + raw_data = raw_data[:(_num * args.batch_size * args.bptt)] + return raw_data + + +def evaluate(data_source, model, vocab, 
ntokens, criterion, args, device): + # Turn on evaluation mode which disables dropout. + model.eval() + total_loss = 0. + mask_id = vocab.stoi[''] + cls_id = vocab.stoi[''] + dataloader = DataLoader(data_source, batch_size=args.batch_size * args.bptt, + shuffle=False, collate_fn=lambda b: collate_batch(b, args, mask_id, cls_id)) + with torch.no_grad(): + for batch, (data, lm_mask, targets) in enumerate(dataloader): + if args.parallel == 'DDP': + data = data.to(device[0]) + targets = targets.to(device[0]) + else: + data = data.to(device) + targets = targets.to(device) + data = data.transpose(0, 1) # Wrap up by DDP or DataParallel + output = model(data) + output = torch.stack([output[i] for i in range(lm_mask.size(0)) if lm_mask[i]]) + output_flat = output.view(-1, ntokens) + total_loss += criterion(output_flat, targets).item() + return total_loss / ((len(data_source) - 1) / args.bptt / args.batch_size) + + +def train(model, vocab, train_loss_log, train_data, + optimizer, criterion, ntokens, epoch, scheduler, args, device, rank=None): + model.train() + total_loss = 0. + start_time = time.time() + mask_id = vocab.stoi[''] + cls_id = vocab.stoi[''] + train_loss_log.append(0.0) + dataloader = DataLoader(train_data, batch_size=args.batch_size * args.bptt, + shuffle=False, collate_fn=lambda b: collate_batch(b, args, mask_id, cls_id)) + + for batch, (data, lm_mask, targets) in enumerate(dataloader): + optimizer.zero_grad() + if args.parallel == 'DDP': + data = data.to(device[0]) + targets = targets.to(device[0]) + else: + data = data.to(device) + targets = targets.to(device) + data = data.transpose(0, 1) # Wrap up by DDP or DataParallel + output = model(data) + output = torch.stack([output[i] for i in range(lm_mask.size(0)) if lm_mask[i]]) + loss = criterion(output.view(-1, ntokens), targets) + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) + optimizer.step() + total_loss += loss.item() + if batch % args.log_interval == 0 and batch > 0: + cur_loss = total_loss / args.log_interval + elapsed = time.time() - start_time + if (rank is None) or rank == 0: + train_loss_log[-1] = cur_loss + print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:05.5f} | ms/batch {:5.2f} | ' + 'loss {:5.2f} | ppl {:8.2f}'.format(epoch, batch, + len(train_data) // (args.bptt * args.batch_size), + scheduler.get_last_lr()[0], + elapsed * 1000 / args.log_interval, + cur_loss, math.exp(cur_loss))) + total_loss = 0 + start_time = time.time() + + +def run_main(args, rank=None): + torch.manual_seed(args.seed) + if args.parallel == 'DDP': + n = torch.cuda.device_count() // args.world_size + device = list(range(rank * n, (rank + 1) * n)) + else: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + import torchtext + if args.dataset == 'WikiText103': + from torchtext.experimental.datasets import WikiText103 as WLMDataset + elif args.dataset == 'WikiText2': + from torchtext.experimental.datasets import WikiText2 as WLMDataset + elif args.dataset == 'WMTNewsCrawl': + from data import WMTNewsCrawl as WLMDataset + elif args.dataset == 'EnWik9': + from torchtext.datasets import EnWik9 + elif args.dataset == 'BookCorpus': + from data import BookCorpus + else: + print("dataset for MLM task is not supported") + + try: + vocab = torch.load(args.save_vocab) + except: + train_dataset, test_dataset, valid_dataset = WLMDataset() + old_vocab = train_dataset.vocab + vocab = torchtext.vocab.Vocab(counter=old_vocab.freqs, + specials=['', '', '']) + with open(args.save_vocab, 'wb') as f: + 
torch.save(vocab, f) + + if args.dataset == 'WikiText103' or args.dataset == 'WikiText2': + train_dataset, test_dataset, valid_dataset = WLMDataset(vocab=vocab) + elif args.dataset == 'WMTNewsCrawl': + from torchtext.experimental.datasets import WikiText2 + test_dataset, valid_dataset = WikiText2(vocab=vocab, data_select=('test', 'valid')) + train_dataset, = WLMDataset(vocab=vocab, data_select='train') + elif args.dataset == 'EnWik9': + enwik9 = EnWik9() + idx1, idx2 = int(len(enwik9) * 0.8), int(len(enwik9) * 0.9) + train_data = torch.tensor([vocab.stoi[_id] + for _id in enwik9[0:idx1]]).long() + val_data = torch.tensor([vocab.stoi[_id] + for _id in enwik9[idx1:idx2]]).long() + test_data = torch.tensor([vocab.stoi[_id] + for _id in enwik9[idx2:]]).long() + from torchtext.experimental.datasets import LanguageModelingDataset + train_dataset = LanguageModelingDataset(train_data, vocab) + valid_dataset = LanguageModelingDataset(val_data, vocab) + test_dataset = LanguageModelingDataset(test_data, vocab) + elif args.dataset == 'BookCorpus': + train_dataset, test_dataset, valid_dataset = BookCorpus(vocab) + + train_data = process_raw_data(train_dataset.data, args) + if rank is not None: + # Chunk training data by rank for different gpus + chunk_len = len(train_data) // args.world_size + train_data = train_data[(rank * chunk_len):((rank + 1) * chunk_len)] + val_data = process_raw_data(valid_dataset.data, args) + test_data = process_raw_data(test_dataset.data, args) + + ntokens = len(train_dataset.get_vocab()) + if args.checkpoint != 'None': + model = torch.load(args.checkpoint) + else: + model = MLMTask(ntokens, args.emsize, args.nhead, args.nhid, args.nlayers, args.dropout) + if args.parallel == 'DDP': + model = model.to(device[0]) + model = DDP(model, device_ids=device) + else: + model = model.to(device) + criterion = nn.CrossEntropyLoss() + optimizer = torch.optim.SGD(model.parameters(), lr=args.lr) + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1) + best_val_loss = None + train_loss_log, val_loss_log = [], [] + + for epoch in range(1, args.epochs + 1): + epoch_start_time = time.time() + train(model, train_dataset.vocab, train_loss_log, train_data, + optimizer, criterion, ntokens, epoch, scheduler, args, device, rank) + val_loss = evaluate(val_data, model, train_dataset.vocab, ntokens, criterion, args, device) + if (rank is None) or (rank == 0): + val_loss_log.append(val_loss) + print('-' * 89) + print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | ' + 'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time), + val_loss, math.exp(val_loss))) + print('-' * 89) + if not best_val_loss or val_loss < best_val_loss: + if rank is None: + with open(args.save, 'wb') as f: + torch.save(model, f) + elif rank == 0: + with open(args.save, 'wb') as f: + torch.save(model.state_dict(), f) + best_val_loss = val_loss + else: + scheduler.step() + if args.parallel == 'DDP': + dist.barrier() + rank0_devices = [x - rank * len(device) for x in device] + device_pairs = zip(rank0_devices, device) + map_location = {'cuda:%d' % x: 'cuda:%d' % y for x, y in device_pairs} + model.load_state_dict( + torch.load(args.save, map_location=map_location)) + test_loss = evaluate(test_data, model, train_dataset.vocab, ntokens, criterion, args, device) + if rank == 0: + wrap_up(train_loss_log, val_loss_log, test_loss, args, model.module, 'mlm_loss.txt', 'full_mlm_model.pt') + else: + with open(args.save, 'rb') as f: + model = torch.load(f) + test_loss = evaluate(test_data, model, 
train_dataset.vocab, ntokens, criterion, args, device) + wrap_up(train_loss_log, val_loss_log, test_loss, args, model, 'mlm_loss.txt', 'full_mlm_model.pt') + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='PyTorch Wikitext-2 Transformer Language Model') + parser.add_argument('--emsize', type=int, default=768, + help='size of word embeddings') + parser.add_argument('--nhid', type=int, default=3072, + help='number of hidden units per layer') + parser.add_argument('--nlayers', type=int, default=12, + help='number of layers') + parser.add_argument('--nhead', type=int, default=12, + help='the number of heads in the encoder/decoder of the transformer model') + parser.add_argument('--lr', type=float, default=6, + help='initial learning rate') + parser.add_argument('--clip', type=float, default=0.1, + help='gradient clipping') + parser.add_argument('--epochs', type=int, default=8, + help='upper epoch limit') + parser.add_argument('--batch_size', type=int, default=32, metavar='N', + help='batch size') + parser.add_argument('--bptt', type=int, default=128, + help='sequence length') + parser.add_argument('--dropout', type=float, default=0.2, + help='dropout applied to layers (0 = no dropout)') + parser.add_argument('--seed', type=int, default=5431916812, + help='random seed') + parser.add_argument('--log-interval', type=int, default=10, metavar='N', + help='report interval') + parser.add_argument('--checkpoint', type=str, default='None', + help='path to load the checkpoint') + parser.add_argument('--save', type=str, default='mlm_bert.pt', + help='path to save the final model') + parser.add_argument('--save-vocab', type=str, default='torchtext_bert_vocab.pt', + help='path to save the vocab') + parser.add_argument('--mask_frac', type=float, default=0.15, + help='the fraction of masked tokens') + parser.add_argument('--dataset', type=str, default='WikiText2', + help='dataset used for MLM task') + parser.add_argument('--parallel', type=str, default='None', + help='Use DataParallel to train model') + parser.add_argument('--world_size', type=int, default=8, + help='the world size to initiate DPP') + args = parser.parse_args() + + if args.parallel == 'DDP': + run_demo(run_ddp, run_main, args) + else: + run_main(args) diff --git a/examples/BERT/model.py b/examples/BERT/model.py new file mode 100644 index 0000000000..316841c751 --- /dev/null +++ b/examples/BERT/model.py @@ -0,0 +1,175 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Linear, Dropout, LayerNorm, TransformerEncoder +from torchtext.nn import MultiheadAttentionContainer, InProjContainer, ScaledDotProduct + + +class PositionalEncoding(nn.Module): + def __init__(self, d_model, max_len=5000): + super(PositionalEncoding, self).__init__() + self.pos_embedding = nn.Embedding(max_len, d_model) + + def forward(self, x): + S, N = x.size() + pos = torch.arange(S, + dtype=torch.long, + device=x.device).unsqueeze(0).expand((N, S)).t() + return self.pos_embedding(pos) + + +class TokenTypeEncoding(nn.Module): + def __init__(self, type_token_num, d_model): + super(TokenTypeEncoding, self).__init__() + self.token_type_embeddings = nn.Embedding(type_token_num, d_model) + + def forward(self, seq_input, token_type_input): + S, N = seq_input.size() + if token_type_input is None: + token_type_input = torch.zeros((S, N), + dtype=torch.long, + device=seq_input.device) + return self.token_type_embeddings(token_type_input) + + +class BertEmbedding(nn.Module): + def __init__(self, ntoken, ninp, 
dropout=0.5): + super(BertEmbedding, self).__init__() + self.ninp = ninp + self.ntoken = ntoken + self.pos_embed = PositionalEncoding(ninp) + self.embed = nn.Embedding(ntoken, ninp) + self.tok_type_embed = TokenTypeEncoding(2, ninp) # Two sentence type + self.norm = LayerNorm(ninp) + self.dropout = Dropout(dropout) + + def forward(self, src, token_type_input): + src = self.embed(src) + self.pos_embed(src) \ + + self.tok_type_embed(src, token_type_input) + return self.dropout(self.norm(src)) + + +class TransformerEncoderLayer(nn.Module): + def __init__(self, d_model, nhead, dim_feedforward=2048, + dropout=0.1, activation="gelu"): + super(TransformerEncoderLayer, self).__init__() + in_proj_container = InProjContainer(Linear(d_model, d_model), + Linear(d_model, d_model), + Linear(d_model, d_model)) + self.mha = MultiheadAttentionContainer(nhead, in_proj_container, + ScaledDotProduct(), Linear(d_model, d_model)) + self.linear1 = Linear(d_model, dim_feedforward) + self.dropout = Dropout(dropout) + self.linear2 = Linear(dim_feedforward, d_model) + + self.norm1 = LayerNorm(d_model) + self.norm2 = LayerNorm(d_model) + self.dropout1 = Dropout(dropout) + self.dropout2 = Dropout(dropout) + + if activation == "relu": + self.activation = F.relu + elif activation == "gelu": + self.activation = F.gelu + else: + raise RuntimeError("only relu/gelu are supported, not {}".format(activation)) + + def init_weights(self): + self.mha.in_proj_container.query_proj.init_weights() + self.mha.in_proj_container.key_proj.init_weights() + self.mha.in_proj_container.value_proj.init_weights() + self.mha.out_proj.init_weights() + self.linear1.weight.data.normal_(mean=0.0, std=0.02) + self.linear2.weight.data.normal_(mean=0.0, std=0.02) + self.norm1.bias.data.zero_() + self.norm1.weight.data.fill_(1.0) + self.norm2.bias.data.zero_() + self.norm2.weight.data.fill_(1.0) + + def forward(self, src, src_mask=None, src_key_padding_mask=None): + attn_output, attn_output_weights = self.mha(src, src, src, attn_mask=src_mask) + src = src + self.dropout1(attn_output) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + +class BertModel(nn.Module): + """Contain a transformer encoder.""" + + def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5): + super(BertModel, self).__init__() + self.model_type = 'Transformer' + self.bert_embed = BertEmbedding(ntoken, ninp) + encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout) + self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) + self.ninp = ninp + + def forward(self, src, token_type_input): + src = self.bert_embed(src, token_type_input) + output = self.transformer_encoder(src) + return output + + +class MLMTask(nn.Module): + """Contain a transformer encoder plus MLM head.""" + + def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5): + super(MLMTask, self).__init__() + self.bert_model = BertModel(ntoken, ninp, nhead, nhid, nlayers, dropout=0.5) + self.mlm_span = Linear(ninp, ninp) + self.activation = F.gelu + self.norm_layer = LayerNorm(ninp, eps=1e-12) + self.mlm_head = Linear(ninp, ntoken) + + def forward(self, src, token_type_input=None): + src = src.transpose(0, 1) # Wrap up by nn.DataParallel + output = self.bert_model(src, token_type_input) + output = self.mlm_span(output) + output = self.activation(output) + output = self.norm_layer(output) + output = self.mlm_head(output) + return output + + +class 
NextSentenceTask(nn.Module): + """Contain a pretrain BERT model and a linear layer.""" + + def __init__(self, bert_model): + super(NextSentenceTask, self).__init__() + self.bert_model = bert_model + self.linear_layer = Linear(bert_model.ninp, + bert_model.ninp) + self.ns_span = Linear(bert_model.ninp, 2) + self.activation = nn.Tanh() + + def forward(self, src, token_type_input): + src = src.transpose(0, 1) # Wrap up by nn.DataParallel + output = self.bert_model(src, token_type_input) + # Send the first <'cls'> seq to a classifier + output = self.activation(self.linear_layer(output[0])) + output = self.ns_span(output) + return output + + +class QuestionAnswerTask(nn.Module): + """Contain a pretrain BERT model and a linear layer.""" + + def __init__(self, bert_model): + super(QuestionAnswerTask, self).__init__() + self.bert_model = bert_model + self.activation = F.gelu + self.qa_span = Linear(bert_model.ninp, 2) + + def forward(self, src, token_type_input): + output = self.bert_model(src, token_type_input) + # transpose output (S, N, E) to (N, S, E) + output = output.transpose(0, 1) + output = self.activation(output) + pos_output = self.qa_span(output) + start_pos, end_pos = pos_output.split(1, dim=-1) + start_pos = start_pos.squeeze(-1) + end_pos = end_pos.squeeze(-1) + return start_pos, end_pos diff --git a/examples/BERT/ns_task.py b/examples/BERT/ns_task.py new file mode 100644 index 0000000000..062ebc343f --- /dev/null +++ b/examples/BERT/ns_task.py @@ -0,0 +1,262 @@ +import argparse +import time +import math +import torch +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import DataLoader +from model import NextSentenceTask, BertModel +from utils import run_demo, run_ddp, wrap_up + + +def process_raw_data(whole_data, args): + processed_data = [] + for _idx in range(len(whole_data)): + item = whole_data[_idx] + if isinstance(item, list): + item = torch.tensor(item) + if len(item) > 1: + # idx to split the text into two sentencd + split_idx = torch.randint(1, len(item), size=(1, 1)).item() + # Index 2 means same sentence label. 
Initial true int(1) + processed_data.append([item[:split_idx], item[split_idx:], 1]) + # Random shuffle data to have args.frac_ns next sentence set up + shuffle_idx1 = torch.randperm(len(processed_data)) + shuffle_idx2 = torch.randperm(len(processed_data)) + num_shuffle = int(len(processed_data) * args.frac_ns) + shuffle_zip = list(zip(shuffle_idx1, shuffle_idx2))[:num_shuffle] + for (i, j) in shuffle_zip: + processed_data[i][1] = processed_data[j][0] + processed_data[i][2] = int(0) # Switch same sentence label to false 0 + return processed_data + + +def collate_batch(batch, args, cls_id, sep_id, pad_id): + # Fix sequence length to args.bptt with padding or trim + seq_list = [] + tok_type = [] + same_sentence_labels = [] + for item in batch: + qa_item = torch.cat([item[0], torch.tensor([sep_id]).long(), item[1], torch.tensor([sep_id]).long()]) + if qa_item.size(0) > args.bptt: + qa_item = qa_item[:args.bptt] + elif qa_item.size(0) < args.bptt: + qa_item = torch.cat((qa_item, + torch.tensor([pad_id] * (args.bptt - + qa_item.size(0))))) + seq_list.append(qa_item) + _tok_tp = torch.ones((qa_item.size(0))) + _idx = min(len(item[0]) + 1, args.bptt) + _tok_tp[:_idx] = 0.0 + tok_type.append(_tok_tp) + same_sentence_labels.append(item[2]) + seq_input = torch.stack(seq_list).long().t().contiguous() + seq_input = torch.cat((torch.tensor([[cls_id] * seq_input.size(1)]).long(), seq_input)) + tok_type = torch.stack(tok_type).long().t().contiguous() + tok_type = torch.cat((torch.tensor([[0] * tok_type.size(1)]).long(), tok_type)) + return seq_input, tok_type, torch.tensor(same_sentence_labels).long().contiguous() + + +def evaluate(data_source, model, device, criterion, cls_id, sep_id, pad_id, args): + model.eval() + total_loss = 0. + batch_size = args.batch_size + dataloader = DataLoader(data_source, batch_size=batch_size, shuffle=True, + collate_fn=lambda b: collate_batch(b, args, cls_id, sep_id, pad_id)) + with torch.no_grad(): + for idx, (seq_input, tok_type, target_ns_labels) in enumerate(dataloader): + if args.parallel == 'DDP': + seq_input = seq_input.to(device[0]) + tok_type = tok_type.to(device[0]) + target_ns_labels = target_ns_labels.to(device[0]) + else: + seq_input = seq_input.to(device) + tok_type = tok_type.to(device) + target_ns_labels = target_ns_labels.to(device) + seq_input = seq_input.transpose(0, 1) # Wrap up by DDP or DataParallel + ns_labels = model(seq_input, token_type_input=tok_type) + loss = criterion(ns_labels, target_ns_labels) + total_loss += loss.item() + return total_loss / (len(data_source) // batch_size) + + +def train(train_dataset, model, train_loss_log, device, optimizer, criterion, + epoch, scheduler, cls_id, sep_id, pad_id, args, rank=None): + model.train() + total_loss = 0. 
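+    # Each element of train_dataset here is [segment_a, segment_b, label] produced by
+    # process_raw_data(): label 1 means the two segments are contiguous text, 0 means
+    # segment_b was swapped in from another sample ("not next sentence").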
+ start_time = time.time() + batch_size = args.batch_size + dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, + collate_fn=lambda b: collate_batch(b, args, cls_id, sep_id, pad_id)) + train_loss_log.append(0.0) + for idx, (seq_input, tok_type, target_ns_labels) in enumerate(dataloader): + if args.parallel == 'DDP': + seq_input = seq_input.to(device[0]) + tok_type = tok_type.to(device[0]) + target_ns_labels = target_ns_labels.to(device[0]) + else: + seq_input = seq_input.to(device) + tok_type = tok_type.to(device) + target_ns_labels = target_ns_labels.to(device) + optimizer.zero_grad() + seq_input = seq_input.transpose(0, 1) # Wrap up by DDP or DataParallel + ns_labels = model(seq_input, token_type_input=tok_type) + loss = criterion(ns_labels, target_ns_labels) + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) + optimizer.step() + total_loss += loss.item() + if idx % args.log_interval == 0 and idx > 0: + cur_loss = total_loss / args.log_interval + elapsed = time.time() - start_time + if (rank is None) or rank == 0: + train_loss_log[-1] = cur_loss + print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:05.5f} | ' + 'ms/batch {:5.2f} | ' + 'loss {:8.5f} | ppl {:5.2f}'.format(epoch, idx, + len(train_dataset) // batch_size, + scheduler.get_last_lr()[0], + elapsed * 1000 / args.log_interval, + cur_loss, math.exp(cur_loss))) + total_loss = 0 + start_time = time.time() + + +def run_main(args, rank=None): + # Set the random seed manually for reproducibility. + torch.manual_seed(args.seed) + if args.parallel == 'DDP': + n = torch.cuda.device_count() // args.world_size + device = list(range(rank * n, (rank + 1) * n)) + else: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + vocab = torch.load(args.save_vocab) + cls_id = vocab.stoi[''] + pad_id = vocab.stoi[''] + sep_id = vocab.stoi[''] + + if args.dataset == 'WikiText103': + from torchtext.experimental.datasets import WikiText103 + train_dataset, valid_dataset, test_dataset = WikiText103(vocab=vocab, + single_line=False) + elif args.dataset == 'BookCorpus': + from data import BookCorpus + train_dataset, test_dataset, valid_dataset = BookCorpus(vocab, min_sentence_len=60) + + if rank is not None: + chunk_len = len(train_dataset.data) // args.world_size + train_dataset.data = train_dataset.data[(rank * chunk_len):((rank + 1) * chunk_len)] + + if args.checkpoint != 'None': + model = torch.load(args.checkpoint) + else: + pretrained_bert = BertModel(len(vocab), args.emsize, args.nhead, args.nhid, args.nlayers, args.dropout) + pretrained_bert.load_state_dict(torch.load(args.bert_model)) + model = NextSentenceTask(pretrained_bert) + + if args.parallel == 'DDP': + model = model.to(device[0]) + model = DDP(model, device_ids=device) + else: + model = model.to(device) + criterion = nn.CrossEntropyLoss() + optimizer = torch.optim.SGD(model.parameters(), lr=args.lr) + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1) + best_val_loss = None + train_loss_log, val_loss_log = [], [] + + for epoch in range(1, args.epochs + 1): + epoch_start_time = time.time() + train(process_raw_data(train_dataset, args), model, train_loss_log, device, optimizer, + criterion, epoch, scheduler, cls_id, sep_id, pad_id, args, rank) + val_loss = evaluate(process_raw_data(valid_dataset, args), model, device, criterion, + cls_id, sep_id, pad_id, args) + val_loss_log.append(val_loss) + + if (rank is None) or (rank == 0): + print('-' * 89) + print('| end of epoch {:3d} | time: {:5.2f}s ' + '| 
valid loss {:8.5f} | '.format(epoch, + (time.time() - epoch_start_time), + val_loss)) + print('-' * 89) + if not best_val_loss or val_loss < best_val_loss: + if rank is None: + with open(args.save, 'wb') as f: + torch.save(model, f) + elif rank == 0: + with open(args.save, 'wb') as f: + torch.save(model.state_dict(), f) + best_val_loss = val_loss + else: + scheduler.step() + if args.parallel == 'DDP': + rank0_devices = [x - rank * len(device) for x in device] + device_pairs = zip(rank0_devices, device) + map_location = {'cuda:%d' % x: 'cuda:%d' % y for x, y in device_pairs} + model.load_state_dict(torch.load(args.save, map_location=map_location)) + test_loss = evaluate(process_raw_data(test_dataset, args), model, device, criterion, + cls_id, sep_id, pad_id, args) + if rank == 0: + wrap_up(train_loss_log, val_loss_log, test_loss, args, model.module, 'ns_loss.txt', 'ns_model.pt') + else: + with open(args.save, 'rb') as f: + model = torch.load(f) + + test_loss = evaluate(process_raw_data(test_dataset, args), model, device, criterion, + cls_id, sep_id, pad_id, args) + wrap_up(train_loss_log, val_loss_log, test_loss, args, model, 'ns_loss.txt', 'ns_model.pt') + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Question-Answer fine-tuning task') + parser.add_argument('--dataset', type=str, default='WikiText103', + help='dataset used for next sentence task') + parser.add_argument('--lr', type=float, default=0.25, + help='initial learning rate') + parser.add_argument('--clip', type=float, default=0.1, + help='gradient clipping') + parser.add_argument('--epochs', type=int, default=5, + help='upper epoch limit') + parser.add_argument('--batch_size', type=int, default=24, metavar='N', + help='batch size') + parser.add_argument('--bptt', type=int, default=128, + help='max. sequence length for the next-sentence pair') + parser.add_argument('--min_sentence_len', type=int, default=60, + help='min. 
sequence length for the raw text tokens') + parser.add_argument('--seed', type=int, default=312216194, + help='random seed') + parser.add_argument('--cuda', action='store_true', + help='use CUDA') + parser.add_argument('--log-interval', type=int, default=600, metavar='N', + help='report interval') + parser.add_argument('--checkpoint', type=str, default='None', + help='path to load the checkpoint') + parser.add_argument('--save', type=str, default='ns_bert.pt', + help='path to save the bert model') + parser.add_argument('--save-vocab', type=str, default='torchtext_bert_vocab.pt', + help='path to save the vocab') + parser.add_argument('--bert-model', type=str, default='mlm_bert.pt', + help='path to save the pretrained bert') + parser.add_argument('--frac_ns', type=float, default=0.5, + help='fraction of not next sentence') + parser.add_argument('--parallel', type=str, default='None', + help='Use DataParallel/DDP to train model') + parser.add_argument('--world_size', type=int, default=8, + help='the world size to initiate DPP') + parser.add_argument('--emsize', type=int, default=768, + help='size of word embeddings') + parser.add_argument('--nhid', type=int, default=3072, + help='number of hidden units per layer') + parser.add_argument('--nlayers', type=int, default=12, + help='number of layers') + parser.add_argument('--nhead', type=int, default=12, + help='the number of heads in the encoder/decoder of the transformer model') + parser.add_argument('--dropout', type=float, default=0.2, + help='dropout applied to layers (0 = no dropout)') + args = parser.parse_args() + + if args.parallel == 'DDP': + run_demo(run_ddp, run_main, args) + else: + run_main(args) diff --git a/examples/BERT/qa_task.py b/examples/BERT/qa_task.py new file mode 100644 index 0000000000..a92f67be05 --- /dev/null +++ b/examples/BERT/qa_task.py @@ -0,0 +1,213 @@ +import argparse +import time +import math +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +import torchtext +from torchtext.experimental.datasets import SQuAD1 +from model import QuestionAnswerTask +from metrics import compute_qa_exact, compute_qa_f1 +from utils import print_loss_log +from model import BertModel + + +def process_raw_data(data): + _data = [] + for item in data: + right_length = True + for _idx in range(len(item['ans_pos'])): + if item['ans_pos'][_idx][1] + item['question'].size(0) + 2 >= args.bptt: + right_length = False + if right_length: + _data.append(item) + return _data + + +def collate_batch(batch): + seq_list = [] + ans_pos_list = [] + tok_type = [] + for item in batch: + qa_item = torch.cat((torch.tensor([cls_id]), item['question'], torch.tensor([sep_id]), + item['context'], torch.tensor([sep_id]))) + if qa_item.size(0) > args.bptt: + qa_item = qa_item[:args.bptt] + elif qa_item.size(0) < args.bptt: + qa_item = torch.cat((qa_item, + torch.tensor([pad_id] * (args.bptt - + qa_item.size(0))))) + seq_list.append(qa_item) + pos_list = [pos + item['question'].size(0) + 2 for pos in item['ans_pos']] # 1 for sep and 1 for cls + ans_pos_list.append(pos_list) + tok_type.append(torch.cat((torch.zeros((item['question'].size(0) + 2)), + torch.ones((args.bptt - + item['question'].size(0) - 2))))) + _ans_pos_list = [] + for pos in zip(*ans_pos_list): + _ans_pos_list.append(torch.stack(list(pos))) + return torch.stack(seq_list).long().t().contiguous().to(device), \ + _ans_pos_list, \ + torch.stack(tok_type).long().t().contiguous().to(device) + + +def evaluate(data_source, vocab): + model.eval() + total_loss = 0. 
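+    # Besides the loss, collect (answer candidates, predicted span) token lists per sample
+    # so exact-match and F1 can be computed via metrics.compute_qa_exact / compute_qa_f1.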
+ batch_size = args.batch_size + dataloader = DataLoader(data_source, batch_size=batch_size, shuffle=True, + collate_fn=collate_batch) + ans_pred_tokens_samples = [] + with torch.no_grad(): + for idx, (seq_input, ans_pos_list, tok_type) in enumerate(dataloader): + start_pos, end_pos = model(seq_input, token_type_input=tok_type) + target_start_pos, target_end_pos = [], [] + for item in ans_pos_list: + _target_start_pos, _target_end_pos = item.to(device).split(1, dim=-1) + target_start_pos.append(_target_start_pos.squeeze(-1)) + target_end_pos.append(_target_end_pos.squeeze(-1)) + loss = (criterion(start_pos, target_start_pos[0]) + + criterion(end_pos, target_end_pos[0])) / 2 + total_loss += loss.item() + start_pos = nn.functional.softmax(start_pos, dim=1).argmax(1) + end_pos = nn.functional.softmax(end_pos, dim=1).argmax(1) + seq_input = seq_input.transpose(0, 1) # convert from (S, N) to (N, S) + for num in range(0, seq_input.size(0)): + if int(start_pos[num]) > int(end_pos[num]): + continue # start pos is in front of end pos + ans_tokens = [] + for _idx in range(len(target_end_pos)): + ans_tokens.append([vocab.itos[int(seq_input[num][i])] + for i in range(target_start_pos[_idx][num], + target_end_pos[_idx][num] + 1)]) + pred_tokens = [vocab.itos[int(seq_input[num][i])] + for i in range(start_pos[num], + end_pos[num] + 1)] + ans_pred_tokens_samples.append((ans_tokens, pred_tokens)) + return total_loss / (len(data_source) // batch_size), \ + compute_qa_exact(ans_pred_tokens_samples), \ + compute_qa_f1(ans_pred_tokens_samples) + + +def train(): + model.train() + total_loss = 0. + start_time = time.time() + batch_size = args.batch_size + dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, + collate_fn=collate_batch) + train_loss_log.append(0.0) + for idx, (seq_input, ans_pos, tok_type) in enumerate(dataloader): + optimizer.zero_grad() + start_pos, end_pos = model(seq_input, token_type_input=tok_type) + target_start_pos, target_end_pos = ans_pos[0].to(device).split(1, dim=-1) + target_start_pos = target_start_pos.squeeze(-1) + target_end_pos = target_end_pos.squeeze(-1) + loss = (criterion(start_pos, target_start_pos) + criterion(end_pos, target_end_pos)) / 2 + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) + optimizer.step() + total_loss += loss.item() + if idx % args.log_interval == 0 and idx > 0: + cur_loss = total_loss / args.log_interval + train_loss_log[-1] = cur_loss + elapsed = time.time() - start_time + print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:05.5f} | ' + 'ms/batch {:5.2f} | ' + 'loss {:5.2f} | ppl {:8.2f}'.format(epoch, idx, + len(train_dataset) // batch_size, + scheduler.get_last_lr()[0], + elapsed * 1000 / args.log_interval, + cur_loss, math.exp(cur_loss))) + total_loss = 0 + start_time = time.time() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Question-Answer fine-tuning task') + parser.add_argument('--lr', type=float, default=5.0, + help='initial learning rate') + parser.add_argument('--clip', type=float, default=0.1, + help='gradient clipping') + parser.add_argument('--epochs', type=int, default=2, + help='upper epoch limit') + parser.add_argument('--batch_size', type=int, default=72, metavar='N', + help='batch size') + parser.add_argument('--bptt', type=int, default=128, + help='max. 
sequence length for context + question') + parser.add_argument('--seed', type=int, default=21192391, + help='random seed') + parser.add_argument('--log-interval', type=int, default=200, metavar='N', + help='report interval') + parser.add_argument('--save', type=str, default='qa_model.pt', + help='path to save the final bert model') + parser.add_argument('--save-vocab', type=str, default='torchtext_bert_vocab.pt', + help='path to save the vocab') + parser.add_argument('--bert-model', type=str, default='ns_bert.pt', + help='path to save the pretrained bert') + parser.add_argument('--emsize', type=int, default=768, + help='size of word embeddings') + parser.add_argument('--nhid', type=int, default=3072, + help='number of hidden units per layer') + parser.add_argument('--nlayers', type=int, default=12, + help='number of layers') + parser.add_argument('--nhead', type=int, default=12, + help='the number of heads in the encoder/decoder of the transformer model') + parser.add_argument('--dropout', type=float, default=0.2, + help='dropout applied to layers (0 = no dropout)') + args = parser.parse_args() + torch.manual_seed(args.seed) + + try: + vocab = torch.load(args.save_vocab) + except: + train_dataset, dev_dataset = SQuAD1() + old_vocab = train_dataset.vocab + vocab = torchtext.vocab.Vocab(counter=old_vocab.freqs, + specials=['', '', '']) + with open(args.save_vocab, 'wb') as f: + torch.save(vocab, f) + pad_id = vocab.stoi[''] + sep_id = vocab.stoi[''] + cls_id = vocab.stoi[''] + train_dataset, dev_dataset = SQuAD1(vocab=vocab) + train_dataset = process_raw_data(train_dataset) + dev_dataset = process_raw_data(dev_dataset) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + pretrained_bert = BertModel(len(vocab), args.emsize, args.nhead, args.nhid, args.nlayers, args.dropout) + pretrained_bert.load_state_dict(torch.load(args.bert_model)) + model = QuestionAnswerTask(pretrained_bert).to(device) + criterion = nn.CrossEntropyLoss() + optimizer = torch.optim.SGD(model.parameters(), lr=args.lr) + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1) + best_f1 = None + train_loss_log, val_loss_log = [], [] + + for epoch in range(1, args.epochs + 1): + epoch_start_time = time.time() + train() + val_loss, val_exact, val_f1 = evaluate(dev_dataset, vocab) + val_loss_log.append(val_loss) + print('-' * 89) + print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | ' + 'exact {:8.3f}% | ' + 'f1 {:8.3f}%'.format(epoch, (time.time() - epoch_start_time), + val_loss, val_exact, val_f1)) + print('-' * 89) + if best_f1 is None or val_f1 > best_f1: + with open(args.save, 'wb') as f: + torch.save(model, f) + best_f1 = val_f1 + else: + scheduler.step() + + with open(args.save, 'rb') as f: + model = torch.load(f) + test_loss, test_exact, test_f1 = evaluate(dev_dataset, vocab) + print('=' * 89) + print('| End of training | test loss {:5.2f} | exact {:8.3f}% | f1 {:8.3f}%'.format( + test_loss, test_exact, test_f1)) + print('=' * 89) + print_loss_log('qa_loss.txt', train_loss_log, val_loss_log, test_loss) + with open(args.save, 'wb') as f: + torch.save(model, f) diff --git a/examples/BERT/utils.py b/examples/BERT/utils.py new file mode 100644 index 0000000000..94cf371663 --- /dev/null +++ b/examples/BERT/utils.py @@ -0,0 +1,58 @@ +import torch +import torch.distributed as dist +import os +import torch.multiprocessing as mp +import math + + +def setup(rank, world_size, seed): + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12355' + # initialize the 
process group + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + # Explicitly setting seed to make sure that models created in two processes + # start from same random weights and biases. + torch.manual_seed(seed) + + +def cleanup(): + dist.destroy_process_group() + + +def run_demo(demo_fn, main_fn, args): + mp.spawn(demo_fn, + args=(main_fn, args,), + nprocs=args.world_size, + join=True) + + +def run_ddp(rank, main_fn, args): + setup(rank, args.world_size, args.seed) + main_fn(args, rank) + cleanup() + + +def print_loss_log(file_name, train_loss, val_loss, test_loss, args=None): + with open(file_name, 'w') as f: + if args: + for item in args.__dict__: + f.write(item + ': ' + str(args.__dict__[item]) + '\n') + for idx in range(len(train_loss)): + f.write('epoch {:3d} | train loss {:8.5f}'.format(idx + 1, + train_loss[idx]) + '\n') + for idx in range(len(val_loss)): + f.write('epoch {:3d} | val loss {:8.5f}'.format(idx + 1, + val_loss[idx]) + '\n') + f.write('test loss {:8.5f}'.format(test_loss) + '\n') + + +def wrap_up(train_loss_log, val_loss_log, test_loss, args, model, ns_loss_log, model_filename): + print('=' * 89) + print('| End of training | test loss {:8.5f} | test ppl {:8.5f}'.format(test_loss, math.exp(test_loss))) + print('=' * 89) + print_loss_log(ns_loss_log, train_loss_log, val_loss_log, test_loss) + with open(args.save, 'wb') as f: + torch.save(model.bert_model.state_dict(), f) + with open(model_filename, 'wb') as f: + torch.save(model.state_dict(), f)
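
A note on the DDP plumbing in ``utils.py``: ``run_demo`` spawns ``args.world_size`` processes, each of which runs ``run_ddp(rank, main_fn, args)``; that call initialises the NCCL process group via ``setup`` and then invokes ``main_fn(args, rank)``. The sketch below is a hypothetical, minimal driver (``toy_main`` stands in for the ``run_main`` of mlm_task.py / ns_task.py) and assumes a machine with CUDA GPUs, since ``setup`` uses the NCCL backend:

    import argparse

    import torch.distributed as dist

    from utils import run_demo, run_ddp


    def toy_main(args, rank=None):
        # Runs once in every spawned process; by this point setup() has already
        # created the process group and seeded torch with args.seed.
        print('rank', rank, 'of', args.world_size, 'initialised:', dist.is_initialized())


    if __name__ == "__main__":
        parser = argparse.ArgumentParser()
        parser.add_argument('--world_size', type=int, default=2)
        parser.add_argument('--seed', type=int, default=1234)
        args = parser.parse_args()
        run_demo(run_ddp, toy_main, args)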