Skip to content

Commit 9669c80

Browse files
tchatoncarmoccarohitgr7
authored andcommitted
[bugfix] remove nan loss in manual optimization (#5121)
* remove nan loss whe missing * Update pytorch_lightning/core/lightning.py Co-authored-by: Carlos Mocholí <[email protected]> * Apply suggestions from code review Co-authored-by: Carlos Mocholí <[email protected]> Co-authored-by: Rohit Gupta <[email protected]>
1 parent 13bbf4b commit 9669c80

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

pytorch_lightning/core/lightning.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1392,12 +1392,15 @@ def get_progress_bar_dict(self):
13921392
"""
13931393
# call .item() only once but store elements without graphs
13941394
running_train_loss = self.trainer.train_loop.running_loss.mean()
1395-
avg_training_loss = (
1396-
running_train_loss.cpu().item()
1397-
if running_train_loss is not None
1398-
else float("NaN")
1399-
)
1400-
tqdm_dict = {"loss": "{:.3g}".format(avg_training_loss)}
1395+
avg_training_loss = None
1396+
if running_train_loss is not None:
1397+
avg_training_loss = running_train_loss.cpu().item()
1398+
elif self.trainer.train_loop.automatic_optimization:
1399+
avg_training_loss = float('NaN')
1400+
1401+
tqdm_dict = {}
1402+
if avg_training_loss is not None:
1403+
tqdm_dict["loss"] = f"{avg_training_loss:.3g}"
14011404

14021405
if self.trainer.truncated_bptt_steps is not None:
14031406
tqdm_dict["split_idx"] = self.trainer.split_idx

0 commit comments

Comments
 (0)