Skip to content

Commit 28f7fb4

Browse files
committed
isort
1 parent abadc86 commit 28f7fb4

File tree

6 files changed

+23
-15
lines changed

6 files changed

+23
-15
lines changed

pytorch_lightning/metrics/classification/precision_recall.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
# limitations under the License.
1414
from typing import Any, Callable, Optional
1515

16-
import torch
17-
1816
from torchmetrics import Precision as _Precision
1917
from torchmetrics import Recall as _Recall
2018

pytorch_lightning/metrics/functional/precision_recall.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515

1616
import torch
1717
from torchmetrics.functional import precision as _precision
18-
from torchmetrics.functional import recall as _recall
1918
from torchmetrics.functional import precision_recall as _precision_recall
19+
from torchmetrics.functional import recall as _recall
2020

2121
from pytorch_lightning.utilities.deprecation import deprecated
2222

pytorch_lightning/metrics/functional/precision_recall_curve.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from pytorch_lightning.utilities.deprecation import deprecated
2020

2121

22-
2322
@deprecated(target=_precision_recall_curve, ver_deprecate="1.3.0", ver_remove="1.5.0")
2423
def precision_recall_curve(
2524
preds: torch.Tensor,
@@ -28,7 +27,7 @@ def precision_recall_curve(
2827
pos_label: Optional[int] = None,
2928
sample_weights: Optional[Sequence] = None,
3029
) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor],
31-
List[torch.Tensor]],]:
30+
List[torch.Tensor]], ]:
3231
"""
3332
.. deprecated::
3433
Use :func:`torchmetrics.functional.accuracy`. Will be removed in v1.5.0.

pytorch_lightning/utilities/deprecation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
import inspect
1515
from functools import wraps
16-
from typing import Any, Callable, List, Tuple, Optional
16+
from typing import Any, Callable, List, Optional, Tuple
1717

1818
from pytorch_lightning.utilities import rank_zero_warn
1919

tests/metrics/test_remove_1-5_metrics.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,21 @@
1616
import pytest
1717
import torch
1818

19-
from pytorch_lightning.metrics import (Accuracy, MetricCollection, AveragePrecision, Precision, Recall,
20-
PrecisionRecallCurve)
21-
from pytorch_lightning.metrics.functional import (average_precision, precision, recall, precision_recall,
22-
precision_recall_curve)
19+
from pytorch_lightning.metrics import (
20+
Accuracy,
21+
AveragePrecision,
22+
MetricCollection,
23+
Precision,
24+
PrecisionRecallCurve,
25+
Recall,
26+
)
27+
from pytorch_lightning.metrics.functional import (
28+
average_precision,
29+
precision,
30+
precision_recall,
31+
precision_recall_curve,
32+
recall,
33+
)
2334
from pytorch_lightning.metrics.functional.accuracy import accuracy
2435
from pytorch_lightning.metrics.utils import get_num_classes, select_topk, to_categorical, to_onehot
2536

@@ -58,7 +69,7 @@ def test_v1_5_metrics_collection():
5869
MetricCollection.__init__.warned = False
5970
with pytest.deprecated_call(
6071
match="`pytorch_lightning.metrics.metric.MetricCollection` was deprecated since v1.3.0 in favor"
61-
" of `torchmetrics.collections.MetricCollection`. It will be removed in v1.5.0."
72+
" of `torchmetrics.collections.MetricCollection`. It will be removed in v1.5.0."
6273
):
6374
metrics = MetricCollection([Accuracy()])
6475
assert metrics(preds, target) == {'Accuracy': torch.tensor(0.1250)}

tests/utilities/test_deprecation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def dep3_sum(a, b=4):
3030
def test_deprecated_func():
3131
with pytest.deprecated_call(
3232
match='`tests.utilities.test_deprecation.dep_sum` was deprecated since v0.1 in favor'
33-
' of `tests.utilities.test_deprecation.my_sum`. It will be removed in v0.5.'
33+
' of `tests.utilities.test_deprecation.my_sum`. It will be removed in v0.5.'
3434
):
3535
assert dep_sum(2) == 7
3636

@@ -41,7 +41,7 @@ def test_deprecated_func():
4141
# and does not affect other functions
4242
with pytest.deprecated_call(
4343
match='`tests.utilities.test_deprecation.dep3_sum` was deprecated since v0.1 in favor'
44-
' of `tests.utilities.test_deprecation.my2_sum`. It will be removed in v0.5.'
44+
' of `tests.utilities.test_deprecation.my2_sum`. It will be removed in v0.5.'
4545
):
4646
assert dep3_sum(2, 1) == 3
4747

@@ -61,7 +61,7 @@ def test_deprecated_func_incomplete():
6161
# does not affect other functions
6262
with pytest.deprecated_call(
6363
match='`tests.utilities.test_deprecation.dep2_sum` was deprecated since v0.1 in favor'
64-
' of `tests.utilities.test_deprecation.my2_sum`. It will be removed in v0.5.'
64+
' of `tests.utilities.test_deprecation.my2_sum`. It will be removed in v0.5.'
6565
):
6666
assert dep2_sum(b=2, a=1) == 3
6767

@@ -83,7 +83,7 @@ def __init__(self, c, d="efg"):
8383
def test_deprecated_class():
8484
with pytest.deprecated_call(
8585
match='`tests.utilities.test_deprecation.PastCls` was deprecated since v0.2 in favor'
86-
' of `tests.utilities.test_deprecation.NewCls`. It will be removed in v0.4.'
86+
' of `tests.utilities.test_deprecation.NewCls`. It will be removed in v0.4.'
8787
):
8888
past = PastCls(2)
8989
assert past.my_c == 2

0 commit comments

Comments
 (0)