|
53 | 53 | from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin |
54 | 54 | from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin |
55 | 55 | from pytorch_lightning.trainer.properties import TrainerProperties |
56 | | -from pytorch_lightning.trainer.states import TrainerState |
| 56 | +from pytorch_lightning.trainer.states import RunningStage, TrainerState |
57 | 57 | from pytorch_lightning.trainer.training_loop import TrainLoop |
58 | 58 | from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin |
59 | 59 | from pytorch_lightning.tuner.tuning import Tuner |
@@ -921,3 +921,47 @@ def available_plugins(): |
921 | 921 | Returns: List of all available plugins that are supported as string arguments. |
922 | 922 | """ |
923 | 923 | return PluginConnector.available_plugins() |
| 924 | + |
| 925 | + @property |
| 926 | + def training(self) -> bool: |
| 927 | + return self._running_stage == RunningStage.TRAINING |
| 928 | + |
| 929 | + @training.setter |
| 930 | + def training(self, val: bool) -> None: |
| 931 | + if val: |
| 932 | + self._running_stage = RunningStage.TRAINING |
| 933 | + elif self.training: |
| 934 | + self._running_stage = None |
| 935 | + |
| 936 | + @property |
| 937 | + def testing(self) -> bool: |
| 938 | + return self._running_stage == RunningStage.TESTING |
| 939 | + |
| 940 | + @testing.setter |
| 941 | + def testing(self, val: bool) -> None: |
| 942 | + if val: |
| 943 | + self._running_stage = RunningStage.TESTING |
| 944 | + elif self.testing: |
| 945 | + self._running_stage = None |
| 946 | + |
| 947 | + @property |
| 948 | + def tuning(self) -> bool: |
| 949 | + return self._running_stage == RunningStage.TUNING |
| 950 | + |
| 951 | + @tuning.setter |
| 952 | + def tuning(self, val: bool) -> None: |
| 953 | + if val: |
| 954 | + self._running_stage = RunningStage.TUNING |
| 955 | + elif self.tuning: |
| 956 | + self._running_stage = None |
| 957 | + |
| 958 | + @property |
| 959 | + def evaluating(self) -> bool: |
| 960 | + return self._running_stage == RunningStage.EVALUATING |
| 961 | + |
| 962 | + @evaluating.setter |
| 963 | + def evaluating(self, val: bool) -> None: |
| 964 | + if val: |
| 965 | + self._running_stage = RunningStage.EVALUATING |
| 966 | + elif self.evaluating: |
| 967 | + self._running_stage = None |
0 commit comments