Closed
Labels
bug (something isn't working), help wanted (open to be worked on), priority: 1 (medium priority task)
Description
The `Precision` and `Recall` metrics raise a `RuntimeError` when the inputs are on the GPU.
>>> import torch
>>> from pytorch_lightning.metrics import Precision
>>> target = torch.tensor([0, 1, 2, 0, 1, 2]).cuda()
>>> preds = torch.tensor([0, 2, 1, 0, 0, 1]).cuda()
>>> precision = Precision(num_classes=3)
>>> precision(preds, target)
Expected output:
tensor(0.3333)
Actual output:
RuntimeError Traceback (most recent call last)
<ipython-input-39-834a61ec6e2d> in <module>
3 preds = torch.tensor([0, 2, 1, 0, 0, 1]).cuda()
4 precision = Precision(num_classes=3)
----> 5 precision(preds, target)
6 tensor(0.3333)
/anaconda/envs/azureml_py36/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),
/anaconda/envs/azureml_py36/lib/python3.6/site-packages/pytorch_lightning/metrics/metric.py in forward(self, *args, **kwargs)
154 # add current step
155 with torch.no_grad():
--> 156 self.update(*args, **kwargs)
157 self._forward_cache = None
158
/anaconda/envs/azureml_py36/lib/python3.6/site-packages/pytorch_lightning/metrics/metric.py in wrapped_func(*args, **kwargs)
200 def wrapped_func(*args, **kwargs):
201 self._computed = None
--> 202 return update(*args, **kwargs)
203 return wrapped_func
204
/anaconda/envs/azureml_py36/lib/python3.6/site-packages/pytorch_lightning/metrics/classification/precision_recall.py in update(self, preds, target)
130
131 # multiply because we are counting (1, 1) pair for true positives
--> 132 self.true_positives += torch.sum(preds * target, dim=1)
133 self.predicted_positives += torch.sum(preds, dim=1)
134
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
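The traceback points at `self.true_positives += torch.sum(preds * target, dim=1)`, which suggests that the metric's accumulated state is still on the CPU while `preds` and `target` are on `cuda:0`. The sketch below reproduces that pattern in plain PyTorch and shows one possible workaround: keep the state in registered buffers and move the module to the same device as the inputs. `TinyMicroPrecision` is a hypothetical stand-in written for illustration, not Lightning's implementation, and it falls back to the CPU when no GPU is available.

```python
import torch
from torch import nn


class TinyMicroPrecision(nn.Module):
    """Hypothetical stand-in for a stateful metric (not Lightning's code)."""

    def __init__(self, num_classes: int):
        super().__init__()
        self.num_classes = num_classes
        # Registered buffers follow the module on .to(device) / .cuda().
        self.register_buffer("true_positives", torch.zeros(num_classes))
        self.register_buffer("predicted_positives", torch.zeros(num_classes))

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        # One-hot encode to shape (num_classes, N) so we can sum over dim=1.
        preds_1h = nn.functional.one_hot(preds, self.num_classes).T.float()
        target_1h = nn.functional.one_hot(target, self.num_classes).T.float()
        # A (1, 1) pair marks a true positive, hence the elementwise product.
        self.true_positives += torch.sum(preds_1h * target_1h, dim=1)
        self.predicted_positives += torch.sum(preds_1h, dim=1)

    def compute(self) -> torch.Tensor:
        # Micro-averaged precision over all classes.
        return self.true_positives.sum() / self.predicted_positives.sum()


# Use the GPU when present; otherwise the same code runs on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
target = torch.tensor([0, 1, 2, 0, 1, 2], device=device)
preds = torch.tensor([0, 2, 1, 0, 0, 1], device=device)

# Moving the module moves its buffers, so state and inputs share a device.
metric = TinyMicroPrecision(num_classes=3).to(device)
metric.update(preds, target)
result = metric.compute()
print(result)
```

Since the traceback shows that the Lightning metric is called through `torch.nn.Module`, the analogous workaround would be `Precision(num_classes=3).to(preds.device)`. Whether that call also moves the metric's internal state in the affected version is an assumption and has not been verified here.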