@@ -88,6 +88,19 @@ def forward(self, x):
         return self.reduce(self.embed(x))
 
 
+class PartialScriptModel(LightningModule):
+    """A model which contains scripted layers."""
+
+    def __init__(self):
+        super().__init__()
+        self.layer1 = torch.jit.script(nn.Linear(5, 3))
+        self.layer2 = nn.Linear(3, 2)
+        self.example_input_array = torch.rand(2, 5)
+
+    def forward(self, x):
+        return self.layer2(self.layer1(x))
+
+
 def test_invalid_weights_summmary():
     """Test that invalid value for weights_summary raises an error."""
     with pytest.raises(MisconfigurationException, match='`mode` can be None, .* got temp'):
@@ -97,11 +110,8 @@ def test_invalid_weights_summmary():
         Trainer(weights_summary='temp')
 
 
-@pytest.mark.parametrize(['mode'], [
-    pytest.param(ModelSummary.MODE_FULL),
-    pytest.param(ModelSummary.MODE_TOP),
-])
-def test_empty_model_summary_shapes(mode):
+@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
+def test_empty_model_summary_shapes(mode: ModelSummary):
     """Test that the summary works for models that have no submodules."""
     model = EmptyModule()
     summary = model.summarize(mode=mode)
@@ -110,10 +120,7 @@ def test_empty_model_summary_shapes(mode):
     assert summary.param_nums == []
 
 
-@pytest.mark.parametrize(['mode'], [
-    pytest.param(ModelSummary.MODE_FULL),
-    pytest.param(ModelSummary.MODE_TOP),
-])
+@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
 @pytest.mark.parametrize(['device'], [
     pytest.param(torch.device('cpu')),
     pytest.param(torch.device('cuda', 0)),
@@ -157,10 +164,7 @@ def test_mixed_dtype_model_summary():
     ]
 
 
-@pytest.mark.parametrize(['mode'], [
-    pytest.param(ModelSummary.MODE_FULL),
-    pytest.param(ModelSummary.MODE_TOP),
-])
+@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
 def test_hooks_removed_after_summarize(mode):
     """Test that all hooks were properly removed after summary, even ones that were not run."""
     model = UnorderedModel()
@@ -171,10 +175,7 @@ def test_hooks_removed_after_summarize(mode):
     assert handle.id not in handle.hooks_dict_ref()
 
 
-@pytest.mark.parametrize(['mode'], [
-    pytest.param(ModelSummary.MODE_FULL),
-    pytest.param(ModelSummary.MODE_TOP),
-])
+@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
 def test_rnn_summary_shapes(mode):
     """Test that the model summary works for RNNs."""
     model = ParityModuleRNN()
@@ -198,10 +199,7 @@ def test_rnn_summary_shapes(mode):
     ]
 
 
-@pytest.mark.parametrize(['mode'], [
-    pytest.param(ModelSummary.MODE_FULL),
-    pytest.param(ModelSummary.MODE_TOP),
-])
+@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
 def test_summary_parameter_count(mode):
     """Test that the summary counts the number of parameters in every submodule."""
     model = UnorderedModel()
@@ -215,10 +213,7 @@ def test_summary_parameter_count(mode):
     ]
 
 
-@pytest.mark.parametrize(['mode'], [
-    pytest.param(ModelSummary.MODE_FULL),
-    pytest.param(ModelSummary.MODE_TOP),
-])
+@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
 def test_summary_layer_types(mode):
     """Test that the summary displays the layer names correctly."""
     model = UnorderedModel()
@@ -232,10 +227,16 @@ def test_summary_layer_types(mode):
     ]
 
 
-@pytest.mark.parametrize(['mode'], [
-    pytest.param(ModelSummary.MODE_FULL),
-    pytest.param(ModelSummary.MODE_TOP),
-])
+@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
+def test_summary_with_scripted_modules(mode):
+    model = PartialScriptModel()
+    summary = model.summarize(mode=mode)
+    assert summary.layer_types == ["RecursiveScriptModule", "Linear"]
+    assert summary.in_sizes == [UNKNOWN_SIZE, [2, 3]]
+    assert summary.out_sizes == [UNKNOWN_SIZE, [2, 2]]
+
+
+@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
 @pytest.mark.parametrize(['example_input', 'expected_size'], [
     pytest.param([], UNKNOWN_SIZE),
     pytest.param((1, 2, 3), [UNKNOWN_SIZE] * 3),
@@ -269,21 +270,15 @@ def forward(self, *args, **kwargs):
     assert summary.in_sizes == [expected_size]
 
 
-@pytest.mark.parametrize(['mode'], [
-    pytest.param(ModelSummary.MODE_FULL),
-    pytest.param(ModelSummary.MODE_TOP),
-])
+@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
 def test_model_size(mode):
     """Test model size is calculated correctly."""
     model = PreCalculatedModel()
     summary = model.summarize(mode=mode)
     assert model.pre_calculated_model_size == summary.model_size
 
 
-@pytest.mark.parametrize(['mode'], [
-    pytest.param(ModelSummary.MODE_FULL),
-    pytest.param(ModelSummary.MODE_TOP),
-])
+@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
 def test_empty_model_size(mode):
     """Test empty model size is zero."""
     model = EmptyModule()
@@ -293,23 +288,17 @@ def test_empty_model_size(mode):
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.")
 @pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="test requires native AMP.")
-@pytest.mark.parametrize(
-    'precision', [
-        pytest.param(16, marks=pytest.mark.skip(reason="no longer valid, because 16 can mean mixed precision")),
-        pytest.param(32),
-    ]
-)
-def test_model_size_precision(monkeypatch, tmpdir, precision):
+def test_model_size_precision(tmpdir):
     """Test model size for half and full precision."""
-    model = PreCalculatedModel(precision)
+    model = PreCalculatedModel()
 
     # fit model
     trainer = Trainer(
         default_root_dir=tmpdir,
         gpus=1,
         max_steps=1,
         max_epochs=1,
-        precision=precision,
+        precision=32,
     )
     trainer.fit(model)
     summary = model.summarize()
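
A minimal standalone sketch of the behavior the new test_summary_with_scripted_modules relies on, assuming only a stock torch install (nothing below is part of the patch). torch.jit.script wraps the Linear in a RecursiveScriptModule, whose class name is what the summary reports; and since the summary records shapes through forward hooks, which some PyTorch versions refuse to register on script modules, the scripted layer's sizes stay at UNKNOWN_SIZE while the plain nn.Linear behind it still reports [2, 3] and [2, 2].

# Standalone sketch, not part of the patch; illustrative only.
import torch
from torch import nn

layer = torch.jit.script(nn.Linear(5, 3))

# Scripting wraps the module; the summary reports this class name.
print(type(layer).__name__)  # RecursiveScriptModule

# The summary observes input/output shapes via forward hooks. Some PyTorch
# versions reject hooks on script modules, in which case the shapes for this
# layer cannot be recorded and the summary falls back to UNKNOWN_SIZE.
try:
    layer.register_forward_hook(lambda module, inputs, output: None)
    print("forward hooks are supported on script modules in this torch version")
except RuntimeError as err:
    print(f"forward hooks rejected on script modules: {err}")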