-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Closed
Labels
bug — Something isn't working
Description
Code to reproduce (self-contained script):
import os, torch
from torch.utils.data import DataLoader
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
# Samples per DataLoader batch; shared by both training runs in this script.
BATCH_SIZE: int = 64
class MNISTModel(LightningModule):
    """Minimal LightningModule: one linear layer over flattened 28*28 inputs.

    The output is passed through ReLU and trained with cross-entropy using
    Adam (lr=0.02). The training loss is logged under the key ``'loss'``.
    """

    def __init__(self):
        super().__init__()
        # The attribute name ``l1`` determines the state_dict keys
        # ('l1.weight', 'l1.bias') that end up in saved checkpoints.
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        """Flatten every dimension after the batch dim, apply ``l1``, then ReLU."""
        batch_size = x.size(0)
        flattened = x.view(batch_size, -1)
        scores = self.l1(flattened)
        return torch.relu(scores)

    def training_step(self, batch, batch_nb):
        """Compute and log the cross-entropy loss for one ``(inputs, targets)`` batch."""
        inputs, targets = batch
        predictions = self(inputs)
        # NOTE(review): the Data dataset in this script yields float targets
        # shaped like the predictions, which cross_entropy treats as per-class
        # probabilities — confirm this matches the intended setup.
        loss = torch.nn.functional.cross_entropy(predictions, targets)
        # Logged per step and aggregated per epoch; ModelCheckpoint monitors 'loss'.
        self.log('loss', loss, prog_bar=True, on_step=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with lr=0.02."""
        optimizer = torch.optim.Adam(self.parameters(), lr=0.02)
        return optimizer
class Data:
    """Synthetic map-style dataset of 1000 random samples.

    Each access draws fresh random tensors, so reading the same index twice
    gives different values unless the global torch RNG is reseeded.
    """

    _NUM_SAMPLES = 1000

    def __len__(self):
        return self._NUM_SAMPLES

    def __getitem__(self, batch_idx):
        """Return a freshly drawn ``(features, targets)`` pair.

        ``batch_idx`` is the per-sample index passed in by the sampler; it is
        not used.

        ``features`` has shape (784,) and is drawn from a standard normal.
        ``targets`` has shape (10,) and is drawn uniformly from [0, 1).
        """
        # Draw order (randn, then rand) matters for reproducibility under a fixed seed.
        features = torch.randn(28 * 28)
        targets = torch.rand(10)
        return features, targets
def make_ckpt_cb():
    """Build a new ModelCheckpoint callback.

    Checkpoints are written to the current directory once per epoch. They are
    ranked by the logged ``'loss'`` metric (lower is better), and
    ``save_top_k=-1`` keeps every one of them. Each call returns a separate
    instance, so every Trainer that calls this gets its own callback.
    """
    # A named function instead of an assigned lambda (PEP 8 E731): it gets a
    # real __name__ for tracebacks and has room for this docstring.
    return ModelCheckpoint(dirpath='.', monitor='loss', mode='min',
                           every_n_epochs=1, save_top_k=-1)
# Run 1: train for one epoch so the checkpoint callback writes a checkpoint.
first_ckpt_cb = make_ckpt_cb()
mnist_model = MNISTModel()
train_loader = DataLoader(Data(), batch_size=BATCH_SIZE)
trainer = Trainer(gpus=0, max_epochs=1, callbacks=first_ckpt_cb)
trainer.fit(mnist_model, train_loader)

# Ask the callback which file it wrote instead of scanning the directory.
# os.listdir() order is arbitrary, and stale *.ckpt files left in '.' by an
# earlier run of this script would otherwise be picked up.
ckpt_path = first_ckpt_cb.best_model_path
if not ckpt_path:
    raise RuntimeError("first training run did not save a checkpoint")

# Run 2: fresh model and trainer, resumed from that checkpoint up to 2 epochs.
mnist_model = MNISTModel()
train_loader = DataLoader(Data(), batch_size=BATCH_SIZE)
trainer = Trainer(gpus=0, max_epochs=2, callbacks=make_ckpt_cb())
trainer.fit(mnist_model, train_loader, ckpt_path=ckpt_path)
Metadata
Assignees
Labels
bug — Something isn't working