Skip to content

Commit dfbb592

Browse files
committed
fix test - reduce metric
1 parent 74d0652 commit dfbb592

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

tests/trainer/logging_tests/test_train_loop_logging_1_0.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -863,19 +863,18 @@ def test_metric_are_properly_reduced(tmpdir):
863863
class TestingModel(BoringModel):
864864
def __init__(self, *args, **kwargs):
865865
super().__init__()
866-
self.train_acc = pl.metrics.Accuracy()
867866
self.val_acc = pl.metrics.Accuracy()
868867

869868
def training_step(self, batch, batch_idx):
870-
self.train_acc(torch.rand(1, 3, device=self.device), torch.randint(0, 2, (1,), device=self.device))
871-
self.log('train_acc', self.train_acc, on_step=True, on_epoch=True)
872-
return super().training_step(batch, batch_idx)
869+
output = super().training_step(batch, batch_idx)
870+
self.log("train_loss", output["loss"])
871+
return output
873872

874873
def validation_step(self, batch, batch_idx):
875-
preds = torch.tensor(0, device=self.device)
876-
targets = torch.tensor(1, device=self.device)
874+
preds = torch.tensor([[0.9, 0.1]], device=self.device)
875+
targets = torch.tensor([1], device=self.device)
877876
if batch_idx < 8:
878-
targets = preds
877+
preds = torch.tensor([[0.1, 0.9]], device=self.device)
879878
self.val_acc(preds, targets)
880879
self.log('val_acc', self.val_acc, on_step=True, on_epoch=True)
881880
return super().validation_step(batch, batch_idx)
@@ -900,4 +899,4 @@ def validation_step(self, batch, batch_idx):
900899
trainer.fit(model)
901900

902901
assert trainer.callback_metrics["val_acc"] == 8 / 32.
903-
assert "train_acc" in trainer.callback_metrics
902+
assert "train_loss" in trainer.callback_metrics

0 commit comments

Comments
 (0)