Skip to content

Commit dae8d5f

Browse files
committed
update test
1 parent 811e8e1 commit dae8d5f

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

tests/core/test_datamodules.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,9 +296,10 @@ def transfer_batch_to_device(self, batch, device, dataloader_idx):
296296
batch = CustomBatch((torch.zeros(5, 32), torch.ones(5, 1, dtype=torch.long)))
297297

298298
trainer = Trainer(gpus=1)
299+
model.trainer = trainer
299300
# running .fit() would require us to implement custom data loaders, we mock the model reference instead
300301
get_module_mock.return_value = model
301-
trainer.attach_datamodule(model, datamodule=dm)
302+
trainer._data_connector.attach_datamodule(model, datamodule=dm)
302303
batch_gpu = trainer.accelerator.batch_to_device(batch, expected_device)
303304

304305
assert dm.on_before_batch_transfer_hook_rank == 0

0 commit comments

Comments
 (0)