 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from functools import wraps
 from typing import Any, Generator, List, Tuple
 
 import torch
 import torch.nn as nn
 from torch.optim import Optimizer
 
 from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 
 
-class _DoublePrecisionPatch:
-    """Class to handle patching of methods in the ``LightningModule`` and subsequent teardown."""
+class LightningDoublePrecisionModule(_LightningPrecisionModuleWrapperBase):
+    """
+    LightningModule wrapper which converts incoming floating point data in ``*_step`` and ``forward`` to double
+    (``torch.float64``) precision.
 
-    def __init__(self, model: nn.Module, method_name: str, old_method: Any) -> None:
-        self.model = model
-        self.method_name = method_name
-        self.old_method = old_method
+    Args:
+        pl_module: the model to wrap
+    """
 
-    def teardown(self) -> None:
-        setattr(self.model, self.method_name, self.old_method)
+    def __init__(self, pl_module: LightningModule):
+        super().__init__(pl_module)
 
     @staticmethod
     def _to_double_precision(data: torch.Tensor) -> torch.Tensor:
@@ -43,55 +44,63 @@ def _to_double_precision(data: torch.Tensor) -> torch.Tensor:
 
     @staticmethod
     def _move_float_tensors_to_double(collection: Any) -> Any:
-        return apply_to_collection(collection, torch.Tensor, function=_DoublePrecisionPatch._to_double_precision)
-
-    @classmethod
-    def patch(cls, model: nn.Module, method_name: str) -> '_DoublePrecisionPatch':
-        old_method = getattr(model, method_name)
-
-        @wraps(old_method)
-        def new_method(*args: Any, **kwargs: Any) -> Any:
-            return old_method(
-                *_DoublePrecisionPatch._move_float_tensors_to_double(args),
-                **_DoublePrecisionPatch._move_float_tensors_to_double(kwargs)
-            )
-
-        setattr(model, method_name, new_method if callable(old_method) else old_method)
-        return cls(model, method_name, old_method)
+        return apply_to_collection(
+            collection,
+            torch.Tensor,
+            LightningDoublePrecisionModule._to_double_precision,
+        )
+
+    def training_step(self, *args: Any, **kwargs: Any) -> Any:
+        return self.module.training_step(
+            *LightningDoublePrecisionModule._move_float_tensors_to_double(args),
+            **LightningDoublePrecisionModule._move_float_tensors_to_double(kwargs),
+        )
+
+    def validation_step(self, *args: Any, **kwargs: Any) -> Any:
+        return self.module.validation_step(
+            *LightningDoublePrecisionModule._move_float_tensors_to_double(args),
+            **LightningDoublePrecisionModule._move_float_tensors_to_double(kwargs),
+        )
+
+    def test_step(self, *args: Any, **kwargs: Any) -> Any:
+        return self.module.test_step(
+            *LightningDoublePrecisionModule._move_float_tensors_to_double(args),
+            **LightningDoublePrecisionModule._move_float_tensors_to_double(kwargs),
+        )
+
+    def predict_step(self, *args: Any, **kwargs: Any) -> Any:
+        return self.module.predict_step(
+            *LightningDoublePrecisionModule._move_float_tensors_to_double(args),
+            **LightningDoublePrecisionModule._move_float_tensors_to_double(kwargs),
+        )
+
+    def forward(self, *args: Any, **kwargs: Any) -> Any:
+        return self.module(
+            *LightningDoublePrecisionModule._move_float_tensors_to_double(args),
+            **LightningDoublePrecisionModule._move_float_tensors_to_double(kwargs),
+        )
 
 
 class DoublePrecisionPlugin(PrecisionPlugin):
-    """Plugin for training with double (``torch.float64``) precision."""
+    """ Plugin for training with double (``torch.float64``) precision. """
 
     precision: int = 64
 
-    def __init__(self) -> None:
-        super().__init__()
-        self.patches: List[_DoublePrecisionPatch] = []
-
     def connect(
         self,
         model: nn.Module,
         optimizers: List[Optimizer],
         lr_schedulers: List[Any],
-    ) -> Tuple[nn.Module, List[Optimizer], List[Any]]:
-        """Converts the model to double precision and wraps the `training_step`, `validation_step`, `test_step`,
-        `predict_step`, and `forward` methods to convert incoming floating point data to double. Does not alter
-        `optimizers` or `lr_schedulers`."""
+    ) -> Tuple[nn.Module, List['Optimizer'], List[Any]]:
+        """Converts the model to double precision and wraps it in a ``LightningDoublePrecisionModule`` to convert
+        incoming floating point data to double (``torch.float64``) precision. Does not alter `optimizers` or
+        `lr_schedulers`.
+        """
         model = model.to(dtype=torch.float64)
-        if isinstance(model, LightningModule):
-            self.patches.append(_DoublePrecisionPatch.patch(model, 'training_step'))
-            self.patches.append(_DoublePrecisionPatch.patch(model, 'validation_step'))
-            self.patches.append(_DoublePrecisionPatch.patch(model, 'test_step'))
-            self.patches.append(_DoublePrecisionPatch.patch(model, 'predict_step'))
-            self.patches.append(_DoublePrecisionPatch.patch(model, 'forward'))
+        model = LightningDoublePrecisionModule(model)
 
         return super().connect(model, optimizers, lr_schedulers)
 
-    def post_dispatch(self) -> None:
-        while len(self.patches) > 0:
-            self.patches.pop().teardown()
-
     @contextmanager
     def train_step_context(self) -> Generator[None, None, None]:
         """
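
For context, a minimal sketch of how this plugin is exercised end to end after the change. Passing ``precision=64`` to the ``Trainer`` selects ``DoublePrecisionPlugin``, whose ``connect`` casts the model's parameters to ``torch.float64`` and wraps the model in ``LightningDoublePrecisionModule``, so float32 batches are up-cast before reaching ``training_step``. The model name, shapes, and hyperparameters below are illustrative, not part of this diff.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class TinyModel(pl.LightningModule):
    # Illustrative module; any LightningModule behaves the same way here.
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        # The wrapper has already up-cast the float32 batch to float64,
        # matching the float64 parameters produced by ``connect``.
        assert x.dtype == torch.float64
        return nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# The dataset stays plain float32; no manual casting is needed.
data = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
trainer = pl.Trainer(precision=64, max_epochs=1)
trainer.fit(TinyModel(), DataLoader(data, batch_size=8))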