From 8b904e127cf3717675eebecd0236bf073d09a9a0 Mon Sep 17 00:00:00 2001
From: Xinyu Ye
Date: Fri, 2 Dec 2022 14:50:27 +0800
Subject: [PATCH 1/2] Added distributed training support for distillation of CNN-2.

Signed-off-by: Xinyu Ye
---
 .../CNN-2/distillation/eager/main.py          | 82 +++++++++++--------
 .../CNN-2/distillation/eager/requirements.txt |  1 +
 2 files changed, 50 insertions(+), 33 deletions(-)

diff --git a/examples/pytorch/image_recognition/CNN-2/distillation/eager/main.py b/examples/pytorch/image_recognition/CNN-2/distillation/eager/main.py
index e24eb7767ff..685a0109450 100644
--- a/examples/pytorch/image_recognition/CNN-2/distillation/eager/main.py
+++ b/examples/pytorch/image_recognition/CNN-2/distillation/eager/main.py
@@ -10,6 +10,7 @@
 import torchvision.datasets as datasets
 import torchvision.transforms as transforms
 
+from accelerate import Accelerator
 from plain_cnn_cifar import ConvNetMaker, plane_cifar100_book
 
 # used for logging to TensorBoard
@@ -60,6 +61,7 @@
                     help='loss weights of distillation, should be a list of length 2, '
                          'and sum to 1.0, first for student targets loss weight, '
                          'second for teacher student loss weight.')
+parser.add_argument("--no_cuda", action='store_true', help='use cpu for training.')
 
 parser.set_defaults(augment=True)
 
@@ -75,10 +77,13 @@ def set_seed(seed):
 def main():
     global args, best_prec1
     args, _ = parser.parse_known_args()
+    accelerator = Accelerator(cpu=args.no_cuda)
+
     best_prec1 = 0
     if args.seed is not None:
         set_seed(args.seed)
-    if args.tensorboard: configure("runs/%s" % (args.name))
+    with accelerator.local_main_process_first():
+        if args.tensorboard: configure("runs/%s"%(args.name))
     # Data loading code
     normalize = transforms.Normalize(mean=[0.5071, 0.4866, 0.4409],
                                      std=[0.2675, 0.2565, 0.2761])
@@ -121,9 +126,9 @@ def main():
         raise NotImplementedError('Unsupported student model type')
 
     # get the number of model parameters
-    print('Number of teacher model parameters: {}'.format(
+    accelerator.print('Number of teacher model parameters: {}'.format(
         sum([p.data.nelement() for p in teacher_model.parameters()])))
-    print('Number of student model parameters: {}'.format(
+    accelerator.print('Number of student model parameters: {}'.format(
         sum([p.data.nelement() for p in student_model.parameters()])))
 
     kwargs = {'num_workers': 0, 'pin_memory': True}
@@ -135,10 +140,10 @@ def main():
     if args.loss_weights[1] > 0:
         from tqdm import tqdm
         def get_logits(teacher_model, train_dataset):
-            print("***** Getting logits of teacher model *****")
-            print(f"  Num examples = {len(train_dataset) }")
+            accelerator.print("***** Getting logits of teacher model *****")
+            accelerator.print(f"  Num examples = {len(train_dataset) }")
             logits_file = os.path.join(os.path.dirname(args.teacher_model), 'teacher_logits.npy')
-            if not os.path.exists(logits_file):
+            if not os.path.exists(logits_file) and accelerator.is_local_main_process:
                 teacher_model.eval()
                 train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, **kwargs)
                 train_dataloader = tqdm(train_dataloader, desc="Evaluating")
@@ -147,8 +152,8 @@ def get_logits(teacher_model, train_dataset):
                     outputs = teacher_model(input)
                     teacher_logits += [x for x in outputs.numpy()]
                 np.save(logits_file, np.array(teacher_logits))
-            else:
-                teacher_logits = np.load(logits_file)
+            accelerator.wait_for_everyone()
+            teacher_logits = np.load(logits_file)
             train_dataset.targets = [{'labels':l, 'teacher_logits':tl} \
                                      for l, tl in zip(train_dataset.targets, teacher_logits)]
             return train_dataset
@@ -163,15 +168,15 @@ def get_logits(teacher_model, train_dataset):
     # optionally resume from a checkpoint
     if args.resume:
         if os.path.isfile(args.resume):
-            print("=> loading checkpoint '{}'".format(args.resume))
+            accelerator.print("=> loading checkpoint '{}'".format(args.resume))
             checkpoint = torch.load(args.resume)
             args.start_epoch = checkpoint['epoch']
             best_prec1 = checkpoint['best_prec1']
             student_model.load_state_dict(checkpoint['state_dict'])
-            print("=> loaded checkpoint '{}' (epoch {})"
+            accelerator.print("=> loaded checkpoint '{}' (epoch {})"
                   .format(args.resume, checkpoint['epoch']))
         else:
-            print("=> no checkpoint found at '{}'".format(args.resume))
+            accelerator.print("=> no checkpoint found at '{}'".format(args.resume))
 
     # define optimizer
     optimizer = torch.optim.SGD(student_model.parameters(), args.lr,
@@ -179,13 +184,18 @@ def get_logits(teacher_model, train_dataset):
                                 weight_decay=args.weight_decay)
 
     # cosine learning rate
-    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(train_loader)*args.epochs)
+    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+        optimizer, len(train_loader) * args.epochs // accelerator.num_processes
+    )
+
+    student_model, teacher_model, train_loader, val_loader, optimizer = \
+        accelerator.prepare(student_model, teacher_model, train_loader, val_loader, optimizer)
 
     def train_func(model):
-        return train(train_loader, model, scheduler, distiller, best_prec1)
+        return train(train_loader, model, scheduler, distiller, best_prec1, accelerator)
 
     def eval_func(model):
-        return validate(val_loader, model, distiller)
+        return validate(val_loader, model, distiller, accelerator)
 
     from neural_compressor.experimental import Distillation, common
     from neural_compressor.experimental.common.criterion import PyTorchKnowledgeDistillationLoss
@@ -204,11 +214,12 @@ def eval_func(model):
 
     directory = "runs/%s/"%(args.name)
     os.makedirs(directory, exist_ok=True)
+    model._model = accelerator.unwrap_model(model.model)
     model.save(directory)
     # change to framework model for further use
     model = model.model
 
-def train(train_loader, model, scheduler, distiller, best_prec1):
+def train(train_loader, model, scheduler, distiller, best_prec1, accelerator):
     distiller.on_train_begin()
     for epoch in range(args.start_epoch, args.epochs):
         """Train for one epoch on the training set"""
@@ -233,13 +244,15 @@ def train(train_loader, model, scheduler, distiller, best_prec1):
             loss = distiller.on_after_compute_loss(input, output, loss, teacher_logits)
 
             # measure accuracy and record loss
+            output = accelerator.gather(output)
+            target = accelerator.gather(target)
             prec1 = accuracy(output.data, target, topk=(1,))[0]
-            losses.update(loss.data.item(), input.size(0))
-            top1.update(prec1.item(), input.size(0))
+            losses.update(accelerator.gather(loss).sum().data.item(), input.size(0)*accelerator.num_processes)
+            top1.update(prec1.item(), input.size(0)*accelerator.num_processes)
 
             # compute gradient and do SGD step
            distiller.optimizer.zero_grad()
-            loss.backward()
+            accelerator.backward(loss) # loss.backward()
             distiller.optimizer.step()
             scheduler.step()
 
@@ -248,7 +261,7 @@ def train(train_loader, model, scheduler, distiller, best_prec1):
             end = time.time()
 
             if i % args.print_freq == 0:
-                print('Epoch: [{0}][{1}/{2}]\t'
+                accelerator.print('Epoch: [{0}][{1}/{2}]\t'
                       'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                       'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
@@ -260,19 +273,20 @@ def train(train_loader, model, scheduler, distiller, best_prec1):
 
         # remember best prec@1 and save checkpoint
         is_best = distiller.best_score > best_prec1
         best_prec1 = max(distiller.best_score, best_prec1)
-        save_checkpoint({
-            'epoch': distiller._epoch_runned + 1,
-            'state_dict': model.state_dict(),
-            'best_prec1': best_prec1,
-        }, is_best)
-        # log to TensorBoard
-        if args.tensorboard:
-            log_value('train_loss', losses.avg, epoch)
-            log_value('train_acc', top1.avg, epoch)
-            log_value('learning_rate', scheduler._last_lr[0], epoch)
+        if accelerator.is_local_main_process:
+            save_checkpoint({
+                'epoch': distiller._epoch_runned + 1,
+                'state_dict': model.state_dict(),
+                'best_prec1': best_prec1,
+            }, is_best)
+            # log to TensorBoard
+            if args.tensorboard:
+                log_value('train_loss', losses.avg, epoch)
+                log_value('train_acc', top1.avg, epoch)
+                log_value('learning_rate', scheduler._last_lr[0], epoch)
 
-def validate(val_loader, model, distiller):
+def validate(val_loader, model, distiller, accelerator):
     """Perform validation on the validation set"""
     batch_time = AverageMeter()
     top1 = AverageMeter()
@@ -287,6 +301,8 @@ def validate(val_loader, model, distiller):
             output = model(input)
 
         # measure accuracy
+        output = accelerator.gather(output)
+        target = accelerator.gather(target)
         prec1 = accuracy(output.data, target, topk=(1,))[0]
         top1.update(prec1.item(), input.size(0))
 
@@ -295,15 +311,15 @@ def validate(val_loader, model, distiller):
         end = time.time()
 
         if i % args.print_freq == 0:
-            print('Test: [{0}/{1}]\t'
+            accelerator.print('Test: [{0}/{1}]\t'
                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                   'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                       i, len(val_loader), batch_time=batch_time,
                       top1=top1))
 
-    print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1))
+    accelerator.print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1))
     # log to TensorBoard
-    if args.tensorboard:
+    if accelerator.is_local_main_process and args.tensorboard:
         log_value('val_acc', top1.avg, distiller._epoch_runned)
 
     return top1.avg

diff --git a/examples/pytorch/image_recognition/CNN-2/distillation/eager/requirements.txt b/examples/pytorch/image_recognition/CNN-2/distillation/eager/requirements.txt
index 8db2f310ef5..71252629880 100644
--- a/examples/pytorch/image_recognition/CNN-2/distillation/eager/requirements.txt
+++ b/examples/pytorch/image_recognition/CNN-2/distillation/eager/requirements.txt
@@ -2,3 +2,4 @@
 torch==1.5.0+cpu
 torchvision==0.6.0+cpu
 tensorboard_logger
+accelerate
\ No newline at end of file

From e4146bccfe624ac1013bd8511f6bce1ebabdd6b2 Mon Sep 17 00:00:00 2001
From: Xinyu Ye
Date: Fri, 2 Dec 2022 14:51:26 +0800
Subject: [PATCH 2/2] modified readme

Signed-off-by: Xinyu Ye
---
 .../CNN-2/distillation/eager/README.md        | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/examples/pytorch/image_recognition/CNN-2/distillation/eager/README.md b/examples/pytorch/image_recognition/CNN-2/distillation/eager/README.md
index 484ddc0d93a..3180b03d1ac 100644
--- a/examples/pytorch/image_recognition/CNN-2/distillation/eager/README.md
+++ b/examples/pytorch/image_recognition/CNN-2/distillation/eager/README.md
@@ -9,3 +9,21 @@ python train_without_distillation.py --model_type CNN-10 --epochs 200 --lr 0.1 --tensorboard
 # for distillation of the student model CNN-2 with the teacher model CNN-10
 python main.py --epochs 200 --lr 0.02 --name CNN-2-distillation --student_type CNN-2 --teacher_type CNN-10 --teacher_model runs/CNN-10/model_best.pth.tar --tensorboard
 ```
+
+We also support Distributed Data Parallel training for distillation, in both single-node and multi-node settings. To use Distributed Data Parallel to speed up training, the bash command needs a small adjustment.
+
+For example, the bash command will look like the following, where *`<MASTER_ADDRESS>`* is the address of the master node (it is not necessary in the single-node case); *`<NUM_PROCESSES_PER_NODE>`* is the number of processes to launch on the current node (for a node with GPUs, usually set to the number of GPUs in that node; for a node without GPUs that trains on CPU, 1 is recommended); *`<NUM_NODES>`* is the number of nodes to use; and *`<NODE_RANK>`* is the rank of the current node, an integer from 0 to *`<NUM_NODES>`*`-1`.
+
+Also please note that to use CPU for training on each node in the multi-node setting, the `--no_cuda` argument is mandatory. In the multi-node setting, the following command needs to be launched on each node, and all the commands should be the same except for *`<NODE_RANK>`*, which should be the integer from 0 to *`<NUM_NODES>`*`-1` assigned to each node.
+
+```bash
+python -m torch.distributed.launch --master_addr=<MASTER_ADDRESS> --nproc_per_node=<NUM_PROCESSES_PER_NODE> --nnodes=<NUM_NODES> --node_rank=<NODE_RANK> \
+    main.py --epochs 200 --lr 0.02 --name CNN-2-distillation --student_type CNN-2 --teacher_type CNN-10 --teacher_model runs/CNN-10/model_best.pth.tar --tensorboard
+```
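+
+For instance, a minimal single-node, CPU-only sketch with two processes (a hypothetical configuration; adjust the process count to your machine) could be launched as follows:
+
+```bash
+# single node (rank 0 of 1), 2 CPU worker processes; --no_cuda keeps training on CPU
+python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 --node_rank=0 \
+    main.py --epochs 200 --lr 0.02 --name CNN-2-distillation --student_type CNN-2 --teacher_type CNN-10 --teacher_model runs/CNN-10/model_best.pth.tar --tensorboard --no_cuda
+```
\ No newline at end of file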