Skip to content

Commit dde8f55

Browse files
committed
fix
1 parent d8d62e2 commit dde8f55

File tree

1 file changed

+0
-25
lines changed

1 file changed

+0
-25
lines changed

tests/metrics/functional/test_classification.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,9 @@
11
import pytest
22
import torch
33

4-
from pytorch_lightning import seed_everything
54
from pytorch_lightning.metrics.functional.classification import dice_score
65

76

8-
@pytest.mark.parametrize(['sample_weight', 'pos_label', "exp_shape"], [
9-
pytest.param(1, 1., 42),
10-
pytest.param(None, 1., 42),
11-
])
12-
def test_binary_clf_curve(sample_weight, pos_label, exp_shape):
13-
# TODO: move back the pred and target to test func arguments
14-
# if you fix the array inside the function, you'd also have fix the shape,
15-
# because when the array changes, you also have to fix the shape
16-
seed_everything(0)
17-
pred = torch.randint(low=51, high=99, size=(100, ), dtype=torch.float) / 100
18-
target = torch.tensor([0, 1] * 50, dtype=torch.int)
19-
if sample_weight is not None:
20-
sample_weight = torch.ones_like(pred) * sample_weight
21-
22-
fps, tps, thresh = _binary_clf_curve(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label)
23-
24-
assert isinstance(tps, torch.Tensor)
25-
assert isinstance(fps, torch.Tensor)
26-
assert isinstance(thresh, torch.Tensor)
27-
assert tps.shape == (exp_shape, )
28-
assert fps.shape == (exp_shape, )
29-
assert thresh.shape == (exp_shape, )
30-
31-
327
@pytest.mark.parametrize(['pred', 'target', 'expected'], [
338
pytest.param([[0, 0], [1, 1]], [[0, 0], [1, 1]], 1.),
349
pytest.param([[1, 1], [0, 0]], [[0, 0], [1, 1]], 0.),

0 commit comments

Comments
 (0)