 import torch
 from torch.utils.data.dataloader import DataLoader
 
+from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
 from pytorch_lightning.lite import LightningLite
 from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
 from tests.helpers.runif import RunIf
@@ -65,6 +66,27 @@ def check_autocast(forward_input):
     assert out.dtype == input_type or out.dtype == torch.get_default_dtype()
 
 
+@pytest.mark.parametrize(
+    "device", [torch.device("cpu"), pytest.param(torch.device("cuda", 0), marks=RunIf(min_gpus=1))]
+)
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
+def test_lite_module_device_dtype_propagation(device, dtype):
+    """Test that the LiteModule propagates device and dtype properties to its submodules (e.g. torchmetrics)."""
+
+    class DeviceModule(DeviceDtypeModuleMixin):
+        pass
+
+    device_module = DeviceModule()
+    lite_module = _LiteModule(device_module, Mock())
+    lite_module.to(device)
+    assert device_module.device == device
+    assert lite_module.device == device
+
+    lite_module.to(dtype)
+    assert device_module.dtype == dtype
+    assert lite_module.dtype == dtype
+
+
 def test_lite_dataloader_iterator():
     """Test that the iteration over a LiteDataLoader wraps the iterator of the underlying dataloader (no automatic
     device placement)."""
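
For illustration, the sketch below shows how the behaviour covered by the new test plays out for a torchmetrics-style submodule that owns a tensor buffer. The _ToyMetric class is a made-up stand-in and not part of this change; as in the test above, Mock() simply takes the place of the precision plugin.

from unittest.mock import Mock

import torch

from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
from pytorch_lightning.lite.wrappers import _LiteModule


class _ToyMetric(DeviceDtypeModuleMixin):
    """Made-up stand-in for a torchmetrics metric: a submodule that owns a buffer."""

    def __init__(self):
        super().__init__()
        self.register_buffer("total", torch.tensor(0.0))


metric = _ToyMetric()
wrapped = _LiteModule(metric, Mock())  # Mock() stands in for the precision plugin

wrapped.to(torch.float16)
# both the reported dtype and the buffer itself follow the conversion
assert metric.dtype == torch.float16
assert metric.total.dtype == torch.float16

if torch.cuda.is_available():
    wrapped.to(torch.device("cuda", 0))
    # the submodule now reports the device its tensors actually live on
    assert metric.device == torch.device("cuda", 0)
    assert metric.total.is_cuda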