Closed
Labels: bug (Something isn't working)
Description
Describe the bug
Lines 989 to 996 in 81900a6:

```python
def train_one_epoch(
        epoch,
        model,
        loader,
        optimizer,
        loss_fn,
        args,
        device=torch.device('cuda'),
```
Lines 888 to 904 in 81900a6:

```python
train_metrics = train_one_epoch(
    epoch,
    model,
    loader_train,
    optimizer,
    train_loss_fn,
    args,
    lr_scheduler=lr_scheduler,
    saver=saver,
    output_dir=output_dir,
    amp_autocast=amp_autocast,
    loss_scaler=loss_scaler,
    model_dtype=model_dtype,
    model_ema=model_ema,
    mixup_fn=mixup_fn,
    num_updates_total=num_epochs * updates_per_epoch,
)
```
The `device` parameter is not passed to `train_one_epoch`, so it stays at its default (`cuda`) even when a different device was resolved earlier, and that default is later used here:
Line 1039 in 81900a6:

```python
input, target = input.to(device=device, dtype=model_dtype), target.to(device=device)
```
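The pitfall can be shown with a minimal, torch-free sketch (the function name mirrors `train_one_epoch`, but the bodies here are placeholders, not timm's code): because the caller never forwards its resolved device, the keyword default silently wins.

```python
# Sketch of the bug: a keyword default "wins" whenever the caller
# forgets to forward the value it actually resolved.

def train_one_epoch(epoch, model, device="cuda"):
    # Stands in for timm's device=torch.device('cuda') default;
    # returns which device would actually be used.
    return device

resolved_device = "cpu"  # e.g. what the training script resolved from args

# The buggy call: resolved_device is never forwarded, default is used.
used = train_one_epoch(0, model=None)
assert used == "cuda"

# The fix: pass the resolved device through explicitly.
used = train_one_epoch(0, model=None, device=resolved_device)
assert used == "cpu"
```

Adding `device=device` to the `train_one_epoch(...)` call in the training loop would make the snippet at line 1039 move tensors to the device the script actually resolved.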