Correctly reset metric objects in self.log #7055
Changes from all commits: d713d49, 0e6c141, 747b643, 6c8e5c2, 7184738, 794eeff, ee8e623, 128e12a, d52cdc6, 783cfd3, d6bd61a, 2969903
```diff
@@ -233,6 +233,18 @@ def auto_reduce_results_on_epoch_end(self) -> None:
         self.has_reduced = True

+    def reset(self) -> None:
+        """
+        Call at the end of epoch to reset Result objects
+        """
+        for dl_idx in range(self.num_dataloaders):
+            epoch_metrics = self._internals[dl_idx] if not self.has_reduced else self._internals_reduced[dl_idx]
+            if self._internal_type == ResultStoreType.INSIDE_BATCH_TRAIN_LOOP:
+                for opt_idx in list(epoch_metrics):
+                    epoch_metrics[opt_idx].reset()
+            else:
+                epoch_metrics.reset()
+
     def __getitem__(self, key: str) -> Any:
         return self._internals.get(key, None)
```

Contributor (on the `epoch_metrics = …` line): In current usage, …

(Borda marked this conversation as resolved.)

Comment on lines +242 to +244:

Contributor: Could you explain this check? At the surface, "inside the batch train loop" reads like we're not at the epoch end?

Collaborator (Author): I am not completely sure about what is going on myself, but apparently the …

Contributor: Correct. I think … But I'd also like to clear all this entirely at some point.
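To make the shape question in the thread above concrete, here is a minimal, self-contained sketch of the two layouts the new `reset()` has to handle. `FakeResult` and the literal indices are illustrative stand-ins, not Lightning internals:

```python
# Stand-in for Lightning's Result object; in Lightning, a Result holds
# the torchmetrics.Metric instances whose accumulated state must clear.
class FakeResult:
    def __init__(self) -> None:
        self.state = ["accumulated", "values"]

    def reset(self) -> None:
        self.state = []

# Inside the batch train loop, each dataloader index maps to a dict of
# {optimizer_idx: Result}; everywhere else it maps directly to a Result.
internals_train = {0: {0: FakeResult(), 1: FakeResult()}}
internals_eval = {0: FakeResult()}

# Mirrors the branch in HookResultStore.reset():
for epoch_metrics in internals_train.values():
    for opt_idx in list(epoch_metrics):   # per-optimizer dict
        epoch_metrics[opt_idx].reset()

for epoch_metrics in internals_eval.values():
    epoch_metrics.reset()                 # plain Result
```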
```diff
@@ -262,6 +274,7 @@ def __init__(self, trainer: 'pl.Trainer') -> None:
         _should_warn = trainer.accelerator_connector.is_distributed
         _should_warn &= not trainer.training_type_plugin.rpc_enabled
         self._should_warn = _should_warn
+        self._internals = {}

         self.reset()
```
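The one-line addition above matters because of ordering: `__init__` calls `self.reset()` immediately afterwards, and the reworked `reset()` (last hunk below) now iterates `self._internals` before rebinding it. A condensed sketch of that ordering, with the trainer wiring omitted:

```python
class EpochResultStore:
    def __init__(self) -> None:
        # Must exist before the reset() call below reads it; otherwise
        # the first reset() would raise AttributeError.
        self._internals = {}
        self.reset()

    def reset(self) -> None:
        for k, value in self._internals.items():
            value.reset()     # clear each stored result's metric state
        self._internals = {}  # then rebind to a fresh store

store = EpochResultStore()  # works: _internals exists when reset() runs
```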
```diff
@@ -442,7 +455,9 @@ def get_epoch_log_metrics(self) -> Dict:
     def get_forked_metrics(self) -> Dict:
         return self.run_epoch_by_func_name("get_forked_metrics")

-    def reset(self):
+    def reset(self) -> None:
+        for k, value in self._internals.items():
+            value.reset()
         self._internals = {}
         self._dataloader_idx: Optional[int] = None
         self._split_idx: Optional[int] = None
```
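Before this change, `reset()` only rebound `self._internals` to a fresh dict. Since the `Metric` objects logged through `self.log` are typically also attributes of the user's `LightningModule`, dropping the store's reference never cleared their accumulated state, which then leaked into the next epoch. A standalone sketch of that failure mode, assuming a recent `torchmetrics` install:

```python
import torch
from torchmetrics import MeanMetric

metric = MeanMetric()

# "Epoch 1": the metric accumulates state across updates.
metric.update(torch.tensor(1.0))
metric.update(torch.tensor(3.0))
print(metric.compute())   # tensor(2.)

store = {"train_loss": metric}
store = {}  # rebinding the dict does NOT clear the metric's own state

# "Epoch 2": without an explicit reset, old state leaks in.
metric.update(torch.tensor(5.0))
print(metric.compute())   # tensor(3.), the mean of 1, 3 and 5

metric.reset()  # what the fixed reset() now does for each stored value
```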