
Commit df8b676

tchaton authored and rohitgr7 committed
Un-balanced logging properly supported (#5119)
* resolve bug
* clean code
* resolve comments
* Update tests/trainer/optimization/test_multiple_optimizers.py
  Co-authored-by: Rohit Gupta <[email protected]>
* resolve another bug
* add comments
* use abs to find diff
* update
* resolve flake8

Co-authored-by: Rohit Gupta <[email protected]>
1 parent b3fc662 commit df8b676
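
For context: "un-balanced" logging means that, with multiple optimizers, different metrics are logged for different optimizer indices and not necessarily on every step. The core of the fix is how the latest logged Result is cached: previously it was keyed by dataloader index only, so whichever optimizer stepped last could overwrite the other optimizer's latest metrics. A rough, self-contained sketch of the storage change using plain dicts and hypothetical values (not Lightning's actual Result objects):

# Before: one "latest" entry per dataloader -- whichever optimizer ran last wins,
# so metrics logged by the other optimizer can go missing.
latest_ref_before = {0: {"loss_2_step": 0.7}}  # dataloader_idx -> latest metrics

# After: one "latest" entry per (dataloader_idx, opt_idx) pair, so metrics logged
# by different optimizers are kept side by side and merged when read back.
latest_ref_after = {
    0: {  # dataloader_idx
        0: {"loss_1_step": 0.5},  # opt_idx 0
        1: {"loss_2_step": 0.7},  # opt_idx 1
    }
}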

2 files changed: +78 -11 lines changed

pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py

Lines changed: 15 additions & 11 deletions
@@ -91,11 +91,13 @@ def check_dataloader_idx(self, result: Result) -> bool:
         random_key = list(result.keys())[-1]
         return result["meta"][random_key]["dataloader_idx"] is not None

-    def get_latest_from_func_name(self, latest_result, func_name: str, *args, **kwargs) -> Dict:
+    def get_latest_from_func_name(self, latest_result_opt, func_name: str, *args, **kwargs) -> Dict:
         results = {}
-        add_dataloader_idx = self.check_dataloader_idx(latest_result)
-        func = getattr(latest_result, func_name)
-        results.update(func(*args, add_dataloader_idx=add_dataloader_idx, **kwargs))
+        for opt_idx in latest_result_opt:
+            latest_result = latest_result_opt[opt_idx]
+            add_dataloader_idx = self.check_dataloader_idx(latest_result)
+            func = getattr(latest_result, func_name)
+            results.update(func(*args, add_dataloader_idx=add_dataloader_idx, **kwargs))
         return results

     def run_latest_batch_metrics_with_func_name(self, func_name, *args, **kwargs) -> List[Dict]:
@@ -156,6 +158,7 @@ def append(self, result, dataloader_idx: Optional[int] = None, extra_info: Optio
         assert isinstance(result, Result)
         if dataloader_idx is None:
             dataloader_idx = 0
+
         if extra_info is None:
             extra_info = {}

@@ -166,22 +169,27 @@ def append(self, result, dataloader_idx: Optional[int] = None, extra_info: Optio
             if dataloader_idx not in self._internals:
                 self._internals[dataloader_idx] = {}
                 self._internals_reduced[dataloader_idx] = defaultdict(dict)
+                self._latest_ref[dataloader_idx] = {}

             # extract infos
             opt_idx = extra_info["opt_idx"]
             batch_idx = extra_info["batch_idx"]

             self._append_to_structure(self._internals[dataloader_idx], opt_idx, batch_idx, result)

-            self._latest_ref[dataloader_idx] = result
+            self._latest_ref[dataloader_idx][opt_idx] = result

         # [dataloader_idx] is a list
         else:
             self._internal_type = ResultStoreType.OUTSIDE_BATCH_TRAIN_LOOP
             self._internals.setdefault(dataloader_idx, [])
             self._internals[dataloader_idx].append(result)

-            self._latest_ref[dataloader_idx] = result
+            if dataloader_idx not in self._latest_ref:
+                self._latest_ref[dataloader_idx] = {}
+                self._latest_ref[dataloader_idx][0] = {}
+
+            self._latest_ref[dataloader_idx][0] = result

     def auto_reduce_results_on_epoch_end(self) -> None:
         """
@@ -206,13 +214,9 @@ def auto_reduce_results_on_epoch_end(self) -> None:
                     # TODO: How to start training in middle of epoch
                     opt_outputs = epoch_metrics[opt_idx]

-                    num_batch_idx = len(self._internals[dl_idx][num_opt_idx]) - 1
-                    assert num_batch_idx >= 0
-                    batch_indexes = self._internals[dl_idx][num_opt_idx].keys()
-
                     # reduce across time first
                     time_reduced_outputs = []
-                    for batch_idx in batch_indexes:
+                    for batch_idx in opt_outputs.keys():
                         tbptt_outs = opt_outputs[batch_idx]
                         tbptt_outs = tbptt_outs[0].__class__.reduce_across_time(tbptt_outs)
                         if len(tbptt_outs) > 1:
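
A minimal, self-contained sketch of the read side of the change above (plain dicts and hypothetical values rather than Lightning's Result objects): get_latest_from_func_name now receives the per-optimizer dict and merges the latest metrics from every optimizer, instead of reading a single Result:

# Per-optimizer "latest" metrics for one dataloader, mirroring the new
# self._latest_ref[dataloader_idx][opt_idx] layout.
latest_result_opt = {
    0: {"loss_1_step": 0.5},  # last metrics logged while optimizer 0 was active
    1: {"loss_2_step": 0.7},  # last metrics logged while optimizer 1 was active
}

# Merge per-optimizer results the same way the new loop does,
# so neither optimizer's metrics are dropped.
results = {}
for opt_idx in latest_result_opt:
    results.update(latest_result_opt[opt_idx])

print(results)  # {'loss_1_step': 0.5, 'loss_2_step': 0.7}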
tests/trainer/optimization/test_multiple_optimizers.py (new file)

Lines changed: 63 additions & 0 deletions

@@ -0,0 +1,63 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Tests to ensure that the behaviours related to multiple optimizers works
+"""
+import torch
+
+import pytorch_lightning as pl
+from tests.base.boring_model import BoringModel
+
+
+def test_unbalanced_logging_with_multiple_optimizers(tmpdir):
+    """
+    This tests ensures reduction works in un-balanced logging settings
+    """
+    class TestModel(BoringModel):
+
+        loss_1 = []
+        loss_2 = []
+
+        def training_step(self, batch, batch_idx, optimizer_idx):
+            output = self.layer(batch)
+            loss = self.loss(batch, output)
+            if optimizer_idx == 0 and self.trainer.global_step > 10:
+                self.log("loss_1", loss, on_epoch=True, prog_bar=True)
+                self.loss_1.append(loss.detach().clone())
+            elif optimizer_idx == 1:
+                self.log("loss_2", loss, on_epoch=True, prog_bar=True)
+                self.loss_2.append(loss.detach().clone())
+            return {"loss": loss}
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.001)
+            optimizer2 = torch.optim.SGD(self.layer.parameters(), lr=0.001)
+            return [optimizer, optimizer2]
+
+    model = TestModel()
+    model.training_epoch_end = None
+
+    # Initialize a trainer
+    trainer = pl.Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+    )
+
+    trainer.fit(model)
+
+    assert torch.equal(trainer.callback_metrics["loss_2_step"], model.loss_2[-1])
+    assert torch.equal(trainer.callback_metrics["loss_1_step"], model.loss_1[-1])
+    # test loss are properly reduced
+    assert torch.abs(trainer.callback_metrics["loss_2_epoch"] - torch.FloatTensor(model.loss_2).mean()) < 1e-6
+    assert torch.abs(trainer.callback_metrics["loss_1_epoch"] - torch.FloatTensor(model.loss_1).mean()) < 1e-6
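
One note on the final assertions, illustrated below with made-up numbers (not from the commit): because the two losses are logged an unequal number of times, each *_epoch metric is reduced over only the steps on which that metric was actually logged, and the test compares against that mean with a small absolute tolerance rather than exact equality:

import torch

# Hypothetical step values: loss_1 was logged on fewer steps than loss_2.
loss_1_steps = [0.9, 0.7]            # logged only once global_step > 10, optimizer 0
loss_2_steps = [1.0, 0.9, 0.8, 0.7]  # logged on every optimizer-1 step

# The epoch value is the mean over the steps that were actually logged.
loss_1_epoch = torch.FloatTensor(loss_1_steps).mean()  # ~0.80
loss_2_epoch = torch.FloatTensor(loss_2_steps).mean()  # ~0.85

# Same style of check as in the test above: absolute difference under a tolerance.
assert torch.abs(loss_1_epoch - 0.80) < 1e-6
assert torch.abs(loss_2_epoch - 0.85) < 1e-6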
