diff --git a/pyproject.toml b/pyproject.toml index 8db782df357d8..b5e806bc69900 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,6 @@ module = [ "pytorch_lightning.demos.mnist_datamodule", "pytorch_lightning.profilers.base", "pytorch_lightning.profilers.pytorch", - "pytorch_lightning.profilers.simple", "pytorch_lightning.strategies.sharded", "pytorch_lightning.strategies.sharded_spawn", "pytorch_lightning.trainer.callback_hook", diff --git a/src/pytorch_lightning/profilers/simple.py b/src/pytorch_lightning/profilers/simple.py index 20d76f9b2d378..0fb9497ff17fb 100644 --- a/src/pytorch_lightning/profilers/simple.py +++ b/src/pytorch_lightning/profilers/simple.py @@ -60,7 +60,7 @@ def __init__( """ super().__init__(dirpath=dirpath, filename=filename) self.current_actions: Dict[str, float] = {} - self.recorded_durations = defaultdict(list) + self.recorded_durations: Dict = defaultdict(list) self.extended = extended self.start_time = time.monotonic() @@ -104,20 +104,23 @@ def summary(self) -> str: if len(self.recorded_durations) > 0: max_key = max(len(k) for k in self.recorded_durations.keys()) - def log_row(action, mean, num_calls, total, per): + def log_row_extended(action: str, mean: str, num_calls: str, total: str, per: str) -> str: row = f"{sep}| {action:<{max_key}s}\t| {mean:<15}\t|" row += f" {num_calls:<15}\t| {total:<15}\t| {per:<15}\t|" return row - header_string = log_row("Action", "Mean duration (s)", "Num calls", "Total time (s)", "Percentage %") + header_string = log_row_extended( + "Action", "Mean duration (s)", "Num calls", "Total time (s)", "Percentage %" + ) output_string_len = len(header_string.expandtabs()) sep_lines = f"{sep}{'-' * output_string_len}" output_string += sep_lines + header_string + sep_lines - report, total_calls, total_duration = self._make_report_extended() - output_string += log_row("Total", "-", f"{total_calls:}", f"{total_duration:.5}", "100 %") + report_extended: _TABLE_DATA_EXTENDED + report_extended, total_calls, total_duration = self._make_report_extended() + output_string += log_row_extended("Total", "-", f"{total_calls:}", f"{total_duration:.5}", "100 %") output_string += sep_lines - for action, mean_duration, num_calls, total_duration, duration_per in report: - output_string += log_row( + for action, mean_duration, num_calls, total_duration, duration_per in report_extended: + output_string += log_row_extended( action, f"{mean_duration:.5}", f"{num_calls}", @@ -128,7 +131,7 @@ def log_row(action, mean, num_calls, total, per): else: max_key = max(len(k) for k in self.recorded_durations) - def log_row(action, mean, total): + def log_row(action: str, mean: str, total: str) -> str: return f"{sep}| {action:<{max_key}s}\t| {mean:<15}\t| {total:<15}\t|" header_string = log_row("Action", "Mean duration (s)", "Total time (s)")