
[BUG] train.py calls train_one_epoch without the device parameter #2488

@frostylight

Description


Describe the bug

    def train_one_epoch(
            epoch,
            model,
            loader,
            optimizer,
            loss_fn,
            args,
            device=torch.device('cuda'),
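Because the signature above hard-codes torch.device('cuda') as the default, a call site that forgets to pass device fails silently instead of raising an error. One defensive alternative is a None sentinel that forces every caller to pass the device explicitly. The sketch below is hypothetical, not timm's code: the function name train_one_epoch_strict is invented, and devices are plain strings so it runs without PyTorch.

```python
from typing import Optional


def train_one_epoch_strict(epoch, args, device: Optional[str] = None, **kwargs):
    # Hypothetical variant with no hard-coded 'cuda' default. A caller that
    # forgets to pass `device` fails immediately instead of silently
    # moving every batch to CUDA.
    if device is None:
        raise ValueError("train_one_epoch_strict: `device` must be passed explicitly")
    # Stand-in for: input.to(device=device, dtype=model_dtype), target.to(device=device)
    return device
```

With this pattern, the omission described in this issue would surface as a ValueError on the first call rather than as data landing on the wrong device.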

    train_metrics = train_one_epoch(
        epoch,
        model,
        loader_train,
        optimizer,
        train_loss_fn,
        args,
        lr_scheduler=lr_scheduler,
        saver=saver,
        output_dir=output_dir,
        amp_autocast=amp_autocast,
        loss_scaler=loss_scaler,
        model_dtype=model_dtype,
        model_ema=model_ema,
        mixup_fn=mixup_fn,
        num_updates_total=num_epochs * updates_per_epoch,
    )

The device argument is never passed to train_one_epoch, so it keeps its default value, torch.device('cuda'), which is later used in:

    input, target = input.to(device=device, dtype=model_dtype), target.to(device=device)
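A minimal pure-Python model of the problem, assuming device names can be represented as plain strings so it runs without PyTorch or a GPU. The stub train_one_epoch_stub and helper run are hypothetical stand-ins, not timm functions: the stub returns the device a batch would be moved to, and run calls it either the way the current train.py call does (device omitted) or with device forwarded.

```python
# Hypothetical stand-in for train_one_epoch: it only reports the device
# that `input.to(device=device, ...)` would receive.
def train_one_epoch_stub(epoch, args, device="cuda", **kwargs):
    return device


def run(device, forward_device):
    """Call the stub as train.py does, optionally forwarding `device`."""
    args = object()  # placeholder for the parsed argparse namespace
    if forward_device:
        # Proposed fix: pass the resolved device explicitly.
        return train_one_epoch_stub(0, args, device=device)
    # Current behavior: `device` is omitted, so the default is used.
    return train_one_epoch_stub(0, args)


if __name__ == "__main__":
    print(run("cpu", forward_device=False))  # prints cuda (default silently used)
    print(run("cpu", forward_device=True))   # prints cpu
```

In train.py itself the corresponding fix would presumably be adding device=device, to the train_one_epoch(...) call, assuming a resolved device variable is in scope at that call site.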

Metadata


Assignees

Labels

bug (Something isn't working)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
