@@ -12,25 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import inspect
 from abc import ABC
-from collections import Mapping
 
 import torch
 
-from pytorch_lightning.utilities import DistributedType
-from pytorch_lightning.utilities.distributed import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.memory import recursive_detach
 
 
 class TrainerLoggingMixin(ABC):
 
-    # this is just a summary on variables used in this abstract class,
-    # the proper values/initialisation should be done in child class
-    _distrib_type: DistributedType
-    num_gpus: int
-
     def metrics_to_scalars(self, metrics):
         new_metrics = {}
         # TODO: this is duplicated in MetricsHolder. should be unified
@@ -49,128 +39,3 @@ def metrics_to_scalars(self, metrics):
             new_metrics[k] = v
 
         return new_metrics
-
-    def process_dict_result(self, output, train=False):
-        """Reduces output according to the training mode.
-
-        Separates loss from logging and progress bar metrics
-        """
-        # --------------------
-        # WARN DEPRECATED KEYS
-        # --------------------
-        # TODO: 1.0.0 remove
-        if isinstance(output, dict):
-            for k, v in output.items():
-                if k in ['log', 'progress_bar']:
-                    m = inspect.cleandoc(
-                        f"The {{{k}:dict keyword}} was deprecated in 0.9.1 and will be removed in 1.0.0\n"
-                        " Please use self.log(...) inside the lightningModule instead.\n"
-                        " # log on a step or aggregate epoch metric to the logger and/or progress bar"
-                        " (inside LightningModule)\n"
-                        " self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)"
-                    )
-                    rank_zero_warn(m)
-
-        # --------------------------
-        # handle single scalar only
-        # --------------------------
-        # single scalar returned from a xx_step
-        if isinstance(output, torch.Tensor):
-            return output, {}, {}, None
-
-        # ---------------
-        # EXTRACT PROGRESS BAR KEYS
-        # ---------------
-        try:
-            progress_output = output['progress_bar']
-
-            # reduce progress metrics for progress bar when using dp
-            if train and self._distrib_type in (DistributedType.DP, DistributedType.DDP2):
-                num_gpus = self.num_gpus
-                progress_output = self.reduce_distributed_output(progress_output, num_gpus)
-
-            progress_bar_metrics = progress_output
-        # todo: specify the possible exception
-        except Exception:
-            progress_bar_metrics = {}
-
-        # ---------------
-        # EXTRACT LOGGING KEYS
-        # ---------------
-        # extract metrics to log to experiment
-        try:
-            log_output = output['log']
-
-            # reduce progress metrics for progress bar when using dp
-            if train and self._distrib_type in (DistributedType.DP, DistributedType.DDP2):
-                num_gpus = self.num_gpus
-                log_output = self.reduce_distributed_output(log_output, num_gpus)
-
-            log_metrics = log_output
-        # todo: specify the possible exception
-        except Exception:
-            log_metrics = {}
-
-        # ---------------
-        # EXTRACT LOSS
-        # ---------------
-        # if output dict doesn't have the keyword loss
-        # then assume the output=loss if scalar
-        loss = None
-        if train:
-            try:
-                loss = output['loss']
-            # todo: specify the possible exception
-            except Exception as exp:
-                if isinstance(output, torch.Tensor):
-                    loss = output
-                else:
-                    raise RuntimeError(
-                        'No `loss` value in the dictionary returned from `model.training_step()`.'
-                    ) from exp
-
-            # when using dp need to reduce the loss
-            if self._distrib_type in (DistributedType.DP, DistributedType.DDP2):
-                loss = self.reduce_distributed_output(loss, self.num_gpus)
-
-        # ---------------
-        # EXTRACT HIDDEN
-        # ---------------
-        hiddens = output.get('hiddens', None) if isinstance(output, Mapping) else None
-        if hiddens is not None:
-            hiddens = hiddens.detach()
-
-        # detach all metrics for callbacks to prevent memory leaks
-        # no .item() because it will slow things down
-        progress_bar_metrics = recursive_detach(progress_bar_metrics)
-        log_metrics = recursive_detach(log_metrics)
-
-        return loss, progress_bar_metrics, log_metrics, hiddens
-
-    def reduce_distributed_output(self, output, num_gpus):
-        if num_gpus <= 1:
-            return output
-
-        # when using DP, we get one output per gpu
-        # average outputs and return
-        if isinstance(output, torch.Tensor):
-            return output.mean()
-
-        for k, v in output.items():
-            # recurse on nested dicts
-            if isinstance(output[k], dict):
-                output[k] = self.reduce_distributed_output(output[k], num_gpus)
-
-            # compute the average of scalars
-            elif isinstance(output[k], list):
-                output[k] = sum(output[k]) / len(output[k])
-
-            # do nothing when there's a scalar
-            elif isinstance(output[k], torch.Tensor) and output[k].dim() == 0:
-                pass
-
-            # do not reduce metrics that have batch size > num gpus
-            elif output[k].size(0) <= num_gpus:
-                output[k] = torch.mean(output[k])
-
-        return output
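
For context, the deprecation warning removed above points users at the `self.log(...)` API instead of returning `'log'`/`'progress_bar'` dicts from a step. Below is a minimal sketch of that replacement inside a `LightningModule` `training_step`; the module, layer sizes, and metric name are illustrative only and not part of this diff, while the `self.log('train_loss', ...)` call is taken verbatim from the warning text.

```python
import torch
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    # Hypothetical module used only to illustrate the self.log(...) API.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self.layer(x), y)
        # Log to the logger and/or progress bar instead of returning
        # a {'log': ...} or {'progress_bar': ...} dict from the step.
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```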