
Commit 5dace57

Merge branch 'master' into ci/gha-python
2 parents: 10e3d84 + 64163c2

File tree: 5 files changed (+54 −7 lines)


docs/source/datamodules.rst

Lines changed: 1 addition & 0 deletions
@@ -268,6 +268,7 @@ Use this method to generate the val dataloader. Usually you just wrap the datas
     def val_dataloader(self):
         return DataLoader(self.mnist_val, batch_size=64)
 
+.. _datamodule-test-dataloader-label:
 
 test_dataloader
 ^^^^^^^^^^^^^^^

docs/source/optimizers.rst

Lines changed: 4 additions & 4 deletions
@@ -201,12 +201,12 @@ For example, here step optimizer A every 2 batches and optimizer B every 4 batch
     # Alternating schedule for optimizer steps (ie: GANs)
     def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
         # update generator opt every 2 steps
-        if optimizer_i == 0:
+        if optimizer_idx == 0:
             if batch_nb % 2 == 0 :
                 optimizer.step(closure=closure)
 
         # update discriminator opt every 4 steps
-        if optimizer_i == 1:
+        if optimizer_idx == 1:
             if batch_nb % 4 == 0 :
                 optimizer.step(closure=closure)
@@ -220,11 +220,11 @@ For example, here step optimizer A every 2 batches and optimizer B every 4 batch
     # Alternating schedule for optimizer steps (ie: GANs)
     def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
         # update generator opt every 2 steps
-        if optimizer_i == 0:
+        if optimizer_idx == 0:
             optimizer.step(closure=closure, make_optimizer_step=(batch_nb % 2) == 0)
 
         # update discriminator opt every 4 steps
-        if optimizer_i == 1:
+        if optimizer_idx == 1:
             optimizer.step(closure=closure, make_optimizer_step=(batch_nb % 4) == 0)
 
 Here we add a learning-rate warm up
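For context, ``optimizer_idx`` (the corrected name) is simply the index of the optimizer in the list returned by ``configure_optimizers``. A minimal sketch of how the fixed hook fits into a two-optimizer module; the class name, layer sizes, and learning rates below are illustrative placeholders, not part of this commit:

    import torch
    from torch import nn
    import pytorch_lightning as pl

    class LitGAN(pl.LightningModule):
        def __init__(self):
            super().__init__()
            # tiny stand-in networks; a real GAN would define proper generator/discriminator models
            self.generator = nn.Linear(32, 32)
            self.discriminator = nn.Linear(32, 1)

        def configure_optimizers(self):
            # index 0 -> generator optimizer, index 1 -> discriminator optimizer
            opt_g = torch.optim.Adam(self.generator.parameters(), lr=2e-4)
            opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=2e-4)
            return [opt_g, opt_d]

        def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx,
                           closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
            # step the generator every 2 batches and the discriminator every 4 batches
            if optimizer_idx == 0 and batch_nb % 2 == 0:
                optimizer.step(closure=closure)
            if optimizer_idx == 1 and batch_nb % 4 == 0:
                optimizer.step(closure=closure)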

docs/source/test_set.rst

Lines changed: 23 additions & 1 deletion
@@ -3,6 +3,10 @@
 Test set
 ========
 Lightning forces the user to run the test set separately to make sure it isn't evaluated by mistake.
+Testing is performed using the ``trainer`` object's ``.test()`` method.
+
+.. automethod:: pytorch_lightning.trainer.Trainer.test
+   :noindex:
 
 ----------
 
@@ -82,4 +86,22 @@ is not available at the time your model was declared.
     trainer.test(test_dataloaders=test)
 
 You can either pass in a single dataloader or a list of them. This optional named
-parameter can be used in conjunction with any of the above use cases.
+parameter can be used in conjunction with any of the above use cases. Additionally,
+you can also pass in an :ref:`datamodules` that have overridden the
+:ref:`datamodule-test-dataloader-label` method.
+
+.. code-block:: python
+
+    class MyDataModule(pl.LightningDataModule):
+        ...
+        def test_dataloader(self):
+            return DataLoader(...)
+
+    # setup your datamodule
+    dm = MyDataModule(...)
+
+    # test (pass in datamodule)
+    trainer.test(datamodule=dm)
+
+
+
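As a usage illustration only (none of the names below come from the commit), ``Trainer.test`` can be exercised through either of the two documented entry points: explicit ``test_dataloaders`` or a datamodule whose ``test_dataloader`` is overridden. A self-contained sketch with a throwaway model and random data:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import pytorch_lightning as pl

    class LitClassifier(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def test_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.cross_entropy(self.layer(x), y)
            self.log("test_loss", loss)

    class RandomTestDataModule(pl.LightningDataModule):
        def test_dataloader(self):
            # random tensors keep the sketch self-contained
            data = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
            return DataLoader(data, batch_size=16)

    model = LitClassifier()
    trainer = pl.Trainer()

    # option 1: pass a single dataloader (or a list of them) explicitly
    trainer.test(model, test_dataloaders=RandomTestDataModule().test_dataloader())

    # option 2: pass the datamodule and let Lightning call its test_dataloader()
    trainer.test(model, datamodule=RandomTestDataModule())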

pytorch_lightning/metrics/metric.py

Lines changed: 3 additions & 2 deletions
@@ -94,7 +94,8 @@ def add_state(
                 reset to this value when ``self.reset()`` is called.
             dist_reduce_fx (Optional): Function to reduce state accross mutliple processes in distributed mode.
                 If value is ``"sum"``, ``"mean"``, or ``"cat"``, we will use ``torch.sum``, ``torch.mean``,
-                and ``torch.cat`` respectively, each with argument ``dim=0``. The user can also pass a custom
+                and ``torch.cat`` respectively, each with argument ``dim=0``. Note that the ``"cat"`` reduction
+                only makes sense if the state is a list, and not a tensor. The user can also pass a custom
                 function in this parameter.
             persistent (Optional): whether the state will be saved as part of the modules ``state_dict``.
                 Default is ``False``.
@@ -244,7 +245,7 @@ def reset(self):
         """
         for attr, default in self._defaults.items():
             current_val = getattr(self, attr)
-            if isinstance(current_val, torch.Tensor):
+            if isinstance(default, torch.Tensor):
                 setattr(self, attr, deepcopy(default).to(current_val.device))
             else:
                 setattr(self, attr, deepcopy(default))
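The ``reset`` fix keys the restore on the *default* value rather than the current one, so a list state whose current value has become a tensor (for example after a ``"cat"`` reduction) is reset back to an empty list instead of attempting ``.to()`` on a list. A hedged sketch of the kind of metric this affects; the class below is illustrative and not part of this commit:

    import torch
    from pytorch_lightning.metrics.metric import Metric

    class CatPredictions(Metric):
        """Accumulates prediction tensors in a list state; "cat" concatenates them across processes."""

        def __init__(self):
            super().__init__()
            # the default is an empty list, so reset() must restore a list, not a tensor
            self.add_state("preds", default=[], dist_reduce_fx="cat")

        def update(self, preds: torch.Tensor):
            self.preds.append(preds)

        def compute(self):
            return torch.cat(self.preds, dim=0)

    metric = CatPredictions()
    metric.update(torch.randn(4, 10))
    metric.reset()
    assert metric.preds == []  # back to the list default, mirroring the new test below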

tests/metrics/test_metric.py

Lines changed: 23 additions & 0 deletions
@@ -26,6 +26,20 @@ def compute(self):
         pass
 
 
+class DummyList(Metric):
+    name = "DummyList"
+
+    def __init__(self):
+        super().__init__()
+        self.add_state("x", list(), dist_reduce_fx=None)
+
+    def update(self):
+        pass
+
+    def compute(self):
+        pass
+
+
 def test_inherit():
     a = Dummy()
 
@@ -77,12 +91,21 @@ def test_reset():
     class A(Dummy):
         pass
 
+    class B(DummyList):
+        pass
+
     a = A()
     assert a.x == 0
     a.x = torch.tensor(5)
     a.reset()
     assert a.x == 0
 
+    b = B()
+    assert isinstance(b.x, list) and len(b.x) == 0
+    b.x = torch.tensor(5)
+    b.reset()
+    assert isinstance(b.x, list) and len(b.x) == 0
+
 
 def test_update():
     class A(Dummy):
