Skip to content

Commit 3adc0c9

Browse files
committed
update test
1 parent 26b9e71 commit 3adc0c9

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

tests/core/test_datamodules.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,9 +297,10 @@ def transfer_batch_to_device(self, batch, device, dataloader_idx):
297297
batch = CustomBatch((torch.zeros(5, 32), torch.ones(5, 1, dtype=torch.long)))
298298

299299
trainer = Trainer(gpus=1)
300+
model.trainer = trainer
300301
# running .fit() would require us to implement custom data loaders, we mock the model reference instead
301302
get_module_mock.return_value = model
302-
trainer.attach_datamodule(model, datamodule=dm)
303+
trainer._data_connector.attach_datamodule(model, datamodule=dm)
303304
batch_gpu = trainer.accelerator.batch_to_device(batch, expected_device)
304305

305306
assert dm.on_before_batch_transfer_hook_rank == 0

0 commit comments

Comments
 (0)