Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ module = [
"pytorch_lightning.demos.mnist_datamodule",
"pytorch_lightning.profilers.base",
"pytorch_lightning.profilers.pytorch",
"pytorch_lightning.profilers.simple",
"pytorch_lightning.strategies.sharded",
"pytorch_lightning.strategies.sharded_spawn",
"pytorch_lightning.trainer.callback_hook",
Expand Down
19 changes: 11 additions & 8 deletions src/pytorch_lightning/profilers/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(
"""
super().__init__(dirpath=dirpath, filename=filename)
self.current_actions: Dict[str, float] = {}
self.recorded_durations = defaultdict(list)
self.recorded_durations: Dict = defaultdict(list)
self.extended = extended
self.start_time = time.monotonic()

Expand Down Expand Up @@ -104,20 +104,23 @@ def summary(self) -> str:
if len(self.recorded_durations) > 0:
max_key = max(len(k) for k in self.recorded_durations.keys())

def log_row(action, mean, num_calls, total, per):
def log_row_extended(action: str, mean: str, num_calls: str, total: str, per: str) -> str:
row = f"{sep}| {action:<{max_key}s}\t| {mean:<15}\t|"
row += f" {num_calls:<15}\t| {total:<15}\t| {per:<15}\t|"
return row

header_string = log_row("Action", "Mean duration (s)", "Num calls", "Total time (s)", "Percentage %")
header_string = log_row_extended(
"Action", "Mean duration (s)", "Num calls", "Total time (s)", "Percentage %"
)
output_string_len = len(header_string.expandtabs())
sep_lines = f"{sep}{'-' * output_string_len}"
output_string += sep_lines + header_string + sep_lines
report, total_calls, total_duration = self._make_report_extended()
output_string += log_row("Total", "-", f"{total_calls:}", f"{total_duration:.5}", "100 %")
report_extended: _TABLE_DATA_EXTENDED
report_extended, total_calls, total_duration = self._make_report_extended()
output_string += log_row_extended("Total", "-", f"{total_calls:}", f"{total_duration:.5}", "100 %")
output_string += sep_lines
for action, mean_duration, num_calls, total_duration, duration_per in report:
output_string += log_row(
for action, mean_duration, num_calls, total_duration, duration_per in report_extended:
output_string += log_row_extended(
action,
f"{mean_duration:.5}",
f"{num_calls}",
Expand All @@ -128,7 +131,7 @@ def log_row(action, mean, num_calls, total, per):
else:
max_key = max(len(k) for k in self.recorded_durations)

def log_row(action, mean, total):
def log_row(action: str, mean: str, total: str) -> str:
return f"{sep}| {action:<{max_key}s}\t| {mean:<15}\t| {total:<15}\t|"

header_string = log_row("Action", "Mean duration (s)", "Total time (s)")
Expand Down