Metrics fail if inputs are on GPU #4909

@simonm3

Description

The Precision and Recall metrics fail when the input tensors are on the GPU.

>>> import torch
>>> from pytorch_lightning.metrics import Precision
>>> target = torch.tensor([0, 1, 2, 0, 1, 2]).cuda()
>>> preds = torch.tensor([0, 2, 1, 0, 0, 1]).cuda()
>>> precision = Precision(num_classes=3)
>>> precision(preds, target)  # the docs example expects tensor(0.3333), but this raises:
RuntimeError                              Traceback (most recent call last)
<ipython-input-39-834a61ec6e2d> in <module>
      3 preds = torch.tensor([0, 2, 1, 0, 0, 1]).cuda()
      4 precision = Precision(num_classes=3)
----> 5 precision(preds, target)
      6 tensor(0.3333)

/anaconda/envs/azureml_py36/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

/anaconda/envs/azureml_py36/lib/python3.6/site-packages/pytorch_lightning/metrics/metric.py in forward(self, *args, **kwargs)
    154         # add current step
    155         with torch.no_grad():
--> 156             self.update(*args, **kwargs)
    157         self._forward_cache = None
    158 

/anaconda/envs/azureml_py36/lib/python3.6/site-packages/pytorch_lightning/metrics/metric.py in wrapped_func(*args, **kwargs)
    200         def wrapped_func(*args, **kwargs):
    201             self._computed = None
--> 202             return update(*args, **kwargs)
    203         return wrapped_func
    204 

/anaconda/envs/azureml_py36/lib/python3.6/site-packages/pytorch_lightning/metrics/classification/precision_recall.py in update(self, preds, target)
    130 
    131         # multiply because we are counting (1, 1) pair for true positives
--> 132         self.true_positives += torch.sum(preds * target, dim=1)
    133         self.predicted_positives += torch.sum(preds, dim=1)
    134 

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
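
The failing line in precision_recall.py points at the root cause: the metric's state tensors (true_positives, predicted_positives) are created on the CPU, while preds and target live on cuda:0, so the in-place += mixes devices. A likely workaround, as a minimal sketch (untested against this exact version, and assuming the Metric base class propagates .to() to its registered states, as later releases document), is to move the metric itself to the inputs' device:

import torch
from pytorch_lightning.metrics import Precision

device = torch.device("cuda")
target = torch.tensor([0, 1, 2, 0, 1, 2], device=device)
preds = torch.tensor([0, 2, 1, 0, 0, 1], device=device)

# Metric subclasses torch.nn.Module, so .to(device) should move its
# internal counters (true_positives, predicted_positives) onto the GPU
# as well, and the update `self.true_positives += ...` no longer
# mixes a CPU state with CUDA inputs.
precision = Precision(num_classes=3).to(device)
print(precision(preds, target))  # expected: tensor(0.3333, device='cuda:0')

When the metric is registered as an attribute of a LightningModule, Lightning is expected to handle this device placement automatically; the mismatch surfaces in standalone use like the snippet above.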

Metadata


Labels

bug (Something isn't working), help wanted (Open to be worked on), priority: 1 (Medium priority task)
