
Accuracy metric for preds at half precision is zero with pl=1.0.8 #5013

@luzuku


🐛 Bug

The Accuracy metric returns a wrong result when preds are given in half precision (torch.float16). In the example below it reports 0.0 instead of 0.5.

To Reproduce

import torch
from pytorch_lightning.metrics import Accuracy

acc = Accuracy(threshold=0.5)
target = torch.Tensor([1, 1, 0, 0])
preds = torch.Tensor([0.7, 0.4, 0.8, 0.4])

print(acc(preds, target))         # -> 0.5 (correct)
print(acc(preds.half(), target))  # -> 0.0 (wrong)
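For comparison, here is a minimal sketch of threshold-based binary accuracy in plain PyTorch. It is not Lightning's implementation: binary_accuracy is a hypothetical helper, and the >= comparison against the threshold is an assumption. Upcasting preds to float32 before thresholding gives the same result for both dtypes. The same preds.float() cast is a plausible workaround when calling Accuracy with half-precision predictions, though that has not been checked against pl 1.0.8 here.

```python
import torch


def binary_accuracy(preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    # Hypothetical helper, not the Lightning API.
    # Upcast to float32 so float16 inputs are thresholded exactly like float32 ones.
    preds = preds.float()
    # Turn probabilities into hard 0/1 labels (>= is an assumed convention).
    pred_labels = (preds >= threshold).long()
    correct = (pred_labels == target.long()).sum()
    return correct.float() / target.numel()


target = torch.tensor([1, 1, 0, 0])
preds = torch.tensor([0.7, 0.4, 0.8, 0.4])

print(binary_accuracy(preds, target))         # tensor(0.5000)
print(binary_accuracy(preds.half(), target))  # tensor(0.5000)
```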

Expected behavior

The accuracy metric should not fail silently. Either an error should be raised when preds are in half precision, or the metric should compute the correct value.
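The "raise an error" option could look roughly like the sketch below. It rejects floating-point preds outside an allowed set of dtypes instead of returning a misleading value. The function name check_preds_dtype and the choice of float32/float64 as the accepted dtypes are assumptions for illustration only, not part of Lightning.

```python
import torch

# Hypothetical allow-list of dtypes for probability-style predictions.
SUPPORTED_PRED_DTYPES = (torch.float32, torch.float64)


def check_preds_dtype(preds: torch.Tensor) -> None:
    # Fail loudly instead of letting the metric return a silently wrong value.
    if preds.is_floating_point() and preds.dtype not in SUPPORTED_PRED_DTYPES:
        raise TypeError(
            f"preds has dtype {preds.dtype}; cast it with preds.float() before computing accuracy"
        )


preds = torch.tensor([0.7, 0.4, 0.8, 0.4])
check_preds_dtype(preds)  # float32: passes without output

try:
    check_preds_dtype(preds.half())
except TypeError as err:
    print(err)  # preds has dtype torch.float16; cast it with preds.float() ...
```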

Environment

  • PyTorch Version (e.g., 1.0): 1.7.0
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, source): conda
  • Build command you used (if compiling from source):
  • Python version: 3.8
  • CUDA/cuDNN version: 10.2
  • GPU models and configuration: ...
  • Any other relevant information:
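Most of the fields above can be read programmatically. PyTorch ships a fuller report via python -m torch.utils.collect_env. The short sketch below only gathers the basics. The helper name environment_summary is hypothetical.

```python
import platform

import torch


def environment_summary() -> dict:
    # Hypothetical helper collecting the basic fields of the issue template.
    return {
        "PyTorch version": str(torch.__version__),
        "Python version": platform.python_version(),
        "OS": platform.system(),
        # torch.version.cuda is None on CPU-only builds.
        "CUDA version": torch.version.cuda,
        "GPU": torch.cuda.get_device_name(0) if torch.cuda.is_available() else None,
    }


for key, value in environment_summary().items():
    print(f"{key}: {value}")
```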

Additional context

This might already be fixed in master. I filed the issue regardless because I don't have time to check.

Labels

bug (Something isn't working), help wanted (Open to be worked on), priority: 0 (High priority task)