
Commit 1de53be

Simplify the gradient clipping code. (#4896)
1 parent f676f94 commit 1de53be

2 files changed, +2 -10 lines changed

references/classification/train.py

Lines changed: 2 additions & 2 deletions

@@ -40,13 +40,13 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
             if args.clip_grad_norm is not None:
                 # we should unscale the gradients of optimizer's assigned params if do gradient clipping
                 scaler.unscale_(optimizer)
-                nn.utils.clip_grad_norm_(utils.get_optimizer_params(optimizer), args.clip_grad_norm)
+                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
             scaler.step(optimizer)
             scaler.update()
         else:
             loss.backward()
             if args.clip_grad_norm is not None:
-                nn.utils.clip_grad_norm_(utils.get_optimizer_params(optimizer), args.clip_grad_norm)
+                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
             optimizer.step()

         if model_ema and i % args.model_ema_steps == 0:
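
After this change, the clipping logic in train_one_epoch follows the standard mixed-precision pattern: unscale the gradients before clipping when a GradScaler is in use, otherwise clip the raw gradients and step. A minimal sketch of that pattern is below; it is not a quote of the reference script, and the helper name clip_and_step and its arguments are illustrative only.

import torch
from torch import nn


def clip_and_step(model, optimizer, loss, scaler=None, max_norm=None):
    if scaler is not None:
        # Mixed precision: backward on the scaled loss, then unscale the gradients
        # before clipping so the norm is computed on the true gradient values.
        scaler.scale(loss).backward()
        if max_norm is not None:
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        scaler.step(optimizer)
        scaler.update()
    else:
        # Full precision: clip the gradients as-is, then step.
        loss.backward()
        if max_norm is not None:
            nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()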

references/classification/utils.py

Lines changed: 0 additions & 8 deletions

@@ -409,11 +409,3 @@ def reduce_across_processes(val):
     dist.barrier()
     dist.all_reduce(t)
     return t
-
-
-def get_optimizer_params(optimizer):
-    """Generator to iterate over all parameters in the optimizer param_groups."""
-
-    for group in optimizer.param_groups:
-        for p in group["params"]:
-            yield p
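
The deleted helper simply walked the optimizer's param_groups; after this commit, clipping uses model.parameters() instead, which covers the same tensors as long as the optimizer was constructed from all of the model's parameters, as the reference scripts do. A small self-contained check of that assumption, using a placeholder model and optimizer, might look like:

import torch
from torch import nn

model = nn.Linear(4, 2)  # placeholder model for illustration
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Every tensor the optimizer updates is a model parameter and vice versa,
# so clipping model.parameters() clips exactly what the optimizer steps on.
optimizer_params = {p for group in optimizer.param_groups for p in group["params"]}
assert optimizer_params == set(model.parameters())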
