Skip to content

Commit 64ac73b

Browse files
BordaSkafteNicki
andcommitted
Document speed comparison (#2072)
* docs * script * dump * desc * import * import * if * norm * t * finished * isort * typing Co-authored-by: Nicki Skafte <[email protected]> * xlabel * pandas * time Co-authored-by: Nicki Skafte <[email protected]>
1 parent 417071a commit 64ac73b

File tree

9 files changed

+155
-23
lines changed

9 files changed

+155
-23
lines changed

benchmarks/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import os
15+
16+
BENCHMARK_ROOT = os.path.dirname(__file__)
17+
PROJECT_ROOT = os.path.dirname(BENCHMARK_ROOT)

benchmarks/generate_comparison.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import os
15+
16+
import matplotlib.pylab as plt
17+
import pandas as pd
18+
19+
from benchmarks.test_basic_parity import lightning_loop, vanilla_loop
20+
from tests.base.models import ParityModuleMNIST, ParityModuleRNN
21+
22+
NUM_EPOCHS = 20
23+
NUM_RUNS = 50
24+
MODEL_CLASSES = (ParityModuleRNN, ParityModuleMNIST)
25+
PATH_HERE = os.path.dirname(__file__)
26+
FIGURE_EXTENSION = '.png'
27+
28+
29+
def _main():
    """Generate the vanilla PyTorch vs. PT Lightning timing histograms.

    For every class in ``MODEL_CLASSES`` the timings are read from the cached
    ``dump-times_<ClassName>.csv`` next to this script when that file exists;
    otherwise both training loops are executed and the timings are cached to
    that file. One histogram per model class is drawn and the whole figure is
    saved as ``figure-parity-times<FIGURE_EXTENSION>`` in the same folder.
    """
    # ``squeeze=False`` always returns a 2D array of axes, so indexing below
    # also works when only a single model class is configured.
    fig, axarr = plt.subplots(nrows=len(MODEL_CLASSES), squeeze=False)

    for i, cls_model in enumerate(MODEL_CLASSES):
        ax = axarr[i, 0]
        path_csv = os.path.join(PATH_HERE, f'dump-times_{cls_model.__name__}.csv')
        if os.path.isfile(path_csv):
            # NOTE(review): dumps cached before the per-epoch normalisation fix
            # hold differently scaled values - delete them to regenerate.
            df_time = pd.read_csv(path_csv, index_col=0)
        else:
            vanilla = vanilla_loop(cls_model, num_epochs=NUM_EPOCHS, num_runs=NUM_RUNS)
            lightning = lightning_loop(cls_model, num_epochs=NUM_EPOCHS, num_runs=NUM_RUNS)

            # drop the first run, as the parity test does (it notes that the
            # first run also initializes the dataset)
            df_time = pd.DataFrame({'vanilla PT': vanilla['durations'][1:], 'PT Lightning': lightning['durations'][1:]})
            # every entry is the wall time of one full run of NUM_EPOCHS epochs,
            # so normalise by the number of epochs to get seconds per epoch
            df_time /= NUM_EPOCHS
            df_time.to_csv(path_csv)
        # todo: add also relative X-axis ticks to see both: relative and absolute time differences
        df_time.plot.hist(
            ax=ax,
            bins=20,
            alpha=0.5,
            title=cls_model.__name__,
            legend=True,
            grid=True,
        )
        ax.set(xlabel='time [seconds]')

    path_fig = os.path.join(PATH_HERE, f'figure-parity-times{FIGURE_EXTENSION}')
    fig.tight_layout()
    fig.savefig(path_fig)
57+
58+
59+
if __name__ == '__main__':
60+
_main()

benchmarks/test_parity.py renamed to benchmarks/test_basic_parity.py

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,23 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
115
import time
216

317
import numpy as np
418
import pytest
519
import torch
20+
from tqdm import tqdm
621

722
from pytorch_lightning import seed_everything, Trainer
823
import tests.base.develop_utils as tutils
@@ -15,34 +30,33 @@
1530
(ParityModuleMNIST, 0.25), # todo: lower this thr
1631
])
1732
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
18-
def test_pytorch_parity(tmpdir, cls_model, max_diff):
33+
def test_pytorch_parity(tmpdir, cls_model, max_diff: float, num_epochs: int = 4, num_runs: int = 3):
1934
"""
2035
Verify that the same pytorch and lightning models achieve the same results
2136
"""
22-
num_epochs = 4
23-
num_rums = 3
24-
lightning_outs, pl_times = lightning_loop(cls_model, num_rums, num_epochs)
25-
manual_outs, pt_times = vanilla_loop(cls_model, num_rums, num_epochs)
37+
lightning = lightning_loop(cls_model, num_runs, num_epochs)
38+
vanilla = vanilla_loop(cls_model, num_runs, num_epochs)
2639

2740
# make sure the losses match exactly to 5 decimal places
28-
for pl_out, pt_out in zip(lightning_outs, manual_outs):
41+
for pl_out, pt_out in zip(lightning['losses'], vanilla['losses']):
2942
np.testing.assert_almost_equal(pl_out, pt_out, 5)
3043

3144
# the first run initializes the dataset (download & filter)
32-
tutils.assert_speed_parity_absolute(pl_times[1:], pt_times[1:],
33-
nb_epochs=num_epochs, max_diff=max_diff)
45+
tutils.assert_speed_parity_absolute(
46+
lightning['durations'][1:], vanilla['durations'][1:], nb_epochs=num_epochs, max_diff=max_diff
47+
)
3448

3549

3650
def vanilla_loop(cls_model, num_runs=10, num_epochs=10):
3751
"""
3852
Returns an array with the last loss from each epoch for each run
3953
"""
40-
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
41-
errors = []
42-
times = []
54+
hist_losses = []
55+
hist_durations = []
4356

57+
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
4458
torch.backends.cudnn.deterministic = True
45-
for i in range(num_runs):
59+
for i in tqdm(range(num_runs), desc=f'Vanilla PT with {cls_model.__name__}'):
4660
time_start = time.perf_counter()
4761

4862
# set seed
@@ -74,18 +88,21 @@ def vanilla_loop(cls_model, num_runs=10, num_epochs=10):
7488
epoch_losses.append(loss.item())
7589

7690
time_end = time.perf_counter()
77-
times.append(time_end - time_start)
91+
hist_durations.append(time_end - time_start)
7892

79-
errors.append(epoch_losses[-1])
93+
hist_losses.append(epoch_losses[-1])
8094

81-
return errors, times
95+
return {
96+
'losses': hist_losses,
97+
'durations': hist_durations,
98+
}
8299

83100

84101
def lightning_loop(cls_model, num_runs=10, num_epochs=10):
85-
errors = []
86-
times = []
102+
hist_losses = []
103+
hist_durations = []
87104

88-
for i in range(num_runs):
105+
for i in tqdm(range(num_runs), desc=f'PT Lightning with {cls_model.__name__}'):
89106
time_start = time.perf_counter()
90107

91108
# set seed
@@ -108,9 +125,12 @@ def lightning_loop(cls_model, num_runs=10, num_epochs=10):
108125
trainer.fit(model)
109126

110127
final_loss = trainer.train_loop.running_loss.last().item()
111-
errors.append(final_loss)
128+
hist_losses.append(final_loss)
112129

113130
time_end = time.perf_counter()
114-
times.append(time_end - time_start)
131+
hist_durations.append(time_end - time_start)
115132

116-
return errors, times
133+
return {
134+
'losses': hist_losses,
135+
'durations': hist_durations,
136+
}

benchmarks/test_sharded_parity.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
115
import os
216
import platform
317
import time
30.8 KB
Loading

docs/source/benchmarking.rst

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
Benchmark with vanilla PyTorch
2+
==============================
3+
4+
In this section we lay the groundwork for comparing vanilla PyTorch and PT Lightning in the most common scenarios.
5+
6+
Time comparison
7+
---------------
8+
9+
We regularly benchmark against a vanilla PyTorch training loop with an RNN and a simple MNIST classifier as part of our CI.
10+
On average, for the simple MNIST CNN classifier we are only about 0.06s slower per epoch; see the detailed chart below.
11+
12+
.. figure:: _images/benchmarks/figure-parity-times.png
13+
:alt: Speed parity to vanilla PT, created on 2020-12-16
14+
:width: 500

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ PyTorch Lightning Documentation
2424
style_guide
2525
performance
2626
Lightning project template<https://github.com/PyTorchLightning/pytorch-lightning-conference-seed>
27+
benchmarking
2728

2829

2930
.. toctree::

requirements/test.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@ pre-commit>=1.0
1717

1818
cloudpickle>=1.3
1919
nltk>=3.3
20+
pandas # needed in benchmarks

tests/base/datasets.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,13 @@ class MNIST(Dataset):
6363
TEST_FILE_NAME = 'test.pt'
6464
cache_folder_name = 'complete'
6565

66-
def __init__(self, root: str = PATH_DATASETS, train: bool = True,
67-
normalize: tuple = (0.5, 1.0), download: bool = True):
66+
def __init__(
67+
self,
68+
root: str = PATH_DATASETS,
69+
train: bool = True,
70+
normalize: tuple = (0.5, 1.0),
71+
download: bool = True,
72+
):
6873
super().__init__()
6974
self.root = root
7075
self.train = train # training set or test set

0 commit comments

Comments
 (0)