11import unittest
22
3+ import pytorch_lightning as pl
34import torch
45import torch .nn .functional as F
56from torch .utils .data import DataLoader , TensorDataset
67
7- import pytorch_lightning as pl
8- from pytorch_lightning .metrics .functional import to_onehot
9-
108
119class LitDataModule (pl .LightningDataModule ):
1210
@@ -16,10 +14,10 @@ def __init__(self, batch_size=16):
1614 self .batch_size = batch_size
1715
1816 def setup (self , stage = None ):
19- X_train = torch .rand (100 , 1 , 28 , 28 ). float ()
20- y_train = to_onehot ( torch .randint (0 , 10 , size = (100 ,)), num_classes = 10 ). float ( )
17+ X_train = torch .rand (100 , 1 , 28 , 28 )
18+ y_train = torch .randint (0 , 10 , size = (100 ,))
2119 X_valid = torch .rand (20 , 1 , 28 , 28 )
22- y_valid = to_onehot ( torch .randint (0 , 10 , size = (20 ,)), num_classes = 10 ). float ( )
20+ y_valid = torch .randint (0 , 10 , size = (20 ,))
2321
2422 self .train_ds = TensorDataset (X_train , y_train )
2523 self .valid_ds = TensorDataset (X_valid , y_valid )
@@ -38,26 +36,23 @@ def __init__(self):
3836 self .l1 = torch .nn .Linear (28 * 28 , 10 )
3937
4038 def forward (self , x ):
41- return torch .relu (self .l1 (x .view (x .size (0 ), - 1 )))
39+ return F .relu (self .l1 (x .view (x .size (0 ), - 1 )))
4240
4341 def training_step (self , batch , batch_idx ):
4442 x , y = batch
4543 y_hat = self (x )
46- loss = F .binary_cross_entropy_with_logits (y_hat , y )
47- result = pl .TrainResult (loss )
48- result .log ('train_loss' , loss , on_epoch = True )
49- return result
44+ loss = F .cross_entropy (y_hat , y )
45+ self .log ('train_loss' , loss )
46+ return loss
5047
5148 def validation_step (self , batch , batch_idx ):
5249 x , y = batch
5350 y_hat = self (x )
54- loss = F .binary_cross_entropy_with_logits (y_hat , y )
55- result = pl .EvalResult (checkpoint_on = loss )
56- result .log ('val_loss' , loss )
57- return result
51+ loss = F .cross_entropy (y_hat , y )
52+ self .log ('val_loss' , loss )
5853
5954 def configure_optimizers (self ):
60- return torch .optim .Adam (self .parameters (), lr = 0.02 )
55+ return torch .optim .Adam (self .parameters (), lr = 1e-2 )
6156
6257
6358class TestPytorchLightning (unittest .TestCase ):
0 commit comments