1+ # Copyright The PyTorch Lightning team.
2+ #
3+ # Licensed under the Apache License, Version 2.0 (the "License");
4+ # you may not use this file except in compliance with the License.
5+ # You may obtain a copy of the License at
6+ #
7+ # http://www.apache.org/licenses/LICENSE-2.0
8+ #
9+ # Unless required by applicable law or agreed to in writing, software
10+ # distributed under the License is distributed on an "AS IS" BASIS,
11+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ # See the License for the specific language governing permissions and
13+ # limitations under the License.
14+
115import time
216
317import numpy as np
418import pytest
519import torch
20+ from tqdm import tqdm
621
722from pytorch_lightning import seed_everything , Trainer
823import tests .base .develop_utils as tutils
1530 (ParityModuleMNIST , 0.25 ), # todo: lower this thr
1631])
1732@pytest .mark .skipif (not torch .cuda .is_available (), reason = "test requires GPU machine" )
18- def test_pytorch_parity (tmpdir , cls_model , max_diff ):
33+ def test_pytorch_parity (tmpdir , cls_model , max_diff : float , num_epochs : int = 4 , num_runs : int = 3 ):
1934 """
2035 Verify that the same pytorch and lightning models achieve the same results
2136 """
22- num_epochs = 4
23- num_rums = 3
24- lightning_outs , pl_times = lightning_loop (cls_model , num_rums , num_epochs )
25- manual_outs , pt_times = vanilla_loop (cls_model , num_rums , num_epochs )
37+ lightning = lightning_loop (cls_model , num_runs , num_epochs )
38+ vanilla = vanilla_loop (cls_model , num_runs , num_epochs )
2639
2740 # make sure the losses match exactly to 5 decimal places
28- for pl_out , pt_out in zip (lightning_outs , manual_outs ):
41+ for pl_out , pt_out in zip (lightning [ 'losses' ], vanilla [ 'losses' ] ):
2942 np .testing .assert_almost_equal (pl_out , pt_out , 5 )
3043
3144 # the fist run initialize dataset (download & filter)
32- tutils .assert_speed_parity_absolute (pl_times [1 :], pt_times [1 :],
33- nb_epochs = num_epochs , max_diff = max_diff )
45+ tutils .assert_speed_parity_absolute (
46+ lightning ['durations' ][1 :], vanilla ['durations' ][1 :], nb_epochs = num_epochs , max_diff = max_diff
47+ )
3448
3549
3650def vanilla_loop (cls_model , num_runs = 10 , num_epochs = 10 ):
3751 """
3852 Returns an array with the last loss from each epoch for each run
3953 """
40- device = torch .device ('cuda' if torch .cuda .is_available () else "cpu" )
41- errors = []
42- times = []
54+ hist_losses = []
55+ hist_durations = []
4356
57+ device = torch .device ('cuda' if torch .cuda .is_available () else "cpu" )
4458 torch .backends .cudnn .deterministic = True
45- for i in range (num_runs ):
59+ for i in tqdm ( range (num_runs ), desc = f'Vanilla PT with { cls_model . __name__ } ' ):
4660 time_start = time .perf_counter ()
4761
4862 # set seed
@@ -74,18 +88,21 @@ def vanilla_loop(cls_model, num_runs=10, num_epochs=10):
7488 epoch_losses .append (loss .item ())
7589
7690 time_end = time .perf_counter ()
77- times .append (time_end - time_start )
91+ hist_durations .append (time_end - time_start )
7892
79- errors .append (epoch_losses [- 1 ])
93+ hist_losses .append (epoch_losses [- 1 ])
8094
81- return errors , times
95+ return {
96+ 'losses' : hist_losses ,
97+ 'durations' : hist_durations ,
98+ }
8299
83100
84101def lightning_loop (cls_model , num_runs = 10 , num_epochs = 10 ):
85- errors = []
86- times = []
102+ hist_losses = []
103+ hist_durations = []
87104
88- for i in range (num_runs ):
105+ for i in tqdm ( range (num_runs ), desc = f'PT Lightning with { cls_model . __name__ } ' ):
89106 time_start = time .perf_counter ()
90107
91108 # set seed
@@ -108,9 +125,12 @@ def lightning_loop(cls_model, num_runs=10, num_epochs=10):
108125 trainer .fit (model )
109126
110127 final_loss = trainer .train_loop .running_loss .last ().item ()
111- errors .append (final_loss )
128+ hist_losses .append (final_loss )
112129
113130 time_end = time .perf_counter ()
114- times .append (time_end - time_start )
131+ hist_durations .append (time_end - time_start )
115132
116- return errors , times
133+ return {
134+ 'losses' : hist_losses ,
135+ 'durations' : hist_durations ,
136+ }
0 commit comments