|
18 | 18 |
|
19 | 19 | import pytest |
20 | 20 | import torch |
| 21 | +import torch.distributed as torch_distrib |
21 | 22 | import torch.nn.functional as F |
22 | 23 |
|
23 | 24 | from pytorch_lightning import Trainer, seed_everything |
@@ -862,7 +863,7 @@ def dis_closure(): |
862 | 863 | self.manual_backward(loss_dis, opt_dis) |
863 | 864 |
|
864 | 865 | # this will accumulate gradients for 2 batches and then call opt_gen.step() |
865 | | - opt_gen.step(closure=gen_closure, make_optimizer_step=batch_idx % 2 == 0, optim='sgd') |
| 866 | + opt_gen.step(closure=gen_closure, make_optimizer_step=(batch_idx % 2 == 0), optim='sgd') |
866 | 867 |
|
867 | 868 |         # update discriminator every 4 batches |
868 | 869 | # therefore, no gradient accumulation for discriminator |
@@ -904,6 +905,114 @@ def configure_optimizers(self): |
904 | 905 | mock_adam_step.assert_has_calls(expected_calls) |
905 | 906 |
|
906 | 907 |
|
| 908 | +@patch("torch.optim.Adam.step") |
| 909 | +@patch("torch.optim.SGD.step") |
| 910 | +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") |
| 911 | +@pytest.mark.skipif(os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') != '1', reason="test should be run outside of pytest") |
| 912 | +def test_step_with_optimizer_closure_with_different_frequencies_ddp(mock_sgd_step, mock_adam_step, tmpdir): |
| 913 | + """ |
| 914 | +    Tests that `step` works with an optimizer closure and different gradient accumulation frequencies under DDP |
| 915 | + """ |
| 916 | + os.environ['PL_DEV_DEBUG'] = '1' |
| 917 | + |
| 918 | + class TestModel(BoringModel): |
| 919 | + |
| 920 | + def loss_ones(self, batch, prediction): |
| 921 | +        # an arbitrary loss, so that the model weights get updated during `Trainer.fit` calls |
| 922 | + return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction)) |
| 923 | + |
| 924 | + def loss_zeros(self, batch, prediction): |
| 925 | +        # an arbitrary loss, so that the model weights get updated during `Trainer.fit` calls |
| 926 | + return torch.nn.functional.mse_loss(prediction, torch.zeros_like(prediction)) |
| 927 | + |
| 928 | + def manual_sync_grad(self) -> bool: |
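| | +        # sum the layer's weight gradient across the 2 DDP ranks by hand (async_op=False blocks until the reduce completes) |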
| 929 | + torch_distrib.all_reduce(self.layer.weight.grad.data, async_op=False) |
| 930 | + return True |
| 931 | + |
| 932 | + def training_step(self, batch, batch_idx, optimizer_idx): |
| 933 | + |
| 934 | +        # emulate GAN training |
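| | +        # self.optimizers() hands back the optimizers configured below (wrapped by Lightning when |
| | +        # enable_pl_optimizer=True); extra kwargs passed to their step(), e.g. optim='sgd', are |
| | +        # forwarded to the underlying torch optimizer.step, which is what the mocked SGD/Adam |
| | +        # assertions at the end of the test verify |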
| 935 | + opt_gen, opt_dis = self.optimizers() |
| 936 | + |
| 937 | +        # Note: be careful not to log on the same key with self.log from both closures, |
| 938 | +        # as the values would be aggregated together on epoch_end |
| 939 | + |
| 940 | + world_size = torch_distrib.get_world_size(torch_distrib.group.WORLD) |
| 941 | + assert world_size == 2 |
| 942 | + |
| 943 | + def compute_loss(): |
| 944 | + x = batch[0] |
| 945 | + x = F.dropout(x, 0.1) |
| 946 | + predictions = self(x) |
| 947 | + predictions = F.dropout(predictions, 0.1) |
| 948 | + loss_ones = self.loss_ones(None, predictions) |
| 949 | + loss_zeros = self.loss_zeros(None, predictions) |
| 950 | + return loss_ones, loss_zeros |
| 951 | + |
| 952 | + def make_manual_backward(loss, opt, retain_graph=False): |
| 953 | + self.manual_backward(loss, opt, retain_graph=retain_graph) |
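| | +            # clone the local grad, all-reduce (sum) it across ranks, then average by the world size; |
| | +            # the averaged grad is expected to equal the local clone, i.e. both ranks computed the same gradients |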
| 954 | + grad_clone = self.layer.weight.grad.clone() |
| 955 | + assert self.manual_sync_grad() |
| 956 | + self.layer.weight.grad /= world_size |
| 957 | + assert torch.equal(self.layer.weight.grad, grad_clone) |
| 958 | + |
| 959 | + def gen_closure(): |
| 960 | + loss_ones_gen, loss_zeros = compute_loss() |
| 961 | + make_manual_backward(loss_ones_gen, opt_gen, retain_graph=True) |
| 962 | + make_manual_backward(loss_ones_gen, opt_gen) |
| 963 | + |
| 964 | + def dis_closure(): |
| 965 | + loss_ones_gen, loss_zeros = compute_loss() |
| 966 | + make_manual_backward(loss_ones_gen, opt_dis, retain_graph=True) |
| 967 | + make_manual_backward(loss_ones_gen, opt_dis) |
| 968 | + |
| 969 | + # this will accumulate gradients for 2 batches and then call opt_gen.step() |
| 970 | +        opt_gen.step(closure=gen_closure, make_optimizer_step=(batch_idx % 2 == 0), optim='sgd') |
| 971 | + |
| 972 | +        # update the discriminator every 4 batches |
| 973 | +        # therefore, no gradient accumulation for the discriminator |
| 974 | +        if batch_idx % 4 == 0: |
| 975 | +            # Note: set make_optimizer_step=True explicitly, otherwise it falls back to the |
| 976 | +            # Trainer(accumulate_grad_batches=x) schedule |
| 977 | + opt_dis.step(closure=dis_closure, make_optimizer_step=True, optim='adam') |
| 978 | + |
| 979 | + def training_epoch_end(self, outputs) -> None: |
| 980 | + # outputs should be an array with an entry per optimizer |
| 981 | + assert len(outputs) == 2 |
| 982 | + |
| 983 | + def configure_optimizers(self): |
| 984 | + optimizer_gen = torch.optim.SGD(self.layer.parameters(), lr=0.1) |
| 985 | + optimizer_dis = torch.optim.Adam(self.layer.parameters(), lr=0.001) |
| 986 | + return [optimizer_gen, optimizer_dis] |
| 987 | + |
| 988 | + seed_everything(42) |
| 989 | + |
| 990 | + model = TestModel() |
| 991 | + model.val_dataloader = None |
| 992 | + model.training_epoch_end = None |
| 993 | + |
| 994 | + limit_train_batches = 8 |
| 995 | + trainer = Trainer( |
| 996 | + automatic_optimization=False, |
| 997 | + default_root_dir=tmpdir, |
| 998 | + limit_train_batches=limit_train_batches, |
| 999 | + limit_val_batches=2, |
| 1000 | + max_epochs=1, |
| 1001 | + log_every_n_steps=1, |
| 1002 | + accumulate_grad_batches=2, |
| 1003 | + enable_pl_optimizer=True, |
| 1004 | + gpus=2, |
| 1005 | + accelerator="ddp", |
| 1006 | + ) |
| 1007 | + |
| 1008 | + trainer.fit(model) |
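| | +    # with limit_train_batches=8: the generator steps on batch_idx 0, 2, 4, 6 -> 4 SGD.step calls, |
| | +    # and the discriminator on batch_idx 0 and 4 -> 2 Adam.step calls |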
| 1009 | + expected_calls = [call(closure=ANY, optim='sgd')] * 4 |
| 1010 | + mock_sgd_step.assert_has_calls(expected_calls) |
| 1011 | + |
| 1012 | + expected_calls = [call(closure=ANY, optim='adam')] * 2 |
| 1013 | + mock_adam_step.assert_has_calls(expected_calls) |
| 1014 | + |
| 1015 | + |
907 | 1016 | def test_step_with_misconfiguraiton_error_when_overriding_optimizer_zero_grad(tmpdir): |
908 | 1017 | """ |
909 | 1018 | Tests that `optimizer_zero_grad` in manual_optimization triggers a MisconfigurationException |
|