1+ from distutils .version import LooseVersion
12from operator import neg , pos
23
34import pytest
67from pytorch_lightning .metrics .compositional import CompositionalMetric
78from pytorch_lightning .metrics .metric import Metric
89
10+ _MARK_TORCH_LOWER_1_4 = dict (condition = LooseVersion (torch .__version__ ) < LooseVersion ("1.5.0" ),
11+ reason = 'required PT >= 1.5' )
12+ _MARK_TORCH_LOWER_1_5 = dict (condition = LooseVersion (torch .__version__ ) < LooseVersion ("1.6.0" ),
13+ reason = 'required PT >= 1.6' )
14+
915
1016class DummyMetric (Metric ):
1117 def __init__ (self , val_to_return ):
@@ -50,6 +56,7 @@ def test_metrics_add(second_operand, expected_result):
5056 ["second_operand" , "expected_result" ],
5157 [(DummyMetric (3 ), torch .tensor (2 )), (3 , torch .tensor (2 )), (3 , torch .tensor (2 )), (torch .tensor (3 ), torch .tensor (2 ))],
5258)
59+ @pytest .mark .skipif (** _MARK_TORCH_LOWER_1_4 )
5360def test_metrics_and (second_operand , expected_result ):
5461 first_metric = DummyMetric (2 )
5562
@@ -92,6 +99,7 @@ def test_metrics_eq(second_operand, expected_result):
9299 (torch .tensor (2 ), torch .tensor (2 )),
93100 ],
94101)
102+ @pytest .mark .skipif (** _MARK_TORCH_LOWER_1_4 )
95103def test_metrics_floordiv (second_operand , expected_result ):
96104 first_metric = DummyMetric (5 )
97105
@@ -261,6 +269,7 @@ def test_metrics_ne(second_operand, expected_result):
261269 ["second_operand" , "expected_result" ],
262270 [(DummyMetric ([1 , 0 , 3 ]), torch .tensor ([- 1 , - 2 , 3 ])), (torch .tensor ([1 , 0 , 3 ]), torch .tensor ([- 1 , - 2 , 3 ]))],
263271)
272+ @pytest .mark .skipif (** _MARK_TORCH_LOWER_1_4 )
264273def test_metrics_or (second_operand , expected_result ):
265274 first_metric = DummyMetric ([- 1 , - 2 , 3 ])
266275
@@ -277,10 +286,10 @@ def test_metrics_or(second_operand, expected_result):
277286@pytest .mark .parametrize (
278287 ["second_operand" , "expected_result" ],
279288 [
280- (DummyMetric (2 ), torch .tensor (4 )),
281- (2 , torch .tensor (4 )),
282- (2.0 , torch .tensor (4.0 )),
283- (torch .tensor (2 ), torch .tensor (4 )),
289+ pytest . param (DummyMetric (2 ), torch .tensor (4 )),
290+ pytest . param (2 , torch .tensor (4 )),
291+ pytest . param (2.0 , torch .tensor (4.0 ), marks = pytest . mark . skipif ( ** _MARK_TORCH_LOWER_1_5 )),
292+ pytest . param (torch .tensor (2 ), torch .tensor (4 )),
284293 ],
285294)
286295def test_metrics_pow (second_operand , expected_result ):
@@ -297,6 +306,7 @@ def test_metrics_pow(second_operand, expected_result):
297306 ["first_operand" , "expected_result" ],
298307 [(5 , torch .tensor (2 )), (5.0 , torch .tensor (2.0 )), (torch .tensor (5 ), torch .tensor (2 ))],
299308)
309+ @pytest .mark .skipif (** _MARK_TORCH_LOWER_1_4 )
300310def test_metrics_rfloordiv (first_operand , expected_result ):
301311 second_operand = DummyMetric (2 )
302312
@@ -329,8 +339,12 @@ def test_metrics_rmod(first_operand, expected_result):
329339
330340
331341@pytest .mark .parametrize (
332- ["first_operand" , "expected_result" ],
333- [(DummyMetric (2 ), torch .tensor (4 )), (2 , torch .tensor (4 )), (2.0 , torch .tensor (4.0 ))],
342+ "first_operand,expected_result" ,
343+ [
344+ pytest .param (DummyMetric (2 ), torch .tensor (4 )),
345+ pytest .param (2 , torch .tensor (4 )),
346+ pytest .param (2.0 , torch .tensor (4.0 ), marks = pytest .mark .skipif (** _MARK_TORCH_LOWER_1_5 )),
347+ ],
334348)
335349def test_metrics_rpow (first_operand , expected_result ):
336350 second_operand = DummyMetric (2 )
@@ -370,6 +384,7 @@ def test_metrics_rsub(first_operand, expected_result):
370384 (torch .tensor (6 ), torch .tensor (2.0 )),
371385 ],
372386)
387+ @pytest .mark .skipif (** _MARK_TORCH_LOWER_1_4 )
373388def test_metrics_rtruediv (first_operand , expected_result ):
374389 second_operand = DummyMetric (3 )
375390
@@ -408,6 +423,7 @@ def test_metrics_sub(second_operand, expected_result):
408423 (torch .tensor (3 ), torch .tensor (2.0 )),
409424 ],
410425)
426+ @pytest .mark .skipif (** _MARK_TORCH_LOWER_1_4 )
411427def test_metrics_truediv (second_operand , expected_result ):
412428 first_metric = DummyMetric (6 )
413429
0 commit comments