Skip to content

Commit 4622bd2

Browse files
committed
add memory metrics to TensorBoard
ghstack-source-id: da7e02b Pull Request resolved: #60
1 parent 196d56e commit 4622bd2

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

torchtrain/metrics.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,15 +122,15 @@ def get_current_stats(self, return_data: bool = False):
122122
)
123123

124124
display_str = ""
125-
display_str += f"Current Memory: {self.device_name} ({self.device_index}): Reserved: {self.device_reserved_memory_pct}%,"
126-
display_str += f"Alloc {self.device_alloc_memory_pct}%, Active: {self.device_active_memory_pct}%\n"
125+
display_str += f"Current Memory: {self.device_name} ({self.device_index}): Reserved: {self.device_reserved_memory_pct}%, "
126+
display_str += f"Alloc {self.device_alloc_memory_pct}%, Active: {self.device_active_memory_pct}%\n"
127127

128128
self.get_peak_stats(curr_mem)
129129

130130
peak_active_pct = self.get_pct_memory(self.peak_active_memory)
131131
peak_allocated_pct = self.get_pct_memory(self.peak_allocated_memory)
132132
peak_reserved_pct = self.get_pct_memory(self.peak_reserved_memory)
133-
display_str += f"Peak Memory: Reserved {peak_reserved_pct}%, Alloc {peak_allocated_pct}%, Active: {peak_active_pct}%\n"
133+
display_str += f"Peak Memory: Reserved {peak_reserved_pct}%, Alloc {peak_allocated_pct}%, Active: {peak_active_pct}%\n"
134134

135135
display_str += f"num retries: {self.num_retries}, num ooms: {self.num_ooms}"
136136
if self.num_retries > 0:

train.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,10 +219,18 @@ def main(args):
219219
time_delta * parallel_dims.model_parallel_size
220220
)
221221

222+
gpu_mem_stats = gpu_metrics.get_current_stats(return_data=True)
223+
222224
metrics = {
223-
"global_avg_loss": global_avg_loss,
224-
"global_max_loss": global_max_loss,
225+
"loss_metrics/global_avg_loss": global_avg_loss,
226+
"loss_metrics/global_max_loss": global_max_loss,
225227
"wps": wps,
228+
"memory_current/active(%)": gpu_mem_stats.active_curr,
229+
"memory_current/allocated(%)": gpu_mem_stats.allocated_curr,
230+
"memory_current/reserved(%)": gpu_mem_stats.reserved_curr,
231+
"memory_peak/active(%)": gpu_mem_stats.active_peak,
232+
"memory_peak/allocated(%)": gpu_mem_stats.allocated_peak,
233+
"memory_peak/reserved(%)": gpu_mem_stats.reserved_peak,
226234
}
227235
metric_logger.log(metrics, step=train_state.step)
228236

0 commit comments

Comments
 (0)