
Improve indexation of the cummulated_loss_tensor in the Trainer class #123

@le1nux

Description


As seen here, in the train function,

post_processing_fun=lambda t: torch.stack([t[0] / t[-1], t[1] / dist.get_world_size()]),

the cumulated_loss_and_gradient_norm tensor is indexed directly via hard-coded positions, which adds unnecessary complexity to the code.

A better solution would be to encapsulate cumulated_loss_and_gradient_norm within a dedicated class that also implements the reduce operations.
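A minimal sketch of what such a wrapper could look like, assuming the tensor layout implied by the lambda above (t[0] summed loss, t[1] summed gradient norm, t[-1] number of accumulated steps). The class name, method names, and semantics are hypothetical illustrations, not taken from the repository:

```python
import torch
import torch.distributed as dist


class CumulatedLossAndGradientNorm:
    """Hypothetical wrapper around the cumulated loss / gradient norm tensor.

    Hides the raw tensor layout behind named accessors and a reduce() method,
    so the train function no longer indexes the tensor directly.
    """

    # assumed positions inside the underlying tensor (mirrors the lambda above)
    _LOSS_SUM = 0
    _GRAD_NORM = 1
    _NUM_STEPS = -1

    def __init__(self, device: torch.device):
        # [summed loss, summed gradient norm, number of accumulated steps]
        self._tensor = torch.zeros(3, device=device)

    def add(self, loss: torch.Tensor, gradient_norm: torch.Tensor) -> None:
        # accumulate one training step
        self._tensor[self._LOSS_SUM] += loss.detach()
        self._tensor[self._GRAD_NORM] += gradient_norm.detach()
        self._tensor[self._NUM_STEPS] += 1

    def reduce(self) -> torch.Tensor:
        # sum across ranks, then post-process exactly like the current lambda:
        # average loss per step, gradient norm averaged over ranks
        if dist.is_initialized():
            dist.all_reduce(self._tensor, op=dist.ReduceOp.SUM)
            world_size = dist.get_world_size()
        else:
            world_size = 1
        return torch.stack(
            [
                self._tensor[self._LOSS_SUM] / self._tensor[self._NUM_STEPS],
                self._tensor[self._GRAD_NORM] / world_size,
            ]
        )

    def zero(self) -> None:
        # reset between logging intervals
        self._tensor.zero_()
```

With such a class, the post_processing_fun lambda and the direct indexing in the Trainer could be replaced by a single call to reduce().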

Metadata

Labels

enhancement (New feature or request)
