Skip to content

Commit 0003f66

Browse files
tchatonBorda
authored andcommitted
[doc] Update Dict Train Loader doc. (#6579)
* update doc * update example (cherry picked from commit 8853a36)
1 parent da4b5eb commit 0003f66

File tree

1 file changed

+42
-8
lines changed

1 file changed

+42
-8
lines changed

docs/source/advanced/multiple_loaders.rst

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ Multiple Datasets
99
Lightning supports multiple dataloaders in a few ways.
1010

1111
1. Create a dataloader that iterates multiple datasets under the hood.
12-
2. In the training loop you can pass multiple loaders as a dict or list/tuple and lightning
12+
2. In the training loop you can pass multiple loaders as a dict or list/tuple and lightning
1313
will automatically combine the batches from different loaders.
1414
3. In the validation and test loop you also have the option to return multiple dataloaders
1515
which lightning will call sequentially.
@@ -75,21 +75,38 @@ For more details please have a look at :attr:`~pytorch_lightning.trainer.trainer
7575

7676
loader_a = torch.utils.data.DataLoader(range(6), batch_size=4)
7777
loader_b = torch.utils.data.DataLoader(range(15), batch_size=5)
78-
78+
7979
# pass loaders as a dict. This will create batches like this:
8080
# {'a': batch from loader_a, 'b': batch from loader_b}
8181
loaders = {'a': loader_a,
8282
'b': loader_b}
8383

84-
# OR:
84+
# OR:
8585
# pass loaders as sequence. This will create batches like this:
8686
# [batch from loader_a, batch from loader_b]
8787
loaders = [loader_a, loader_b]
8888

8989
return loaders
9090

9191
Furthermore, Lightning also supports that nested lists and dicts (or a combination) can
92-
be returned
92+
be returned.
93+
94+
.. testcode::
95+
96+
class LitModel(LightningModule):
97+
98+
def train_dataloader(self):
99+
100+
loader_a = torch.utils.data.DataLoader(range(8), batch_size=4)
101+
loader_b = torch.utils.data.DataLoader(range(16), batch_size=2)
102+
103+
return {'a': loader_a, 'b': loader_b}
104+
105+
def training_step(self, batch, batch_idx):
106+
# access a dictionnary with a batch from each dataloader
107+
batch_a = batch["a"]
108+
batch_b = batch["b"]
109+
93110

94111
.. testcode::
95112

@@ -103,12 +120,29 @@ be returned
103120
loader_c = torch.utils.data.DataLoader(range(64), batch_size=4)
104121

105122
# pass loaders as a nested dict. This will create batches like this:
106-
# {'loader_a_b': {'a': batch from loader a, 'b': batch from loader b},
107-
# 'loader_c_d': {'c': batch from loader c, 'd': batch from loader d}}
108-
loaders = {'loaders_a_b': {'a': loader_a, 'b': loader_b},
109-
'loaders_c_d': {'c': loader_c, 'd': loader_d}}
123+
loaders = {
124+
'loaders_a_b': {
125+
'a': loader_a,
126+
'b': loader_b
127+
},
128+
'loaders_c_d': {
129+
'c': loader_c,
130+
'd': loader_d
131+
}
132+
}
110133
return loaders
111134

135+
def training_step(self, batch, batch_idx):
136+
# access the data
137+
batch_a_b = batch["loaders_a_b"]
138+
batch_c_d = batch["loaders_c_d"]
139+
140+
batch_a = batch_a_b["a"]
141+
batch_b = batch_a_b["a"]
142+
143+
batch_c = batch_c_d["c"]
144+
batch_d = batch_c_d["d"]
145+
112146
----------
113147

114148
Test/Val dataloaders

0 commit comments

Comments
 (0)