
LoggerConnector Refactor #7183

@tchaton


🚀 Feature

Motivation

The LoggerConnector logic is pretty opaque and hard to follow. The EpochResultStore and HookResult add an extra layer of complexity, and the tests are possibly too sparse to catch incorrect behaviour.

One of the reasons for the complexity is the non-uniformity of the stored logged data.

Description of the current internal functionality:

  1. An EpochResultStore is created for each Trainer RUNNING_STAGE.
  2. A new Result object is created when running a new hook. Result objects are enhanced dictionaries containing a key-value mapping together with the extra logged metadata and the inferred batch_size.
  3. The Result object is stored in the associated EpochResultStore. How Result objects are stored differs between TRAIN and TEST/VALIDATION, which makes the code complex and hard to follow.
  4. On batch_end: fetch the latest stored Result object and provide its values to the logger and progress bar based on the metadata.
  5. On epoch_end: reduce the values and provide them to the logger and progress bar based on the metadata. Since a logged value can be either a Metric or a float/tensor, extra internal checks are needed to reduce properly at epoch end (see the sketch after this list).
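To illustrate the last point, here is a minimal sketch of the type dispatch that epoch-end reduction currently has to perform because logged values are not uniform. The reduce_logged_value helper is hypothetical, not the actual implementation:

import torch
import torchmetrics

def reduce_logged_value(value, batch_sizes):
    # Hypothetical helper: reduce one logged value at epoch end.
    # `value` is either a Metric or a list/tensor of per-batch values.
    if isinstance(value, torchmetrics.Metric):
        # Metrics carry their own state and know how to reduce themselves.
        return value.compute()
    # Plain floats/tensors need a manual batch-size weighted mean.
    values = torch.as_tensor(value, dtype=torch.float)
    weights = torch.as_tensor(batch_sizes, dtype=torch.float)
    return (values * weights).sum() / weights.sum()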

Pitch

Proposal: make logged values uniform to simplify both how they are stored and how they are reduced.

TODOs:

  • Simplify the Result object internally
  • Create a single Result object for the entire loop.
  • Storage layout: RunningStage -> hook_name -> dataloader_idx -> LoggedMetric (sketched below)
  • Create a LoggedTensorMetric
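A minimal sketch of the proposed storage layout, assuming a plain nested dictionary keyed by stage, hook name, and dataloader index (types and names are illustrative):

from collections import defaultdict
from typing import Dict

# RunningStage -> hook_name -> dataloader_idx -> LoggedMetric
LoggedStorage = Dict[str, Dict[str, Dict[int, "LoggedMetric"]]]

storage: LoggedStorage = defaultdict(lambda: defaultdict(dict))

# A value logged from `training_step` on dataloader 0 would then live at:
# storage["train"]["training_step"][0] -> LoggedMetric(key="loss", meta=...)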

Here is the pseudo-code for the LoggedMetric. It wraps both Metrics and plain tensors and greatly simplifies how the information is stored internally. It would also make fault-tolerant training simpler, as the state could be reduced and stored/reloaded as (weighted_mean, sum(batch_sizes)).

import torch
import torchmetrics
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class LoggedMetric(torchmetrics.Metric):

    def __init__(self, key, meta, wrap_metric: bool = False):
        super().__init__()
        self.key = key
        self.meta = meta
        self.wrap_metric = wrap_metric
        if not self.wrap_metric:
            # running batch-size weighted sum of the logged values
            self.add_state("value", default=torch.tensor(0.0))
            self.add_state("batch_sizes", default=[])

    def update(self, value, batch_size):
        if not self.wrap_metric:
            # accumulate a weighted sum so compute() can return a weighted mean
            self.value += value * batch_size
            self.batch_sizes.append(batch_size)
        else:
            if not isinstance(value, torchmetrics.Metric):
                raise MisconfigurationException(
                    f"Expected a `torchmetrics.Metric` for key {self.key}, got {type(value)}"
                )

            if not hasattr(self, "value"):
                self.value = value
            elif self.value is not value:
                raise MisconfigurationException(
                    f"A different Metric was already logged under key {self.key}"
                )

    def compute(self):
        if not self.wrap_metric:
            # weighted mean: sum(value_i * batch_size_i) / sum(batch_size_i)
            return self.value / sum(self.batch_sizes)
        return self.value.compute()

    @property
    def on_epoch(self) -> bool:
        return self.meta["on_epoch"]

    @property
    def on_step(self) -> bool:
        return self.meta["on_step"]

    ...
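A usage sketch of the pseudo-code above (the values and the wrapped metric are illustrative only):

import torch
import torchmetrics

# Tensor mode: values are accumulated and reduced as a batch-size weighted mean.
loss_metric = LoggedMetric(key="loss", meta={"on_step": True, "on_epoch": True})
loss_metric.update(torch.tensor(0.5), batch_size=32)
loss_metric.update(torch.tensor(0.3), batch_size=16)
loss_metric.compute()  # (0.5 * 32 + 0.3 * 16) / 48

# Metric mode: the user's torchmetrics.Metric is wrapped and reduces itself.
acc_metric = LoggedMetric(key="acc", meta={"on_step": False, "on_epoch": True}, wrap_metric=True)
acc_metric.update(torchmetrics.Accuracy(task="binary"), batch_size=32)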


Labels

design · feature · help wanted · refactor
