|
| 1 | +import distutils |
1 | 2 | import inspect |
2 | 3 | import os |
3 | 4 | import sys |
4 | 5 | import warnings |
5 | 6 | from argparse import ArgumentParser |
6 | | -from typing import Union, Optional, List, Dict, Tuple, Iterable, Any, Sequence |
7 | | -import distutils |
| 7 | +from typing import Union, Optional, List, Dict, Tuple, Iterable, Any |
8 | 8 |
|
9 | 9 | import torch |
10 | 10 | import torch.distributed as torch_distrib |
11 | 11 | import torch.multiprocessing as mp |
12 | | -from torch import optim |
13 | | -from torch.optim.optimizer import Optimizer |
14 | 12 | from torch.utils.data import DataLoader |
15 | 13 | from tqdm.auto import tqdm |
16 | 14 |
|
|
29 | 27 | from pytorch_lightning.trainer.evaluation_loop import TrainerEvaluationLoopMixin |
30 | 28 | from pytorch_lightning.trainer.logging import TrainerLoggingMixin |
31 | 29 | from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin |
| 30 | +from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin |
| 31 | +from pytorch_lightning.trainer.supporters import TensorRunningMean |
32 | 32 | from pytorch_lightning.trainer.training_io import TrainerIOMixin |
33 | 33 | from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin |
34 | 34 | from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin |
35 | 35 | from pytorch_lightning.utilities.exceptions import MisconfigurationException |
36 | | -from pytorch_lightning.trainer.supporters import TensorRunningMean |
37 | 36 |
|
38 | 37 | try: |
39 | 38 | from apex import amp |
|
54 | 53 |
|
55 | 54 | class Trainer( |
56 | 55 | TrainerIOMixin, |
| 56 | + TrainerOptimizersMixin, |
57 | 57 | TrainerDPMixin, |
58 | 58 | TrainerDDPMixin, |
59 | 59 | TrainerLoggingMixin, |
@@ -712,8 +712,7 @@ def fit( |
712 | 712 |
|
713 | 713 | # CHOOSE OPTIMIZER |
714 | 714 | # allow for lr schedulers as well |
715 | | - self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \ |
716 | | - self.init_optimizers(model.configure_optimizers()) |
| 715 | + self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model) |
717 | 716 |
|
718 | 717 | self.run_pretrain_routine(model) |
719 | 718 |
|
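Note on the `fit` change above: `init_optimizers` now receives the `LightningModule` itself rather than the output of `model.configure_optimizers()`, and the new `TrainerOptimizersMixin` is assumed to call `configure_optimizers()` internally. A minimal sketch of the (unchanged) return contract, using a plain `nn.Module` stand-in and a hypothetical `StepLR` schedule:

    from torch import nn, optim

    class ExampleModule(nn.Module):  # stand-in for a LightningModule
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(4, 1)

        def configure_optimizers(self):
            opt = optim.Adam(self.parameters(), lr=1e-3)
            sched = optim.lr_scheduler.StepLR(opt, step_size=10)
            return [opt], [sched]  # two lists: optimizers + lr schedulers

    model = ExampleModule()
    # the trainer normalizes this into (optimizers, lr_scheduler_dicts, frequencies)
    optimizers, schedulers = model.configure_optimizers()
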
@@ -757,90 +756,6 @@ def __attach_dataloaders(self, model, train_dataloader, val_dataloaders, test_da |
757 | 756 |
|
758 | 757 | model.test_dataloader = _PatchDataLoader(test_dataloaders) |
759 | 758 |
|
760 | | - def init_optimizers( |
761 | | - self, |
762 | | - optim_conf: Union[Optimizer, Sequence[Optimizer], Dict, Sequence[Dict], Tuple[List, List]] |
763 | | - ) -> Tuple[List, List, List]: |
764 | | - |
765 | | - # single output, single optimizer |
766 | | - if isinstance(optim_conf, Optimizer): |
767 | | - return [optim_conf], [], [] |
768 | | - |
769 | | - # two lists, optimizer + lr schedulers |
770 | | - elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 and isinstance(optim_conf[0], list): |
771 | | - optimizers, lr_schedulers = optim_conf |
772 | | - lr_schedulers = self.configure_schedulers(lr_schedulers) |
773 | | - return optimizers, lr_schedulers, [] |
774 | | - |
775 | | - # single dictionary |
776 | | - elif isinstance(optim_conf, dict): |
777 | | - optimizer = optim_conf["optimizer"] |
778 | | - lr_scheduler = optim_conf.get("lr_scheduler", []) |
779 | | - if lr_scheduler: |
780 | | - lr_schedulers = self.configure_schedulers([lr_scheduler]) |
781 | | - return [optimizer], lr_schedulers, [] |
782 | | - |
783 | | - # multiple dictionaries |
784 | | - elif isinstance(optim_conf, (list, tuple)) and isinstance(optim_conf[0], dict): |
785 | | - optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf] |
786 | | - # take only the lr scheduler if it exists and is defined (not None)
787 | | - lr_schedulers = [opt_dict["lr_scheduler"] for opt_dict in optim_conf if opt_dict.get("lr_scheduler")] |
788 | | - # take only the frequency if it exists and is defined (not None)
789 | | - optimizer_frequencies = [opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency")] |
790 | | - |
791 | | - # clean scheduler list |
792 | | - if lr_schedulers: |
793 | | - lr_schedulers = self.configure_schedulers(lr_schedulers) |
794 | | - # assert that if frequencies are present, they are given for all optimizers |
795 | | - if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers): |
796 | | - raise ValueError("A frequency must be given to each optimizer.") |
797 | | - return optimizers, lr_schedulers, optimizer_frequencies |
798 | | - |
799 | | - # single list or tuple, multiple optimizer |
800 | | - elif isinstance(optim_conf, (list, tuple)): |
801 | | - return list(optim_conf), [], [] |
802 | | - |
803 | | - # unknown configuration |
804 | | - else: |
805 | | - raise ValueError( |
806 | | - 'Unknown configuration for model optimizers.' |
807 | | - ' Output from `model.configure_optimizers()` should either be:' |
808 | | - ' * single output, single `torch.optim.Optimizer`' |
809 | | - ' * single output, list of `torch.optim.Optimizer`' |
810 | | - ' * single output, a dictionary with `optimizer` key (`torch.optim.Optimizer`)' |
811 | | - ' and an optional `lr_scheduler` key (`torch.optim.lr_scheduler`)' |
812 | | - ' * two outputs, first being a list of `torch.optim.Optimizer` second being' |
813 | | - ' a list of `torch.optim.lr_scheduler`' |
814 | | - ' * multiple outputs, dictionaries as described with an optional `frequency` key (int)') |
815 | | - |
816 | | - def configure_schedulers(self, schedulers: list): |
817 | | - # Convert each scheduler into a dict structure with the relevant information
818 | | - lr_schedulers = [] |
819 | | - default_config = {'interval': 'epoch', # default every epoch |
820 | | - 'frequency': 1, # default every epoch/batch |
821 | | - 'reduce_on_plateau': False, # most often not ReduceLROnPlateau scheduler |
822 | | - 'monitor': 'val_loss'} # default value to monitor for ReduceLROnPlateau |
823 | | - for scheduler in schedulers: |
824 | | - if isinstance(scheduler, dict): |
825 | | - if 'scheduler' not in scheduler: |
826 | | - raise ValueError('Lr scheduler should have key `scheduler`'
827 | | - ' with item being an lr scheduler')
828 | | - scheduler['reduce_on_plateau'] = isinstance( |
829 | | - scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau) |
830 | | - |
831 | | - lr_schedulers.append({**default_config, **scheduler}) |
832 | | - |
833 | | - elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau): |
834 | | - lr_schedulers.append({**default_config, 'scheduler': scheduler, |
835 | | - 'reduce_on_plateau': True}) |
836 | | - |
837 | | - elif isinstance(scheduler, optim.lr_scheduler._LRScheduler): |
838 | | - lr_schedulers.append({**default_config, 'scheduler': scheduler}) |
839 | | - else: |
840 | | - raise ValueError(f'Input {scheduler} to lr schedulers '
841 | | - 'is an invalid input.')
842 | | - return lr_schedulers |
843 | | - |
844 | 759 | def run_pretrain_routine(self, model: LightningModule): |
845 | 760 | """Sanity check a few things before starting actual training. |
846 | 761 |
|
|
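For reference, the removed `configure_schedulers` merged each user-supplied scheduler over a set of defaults; the new `TrainerOptimizersMixin` is assumed to keep the same dict format. A minimal sketch of that structure (the `StepLR`/`ReduceLROnPlateau` instances are only placeholders):

    from torch import nn, optim

    opt = optim.SGD(nn.Linear(2, 2).parameters(), lr=0.1)

    default_config = {'interval': 'epoch',         # step the scheduler every epoch
                      'frequency': 1,              # once per interval
                      'reduce_on_plateau': False,  # True only for ReduceLROnPlateau
                      'monitor': 'val_loss'}       # metric ReduceLROnPlateau watches

    # a user-supplied dict is merged over the defaults ...
    user_entry = {'scheduler': optim.lr_scheduler.StepLR(opt, step_size=10),
                  'interval': 'step'}
    resolved = {**default_config, **user_entry}  # 'interval' becomes 'step'

    # ... while a bare ReduceLROnPlateau is wrapped and flagged automatically
    plateau = {**default_config,
               'scheduler': optim.lr_scheduler.ReduceLROnPlateau(opt),
               'reduce_on_plateau': True}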