🚀 Feature
Motivation
Pitch
The LoggerConnector logic is pretty opaque and hard to follow.
The EpochResultStore and HookResult add an extra layer of complexity, and the tests are possibly too sparse to catch wrong behaviours.
One reason for this complexity is the non-uniformity of the stored logged data.
Description of the internal functionality (a rough sketch follows the list):
1. An EpochResultStore is created for each Trainer RUNNING_STAGE.
2. A new Result object is created when running a new hook. Result objects are enhanced dictionaries containing a key - value mapping together with the extra logged metadata and the inferred batch_size.
3. The Result object is stored in the associated EpochResultStore. How Result objects are stored differs between TRAIN and TEST/VALIDATION, making the code complex and hard to follow.
4. On batch_end: get the latest stored Result object and provide its values to the logger and progress bar based on meta.
5. On epoch_end: reduce the values and provide them to the logger and progress bar based on meta.
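To make this concrete, here is a minimal, self-contained sketch of the current flow; all names below are illustrative stand-ins, not the actual Lightning internals:

```python
from collections import defaultdict

class Result(dict):
    """Stand-in for the enhanced dict: key -> value plus meta and batch_size."""
    def log(self, key, value, batch_size=1, **meta):
        self[key] = {"value": value, "batch_size": batch_size, "meta": meta}

# one store per Trainer RUNNING_STAGE, keyed by hook name
epoch_result_store = defaultdict(list)

def run_hook(hook_name):
    result = Result()  # a fresh Result per hook call
    epoch_result_store[hook_name].append(result)
    return result

def on_batch_end(hook_name):
    latest = epoch_result_store[hook_name][-1]  # latest stored Result
    # route each value to the logger / progress bar according to its meta
    return {k: v["value"] for k, v in latest.items() if v["meta"].get("logger")}
```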
Since a logged value can be either a Metric or a plain float/tensor, extra internal checks are needed to reduce values properly on epoch_end.
Proposition: uniformize logged values to simplify their storage and reduction.
TODOs:
- Simplify Result Object internally
- Create 1 Result Object for the entire loop.
- Storage: RunningStage -> hook_name -> dataloader_idx -> LoggedMetric (sketched after this list)
- Create a LoggedTensorMetric
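A minimal sketch of that storage layout, using plain nested dicts (the type alias is an assumption for illustration):

```python
from typing import Dict

# RunningStage -> hook_name -> dataloader_idx -> LoggedMetric
Storage = Dict[str, Dict[str, Dict[int, "LoggedMetric"]]]

# e.g. storage["train"]["training_step"][0] would hold the LoggedMetric
# for dataloader 0 of the training_step hook.
```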
Here is the pseudo-code for the LoggedMetric. It wraps both Metrics and plain tensors and greatly simplifies how the information is stored internally.
It would also make fault-tolerant training simpler, as the state could be reduced and stored/reloaded as 'weighted_mean, sum(batch_sizes)'.
```python
import torch
import torchmetrics
# "Mis..." in the original pseudo-code; assumed to be Lightning's usual exception.
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class LoggedMetric(torchmetrics.Metric):
    """Uniform wrapper for a logged value: either a raw tensor or a torchmetrics.Metric."""

    def __init__(self, key, meta, wrap_metric: bool = False):
        super().__init__()
        self.key = key
        self.meta = meta
        self.wrap_metric = wrap_metric
        if not self.wrap_metric:
            # accumulate the batch-size-weighted sum of values plus the batch sizes
            self.add_state("value", default=torch.tensor(0.0))
            self.add_state("batch_sizes", default=[])

    def update(self, value, batch_size):
        if not self.wrap_metric:
            # accumulate value * batch_size so compute() can return a weighted mean
            self.value += value * batch_size
            self.batch_sizes.append(batch_size)
        else:
            if not isinstance(value, torchmetrics.Metric):
                raise MisconfigurationException("`value` should be a torchmetrics.Metric")
            if not hasattr(self, "value"):
                self.value = value
            elif self.value is not value:
                raise MisconfigurationException("a different Metric was logged under the same key")

    def compute(self):
        if not self.wrap_metric:
            # weighted mean: sum(value_i * batch_size_i) / sum(batch_size_i)
            return self.value / sum(self.batch_sizes)
        return self.value.compute()

    @property
    def on_epoch(self) -> bool:
        return self.meta["on_epoch"]

    @property
    def on_step(self) -> bool:
        return self.meta["on_step"]

    ...
```
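For illustration, a hedged usage sketch of the two modes, assuming the class above (`MeanMetric` is just a stand-in for any torchmetrics metric):

```python
import torch
import torchmetrics

# Tensor mode: the LoggedMetric tracks the batch-size-weighted mean itself.
loss = LoggedMetric(key="loss", meta={"on_step": True, "on_epoch": True})
loss.update(torch.tensor(0.5), batch_size=32)
loss.update(torch.tensor(0.3), batch_size=16)
print(loss.compute())  # (0.5 * 32 + 0.3 * 16) / 48

# Metric mode: the LoggedMetric only keeps a reference and delegates compute().
mean = torchmetrics.MeanMetric()
mean.update(torch.tensor([0.5, 0.3]))
wrapped = LoggedMetric(key="mean", meta={"on_step": False, "on_epoch": True}, wrap_metric=True)
wrapped.update(mean, batch_size=48)
print(wrapped.compute())  # delegates to mean.compute()
```

Either way, the stored object is always a LoggedMetric, so the EpochResultStore no longer needs separate code paths for Metrics and raw tensors.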