From b4781882305595234a239f885095ace62e3a36e8 Mon Sep 17 00:00:00 2001 From: Krishna Kalyan Date: Mon, 8 Aug 2022 19:18:17 -0400 Subject: [PATCH 1/3] simple changes --- src/pytorch_lightning/profilers/simple.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/src/pytorch_lightning/profilers/simple.py b/src/pytorch_lightning/profilers/simple.py index 20d76f9b2d378..c12196c34550b 100644 --- a/src/pytorch_lightning/profilers/simple.py +++ b/src/pytorch_lightning/profilers/simple.py @@ -17,7 +17,7 @@ import time from collections import defaultdict from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np @@ -60,7 +60,7 @@ def __init__( """ super().__init__(dirpath=dirpath, filename=filename) self.current_actions: Dict[str, float] = {} - self.recorded_durations = defaultdict(list) + self.recorded_durations: Dict = defaultdict(list) self.extended = extended self.start_time = time.monotonic() @@ -104,20 +104,23 @@ def summary(self) -> str: if len(self.recorded_durations) > 0: max_key = max(len(k) for k in self.recorded_durations.keys()) - def log_row(action, mean, num_calls, total, per): + def log_row_extended(action: Any, mean: Any, num_calls: Any, total: Any, per: Any) -> str: row = f"{sep}| {action:<{max_key}s}\t| {mean:<15}\t|" row += f" {num_calls:<15}\t| {total:<15}\t| {per:<15}\t|" return row - header_string = log_row("Action", "Mean duration (s)", "Num calls", "Total time (s)", "Percentage %") + header_string = log_row_extended( + "Action", "Mean duration (s)", "Num calls", "Total time (s)", "Percentage %" + ) output_string_len = len(header_string.expandtabs()) sep_lines = f"{sep}{'-' * output_string_len}" output_string += sep_lines + header_string + sep_lines - report, total_calls, total_duration = self._make_report_extended() - output_string += log_row("Total", "-", f"{total_calls:}", f"{total_duration:.5}", "100 %") + report_extended: _TABLE_DATA_EXTENDED + report_extended, total_calls, total_duration = self._make_report_extended() + output_string += log_row_extended("Total", "-", f"{total_calls:}", f"{total_duration:.5}", "100 %") output_string += sep_lines - for action, mean_duration, num_calls, total_duration, duration_per in report: - output_string += log_row( + for action, mean_duration, num_calls, total_duration, duration_per in report_extended: + output_string += log_row_extended( action, f"{mean_duration:.5}", f"{num_calls}", @@ -128,7 +131,7 @@ def log_row(action, mean, num_calls, total, per): else: max_key = max(len(k) for k in self.recorded_durations) - def log_row(action, mean, total): + def log_row(action: Any, mean: Any, total: Any) -> str: return f"{sep}| {action:<{max_key}s}\t| {mean:<15}\t| {total:<15}\t|" header_string = log_row("Action", "Mean duration (s)", "Total time (s)") From b3a195a8e09f808fb41406411d4823294941abd0 Mon Sep 17 00:00:00 2001 From: Krishna Kalyan Date: Mon, 8 Aug 2022 19:19:50 -0400 Subject: [PATCH 2/3] simple remove --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8db782df357d8..b5e806bc69900 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,6 @@ module = [ "pytorch_lightning.demos.mnist_datamodule", "pytorch_lightning.profilers.base", "pytorch_lightning.profilers.pytorch", - "pytorch_lightning.profilers.simple", "pytorch_lightning.strategies.sharded", "pytorch_lightning.strategies.sharded_spawn", "pytorch_lightning.trainer.callback_hook", From 69af68ce4877db40c00349468ecc9e9b435a7e0e Mon Sep 17 00:00:00 2001 From: Krishna Kalyan Date: Tue, 9 Aug 2022 03:50:04 -0400 Subject: [PATCH 3/3] str function signature from any --- src/pytorch_lightning/profilers/simple.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/pytorch_lightning/profilers/simple.py b/src/pytorch_lightning/profilers/simple.py index c12196c34550b..0fb9497ff17fb 100644 --- a/src/pytorch_lightning/profilers/simple.py +++ b/src/pytorch_lightning/profilers/simple.py @@ -17,7 +17,7 @@ import time from collections import defaultdict from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import numpy as np @@ -104,7 +104,7 @@ def summary(self) -> str: if len(self.recorded_durations) > 0: max_key = max(len(k) for k in self.recorded_durations.keys()) - def log_row_extended(action: Any, mean: Any, num_calls: Any, total: Any, per: Any) -> str: + def log_row_extended(action: str, mean: str, num_calls: str, total: str, per: str) -> str: row = f"{sep}| {action:<{max_key}s}\t| {mean:<15}\t|" row += f" {num_calls:<15}\t| {total:<15}\t| {per:<15}\t|" return row @@ -131,7 +131,7 @@ def log_row_extended(action: Any, mean: Any, num_calls: Any, total: Any, per: An else: max_key = max(len(k) for k in self.recorded_durations) - def log_row(action: Any, mean: Any, total: Any) -> str: + def log_row(action: str, mean: str, total: str) -> str: return f"{sep}| {action:<{max_key}s}\t| {mean:<15}\t| {total:<15}\t|" header_string = log_row("Action", "Mean duration (s)", "Total time (s)")