Description
🐛 Bug
The Trainer class calls self.fit directly in several places. This causes problems when Trainer is subclassed and the behavior of fit is changed: for example, the test method can no longer be used, because test calls self.fit internally and therefore runs the subclass's modified fit instead of the original one.
Since features such as trainer.validate and grid search are not yet implemented in PyTorch Lightning, users should be able to subclass Trainer and modify these methods without worrying about breaking the Trainer's internal code.
Please reproduce using the BoringModel
https://colab.research.google.com/drive/1jcq5g0XF_3xTLR0p7ZuEMfHcowtZhf4j?usp=sharing
Expected behavior
SeedSearchTrainer uses the fit method of the base class (pl.Trainer) when the test method of the base class is called.
Additional context
This could be solved using name mangling: store a private reference to the fit method by adding a `__fit` attribute to the Trainer class.
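```python
class Trainer:
    def fit(self, *args, **kwargs):
        ...  # existing fit implementation

    # private alias, stored on the class as _Trainer__fit
    __fit = fit
```
Then on lines: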
https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/trainer/trainer.py#L920
https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/trainer/trainer.py#L981
https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/trainer/trainer.py#L1075
replace `self.fit` with `self.__fit`.
Because of name mangling, any `self.__fit` written inside the Trainer class body is compiled as `self._Trainer__fit`. That attribute is bound to the Trainer's own fit function, so the base-class method is always called, even if a derived class overrides fit.
The same approach could be applied to any other method that users may want to override but that the base class calls internally as self.method.
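A sketch of the same pattern applied to more than one method, assuming a hypothetical internal workflow (fit_and_test) that must always use the base-class behaviour. None of these names are part of the PyTorch Lightning API.

```python
class Trainer:
    """Toy stand-in: several public methods, each with a private alias."""

    def fit(self, model):
        return f"fit {model}"

    def test(self, model):
        return f"test {model}"

    # Aliases stored as _Trainer__fit and _Trainer__test.
    __fit = fit
    __test = test

    def fit_and_test(self, model):
        # Hypothetical internal workflow that relies on the base behaviour.
        return [self.__fit(model), self.__test(model)]


class CustomTrainer(Trainer):
    """Subclass that freely overrides both public methods."""

    def fit(self, model):
        return f"custom fit {model}"

    def test(self, model):
        return f"custom test {model}"


trainer = CustomTrainer()
print(trainer.fit_and_test("m"))          # ['fit m', 'test m']
print(trainer.fit("m"), trainer.test("m"))  # custom fit m custom test m
```

The trade-off is deliberate: user code calling trainer.fit or trainer.test gets the overrides, while the base class's own workflows keep using the original implementations.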