Description
Hi!
In the pseudo-code below I have two models that I want to fit with two different datasets.
I have tried to figure out whether this is possible by reading `test_dataloaders.py`, without success...
The documentation states:

> **Multiple training dataloaders**
> For training, the best way to use multiple-dataloaders is to create a Dataloader class which wraps both your dataloaders. (This of course also works for testing and validation dataloaders).
But that doesn't really help me...
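For reference, the wrapper the documentation describes could look roughly like the sketch below. It is a minimal, hypothetical sketch (`ZippedLoaders` is not a Lightning class), assuming the trainer only needs to call `iter()` and `len()` on whatever `train_dataloader` returns; some Lightning versions may expect a real `DataLoader` instead. Each step yields one batch from every wrapped loader, and iteration stops at the shortest one, just like `zip`.

```python
from typing import Iterable, Iterator, Tuple


class ZippedLoaders:
    """Hypothetical wrapper: yields one batch from each wrapped loader per step."""

    def __init__(self, *loaders: Iterable) -> None:
        self.loaders = loaders

    def __iter__(self) -> Iterator[Tuple]:
        # Each step is a tuple (batch_from_loader_0, batch_from_loader_1, ...);
        # zip() stops as soon as the shortest loader is exhausted.
        return zip(*self.loaders)

    def __len__(self) -> int:
        # Steps per epoch = length of the shortest wrapped loader.
        return min(len(loader) for loader in self.loaders)


# Plain lists stand in for two DataLoaders here.
mnist_batches = ["m0", "m1", "m2"]
fashion_batches = ["f0", "f1"]
print(list(ZippedLoaders(mnist_batches, fashion_batches)))
# [('m0', 'f0'), ('m1', 'f1')]
```

With two optimizers, Lightning calls `training_step` once per optimizer on the same batch, so `x, y = batch[optimizer_idx]` should pick the MNIST batch for `l1` and the FashionMNIST batch for `l2`.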
I guess this has already been discussed in #1089,
and that I should study: https://gist.github.com/Dref360/2524e524244569ed47428f19c487f264
But it would be nice to have a `dataloader_idx` parameter, just like the `optimizer_idx` parameter...
Or perhaps a batch could have a dictionary-like structure, where data is sampled into different "baskets",
so that I could write something like:
```python
if optimizer_idx == 0:
    # REQUIRED
    x, y = batch[0]
```
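The "baskets" idea can also be approximated at the dataset level. A minimal sketch, assuming map-style datasets such as `MNIST` and `FashionMNIST` (any object with `__getitem__` and `__len__` works as one): the hypothetical `BasketDataset` returns one sample per basket as a dict, and shorter datasets wrap around by index.

```python
from typing import Any, Dict, Sequence


class BasketDataset:
    """Hypothetical map-style dataset: each item is a dict with one sample per basket."""

    def __init__(self, baskets: Dict[str, Sequence]) -> None:
        self.baskets = baskets

    def __len__(self) -> int:
        # One epoch covers the longest dataset; shorter ones wrap around.
        return max(len(ds) for ds in self.baskets.values())

    def __getitem__(self, index: int) -> Dict[str, Any]:
        # index % len(ds) keeps every basket in range.
        return {name: ds[index % len(ds)] for name, ds in self.baskets.items()}


# Plain lists stand in for MNIST / FashionMNIST here.
ds = BasketDataset({"mnist": [1, 2, 3], "fashion": ["a", "b"]})
print(len(ds))  # 3
print(ds[2])    # {'mnist': 3, 'fashion': 'a'}
```

Wrapped in a single `DataLoader(BasketDataset({"mnist": mnist, "fashion": fashion}), batch_size=32)`, PyTorch's default collate function should turn each batch into a dict of batched `[images, labels]` pairs, so `training_step` could use `x, y = batch["mnist"]` when `optimizer_idx == 0` and `batch["fashion"]` otherwise. The wrap-around means the shorter dataset is revisited within an epoch.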
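The full pseudo-code:

```python
import os

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST, FashionMNIST


class FashionMNIST_and_MNISTModel(pl.LightningModule):
    def __init__(self):
        super(FashionMNIST_and_MNISTModel, self).__init__()
        # l1 should be fit to the MNIST dataset
        self.l1 = torch.nn.Linear(28 * 28, 10)
        # l2 should be fit to the FashionMNIST dataset
        self.l2 = torch.nn.Linear(28 * 28, 10)

    def training_step(self, batch, batch_nb, optimizer_idx):
        if optimizer_idx == 0:
            # REQUIRED
            x, y = batch
            y_hat = torch.relu(self.l1(x.view(x.size(0), -1)))
            loss_l1 = F.cross_entropy(y_hat, y)
            tensorboard_logs = {'train_loss': loss_l1}
            return {'loss': loss_l1, 'log': tensorboard_logs}
        if optimizer_idx == 1:
            # REQUIRED
            x, y = batch
            y_hat = torch.relu(self.l2(x.view(x.size(0), -1)))
            loss_l2 = F.cross_entropy(y_hat, y)
            tensorboard_logs = {'train_loss': loss_l2}
            return {'loss': loss_l2, 'log': tensorboard_logs}

    def configure_optimizers(self):
        # One optimizer per sub-model; its list position becomes optimizer_idx.
        # (Adam with lr=1e-3 is an arbitrary choice for this example.)
        return [
            torch.optim.Adam(self.l1.parameters(), lr=1e-3),
            torch.optim.Adam(self.l2.parameters(), lr=1e-3),
        ]

    def train_dataloader(self):
        # REQUIRED
        return [
            DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32),
            DataLoader(FashionMNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32),
        ]
```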
//Christofer