@@ -11,10 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""[Train, Eval] Result for easier logging, checkpointing, early stopping, epoch-wise reduction."""
+"""Result class for easier logging and epoch-wise reduction."""
 
 import numbers
-import os
 from copy import copy
 from typing import Any, Callable, Dict, Iterable, List, MutableMapping, Optional, Sequence, Tuple, Union
 
@@ -27,33 +26,14 @@
 
 class Result(Dict):
 
-    def __init__(
-        self,
-        minimize: Optional[Tensor] = None,
-        early_stop_on: Optional[Tensor] = None,
-        checkpoint_on: Optional[Union[Tensor, bool]] = None,
-        hiddens: Optional[Tensor] = None,
-    ):
-
+    def __init__(self, minimize: Optional[Tensor] = None):
         super().__init__()
 
-        # temporary until dict results are deprecated
-        os.environ['PL_USING_RESULT_OBJ'] = '1'
-
-        if early_stop_on is not None:
-            self.early_stop_on = early_stop_on
-        if checkpoint_on is not None and checkpoint_on:
-            self.checkpoint_on = checkpoint_on
-        if hiddens is not None:
-            self.hiddens = hiddens.detach()
         if minimize is not None:
             err = 'Minimize can only be used in training_step, training_step_end, training_epoch_end'
             self._assert_grad_tensor_metric('minimize', minimize, err)
             self.minimize = minimize
 
-        if minimize is not None and checkpoint_on is None:
-            self.checkpoint_on = minimize.detach()
-
         self['meta'] = {'_internal': {'_reduce_on_epoch': False, 'batch_sizes': []}}
 
     def __getitem__(self, key: Union[str, Any]) -> Any:
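With the reserved `early_stop_on`, `checkpoint_on`, and `hiddens` arguments gone, the constructor only takes the loss to minimize, and that loss must still carry a `grad_fn`. A minimal usage sketch (assuming the `Result` class above and `torch` are importable; the values are illustrative):

```python
import torch

# A non-leaf tensor, so it has a grad_fn as _assert_grad_tensor_metric requires.
loss = (torch.ones(3, requires_grad=True) * 2).sum()

result = Result(minimize=loss)  # the only remaining constructor argument
# Checkpoint/early-stop monitors are no longer reserved keys on the Result;
# they are configured outside this class (e.g. via callbacks).
```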
@@ -64,9 +44,7 @@ def __getitem__(self, key: Union[str, Any]) -> Any:
 
     def __getattr__(self, key: str) -> Any:
         try:
-            if key == 'callback_metrics':
-                return self.get_callback_metrics()
-            elif key == 'batch_log_metrics':
+            if key == 'batch_log_metrics':
                 return self.get_batch_log_metrics()
             elif key == 'batch_pbar_metrics':
                 return self.get_batch_pbar_metrics()
@@ -80,16 +58,9 @@ def __getattr__(self, key: str) -> Any:
             return None
 
     def __setattr__(self, key: str, val: Union[Tensor, Any]):
-        # ensure reserve keys are tensors and detached
-        if key in {'checkpoint_on', 'early_stop_on'}:
-            self._assert_tensor_metric(key, val)
-            if val is not None and isinstance(val, torch.Tensor):
-                val = val.detach()
-
-        # ensure anything else that is a tensor is detached
-        elif isinstance(val, torch.Tensor) and key != 'minimize':
+        # ensure tensors are detached
+        if isinstance(val, torch.Tensor) and key != 'minimize':
             val = val.detach()
-
         self[key] = val
 
     def __getstate__(self):
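The simplified `__setattr__` makes the detach rule uniform: every tensor stored on a `Result` is detached except `minimize`, which has to stay attached to the graph for the backward pass. A small sketch of that behavior (assuming the class above is in scope; `some_metric` is a hypothetical key):

```python
import torch

r = Result()
x = torch.ones(2, requires_grad=True).sum()  # non-leaf tensor with a grad_fn

r.some_metric = x
assert r['some_metric'].grad_fn is None      # detached on assignment

r.minimize = x
assert r['minimize'].grad_fn is not None     # left attached for backward()
```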
@@ -98,11 +69,6 @@ def __getstate__(self):
     def __setstate__(self, d):
         self.update(d)
 
-    def _assert_tensor_metric(self, name: str, potential_metric: Union[bool, Tensor, None, Any]):
-        if potential_metric is not None and not isinstance(potential_metric, bool):
-            if not isinstance(potential_metric, Tensor):
-                raise TypeError(f'{name} must be a torch.Tensor')
-
     def _assert_grad_tensor_metric(self, name: str, x: Union[torch.Tensor, Any], additional_err: str = ''):
         if x is not None:
             if not isinstance(x, Tensor):
@@ -272,11 +238,6 @@ def get_batch_sizes(self):
         meta = self['meta']
         return torch.tensor(meta['_internal']['batch_sizes'])
 
-    def get_callback_metrics(self) -> dict:
-        result = {'early_stop_on': self.early_stop_on, 'checkpoint_on': self.checkpoint_on}
-
-        return result
-
     def _add_dataloader_idx(self, k: str, dataloader_idx: Union[int, None], add_dataloader_idx: bool) -> str:
         if dataloader_idx is not None and add_dataloader_idx:
             return f"{k}/dataloader_idx_{dataloader_idx}"
@@ -495,25 +456,22 @@ def padded_gather(cls, outputs):
         # find the padding used for other values
         default_padding_idx = 0
         for name, value in result.items():
-            if isinstance(value, list) and len(value) > 0 and isinstance(value[0], torch.Tensor):
-                if name not in {'checkpoint_on', 'early_stop_on', 'minimize'}:
-                    default_padding_idx = meta[name]['tbptt_pad_token']
-                    break
+            if (
+                name != 'minimize' and isinstance(value, list) and len(value) > 0
+                and isinstance(value[0], torch.Tensor)
+            ):
+                default_padding_idx = meta[name]['tbptt_pad_token']
+                break
 
         # pad across each key individually
         for name, value in result.items():
-            is_reserved = name in {'checkpoint_on', 'early_stop_on', 'minimize'}
-            if isinstance(value, list) and len(value) > 0 and isinstance(value[0], torch.Tensor):
-
-                if is_reserved:
-                    padding_key = default_padding_idx
-                else:
-                    padding_key = meta[name]['tbptt_pad_token']
+            if (isinstance(value, list) and len(value) > 0 and isinstance(value[0], torch.Tensor)):
+                padding_key = default_padding_idx if name == 'minimize' else meta[name]['tbptt_pad_token']
                 padded = torch.nn.utils.rnn.pad_sequence(value, batch_first=True, padding_value=padding_key)
                 result[name] = padded
 
                 # also update the result
-                if meta and not is_reserved:
+                if meta and name != "minimize":
                     meta[name]['value'] = padded
         if meta:
             result['meta'] = meta
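`padded_gather` relies on `torch.nn.utils.rnn.pad_sequence` to bring truncated-BPTT splits of different lengths to a common length before stacking; after this change, `minimize` simply borrows the padding token found for the first regular key. An illustration with made-up values:

```python
import torch

splits = [torch.tensor([1., 2., 3.]), torch.tensor([4.])]
padded = torch.nn.utils.rnn.pad_sequence(splits, batch_first=True, padding_value=0)
# tensor([[1., 2., 3.],
#         [4., 0., 0.]])
```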
@@ -581,10 +539,7 @@ def reduce_across_time(cls, time_outputs):
                 continue
 
             # pick the reduce fx
-            if k in ['checkpoint_on', 'early_stop_on', 'minimize']:
-                tbptt_reduce_fx = torch.mean
-            else:
-                tbptt_reduce_fx = meta[k]['tbptt_reduce_fx']
+            tbptt_reduce_fx = torch.mean if k == "minimize" else meta[k]['tbptt_reduce_fx']
 
             if isinstance(value, list):
                 value = torch.tensor(value)
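With the reserved keys removed, only `minimize` keeps a hard-coded reduction (`torch.mean`) across truncated-BPTT steps; every other key uses the `tbptt_reduce_fx` recorded in its meta entry. A tiny worked example (made-up values):

```python
import torch

step_losses = torch.tensor([0.5, 0.7, 0.9])  # loss from each tbptt split
print(torch.mean(step_losses))               # tensor(0.7000)
```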
@@ -612,10 +567,6 @@ def dp_reduce(self):
     def should_reduce_on_epoch_end(self) -> bool:
         return self['meta']['_internal']['_reduce_on_epoch']
 
-    def drop_hiddens(self):
-        if 'hiddens' in self:
-            del self['hiddens']
-
     def rename_keys(self, map_dict: dict):
         """
         Maps key values to the target values. Useful when renaming variables in mass.