|
1 | 1 | import pytest |
2 | 2 | import torch |
3 | 3 |
|
4 | | -from pytorch_lightning import seed_everything |
5 | 4 | from pytorch_lightning.metrics.functional.classification import dice_score |
6 | 5 |
|
7 | 6 |
|
8 | | -@pytest.mark.parametrize(['sample_weight', 'pos_label', "exp_shape"], [ |
9 | | - pytest.param(1, 1., 42), |
10 | | - pytest.param(None, 1., 42), |
11 | | -]) |
12 | | -def test_binary_clf_curve(sample_weight, pos_label, exp_shape): |
13 | | - # TODO: move back the pred and target to test func arguments |
14 | | - # if you fix the array inside the function, you'd also have fix the shape, |
15 | | - # because when the array changes, you also have to fix the shape |
16 | | - seed_everything(0) |
17 | | - pred = torch.randint(low=51, high=99, size=(100, ), dtype=torch.float) / 100 |
18 | | - target = torch.tensor([0, 1] * 50, dtype=torch.int) |
19 | | - if sample_weight is not None: |
20 | | - sample_weight = torch.ones_like(pred) * sample_weight |
21 | | - |
22 | | - fps, tps, thresh = _binary_clf_curve(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label) |
23 | | - |
24 | | - assert isinstance(tps, torch.Tensor) |
25 | | - assert isinstance(fps, torch.Tensor) |
26 | | - assert isinstance(thresh, torch.Tensor) |
27 | | - assert tps.shape == (exp_shape, ) |
28 | | - assert fps.shape == (exp_shape, ) |
29 | | - assert thresh.shape == (exp_shape, ) |
30 | | - |
31 | | - |
32 | 7 | @pytest.mark.parametrize(['pred', 'target', 'expected'], [ |
33 | 8 | pytest.param([[0, 0], [1, 1]], [[0, 0], [1, 1]], 1.), |
34 | 9 | pytest.param([[1, 1], [0, 0]], [[0, 0], [1, 1]], 0.), |
|
0 commit comments