|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | | -from unittest import mock |
15 | 14 |
|
16 | | -import pytest |
17 | 15 | import torch |
18 | 16 | import torch.nn.functional as F |
19 | 17 | from torch.utils.data import DataLoader |
20 | 18 |
|
21 | 19 | import pytorch_lightning as pl |
22 | 20 | import tests_pytorch.helpers.pipelines as tpipes |
23 | 21 | import tests_pytorch.helpers.utils as tutils |
24 | | -from pytorch_lightning import Trainer |
25 | 22 | from pytorch_lightning.callbacks import EarlyStopping |
26 | 23 | from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset |
27 | | -from pytorch_lightning.utilities.exceptions import MisconfigurationException |
28 | 24 | from tests_pytorch.helpers.datamodules import ClassifDataModule |
29 | 25 | from tests_pytorch.helpers.runif import RunIf |
30 | 26 | from tests_pytorch.helpers.simple_models import ClassificationModel |
@@ -154,47 +150,6 @@ def _assert_extra_outputs(self, outputs): |
154 | 150 | assert out.dtype is torch.float |
155 | 151 |
|
156 | 152 |
|
157 | | -@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=2) |
158 | | -@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True) |
159 | | -def test_dp_raise_exception_with_batch_transfer_hooks(mock_is_available, mock_device_count, tmpdir): |
160 | | - """Test that an exception is raised when overriding batch_transfer_hooks in DP model.""" |
161 | | - |
162 | | - class CustomModel(BoringModel): |
163 | | - def transfer_batch_to_device(self, batch, device, dataloader_idx): |
164 | | - batch = batch.to(device) |
165 | | - return batch |
166 | | - |
167 | | - trainer_options = dict(default_root_dir=tmpdir, max_steps=7, accelerator="gpu", devices=[0, 1], strategy="dp") |
168 | | - |
169 | | - trainer = Trainer(**trainer_options) |
170 | | - model = CustomModel() |
171 | | - |
172 | | - with pytest.raises(MisconfigurationException, match=r"Overriding `transfer_batch_to_device` is not .* in DP"): |
173 | | - trainer.fit(model) |
174 | | - |
175 | | - class CustomModel(BoringModel): |
176 | | - def on_before_batch_transfer(self, batch, dataloader_idx): |
177 | | - batch += 1 |
178 | | - return batch |
179 | | - |
180 | | - trainer = Trainer(**trainer_options) |
181 | | - model = CustomModel() |
182 | | - |
183 | | - with pytest.raises(MisconfigurationException, match=r"Overriding `on_before_batch_transfer` is not .* in DP"): |
184 | | - trainer.fit(model) |
185 | | - |
186 | | - class CustomModel(BoringModel): |
187 | | - def on_after_batch_transfer(self, batch, dataloader_idx): |
188 | | - batch += 1 |
189 | | - return batch |
190 | | - |
191 | | - trainer = Trainer(**trainer_options) |
192 | | - model = CustomModel() |
193 | | - |
194 | | - with pytest.raises(MisconfigurationException, match=r"Overriding `on_after_batch_transfer` is not .* in DP"): |
195 | | - trainer.fit(model) |
196 | | - |
197 | | - |
198 | 153 | @RunIf(min_cuda_gpus=2) |
199 | 154 | def test_dp_training_step_dict(tmpdir): |
200 | 155 | """This test verifies that dp properly reduces dictionaries.""" |
|
0 commit comments