diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
index 28025859814cc..6d206f3dd929e 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
@@ -91,11 +91,13 @@ def check_dataloader_idx(self, result: Result) -> bool:
         random_key = list(result.keys())[-1]
         return result["meta"][random_key]["dataloader_idx"] is not None
 
-    def get_latest_from_func_name(self, latest_result, func_name: str, *args, **kwargs) -> Dict:
+    def get_latest_from_func_name(self, latest_result_opt, func_name: str, *args, **kwargs) -> Dict:
         results = {}
-        add_dataloader_idx = self.check_dataloader_idx(latest_result)
-        func = getattr(latest_result, func_name)
-        results.update(func(*args, add_dataloader_idx=add_dataloader_idx, **kwargs))
+        for opt_idx in latest_result_opt:
+            latest_result = latest_result_opt[opt_idx]
+            add_dataloader_idx = self.check_dataloader_idx(latest_result)
+            func = getattr(latest_result, func_name)
+            results.update(func(*args, add_dataloader_idx=add_dataloader_idx, **kwargs))
         return results
 
     def run_latest_batch_metrics_with_func_name(self, func_name, *args, **kwargs) -> List[Dict]:
@@ -156,6 +158,7 @@ def append(self, result, dataloader_idx: Optional[int] = None, extra_info: Optio
         assert isinstance(result, Result)
         if dataloader_idx is None:
             dataloader_idx = 0
+
         if extra_info is None:
             extra_info = {}
 
@@ -166,6 +169,7 @@ def append(self, result, dataloader_idx: Optional[int] = None, extra_info: Optio
             if dataloader_idx not in self._internals:
                 self._internals[dataloader_idx] = {}
                 self._internals_reduced[dataloader_idx] = defaultdict(dict)
+                self._latest_ref[dataloader_idx] = {}
 
             # extract infos
             opt_idx = extra_info["opt_idx"]
@@ -173,7 +177,7 @@ def append(self, result, dataloader_idx: Optional[int] = None, extra_info: Optio
 
             self._append_to_structure(self._internals[dataloader_idx], opt_idx, batch_idx, result)
 
-            self._latest_ref[dataloader_idx] = result
+            self._latest_ref[dataloader_idx][opt_idx] = result
 
         # [dataloader_idx] is a list
         else:
@@ -181,7 +185,11 @@ def append(self, result, dataloader_idx: Optional[int] = None, extra_info: Optio
             self._internals.setdefault(dataloader_idx, [])
             self._internals[dataloader_idx].append(result)
 
-            self._latest_ref[dataloader_idx] = result
+            if dataloader_idx not in self._latest_ref:
+                self._latest_ref[dataloader_idx] = {}
+                self._latest_ref[dataloader_idx][0] = {}
+
+            self._latest_ref[dataloader_idx][0] = result
 
     def auto_reduce_results_on_epoch_end(self) -> None:
         """
@@ -206,13 +214,9 @@ def auto_reduce_results_on_epoch_end(self) -> None:
                     # TODO: How to start training in middle of epoch
                     opt_outputs = epoch_metrics[opt_idx]
 
-                    num_batch_idx = len(self._internals[dl_idx][num_opt_idx]) - 1
-                    assert num_batch_idx >= 0
-                    batch_indexes = self._internals[dl_idx][num_opt_idx].keys()
-
                     # reduce across time first
                     time_reduced_outputs = []
-                    for batch_idx in batch_indexes:
+                    for batch_idx in opt_outputs.keys():
                         tbptt_outs = opt_outputs[batch_idx]
                         tbptt_outs = tbptt_outs[0].__class__.reduce_across_time(tbptt_outs)
                         if len(tbptt_outs) > 1:
diff --git a/tests/trainer/optimization/test_multiple_optimizers.py b/tests/trainer/optimization/test_multiple_optimizers.py
new file mode 100644
index 0000000000000..78b6f8f7ff84a
--- /dev/null
+++ b/tests/trainer/optimization/test_multiple_optimizers.py
@@ -0,0 +1,63 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Tests to ensure that the behaviours related to multiple optimizers work
+"""
+import torch
+
+import pytorch_lightning as pl
+from tests.base.boring_model import BoringModel
+
+
+def test_unbalanced_logging_with_multiple_optimizers(tmpdir):
+    """
+    This test ensures reduction works in unbalanced logging settings
+    """
+    class TestModel(BoringModel):
+
+        loss_1 = []
+        loss_2 = []
+
+        def training_step(self, batch, batch_idx, optimizer_idx):
+            output = self.layer(batch)
+            loss = self.loss(batch, output)
+            if optimizer_idx == 0 and self.trainer.global_step > 10:
+                self.log("loss_1", loss, on_epoch=True, prog_bar=True)
+                self.loss_1.append(loss.detach().clone())
+            elif optimizer_idx == 1:
+                self.log("loss_2", loss, on_epoch=True, prog_bar=True)
+                self.loss_2.append(loss.detach().clone())
+            return {"loss": loss}
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.001)
+            optimizer2 = torch.optim.SGD(self.layer.parameters(), lr=0.001)
+            return [optimizer, optimizer2]
+
+    model = TestModel()
+    model.training_epoch_end = None
+
+    # Initialize a trainer
+    trainer = pl.Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+    )
+
+    trainer.fit(model)
+
+    assert torch.equal(trainer.callback_metrics["loss_2_step"], model.loss_2[-1])
+    assert torch.equal(trainer.callback_metrics["loss_1_step"], model.loss_1[-1])
+    # test losses are properly reduced
+    assert torch.abs(trainer.callback_metrics["loss_2_epoch"] - torch.FloatTensor(model.loss_2).mean()) < 1e-6
+    assert torch.abs(trainer.callback_metrics["loss_1_epoch"] - torch.FloatTensor(model.loss_1).mean()) < 1e-6
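A minimal standalone sketch (not part of the patch) of the bookkeeping change above: `_latest_ref` now maps `dataloader_idx -> opt_idx -> Result` instead of `dataloader_idx -> Result`, so the latest result of one optimizer no longer overwrites the other's within the same batch, and `get_latest_from_func_name` can merge metrics across optimizers. Plain dicts stand in for Lightning's `Result` objects, and `merge_latest_metrics` is a hypothetical stand-in for the metric-collecting call, used only for illustration.

```python
# Sketch only: plain dicts stand in for pytorch_lightning's Result objects.

# Before the patch: one slot per dataloader, so optimizer 1's latest result
# overwrites optimizer 0's for the same batch.
latest_ref_old = {}
latest_ref_old[0] = {"loss_1": 0.5}  # written for opt_idx 0
latest_ref_old[0] = {"loss_2": 0.7}  # opt_idx 1 clobbers opt_idx 0

# After the patch: one slot per (dataloader_idx, opt_idx) pair.
latest_ref_new = {}
latest_ref_new.setdefault(0, {})[0] = {"loss_1": 0.5}  # opt_idx 0 kept
latest_ref_new.setdefault(0, {})[1] = {"loss_2": 0.7}  # opt_idx 1 kept


def merge_latest_metrics(latest_result_opt: dict) -> dict:
    """Mimic the new get_latest_from_func_name loop: merge across optimizers."""
    results = {}
    for opt_idx in latest_result_opt:
        results.update(latest_result_opt[opt_idx])
    return results


# Both optimizers' latest metrics survive for dataloader 0.
assert merge_latest_metrics(latest_ref_new[0]) == {"loss_1": 0.5, "loss_2": 0.7}
```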