
Commit c3fc263

XinyuYe-Intel authored and lvliang-intel committed
Added distributed training support for distillation of CNN-2. (#208)
Signed-off-by: Xinyu Ye <[email protected]>
Signed-off-by: Lv, Liang1 <[email protected]>
1 parent c6f3932 commit c3fc263

File tree: 3 files changed, +61 -33 lines changed

examples/pytorch/image_recognition/CNN-2/distillation/eager/README.md

Lines changed: 11 additions & 0 deletions
@@ -9,3 +9,14 @@ python train_without_distillation.py --model_type CNN-10 --epochs 200 --lr 0.1 -
 # for distillation of the student model CNN-2 with the teacher model CNN-10
 python main.py --epochs 200 --lr 0.02 --name CNN-2-distillation --student_type CNN-2 --teacher_type CNN-10 --teacher_model runs/CNN-10/model_best.pth.tar --tensorboard
 ```
+
+We also support Distributed Data Parallel training in single-node and multi-node settings for distillation. To use Distributed Data Parallel to speed up training, the bash command needs a small adjustment.
+<br>
+For example, the bash command will look like the following, where *`<MASTER_ADDRESS>`* is the address of the master node (not needed in the single-node case), *`<NUM_PROCESSES_PER_NODE>`* is the number of processes to use on the current node (for a node with GPUs, usually set to the number of GPUs on that node; for a node without GPUs that trains on CPU, 1 is recommended), *`<NUM_NODES>`* is the number of nodes to use, and *`<NODE_RANK>`* is the rank of the current node, ranging from 0 to *`<NUM_NODES>`*`-1`.
+<br>
+Also note that to train on CPU on each node in the multi-node setting, the `--no_cuda` argument is mandatory. In the multi-node setting, the following command needs to be launched on every node, and all the commands should be identical except for *`<NODE_RANK>`*, which should be an integer from 0 to *`<NUM_NODES>`*`-1` assigned to each node.
+
+```bash
+python -m torch.distributed.launch --master_addr=<MASTER_ADDRESS> --nproc_per_node=<NUM_PROCESSES_PER_NODE> --nnodes=<NUM_NODES> --node_rank=<NODE_RANK> \
+    main.py --epochs 200 --lr 0.02 --name CNN-2-distillation --student_type CNN-2 --teacher_type CNN-10 --teacher_model runs/CNN-10/model_best.pth.tar --tensorboard
+```
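Each process started by `torch.distributed.launch` receives its rank and world size through environment variables, and Accelerate reads them when `main.py` constructs its `Accelerator`. The snippet below is an illustrative sketch only, not part of this commit, showing how a launched process can inspect that context and how the `--no_cuda` flag added to `main.py` (see the diff below) maps to `Accelerator(cpu=True)`:

```python
# Illustrative sketch (not part of this commit). Launch it with the same
# torch.distributed.launch command as above to see one line per process.
import argparse

from accelerate import Accelerator

parser = argparse.ArgumentParser()
# Mirrors the --no_cuda flag added to main.py in this commit.
parser.add_argument("--no_cuda", action="store_true", help="use cpu for training.")
# parse_known_args ignores the extra --local_rank argument the launcher passes.
args, _ = parser.parse_known_args()

# cpu=True forces CPU training even when CUDA devices are visible.
accelerator = Accelerator(cpu=args.no_cuda)

# Accelerator derives these values from the launcher's environment variables
# (RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR, ...).
print(f"process {accelerator.process_index}/{accelerator.num_processes}, "
      f"local rank {accelerator.local_process_index}, device {accelerator.device}")
```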

examples/pytorch/image_recognition/CNN-2/distillation/eager/main.py

Lines changed: 49 additions & 33 deletions
@@ -10,6 +10,7 @@
 import torchvision.datasets as datasets
 import torchvision.transforms as transforms
 
+from accelerate import Accelerator
 from plain_cnn_cifar import ConvNetMaker, plane_cifar100_book
 
 # used for logging to TensorBoard
@@ -60,6 +61,7 @@
                     help='loss weights of distillation, should be a list of length 2, '
                          'and sum to 1.0, first for student targets loss weight, '
                          'second for teacher student loss weight.')
+parser.add_argument("--no_cuda", action='store_true', help='use cpu for training.')
 parser.set_defaults(augment=True)
 
 
@@ -75,10 +77,13 @@ def set_seed(seed):
 def main():
     global args, best_prec1
     args, _ = parser.parse_known_args()
+    accelerator = Accelerator(cpu=args.no_cuda)
+
     best_prec1 = 0
     if args.seed is not None:
         set_seed(args.seed)
-    if args.tensorboard: configure("runs/%s" % (args.name))
+    with accelerator.local_main_process_first():
+        if args.tensorboard: configure("runs/%s"%(args.name))
 
     # Data loading code
     normalize = transforms.Normalize(mean=[0.5071, 0.4866, 0.4409], std=[0.2675, 0.2565, 0.2761])
@@ -121,9 +126,9 @@ def main():
         raise NotImplementedError('Unsupported student model type')
 
     # get the number of model parameters
-    print('Number of teacher model parameters: {}'.format(
+    accelerator.print('Number of teacher model parameters: {}'.format(
         sum([p.data.nelement() for p in teacher_model.parameters()])))
-    print('Number of student model parameters: {}'.format(
+    accelerator.print('Number of student model parameters: {}'.format(
         sum([p.data.nelement() for p in student_model.parameters()])))
 
     kwargs = {'num_workers': 0, 'pin_memory': True}
@@ -135,10 +140,10 @@ def main():
     if args.loss_weights[1] > 0:
         from tqdm import tqdm
         def get_logits(teacher_model, train_dataset):
-            print("***** Getting logits of teacher model *****")
-            print(f" Num examples = {len(train_dataset) }")
+            accelerator.print("***** Getting logits of teacher model *****")
+            accelerator.print(f" Num examples = {len(train_dataset) }")
             logits_file = os.path.join(os.path.dirname(args.teacher_model), 'teacher_logits.npy')
-            if not os.path.exists(logits_file):
+            if not os.path.exists(logits_file) and accelerator.is_local_main_process:
                 teacher_model.eval()
                 train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, **kwargs)
                 train_dataloader = tqdm(train_dataloader, desc="Evaluating")
@@ -147,8 +152,8 @@ def get_logits(teacher_model, train_dataset):
                     outputs = teacher_model(input)
                     teacher_logits += [x for x in outputs.numpy()]
                 np.save(logits_file, np.array(teacher_logits))
-            else:
-                teacher_logits = np.load(logits_file)
+            accelerator.wait_for_everyone()
+            teacher_logits = np.load(logits_file)
             train_dataset.targets = [{'labels':l, 'teacher_logits':tl} \
                 for l, tl in zip(train_dataset.targets, teacher_logits)]
             return train_dataset
@@ -163,29 +168,34 @@ def get_logits(teacher_model, train_dataset):
     # optionally resume from a checkpoint
     if args.resume:
         if os.path.isfile(args.resume):
-            print("=> loading checkpoint '{}'".format(args.resume))
+            accelerator.print("=> loading checkpoint '{}'".format(args.resume))
             checkpoint = torch.load(args.resume)
             args.start_epoch = checkpoint['epoch']
             best_prec1 = checkpoint['best_prec1']
             student_model.load_state_dict(checkpoint['state_dict'])
-            print("=> loaded checkpoint '{}' (epoch {})"
+            accelerator.print("=> loaded checkpoint '{}' (epoch {})"
                   .format(args.resume, checkpoint['epoch']))
         else:
-            print("=> no checkpoint found at '{}'".format(args.resume))
+            accelerator.print("=> no checkpoint found at '{}'".format(args.resume))
 
     # define optimizer
     optimizer = torch.optim.SGD(student_model.parameters(), args.lr,
                                 momentum=args.momentum, nesterov = args.nesterov,
                                 weight_decay=args.weight_decay)
 
     # cosine learning rate
-    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(train_loader)*args.epochs)
+    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+        optimizer, len(train_loader) * args.epochs // accelerator.num_processes
+    )
+
+    student_model, teacher_model, train_loader, val_loader, optimizer = \
+        accelerator.prepare(student_model, teacher_model, train_loader, val_loader, optimizer)
 
     def train_func(model):
-        return train(train_loader, model, scheduler, distiller, best_prec1)
+        return train(train_loader, model, scheduler, distiller, best_prec1, accelerator)
 
     def eval_func(model):
-        return validate(val_loader, model, distiller)
+        return validate(val_loader, model, distiller, accelerator)
 
     from neural_compressor.experimental import Distillation, common
     from neural_compressor.experimental.common.criterion import PyTorchKnowledgeDistillationLoss
@@ -204,11 +214,12 @@ def eval_func(model):
 
     directory = "runs/%s/"%(args.name)
     os.makedirs(directory, exist_ok=True)
+    model._model = accelerator.unwrap_model(model.model)
    model.save(directory)
     # change to framework model for further use
     model = model.model
 
-def train(train_loader, model, scheduler, distiller, best_prec1):
+def train(train_loader, model, scheduler, distiller, best_prec1, accelerator):
     distiller.on_train_begin()
     for epoch in range(args.start_epoch, args.epochs):
         """Train for one epoch on the training set"""
@@ -233,13 +244,15 @@ def train(train_loader, model, scheduler, distiller, best_prec1):
             loss = distiller.on_after_compute_loss(input, output, loss, teacher_logits)
 
             # measure accuracy and record loss
+            output = accelerator.gather(output)
+            target = accelerator.gather(target)
             prec1 = accuracy(output.data, target, topk=(1,))[0]
-            losses.update(loss.data.item(), input.size(0))
-            top1.update(prec1.item(), input.size(0))
+            losses.update(accelerator.gather(loss).sum().data.item(), input.size(0)*accelerator.num_processes)
+            top1.update(prec1.item(), input.size(0)*accelerator.num_processes)
 
             # compute gradient and do SGD step
             distiller.optimizer.zero_grad()
-            loss.backward()
+            accelerator.backward(loss) # loss.backward()
             distiller.optimizer.step()
             scheduler.step()
 
@@ -248,7 +261,7 @@ def train(train_loader, model, scheduler, distiller, best_prec1):
             end = time.time()
 
             if i % args.print_freq == 0:
-                print('Epoch: [{0}][{1}/{2}]\t'
+                accelerator.print('Epoch: [{0}][{1}/{2}]\t'
                       'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                       'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                       'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
@@ -260,19 +273,20 @@ def train(train_loader, model, scheduler, distiller, best_prec1):
         # remember best prec@1 and save checkpoint
         is_best = distiller.best_score > best_prec1
         best_prec1 = max(distiller.best_score, best_prec1)
-        save_checkpoint({
-            'epoch': distiller._epoch_runned + 1,
-            'state_dict': model.state_dict(),
-            'best_prec1': best_prec1,
-        }, is_best)
-        # log to TensorBoard
-        if args.tensorboard:
-            log_value('train_loss', losses.avg, epoch)
-            log_value('train_acc', top1.avg, epoch)
-            log_value('learning_rate', scheduler._last_lr[0], epoch)
+        if accelerator.is_local_main_process:
+            save_checkpoint({
+                'epoch': distiller._epoch_runned + 1,
+                'state_dict': model.state_dict(),
+                'best_prec1': best_prec1,
+            }, is_best)
+            # log to TensorBoard
+            if args.tensorboard:
+                log_value('train_loss', losses.avg, epoch)
+                log_value('train_acc', top1.avg, epoch)
+                log_value('learning_rate', scheduler._last_lr[0], epoch)
 
 
-def validate(val_loader, model, distiller):
+def validate(val_loader, model, distiller, accelerator):
     """Perform validation on the validation set"""
     batch_time = AverageMeter()
     top1 = AverageMeter()
@@ -287,6 +301,8 @@ def validate(val_loader, model, distiller):
             output = model(input)
 
         # measure accuracy
+        output = accelerator.gather(output)
+        target = accelerator.gather(target)
         prec1 = accuracy(output.data, target, topk=(1,))[0]
         top1.update(prec1.item(), input.size(0))
 
@@ -295,15 +311,15 @@ def validate(val_loader, model, distiller):
         end = time.time()
 
         if i % args.print_freq == 0:
-            print('Test: [{0}/{1}]\t'
+            accelerator.print('Test: [{0}/{1}]\t'
                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                   'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                       i, len(val_loader), batch_time=batch_time,
                       top1=top1))
 
-    print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1))
+    accelerator.print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1))
     # log to TensorBoard
-    if args.tensorboard:
+    if accelerator.is_local_main_process and args.tensorboard:
         log_value('val_acc', top1.avg, distiller._epoch_runned)
     return top1.avg
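The `main.py` changes above follow a common Accelerate pattern: wrap the model, optimizer, and dataloaders with `prepare()`, replace `loss.backward()` with `accelerator.backward(loss)`, `gather()` per-process outputs before computing metrics, and restrict checkpointing to the local main process. Below is a minimal, self-contained sketch of that pattern; the toy model and data are invented for illustration and are not this repository's code:

```python
# Minimal sketch of the Accelerate pattern applied in main.py above
# (toy model/data, not this repository's code).
import torch
from accelerate import Accelerator

accelerator = Accelerator()  # Accelerator(cpu=True) would force CPU training

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.02)
dataset = torch.utils.data.TensorDataset(torch.randn(256, 10), torch.randint(0, 2, (256,)))
loader = torch.utils.data.DataLoader(dataset, batch_size=32)

# After prepare(), the dataloader shards batches across processes and the
# model is wrapped for DDP when more than one process is launched.
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

for input, target in loader:
    output = model(input)
    loss = torch.nn.functional.cross_entropy(output, target)

    optimizer.zero_grad()
    accelerator.backward(loss)  # instead of loss.backward()
    optimizer.step()

    # Gather outputs/targets from all processes so accuracy is computed over
    # the global batch, as the diff does before calling accuracy().
    all_output = accelerator.gather(output)
    all_target = accelerator.gather(target)
    acc = (all_output.argmax(dim=1) == all_target).float().mean()

# Only one process per node writes the checkpoint; unwrap_model() strips the
# DDP wrapper so the saved state_dict has plain parameter names.
if accelerator.is_local_main_process:
    torch.save(accelerator.unwrap_model(model).state_dict(), "model_sketch.pth")
```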

examples/pytorch/image_recognition/CNN-2/distillation/eager/requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -2,3 +2,4 @@
 torch==1.5.0+cpu
 torchvision==0.6.0+cpu
 tensorboard_logger
+accelerate
