Skip to content

Commit 062f090

Browse files
authored
Merge branch 'release/1.2-dev' into refactor/legacy-accel-plug
2 parents 5d46913 + c3587d3 commit 062f090

File tree

12 files changed

+43
-486
lines changed

12 files changed

+43
-486
lines changed

pytorch_lightning/core/step_result.py

Lines changed: 0 additions & 226 deletions
Original file line numberDiff line numberDiff line change
@@ -700,232 +700,6 @@ def collate_tensors(items: Union[List, Tuple]) -> Union[Tensor, List, Tuple]:
700700
return items
701701

702702

703-
class EvalResult(Result):
704-
def __init__(
705-
self,
706-
early_stop_on: Optional[Tensor] = None,
707-
checkpoint_on: Optional[Tensor] = None,
708-
hiddens: Optional[Tensor] = None,
709-
):
710-
"""
711-
Used in val/train loop to auto-log to a logger or progress bar without needing to define
712-
a _step_end or _epoch_end method
713-
714-
Example::
715-
716-
def validation_step(self, batch, batch_idx):
717-
loss = ...
718-
result = EvalResult()
719-
result.log('val_loss', loss)
720-
return result
721-
722-
def test_step(self, batch, batch_idx):
723-
loss = ...
724-
result = EvalResult()
725-
result.log('val_loss', loss)
726-
return result
727-
728-
Args:
729-
early_stop_on: Metric to early stop on.
730-
Should be a one element tensor if combined with default
731-
:class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping`.
732-
If this result is returned by
733-
:meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step`,
734-
the specified value will be averaged across all steps.
735-
checkpoint_on: Metric to checkpoint on.
736-
Should be a one element tensor if combined with default checkpoint callback.
737-
If this result is returned by
738-
:meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step`,
739-
the specified value will be averaged across all steps.
740-
hiddens:
741-
"""
742-
743-
super().__init__(None, early_stop_on, checkpoint_on, hiddens)
744-
745-
def log(
746-
self,
747-
name,
748-
value,
749-
prog_bar: bool = False,
750-
logger: bool = True,
751-
on_step: bool = False,
752-
on_epoch: bool = True,
753-
reduce_fx: Callable = torch.mean,
754-
tbptt_reduce_fx: Callable = torch.mean,
755-
tbptt_pad_token: int = 0,
756-
enable_graph: bool = False,
757-
sync_dist: bool = False,
758-
sync_dist_op: Union[Any, str] = 'mean',
759-
sync_dist_group: Optional[Any] = None,
760-
):
761-
"""
762-
Log a key, value
763-
764-
Example::
765-
766-
result.log('val_loss', loss)
767-
768-
# defaults used
769-
result.log(
770-
name,
771-
value,
772-
on_step=False,
773-
on_epoch=True,
774-
logger=True,
775-
prog_bar=False,
776-
reduce_fx=torch.mean
777-
)
778-
779-
780-
Args:
781-
name: key name
782-
value: value name
783-
prog_bar: if True logs to the progress bar
784-
logger: if True logs to the logger
785-
on_step: if True logs the output of validation_step or test_step
786-
on_epoch: if True, logs the output of the training loop aggregated
787-
reduce_fx: Torch.mean by default
788-
tbptt_reduce_fx: function to reduce on truncated back prop
789-
tbptt_pad_token: token to use for padding
790-
enable_graph: if True, will not auto detach the graph
791-
sync_dist: if True, reduces the metric across GPUs/TPUs
792-
sync_dist_op: the op to sync across
793-
sync_dist_group: the ddp group
794-
"""
795-
super().log(
796-
name=name,
797-
value=value,
798-
prog_bar=prog_bar,
799-
logger=logger,
800-
on_step=on_step,
801-
on_epoch=on_epoch,
802-
reduce_fx=reduce_fx,
803-
enable_graph=enable_graph,
804-
sync_dist=sync_dist,
805-
sync_dist_group=sync_dist_group,
806-
sync_dist_op=sync_dist_op,
807-
tbptt_pad_token=tbptt_pad_token,
808-
tbptt_reduce_fx=tbptt_reduce_fx,
809-
)
810-
811-
def log_dict(
812-
self,
813-
dictionary: dict,
814-
prog_bar: bool = False,
815-
logger: bool = True,
816-
on_step: bool = False,
817-
on_epoch: bool = True,
818-
reduce_fx: Callable = torch.mean,
819-
tbptt_reduce_fx: Callable = torch.mean,
820-
tbptt_pad_token: int = 0,
821-
enable_graph: bool = False,
822-
sync_dist: bool = False,
823-
sync_dist_op: Union[Any, str] = 'mean',
824-
sync_dist_group: Optional[Any] = None,
825-
):
826-
"""
827-
Log a dictionary of values at once
828-
829-
Example::
830-
831-
values = {'loss': loss, 'acc': acc, ..., 'metric_n': metric_n}
832-
result.log_dict(values)
833-
834-
Args:
835-
dictionary: key value pairs (str, tensors)
836-
prog_bar: if True logs to the progress bar
837-
logger: if True logs to the logger
838-
on_step: if True logs the output of validation_step or test_step
839-
on_epoch: if True, logs the output of the training loop aggregated
840-
reduce_fx: Torch.mean by default
841-
tbptt_reduce_fx: function to reduce on truncated back prop
842-
tbptt_pad_token: token to use for padding
843-
enable_graph: if True, will not auto detach the graph
844-
sync_dist: if True, reduces the metric across GPUs/TPUs
845-
sync_dist_op: the op to sync across
846-
sync_dist_group: the ddp group
847-
"""
848-
for k, v in dictionary.items():
849-
self.log(
850-
name=k,
851-
value=v,
852-
prog_bar=prog_bar,
853-
logger=logger,
854-
on_step=on_step,
855-
on_epoch=on_epoch,
856-
reduce_fx=reduce_fx,
857-
enable_graph=enable_graph,
858-
sync_dist=sync_dist,
859-
sync_dist_group=sync_dist_group,
860-
sync_dist_op=sync_dist_op,
861-
tbptt_pad_token=tbptt_pad_token,
862-
tbptt_reduce_fx=tbptt_reduce_fx,
863-
)
864-
865-
def get_callback_metrics(self) -> dict:
866-
result = {}
867-
if self.early_stop_on:
868-
result['early_stop_on'] = self.early_stop_on
869-
if self.checkpoint_on:
870-
result['checkpoint_on'] = self.checkpoint_on
871-
return result
872-
873-
def write(self, name: str, values: Union[Tensor, list], filename: str = 'predictions.pt'):
874-
"""Add feature name and value pair to collection of predictions that will be written to disk on
875-
`validation_end` or `test_end`. If running on multiple GPUs, you will get separate `n_gpu`
876-
prediction files with the rank prepended onto filename.
877-
878-
Example::
879-
880-
result = pl.EvalResult()
881-
result.write('ids', [0, 1, 2])
882-
result.write('preds', ['cat', 'dog', 'dog'])
883-
884-
Args:
885-
name: Feature name that will turn into column header of predictions file
886-
values: Flat tensor or list of row values for given feature column 'name'.
887-
filename: Filepath where your predictions will be saved. Defaults to 'predictions.pt'.
888-
"""
889-
# Type check the incoming arguments
890-
if not isinstance(name, str):
891-
raise ValueError(f"Expected str for 'name' but got {type(name)}")
892-
if not isinstance(filename, str):
893-
raise ValueError(f"Expected str for 'filename' but got {type(name)}")
894-
895-
if isinstance(values, Tensor):
896-
values = values.detach()
897-
898-
preds = getattr(self, 'predictions', None)
899-
if preds is None:
900-
self.predictions = {filename: {name: values}}
901-
elif filename not in preds:
902-
preds[filename] = {name: values}
903-
elif name not in preds[filename]:
904-
preds[filename][name] = values
905-
elif isinstance(values, Tensor):
906-
preds[filename][name] = torch.cat((preds[filename][name], values))
907-
elif isinstance(values, list):
908-
preds[filename][name].extend(values)
909-
910-
def write_dict(self, predictions_dict, filename='predictions.pt'):
911-
"""Calls EvalResult.write() for each key-value pair in predictions_dict.
912-
913-
It is recommended that you use this function call instead of .write if you need to
914-
store more than one column of predictions in your output file.
915-
916-
Example::
917-
918-
predictions_to_write = {'preds': ['cat', 'dog'], 'ids': tensor([0, 1])}
919-
result.write_dict(predictions_to_write)
920-
921-
Args:
922-
predictions_dict ([type]): Dict of predictions to store and then write to filename at eval end.
923-
filename (str, optional): File where your predictions will be stored. Defaults to 'predictions.pt'.
924-
"""
925-
for k, v in predictions_dict.items():
926-
self.write(k, v, filename)
927-
928-
929703
def weighted_mean(result, weights):
930704

931705
if isinstance(result, dict):

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 27 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import torch
2020

2121
from pytorch_lightning.core import memory
22-
from pytorch_lightning.core.step_result import EvalResult, Result
22+
from pytorch_lightning.core.step_result import Result
2323
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
2424
from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator
2525
from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore, LoggerStages
@@ -259,8 +259,8 @@ def add_progress_bar_metrics(self, metrics):
259259

260260
self.trainer.dev_debugger.track_pbar_metrics_history(metrics)
261261

262-
def track_metrics_deprecated(self, deprecated_eval_results, using_eval_result, test_mode):
263-
self._track_callback_metrics(deprecated_eval_results, using_eval_result)
262+
def track_metrics_deprecated(self, deprecated_eval_results, test_mode):
263+
self._track_callback_metrics(deprecated_eval_results)
264264
self.__process_eval_epoch_end_results_and_log_legacy(deprecated_eval_results, test_mode)
265265

266266
def evaluation_epoch_end(self, testing):
@@ -314,53 +314,41 @@ def get_evaluate_epoch_results(self, test_mode):
314314
self.eval_loop_results = []
315315
return results
316316

317-
def _track_callback_metrics(self, eval_results, using_eval_result):
317+
def _track_callback_metrics(self, eval_results):
318318
if len(eval_results) > 0 and (eval_results[0] is None or not isinstance(eval_results[0], Result)):
319319
return
320320

321-
if using_eval_result:
322-
if isinstance(eval_results, list):
323-
for eval_result in eval_results:
324-
self.trainer.logger_connector.callback_metrics.update(eval_result.callback_metrics)
325-
if self.trainer.testing:
326-
self.trainer.logger_connector.evaluation_callback_metrics.update(
327-
eval_result.callback_metrics)
328-
else:
329-
self.trainer.logger_connector.callback_metrics.update(eval_results.callback_metrics)
330-
if self.trainer.testing:
331-
self.trainer.logger_connector.evaluation_callback_metrics.update(eval_results.callback_metrics)
332-
else:
333-
flat = {}
334-
if isinstance(eval_results, list):
335-
for eval_result in eval_results:
336-
# with a scalar return, auto set it to "val_loss" for callbacks
337-
if isinstance(eval_result, torch.Tensor):
338-
flat = {'val_loss': eval_result}
339-
elif isinstance(eval_result, dict):
340-
flat = flatten_dict(eval_result)
341-
342-
# removing val_loss magic word to map to checkpoint + ES callback
343-
if 'val_loss' in flat:
344-
flat['checkpoint_on'] = flat['val_loss']
345-
flat['early_stop_on'] = flat['val_loss']
346-
self.trainer.logger_connector.callback_metrics.update(flat)
347-
if self.trainer.testing:
348-
self.trainer.logger_connector.evaluation_callback_metrics.update(flat)
349-
else:
321+
flat = {}
322+
if isinstance(eval_results, list):
323+
for eval_result in eval_results:
350324
# with a scalar return, auto set it to "val_loss" for callbacks
351-
if isinstance(eval_results, torch.Tensor):
352-
flat = {'val_loss': eval_results}
353-
else:
354-
flat = flatten_dict(eval_results)
325+
if isinstance(eval_result, torch.Tensor):
326+
flat = {'val_loss': eval_result}
327+
elif isinstance(eval_result, dict):
328+
flat = flatten_dict(eval_result)
355329

356330
# removing val_loss magic word to map to checkpoint + ES callback
357331
if 'val_loss' in flat:
358332
flat['checkpoint_on'] = flat['val_loss']
359333
flat['early_stop_on'] = flat['val_loss']
360-
361334
self.trainer.logger_connector.callback_metrics.update(flat)
362335
if self.trainer.testing:
363336
self.trainer.logger_connector.evaluation_callback_metrics.update(flat)
337+
else:
338+
# with a scalar return, auto set it to "val_loss" for callbacks
339+
if isinstance(eval_results, torch.Tensor):
340+
flat = {'val_loss': eval_results}
341+
else:
342+
flat = flatten_dict(eval_results)
343+
344+
# removing val_loss magic word to map to checkpoint + ES callback
345+
if 'val_loss' in flat:
346+
flat['checkpoint_on'] = flat['val_loss']
347+
flat['early_stop_on'] = flat['val_loss']
348+
349+
self.trainer.logger_connector.callback_metrics.update(flat)
350+
if self.trainer.testing:
351+
self.trainer.logger_connector.evaluation_callback_metrics.update(flat)
364352

365353
def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metrics, log_metrics, callback_metrics):
366354
# eval loop returns all metrics
@@ -397,16 +385,7 @@ def __process_eval_epoch_end_results_and_log_legacy(self, eval_results, test_mod
397385
prog_bar_metrics, log_metrics, callback_metrics = {}, {}, {}
398386

399387
for result_idx, result in enumerate(eval_results):
400-
if isinstance(result, EvalResult):
401-
prog_bar_metrics = result.epoch_pbar_metrics
402-
log_metrics = result.epoch_log_metrics
403-
callback_metrics = result.callback_metrics
404-
405-
# in testing we don't need the callback metrics
406-
if test_mode:
407-
callback_metrics = {}
408-
else:
409-
_, prog_bar_metrics, log_metrics, callback_metrics, _ = self.trainer.process_dict_result(result)
388+
_, prog_bar_metrics, log_metrics, callback_metrics, _ = self.trainer.process_dict_result(result)
410389

411390
if num_loaders > 1:
412391
self.__process_eval_epoch_end_results_and_log_legacy_update(

0 commit comments

Comments
 (0)