Skip to content

Commit 34cceb7

Browse files
authored
Merge branch 'release/1.2-dev' into refactor/legacy-accel-plug
2 parents 9edf083 + 86d905c commit 34cceb7

File tree

1 file changed

+22
-6
lines changed

1 file changed

+22
-6
lines changed

tests/metrics/test_composition.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from distutils.version import LooseVersion
12
from operator import neg, pos
23

34
import pytest
@@ -6,6 +7,11 @@
67
from pytorch_lightning.metrics.compositional import CompositionalMetric
78
from pytorch_lightning.metrics.metric import Metric
89

10+
_MARK_TORCH_LOWER_1_4 = dict(condition=LooseVersion(torch.__version__) < LooseVersion("1.5.0"),
11+
reason='required PT >= 1.5')
12+
_MARK_TORCH_LOWER_1_5 = dict(condition=LooseVersion(torch.__version__) < LooseVersion("1.6.0"),
13+
reason='required PT >= 1.6')
14+
915

1016
class DummyMetric(Metric):
1117
def __init__(self, val_to_return):
@@ -50,6 +56,7 @@ def test_metrics_add(second_operand, expected_result):
5056
["second_operand", "expected_result"],
5157
[(DummyMetric(3), torch.tensor(2)), (3, torch.tensor(2)), (3, torch.tensor(2)), (torch.tensor(3), torch.tensor(2))],
5258
)
59+
@pytest.mark.skipif(**_MARK_TORCH_LOWER_1_4)
5360
def test_metrics_and(second_operand, expected_result):
5461
first_metric = DummyMetric(2)
5562

@@ -92,6 +99,7 @@ def test_metrics_eq(second_operand, expected_result):
9299
(torch.tensor(2), torch.tensor(2)),
93100
],
94101
)
102+
@pytest.mark.skipif(**_MARK_TORCH_LOWER_1_4)
95103
def test_metrics_floordiv(second_operand, expected_result):
96104
first_metric = DummyMetric(5)
97105

@@ -261,6 +269,7 @@ def test_metrics_ne(second_operand, expected_result):
261269
["second_operand", "expected_result"],
262270
[(DummyMetric([1, 0, 3]), torch.tensor([-1, -2, 3])), (torch.tensor([1, 0, 3]), torch.tensor([-1, -2, 3]))],
263271
)
272+
@pytest.mark.skipif(**_MARK_TORCH_LOWER_1_4)
264273
def test_metrics_or(second_operand, expected_result):
265274
first_metric = DummyMetric([-1, -2, 3])
266275

@@ -277,10 +286,10 @@ def test_metrics_or(second_operand, expected_result):
277286
@pytest.mark.parametrize(
278287
["second_operand", "expected_result"],
279288
[
280-
(DummyMetric(2), torch.tensor(4)),
281-
(2, torch.tensor(4)),
282-
(2.0, torch.tensor(4.0)),
283-
(torch.tensor(2), torch.tensor(4)),
289+
pytest.param(DummyMetric(2), torch.tensor(4)),
290+
pytest.param(2, torch.tensor(4)),
291+
pytest.param(2.0, torch.tensor(4.0), marks=pytest.mark.skipif(**_MARK_TORCH_LOWER_1_5)),
292+
pytest.param(torch.tensor(2), torch.tensor(4)),
284293
],
285294
)
286295
def test_metrics_pow(second_operand, expected_result):
@@ -297,6 +306,7 @@ def test_metrics_pow(second_operand, expected_result):
297306
["first_operand", "expected_result"],
298307
[(5, torch.tensor(2)), (5.0, torch.tensor(2.0)), (torch.tensor(5), torch.tensor(2))],
299308
)
309+
@pytest.mark.skipif(**_MARK_TORCH_LOWER_1_4)
300310
def test_metrics_rfloordiv(first_operand, expected_result):
301311
second_operand = DummyMetric(2)
302312

@@ -329,8 +339,12 @@ def test_metrics_rmod(first_operand, expected_result):
329339

330340

331341
@pytest.mark.parametrize(
332-
["first_operand", "expected_result"],
333-
[(DummyMetric(2), torch.tensor(4)), (2, torch.tensor(4)), (2.0, torch.tensor(4.0))],
342+
"first_operand,expected_result",
343+
[
344+
pytest.param(DummyMetric(2), torch.tensor(4)),
345+
pytest.param(2, torch.tensor(4)),
346+
pytest.param(2.0, torch.tensor(4.0), marks=pytest.mark.skipif(**_MARK_TORCH_LOWER_1_5)),
347+
],
334348
)
335349
def test_metrics_rpow(first_operand, expected_result):
336350
second_operand = DummyMetric(2)
@@ -370,6 +384,7 @@ def test_metrics_rsub(first_operand, expected_result):
370384
(torch.tensor(6), torch.tensor(2.0)),
371385
],
372386
)
387+
@pytest.mark.skipif(**_MARK_TORCH_LOWER_1_4)
373388
def test_metrics_rtruediv(first_operand, expected_result):
374389
second_operand = DummyMetric(3)
375390

@@ -408,6 +423,7 @@ def test_metrics_sub(second_operand, expected_result):
408423
(torch.tensor(3), torch.tensor(2.0)),
409424
],
410425
)
426+
@pytest.mark.skipif(**_MARK_TORCH_LOWER_1_4)
411427
def test_metrics_truediv(second_operand, expected_result):
412428
first_metric = DummyMetric(6)
413429

0 commit comments

Comments
 (0)