Skip to content

Commit ba23d77

Browse files
author
Ubuntu
committed
resolve a second bug
1 parent 2aa76ed commit ba23d77

File tree

2 files changed

+18
-8
lines changed

2 files changed

+18
-8
lines changed

pytorch_lightning/core/step_result.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,10 @@ def get_forked_metrics(self, add_dataloader_idx=False):
367367
dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx)
368368

369369
if options['forked']:
370-
result[dl_key] = self[k]
370+
if isinstance(self[k], Metric):
371+
result[dl_key] = self[k].compute().detach()
372+
else:
373+
result[dl_key] = self[k]
371374

372375
return result
373376

tests/trainer/logging_tests/test_train_loop_logging_1_0.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -778,24 +778,29 @@ def test_metric_are_properly_reduced(tmpdir):
778778
class TestingModel(BoringModel):
779779
def __init__(self, *args, **kwargs):
780780
super().__init__()
781-
self.acc = pl.metrics.Accuracy()
781+
self.train_acc = pl.metrics.Accuracy()
782+
self.val_acc = pl.metrics.Accuracy()
782783

783784
def training_step(self, batch, batch_idx):
784-
self.acc(torch.rand(1, 3, device=self.device), torch.randint(0, 2, (1,), device=self.device))
785-
self.log('train_acc', self.acc, on_step=True, on_epoch=True)
785+
self.train_acc(torch.rand(1, 3, device=self.device), torch.randint(0, 2, (1,), device=self.device))
786+
self.log('train_acc', self.train_acc, on_step=True, on_epoch=True)
786787
return super().training_step(batch, batch_idx)
787788

788789
def validation_step(self, batch, batch_idx):
789-
self.acc(torch.rand(1, 3, device=self.device), torch.randint(0, 2, (1,), device=self.device))
790-
self.log('val_acc', self.acc, on_step=True, on_epoch=True)
790+
preds = torch.tensor(0, device=self.device)
791+
targets = torch.tensor(1, device=self.device)
792+
if batch_idx < 8:
793+
targets = preds
794+
self.val_acc(preds, targets)
795+
self.log('val_acc', self.val_acc, on_step=True, on_epoch=True)
791796
return super().validation_step(batch, batch_idx)
792797

793798
early_stop = EarlyStopping(monitor='val_acc', mode='max')
794799

795800
checkpoint = ModelCheckpoint(
796801
monitor='val_acc',
797802
save_last=True,
798-
save_top_k=5,
803+
save_top_k=2,
799804
mode='max',
800805
)
801806

@@ -804,8 +809,10 @@ def validation_step(self, batch, batch_idx):
804809
default_root_dir=tmpdir,
805810
gpus=1,
806811
max_epochs=2,
812+
limit_train_batches=5,
813+
limit_val_batches=32,
807814
callbacks=[early_stop, checkpoint])
808815
trainer.fit(model)
809816

810-
assert "val_acc" in trainer.callback_metrics
817+
assert trainer.callback_metrics["val_acc"] == 8 / 32.
811818
assert "train_acc" in trainer.callback_metrics

0 commit comments

Comments
 (0)