@@ -1,14 +1,30 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 from typing import Callable, Union
 
 import torch
+from torchmetrics.metric import CompositionalMetric as _CompositionalMetric
 
-from pytorch_lightning.metrics.metric import Metric
+from pytorch_lightning.metrics import Metric
+from pytorch_lightning.utilities import rank_zero_warn
 
 
-class CompositionalMetric(Metric):
-    """Composition of two metrics with a specific operator
-    which will be executed upon metric's compute
+class CompositionalMetric(_CompositionalMetric):
+    r"""
+    This implementation refers to :class:`~torchmetrics.metric.CompositionalMetric`.
 
+    .. warning:: This metric is deprecated, use ``torchmetrics.metric.CompositionalMetric``. Will be removed in v1.5.0.
     """
 
     def __init__(
@@ -17,76 +33,8 @@ def __init__(
         metric_a: Union[Metric, int, float, torch.Tensor],
         metric_b: Union[Metric, int, float, torch.Tensor, None],
     ):
-        """
-
-        Args:
-            operator: the operator taking in one (if metric_b is None)
-                or two arguments. Will be applied to outputs of metric_a.compute()
-                and (optionally if metric_b is not None) metric_b.compute()
-            metric_a: first metric whose compute() result is the first argument of operator
-            metric_b: second metric whose compute() result is the second argument of operator.
-                For operators taking in only one input, this should be None
-        """
-        super().__init__()
-
-        self.op = operator
-
-        if isinstance(metric_a, torch.Tensor):
-            self.register_buffer("metric_a", metric_a)
-        else:
-            self.metric_a = metric_a
-
-        if isinstance(metric_b, torch.Tensor):
-            self.register_buffer("metric_b", metric_b)
-        else:
-            self.metric_b = metric_b
-
-    def _sync_dist(self, dist_sync_fn=None):
-        # No syncing required here. syncing will be done in metric_a and metric_b
-        pass
-
-    def update(self, *args, **kwargs):
-        if isinstance(self.metric_a, Metric):
-            self.metric_a.update(*args, **self.metric_a._filter_kwargs(**kwargs))
-
-        if isinstance(self.metric_b, Metric):
-            self.metric_b.update(*args, **self.metric_b._filter_kwargs(**kwargs))
-
-    def compute(self):
-
-        # also some parsing for kwargs?
-        if isinstance(self.metric_a, Metric):
-            val_a = self.metric_a.compute()
-        else:
-            val_a = self.metric_a
-
-        if isinstance(self.metric_b, Metric):
-            val_b = self.metric_b.compute()
-        else:
-            val_b = self.metric_b
-
-        if val_b is None:
-            return self.op(val_a)
-
-        return self.op(val_a, val_b)
-
-    def reset(self):
-        if isinstance(self.metric_a, Metric):
-            self.metric_a.reset()
-
-        if isinstance(self.metric_b, Metric):
-            self.metric_b.reset()
-
-    def persistent(self, mode: bool = False):
-        if isinstance(self.metric_a, Metric):
-            self.metric_a.persistent(mode=mode)
-        if isinstance(self.metric_b, Metric):
-            self.metric_b.persistent(mode=mode)
-
-    def __repr__(self):
-        repr_str = (
-            self.__class__.__name__
-            + f"(\n  {self.op.__name__}(\n    {repr(self.metric_a)},\n    {repr(self.metric_b)}\n  )\n)"
+        rank_zero_warn(
+            "This `Metric` was deprecated since v1.3.0 in favor of `torchmetrics.Metric`."
+            " It will be removed in v1.5.0", DeprecationWarning
         )
-
-        return repr_str
+        super().__init__(operator=operator, metric_a=metric_a, metric_b=metric_b)
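For context beyond the diff: `CompositionalMetric` is rarely constructed directly; it is what the arithmetic operators on `Metric` return, and its `compute()` applies the stored operator to the children's `compute()` results, as the removed docstring describes. A minimal usage sketch, assuming a PL 1.3.x environment where the deprecated `pytorch_lightning.metrics.Accuracy` shim is still importable (the sample tensors and printed values are illustrative):

```python
import torch
from pytorch_lightning.metrics import Accuracy  # deprecated shim; emits a DeprecationWarning on 1.3+

# Arithmetic on metrics builds a CompositionalMetric under the hood:
# compute() evaluates torch.true_divide(torch.add(acc_a.compute(), acc_b.compute()), 2).
acc_a = Accuracy()
acc_b = Accuracy()
mean_acc = (acc_a + acc_b) / 2

preds = torch.tensor([0, 1, 1, 0])
target = torch.tensor([0, 1, 0, 0])

# Update the child metrics directly; the composition holds no state of its own
# and reads from its children at compute() time.
acc_a.update(preds, target)
acc_b.update(preds, target)

print(type(mean_acc).__name__)  # CompositionalMetric
print(mean_acc.compute())       # tensor(0.7500)
```

The same composition works unchanged with `torchmetrics.Accuracy`, which is the migration path this change points to.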