Skip to content

Commit 527b28e

Browse files
Fix mypy errors attributed to pytorch_lightning.profilers.simple (#14103)
1 parent 2abed91 commit 527b28e

File tree

2 files changed

+11
-9
lines changed

2 files changed

+11
-9
lines changed

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ module = [
5757
"pytorch_lightning.demos.mnist_datamodule",
5858
"pytorch_lightning.profilers.base",
5959
"pytorch_lightning.profilers.pytorch",
60-
"pytorch_lightning.profilers.simple",
6160
"pytorch_lightning.strategies.sharded",
6261
"pytorch_lightning.strategies.sharded_spawn",
6362
"pytorch_lightning.trainer.callback_hook",

src/pytorch_lightning/profilers/simple.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def __init__(
6060
"""
6161
super().__init__(dirpath=dirpath, filename=filename)
6262
self.current_actions: Dict[str, float] = {}
63-
self.recorded_durations = defaultdict(list)
63+
self.recorded_durations: Dict = defaultdict(list)
6464
self.extended = extended
6565
self.start_time = time.monotonic()
6666

@@ -104,20 +104,23 @@ def summary(self) -> str:
104104
if len(self.recorded_durations) > 0:
105105
max_key = max(len(k) for k in self.recorded_durations.keys())
106106

107-
def log_row(action, mean, num_calls, total, per):
107+
def log_row_extended(action: str, mean: str, num_calls: str, total: str, per: str) -> str:
108108
row = f"{sep}| {action:<{max_key}s}\t| {mean:<15}\t|"
109109
row += f" {num_calls:<15}\t| {total:<15}\t| {per:<15}\t|"
110110
return row
111111

112-
header_string = log_row("Action", "Mean duration (s)", "Num calls", "Total time (s)", "Percentage %")
112+
header_string = log_row_extended(
113+
"Action", "Mean duration (s)", "Num calls", "Total time (s)", "Percentage %"
114+
)
113115
output_string_len = len(header_string.expandtabs())
114116
sep_lines = f"{sep}{'-' * output_string_len}"
115117
output_string += sep_lines + header_string + sep_lines
116-
report, total_calls, total_duration = self._make_report_extended()
117-
output_string += log_row("Total", "-", f"{total_calls:}", f"{total_duration:.5}", "100 %")
118+
report_extended: _TABLE_DATA_EXTENDED
119+
report_extended, total_calls, total_duration = self._make_report_extended()
120+
output_string += log_row_extended("Total", "-", f"{total_calls:}", f"{total_duration:.5}", "100 %")
118121
output_string += sep_lines
119-
for action, mean_duration, num_calls, total_duration, duration_per in report:
120-
output_string += log_row(
122+
for action, mean_duration, num_calls, total_duration, duration_per in report_extended:
123+
output_string += log_row_extended(
121124
action,
122125
f"{mean_duration:.5}",
123126
f"{num_calls}",
@@ -128,7 +131,7 @@ def log_row(action, mean, num_calls, total, per):
128131
else:
129132
max_key = max(len(k) for k in self.recorded_durations)
130133

131-
def log_row(action, mean, total):
134+
def log_row(action: str, mean: str, total: str) -> str:
132135
return f"{sep}| {action:<{max_key}s}\t| {mean:<15}\t| {total:<15}\t|"
133136

134137
header_string = log_row("Action", "Mean duration (s)", "Total time (s)")

0 commit comments

Comments
 (0)