-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Closed
Labels
Labels: bug (Something isn't working), help wanted (Open to be worked on), priority: 1 (Medium priority task)
Description
Common bugs:
Comparing LBFGS + PyTorch Lightning against native PyTorch + LBFGS, PyTorch Lightning does not update the weights properly and the model does not converge. Some observations:
- Adam + PyTorch Lightning on MNIST works fine; however, LBFGS + PyTorch Lightning does not work as expected.
- LBFGS + native PyTorch works very well; however, LBFGS + PyTorch Lightning does not work as expected.
🐛 Bug
LBFGS + PyTorch Lightning has trouble converging, and the weights do not update the way they do with Adam + PyTorch Lightning.
Code sample
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torchvision import transforms,datasets
from torch.utils.data import DataLoader,random_split
import pytorch_lightning as pl
from IPython.display import clear_output
class LightningMNISTClassifier(pl.LightningModule):
    """Three-layer MLP classifier for MNIST, used to compare LBFGS and Adam under Lightning.

    Each image is flattened, passed through two ReLU hidden layers (128 and 256
    units) and mapped to log-probabilities over the 10 digit classes.
    """

    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 256)
        self.layer_3 = nn.Linear(256, 10)

    def forward(self, x):
        """Return per-class log-probabilities of shape ``(batch, 10)``.

        Each sample of ``x`` must flatten to 28 * 28 values (what ``layer_1``
        accepts); presumably ``x`` is an MNIST image batch
        ``(batch, 1, 28, 28)`` — confirm against the dataloader.
        """
        # Flatten everything except the batch dimension.
        x = x.view(x.size(0), -1)
        x = torch.relu(self.layer_1(x))
        x = torch.relu(self.layer_2(x))
        # Log-probabilities, so the loss below is NLL rather than cross-entropy on logits.
        return torch.log_softmax(self.layer_3(x), dim=1)

    def prepare_data(self):
        """Download the MNIST train and test splits to the working directory.

        No model state is assigned here: Lightning may run this hook on a single
        process only, so the datasets themselves are built in :meth:`setup`.
        """
        MNIST(os.getcwd(), train=True, download=True)
        MNIST(os.getcwd(), train=False, download=True)

    def setup(self, stage=None):
        """Build the normalized datasets and the 55000/5000 train/validation split."""
        # Standard MNIST normalization constants (as in the original repro).
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transform)
        # Kept on the module so test_dataloader (commented out below) can use it.
        self.mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transform)
        self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])

    def train_dataloader(self):
        """Return a shuffled loader over the 55000-sample training split."""
        return DataLoader(self.mnist_train, batch_size=1024, shuffle=True)

    # def val_dataloader(self):
    #     return DataLoader(self.mnist_val, batch_size=1024)

    # def test_dataloader(self):
    #     return DataLoader(self.mnist_test, batch_size=1024)

    def configure_optimizers(self):
        """Return the LBFGS optimizer over all model parameters."""
        # optimizer = optim.Adam(self.parameters(), lr=1e-3)
        # NOTE(review): torch's LBFGS defaults to lr=1; 1e-2 scales every step down
        # by 100x, which may also slow convergence — confirm this is intended.
        optimizer = optim.LBFGS(self.parameters(), lr=1e-2)
        return optimizer

    # def backward(self, trainer, loss, optimizer):
    #     loss.backward(retain_graph=True)

    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx,
                       second_order_closure, on_tpu=False, using_native_amp=False,
                       using_lbfgs=False):
        """Take one optimizer step, re-evaluating the loss through the closure.

        LBFGS calls the closure several times within a single ``step`` and
        requires each call to clear gradients, recompute the loss and return it.
        The closure handed in by Lightning presumably runs ``training_step`` plus
        ``backward``; it is not guaranteed to clear gradients between calls. They
        are therefore zeroed here, so evaluations do not sum into one corrupted
        gradient. For single-evaluation optimizers such as Adam the extra
        ``zero_grad`` is harmless. It would, however, defeat
        ``accumulate_grad_batches``, which is disabled in this setup.
        """
        def closure():
            optimizer.zero_grad()
            return second_order_closure()

        optimizer.step(closure)

    def cross_entropy_loss(self, logits, labels):
        """Return the mean NLL loss.

        ``logits`` are really log-probabilities from ``forward``, so this equals
        cross-entropy on the raw scores.
        """
        return F.nll_loss(logits, labels)

    def training_step(self, train_batch, batch_idx):
        """Compute the training loss for one ``(images, labels)`` batch."""
        x, y = train_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        return {'loss': loss}

    def training_epoch_end(self, outputs):
        """Print the mean training loss over all batches of the epoch."""
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        print('epoch={}, avg_Train_loss={:.2f}'.format(self.current_epoch, avg_loss.item()))
        # return {'avg_train_loss': avg_loss}

    # def validation_step(self, val_batch, batch_idx):
    #     x, y = val_batch
    #     logits = self.forward(x)
    #     loss = self.cross_entropy_loss(logits, y)
    #     return {'val_loss': loss}

    # def validation_epoch_end(self, outputs):
    #     avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
    #     print('epoch={}, avg_Test_loss={:.2f}'.format(self.current_epoch, avg_loss.item()))
    #     return {'avg_val_loss': avg_loss}
# Train the classifier. The progress bar and weights summary are disabled so
# that only the per-epoch loss lines printed by training_epoch_end appear.
model = LightningMNISTClassifier()
# from pytorch_lightning.callbacks import EarlyStopping
trainer = pl.Trainer(
    max_epochs=400,
    # Use one GPU when CUDA is available, otherwise train on CPU.
    gpus=1 if torch.cuda.is_available() else None,
    # check_val_every_n_epoch=2,
    # accumulate_grad_batches=5,
    # early_stop_callback=early_stop,
    # limit_train_batches=50,
    # val_check_interval=0.25,
    progress_bar_refresh_rate=0,
    # num_sanity_val_steps=0,
    weights_summary=None,
)
clear_output(wait=True)
trainer.fit(model)
Expected behavior
Environment
Please copy and paste the output from our
environment collection script
(or fill out the checklist below manually).
You can get the script and run it with:
wget https://raw.githubusercontent.com/PyTorchLightning/pytorch-lightning/master/tests/collect_env_details.py
# For security purposes, please check the contents of collect_env_details.py before running it.
python collect_env_details.py
Environment:
- Colab and PyCharm
- PyTorch version: 1.6.0 (CPU and GPU)
- pytorch-lightning==1.0.0rc3
Metadata
Metadata
Assignees
Labels
Labels: bug (Something isn't working), help wanted (Open to be worked on), priority: 1 (Medium priority task)