Skip to content

Commit 29c6bf5

Browse files
tchatoncarmoccarohitgr7
authored andcommitted
[bugfix] remove nan loss in manual optimization (#5121)
* remove nan loss whe missing * Update pytorch_lightning/core/lightning.py Co-authored-by: Carlos Mocholí <[email protected]> * Apply suggestions from code review Co-authored-by: Carlos Mocholí <[email protected]> Co-authored-by: Rohit Gupta <[email protected]>
1 parent df8b676 commit 29c6bf5

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

pytorch_lightning/core/lightning.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1393,12 +1393,15 @@ def get_progress_bar_dict(self):
13931393
"""
13941394
# call .item() only once but store elements without graphs
13951395
running_train_loss = self.trainer.train_loop.running_loss.mean()
1396-
avg_training_loss = (
1397-
running_train_loss.cpu().item()
1398-
if running_train_loss is not None
1399-
else float("NaN")
1400-
)
1401-
tqdm_dict = {"loss": "{:.3g}".format(avg_training_loss)}
1396+
avg_training_loss = None
1397+
if running_train_loss is not None:
1398+
avg_training_loss = running_train_loss.cpu().item()
1399+
elif self.trainer.train_loop.automatic_optimization:
1400+
avg_training_loss = float('NaN')
1401+
1402+
tqdm_dict = {}
1403+
if avg_training_loss is not None:
1404+
tqdm_dict["loss"] = f"{avg_training_loss:.3g}"
14021405

14031406
if self.trainer.truncated_bptt_steps is not None:
14041407
tqdm_dict["split_idx"] = self.trainer.split_idx

0 commit comments

Comments
 (0)