 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from unittest import mock
 from unittest.mock import patch
 
 import pytest
 from pytorch_lightning.demos.boring_classes import BoringModel
 from pytorch_lightning.loops import TrainingEpochLoop
 from pytorch_lightning.trainer.trainer import Trainer
-from tests_pytorch.deprecated_api import no_deprecated_call
 
 _out00 = {"loss": 0.0}
 _out01 = {"loss": 0.1}
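# NOTE: the fixture naming appears to follow _out<optimizer><index>, e.g.
# _out10 reads as the first output for optimizer 1; the companion fixtures
# _out02, _out03, _out10 and _out11 referenced below are defined in lines
# collapsed from this view.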
 
 
 class TestPrepareOutputs:
-    def prepare_outputs(self, fn, tbptt_splits, new_format, batch_outputs, num_optimizers, automatic_optimization):
+    def prepare_outputs(self, fn, tbptt_splits, batch_outputs, num_optimizers, automatic_optimization):
         lightning_module = LightningModule()
-        lightning_module.on_train_batch_end = lambda *_: None  # override to trigger the deprecation message
         lightning_module.automatic_optimization = automatic_optimization
         lightning_module.truncated_bptt_steps = tbptt_splits
-        match = "will change in version v1.8.*new_format=True"
-        will_warn = tbptt_splits and num_optimizers > 1 and not new_format
-        ctx_manager = pytest.deprecated_call if will_warn else no_deprecated_call
-        with ctx_manager(match=match):
-            with mock.patch(
-                "pytorch_lightning.loops.epoch.training_epoch_loop._v1_8_output_format", return_value=new_format
-            ):
-                return fn(
-                    batch_outputs,
-                    lightning_module=lightning_module,
-                    num_optimizers=num_optimizers,  # does not matter for manual optimization
-                )
+        return fn(
+            batch_outputs,
+            lightning_module=lightning_module,
+            num_optimizers=num_optimizers,  # does not matter for manual optimization
+        )
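# A rough sketch of the data flow, inferred from the parametrized cases below:
# `batch_outputs` arrives as a [batch][tbptt_split] structure whose leaves are
# dicts keyed by optimizer index, e.g. [[{0: {"loss": 0.0}}, {0: {"loss": 0.1}}]],
# and the `TrainingEpochLoop._prepare_outputs_*` helpers strip whichever levels
# (optimizer keys, split lists) are unused before the result reaches the hooks.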
 
     def prepare_outputs_training_epoch_end(
-        self, tbptt_splits, new_format, batch_outputs, num_optimizers, automatic_optimization=True
+        self, tbptt_splits, batch_outputs, num_optimizers, automatic_optimization=True
     ):
         return self.prepare_outputs(
             TrainingEpochLoop._prepare_outputs_training_epoch_end,
             tbptt_splits,
-            new_format,
             batch_outputs,
             num_optimizers,
             automatic_optimization=automatic_optimization,
         )
 
     def prepare_outputs_training_batch_end(
-        self, tbptt_splits, new_format, batch_outputs, num_optimizers, automatic_optimization=True
+        self, tbptt_splits, batch_outputs, num_optimizers, automatic_optimization=True
     ):
         return self.prepare_outputs(
             TrainingEpochLoop._prepare_outputs_training_batch_end,
             tbptt_splits,
-            new_format,
             batch_outputs,
             num_optimizers,
             automatic_optimization=automatic_optimization,
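# The two wrappers above differ only in which static helper they exercise:
# `_prepare_outputs_training_epoch_end` shapes the accumulated epoch outputs
# for the `training_epoch_end` hook, while `_prepare_outputs_training_batch_end`
# shapes a single batch's outputs for `on_train_batch_end`.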
@@ -97,53 +85,19 @@ def prepare_outputs_training_batch_end(
             ),
             # 1 batch, tbptt with 2 splits (uneven)
             (1, 2, [[{0: _out00}, {0: _out01}], [{0: _out03}]], [[_out00, _out01], [_out03]]),
-        ],
-    )
-    @pytest.mark.parametrize("new_format", (False, True))
-    def test_prepare_outputs_training_epoch_end_automatic(
-        self, num_optimizers, tbptt_splits, batch_outputs, expected, new_format
-    ):
-        """Test that the loop converts the nested lists of outputs to the format that the `training_epoch_end` hook
-        currently expects in the case of automatic optimization."""
-        assert (
-            self.prepare_outputs_training_epoch_end(tbptt_splits, new_format, batch_outputs, num_optimizers) == expected
-        )
-
-    @pytest.mark.parametrize(
-        "num_optimizers,tbptt_splits,batch_outputs,expected",
-        [
-            # 3 batches, tbptt with 2 splits, 2 optimizers alternating
-            (
-                2,
-                2,
-                [[{0: _out00}, {0: _out01}], [{1: _out10}, {1: _out11}], [{0: _out02}, {0: _out03}]],
-                [[[_out00, _out01], [], [_out02, _out03]], [[], [_out10, _out11], []]],
-            )
-        ],
-    )
-    def test_prepare_outputs_training_epoch_end_automatic_old_format(
-        self, num_optimizers, tbptt_splits, batch_outputs, expected
-    ):
-        assert self.prepare_outputs_training_epoch_end(tbptt_splits, False, batch_outputs, num_optimizers) == expected
-
-    @pytest.mark.parametrize(
-        "num_optimizers,tbptt_splits,batch_outputs,expected",
-        [
             # 3 batches, tbptt with 2 splits, 2 optimizers alternating
             (
                 2,
                 2,
                 [[{0: _out00}, {0: _out01}], [{1: _out10}, {1: _out11}], [{0: _out02}, {0: _out03}]],
                 [[[_out00], [_out01]], [[_out10], [_out11]], [[_out02], [_out03]]],
-            )
+            ),
         ],
     )
-    def test_prepare_outputs_training_epoch_end_automatic_new_format(
-        self, num_optimizers, tbptt_splits, batch_outputs, expected
-    ):
+    def test_prepare_outputs_training_epoch_end_automatic(self, num_optimizers, tbptt_splits, batch_outputs, expected):
         """Test that the loop converts the nested lists of outputs to the format that the `training_epoch_end` hook
         currently expects in the case of automatic optimization."""
-        assert self.prepare_outputs_training_epoch_end(tbptt_splits, True, batch_outputs, num_optimizers) == expected
+        assert self.prepare_outputs_training_epoch_end(tbptt_splits, batch_outputs, num_optimizers) == expected
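# Worked example from the case above: with 2 alternating optimizers and 2
# tbptt splits, each batch entry such as [{0: _out00}, {0: _out01}] (one dict
# per split) becomes [[_out00], [_out01]]: the optimizer keys are dropped and
# each split keeps a list of the outputs it actually produced, giving the hook
# a [batch][split][output] nesting.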
 
     @pytest.mark.parametrize(
         "batch_outputs,expected",
@@ -160,14 +114,10 @@ def test_prepare_outputs_training_epoch_end_automatic_new_format(
             ([[_out00, _out01], [_out02, _out03], [], [_out10]], [[_out00, _out01], [_out02, _out03], [_out10]]),
         ],
     )
-    @pytest.mark.parametrize("new_format", (False, True))
-    def test_prepare_outputs_training_epoch_end_manual(self, batch_outputs, expected, new_format):
+    def test_prepare_outputs_training_epoch_end_manual(self, batch_outputs, expected):
        """Test that the loop converts the nested lists of outputs to the format that the `training_epoch_end` hook
         currently expects in the case of manual optimization."""
-        assert (
-            self.prepare_outputs_training_epoch_end(0, new_format, batch_outputs, -1, automatic_optimization=False)
-            == expected
-        )
+        assert self.prepare_outputs_training_epoch_end(0, batch_outputs, -1, automatic_optimization=False) == expected
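# In the manual-optimization case the per-batch outputs are already plain
# lists; per the cases above, the only transformation is that batches which
# produced no output (the empty lists) are filtered out.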
 
     @pytest.mark.parametrize(
         "num_optimizers,tbptt_splits,batch_end_outputs,expected",
@@ -180,47 +130,17 @@ def test_prepare_outputs_training_epoch_end_manual(self, batch_outputs, expected
             (2, 0, [{0: _out00, 1: _out01}], [_out00, _out01]),
             # tbptt with 2 splits
             (1, 2, [{0: _out00}, {0: _out01}], [_out00, _out01]),
+            # 2 optimizers, tbptt with 2 splits
+            (2, 2, [{0: _out00, 1: _out01}, {0: _out10, 1: _out11}], [[_out00, _out01], [_out10, _out11]]),
         ],
     )
-    @pytest.mark.parametrize("new_format", (False, True))
     def test_prepare_outputs_training_batch_end_automatic(
-        self, num_optimizers, tbptt_splits, batch_end_outputs, expected, new_format
-    ):
-        """Test that the loop converts the nested lists of outputs to the format that the `on_train_batch_end` hook
-        currently expects in the case of automatic optimization."""
-
-        assert (
-            self.prepare_outputs_training_batch_end(tbptt_splits, new_format, batch_end_outputs, num_optimizers)
-            == expected
-        )
-
-    @pytest.mark.parametrize(
-        "num_optimizers,tbptt_splits,batch_end_outputs,expected",
-        # 2 optimizers, tbptt with 2 splits
-        [(2, 2, [{0: _out00, 1: _out01}, {0: _out10, 1: _out11}], [[_out00, _out10], [_out01, _out11]])],
-    )
-    def test_prepare_outputs_training_batch_end_automatic_old_format(
         self, num_optimizers, tbptt_splits, batch_end_outputs, expected
     ):
         """Test that the loop converts the nested lists of outputs to the format that the `on_train_batch_end` hook
         currently expects in the case of automatic optimization."""
-        assert (
-            self.prepare_outputs_training_batch_end(tbptt_splits, False, batch_end_outputs, num_optimizers) == expected
-        )
 
-    @pytest.mark.parametrize(
-        "num_optimizers,tbptt_splits,batch_end_outputs,expected",
-        # 2 optimizers, tbptt with 2 splits
-        [(2, 2, [{0: _out00, 1: _out01}, {0: _out10, 1: _out11}], [[_out00, _out01], [_out10, _out11]])],
-    )
-    def test_prepare_outputs_training_batch_end_automatic_new_format(
-        self, num_optimizers, tbptt_splits, batch_end_outputs, expected
-    ):
-        """Test that the loop converts the nested lists of outputs to the format that the `on_train_batch_end` hook
-        currently expects in the case of automatic optimization."""
-        assert (
-            self.prepare_outputs_training_batch_end(tbptt_splits, True, batch_end_outputs, num_optimizers) == expected
-        )
+        assert self.prepare_outputs_training_batch_end(tbptt_splits, batch_end_outputs, num_optimizers) == expected
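# Per the cases above: a single optimizer with no tbptt yields the bare output
# dict, while extra optimizers or extra splits each add one level of nesting,
# e.g. [{0: _out00, 1: _out01}, {0: _out10, 1: _out11}] becomes
# [[_out00, _out01], [_out10, _out11]] (one inner list per split).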
 
     @pytest.mark.parametrize(
         "batch_end_outputs,expected",
@@ -237,8 +157,7 @@ def test_prepare_outputs_training_batch_end_manual(self, batch_end_outputs, expe
         """Test that the loop converts the nested lists of outputs to the format that the `on_train_batch_end` hook
         currently expects in the case of manual optimization."""
         assert (
-            self.prepare_outputs_training_batch_end(0, False, batch_end_outputs, -1, automatic_optimization=False)
-            == expected
+            self.prepare_outputs_training_batch_end(0, batch_end_outputs, -1, automatic_optimization=False) == expected
         )
 
 