
Commit 215b954

Merge branch 'master' into bugfix/ep_end_ckpt
2 parents: b4811b3 + 4913cbb

File tree

19 files changed: +136 -93 lines changed


docs/source/metrics.rst

Lines changed: 50 additions & 0 deletions
@@ -137,6 +137,56 @@ This metrics API is independent of PyTorch Lightning. Metrics can directly be us
 To change this, after initializing the metric, the method ``.persistent(mode)`` can
 be used to enable (``mode=True``) or disable (``mode=False``) this behaviour.
 
+*******************
+Metrics and devices
+*******************
+
+Metrics are simple subclasses of :class:`~torch.nn.Module` and their metric states behave
+similarly to the buffers and parameters of modules. This means that metric states should
+be moved to the same device as the input of the metric:
+
+.. code-block:: python
+
+    import torch
+    from pytorch_lightning.metrics import Accuracy
+
+    target = torch.tensor([1, 1, 0, 0], device=torch.device("cuda", 0))
+    preds = torch.tensor([0, 1, 0, 0], device=torch.device("cuda", 0))
+
+    # Metric states are always initialized on the CPU and need to be moved to
+    # the correct device
+    confmat = Accuracy(num_classes=2).to(torch.device("cuda", 0))
+    out = confmat(preds, target)
+    print(out.device)  # cuda:0
+
+However, when **properly defined** inside a :class:`~pytorch_lightning.core.lightning.LightningModule`,
+Lightning will automatically move the metrics to the same device as the data. Being
+**properly defined** means that the metric is correctly identified as a child module of the
+model (check the ``.children()`` attribute of the model). Therefore, metrics cannot be placed
+in native Python ``list`` and ``dict``, as they will not be correctly identified
+as child modules. Instead of ``list`` use :class:`~torch.nn.ModuleList`, and instead of
+``dict`` use :class:`~torch.nn.ModuleDict`.
+
+.. testcode::
+
+    class MyModule(LightningModule):
+        def __init__(self):
+            ...
+            # valid ways metrics will be identified as child modules
+            self.metric1 = pl.metrics.Accuracy()
+            self.metric2 = torch.nn.ModuleList([pl.metrics.Accuracy()])
+            self.metric3 = torch.nn.ModuleDict({'accuracy': Accuracy()})
+
+        def training_step(self, batch, batch_idx):
+            # all metrics will be on the same device as the input batch
+            data, target = batch
+            preds = self(data)
+            ...
+            val1 = self.metric1(preds, target)
+            val2 = self.metric2[0](preds, target)
+            val3 = self.metric3['accuracy'](preds, target)
+
+
 *********************
 Implementing a Metric
 *********************
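
The unchanged lines at the top of this hunk mention the ``.persistent(mode)`` method without showing it in use. A minimal sketch of toggling it, reusing the same ``pytorch_lightning.metrics.Accuracy`` class as the diff (the example tensors are illustrative only):

    import torch
    from pytorch_lightning.metrics import Accuracy

    metric = Accuracy()
    metric(torch.tensor([0, 1, 1, 0]), torch.tensor([0, 1, 0, 0]))

    metric.persistent(True)   # keep the metric states in the module's state_dict
    metric.persistent(False)  # leave the metric states out of the state_dict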

docs/source/multi_gpu.rst

Lines changed: 0 additions & 17 deletions
@@ -239,23 +239,6 @@ Note in particular the difference between `gpus=0`, `gpus=[0]` and `gpus="0"`.
 to be in "exclusive mode", such that only one process at a time can access them.
 For more details see the :ref:`Trainer guide <trainer>`.
 
-
-Remove CUDA flags
-^^^^^^^^^^^^^^^^^
-
-CUDA flags make certain GPUs visible to your script.
-Lightning sets these for you automatically, there's NO NEED to do this yourself.
-
-.. testcode::
-
-    # lightning will set according to what you give the trainer
-    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
-    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
-
-However, when using a cluster, Lightning will NOT set these flags (and you should not either).
-SLURM will set these for you.
-For more details see the :ref:`SLURM cluster guide <slurm>`.
-
 ----------
 
 Distributed modes
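
The section removed above argued that Lightning derives GPU visibility from what is passed to the ``Trainer``, so the CUDA environment variables never need to be set by hand. A minimal sketch of that intended usage, assuming the ``gpus`` Trainer argument referenced in the hunk header (the exact parsing of the ``gpus=0`` / ``gpus=[0]`` / ``gpus="0"`` forms is covered by the Trainer guide linked in the context lines):

    from pytorch_lightning import Trainer

    # device selection comes from the `gpus` argument; CUDA_DEVICE_ORDER and
    # CUDA_VISIBLE_DEVICES do not need to be set manually
    trainer = Trainer(gpus=[0])  # train on the GPU with index 0
    trainer = Trainer(gpus=2)    # train on 2 GPUs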

pl_examples/basic_examples/autoencoder.py

Lines changed: 4 additions & 5 deletions
@@ -15,17 +15,16 @@
 from argparse import ArgumentParser
 
 import torch
-import torch.nn.functional as F
 from torch import nn
-from torch.utils.data import DataLoader
-from torch.utils.data import random_split
+import torch.nn.functional as F
+from torch.utils.data import DataLoader, random_split
 
+from pl_examples import cli_lightning_logo, TORCHVISION_AVAILABLE
 import pytorch_lightning as pl
-from pl_examples import TORCHVISION_AVAILABLE, cli_lightning_logo
 
 if TORCHVISION_AVAILABLE:
-    from torchvision.datasets.mnist import MNIST
     from torchvision import transforms
+    from torchvision.datasets.mnist import MNIST
 else:
     from tests.base.datasets import MNIST
 
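Reassembled from the hunk above, the resulting import block of ``autoencoder.py`` reads as follows; the grouping into standard-library, third-party, and project imports reflects an isort-style ordering, which is an assumption about the tooling rather than something stated in the commit:

    from argparse import ArgumentParser

    import torch
    from torch import nn
    import torch.nn.functional as F
    from torch.utils.data import DataLoader, random_split

    from pl_examples import cli_lightning_logo, TORCHVISION_AVAILABLE
    import pytorch_lightning as pl

    if TORCHVISION_AVAILABLE:
        from torchvision import transforms
        from torchvision.datasets.mnist import MNIST
    else:
        from tests.base.datasets import MNIST

The same reordering pattern recurs in the remaining example scripts below.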

pl_examples/basic_examples/backbone_image_classifier.py

Lines changed: 2 additions & 2 deletions
@@ -18,12 +18,12 @@
 from torch.nn import functional as F
 from torch.utils.data import DataLoader, random_split
 
+from pl_examples import cli_lightning_logo, DATASETS_PATH, TORCHVISION_AVAILABLE
 import pytorch_lightning as pl
-from pl_examples import DATASETS_PATH, TORCHVISION_AVAILABLE, cli_lightning_logo
 
 if TORCHVISION_AVAILABLE:
-    from torchvision.datasets.mnist import MNIST
     from torchvision import transforms
+    from torchvision.datasets.mnist import MNIST
 else:
     from tests.base.datasets import MNIST
 
pl_examples/basic_examples/conv_sequential_example.py

Lines changed: 2 additions & 2 deletions
@@ -20,16 +20,16 @@
 To run:
 python conv_model_sequential_example.py --accelerator ddp --gpus 4 --max_epochs 1 --batch_size 256 --use_ddp_sequential
 """
-import math
 from argparse import ArgumentParser
+import math
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchvision
 
-import pytorch_lightning as pl
 from pl_examples import cli_lightning_logo
+import pytorch_lightning as pl
 from pytorch_lightning import Trainer
 from pytorch_lightning.metrics.functional import accuracy
 from pytorch_lightning.plugins.ddp_sequential_plugin import DDPSequentialPlugin

pl_examples/basic_examples/dali_image_classifier.py

Lines changed: 4 additions & 4 deletions
@@ -13,29 +13,29 @@
 # limitations under the License.
 from abc import ABC
 from argparse import ArgumentParser
+from distutils.version import LooseVersion
 from random import shuffle
 from warnings import warn
-from distutils.version import LooseVersion
 
 import numpy as np
 import torch
 from torch.nn import functional as F
 from torch.utils.data import random_split
 
+from pl_examples import cli_lightning_logo, DALI_AVAILABLE, TORCHVISION_AVAILABLE
 import pytorch_lightning as pl
-from pl_examples import TORCHVISION_AVAILABLE, DALI_AVAILABLE, cli_lightning_logo
 
 if TORCHVISION_AVAILABLE:
-    from torchvision.datasets.mnist import MNIST
     from torchvision import transforms
+    from torchvision.datasets.mnist import MNIST
 else:
     from tests.base.datasets import MNIST
 
 if DALI_AVAILABLE:
+    from nvidia.dali import __version__ as dali_version
     from nvidia.dali import ops
     from nvidia.dali.pipeline import Pipeline
     from nvidia.dali.plugin.pytorch import DALIClassificationIterator
-    from nvidia.dali import __version__ as dali_version
 
 NEW_DALI_API = LooseVersion(dali_version) >= LooseVersion('0.28.0')
 if NEW_DALI_API:

pl_examples/basic_examples/simple_image_classifier.py

Lines changed: 1 addition & 1 deletion
@@ -18,9 +18,9 @@
 import torch
 from torch.nn import functional as F
 
-import pytorch_lightning as pl
 from pl_examples import cli_lightning_logo
 from pl_examples.basic_examples.mnist_datamodule import MNISTDataModule
+import pytorch_lightning as pl
 
 
 class LitClassifier(pl.LightningModule):

pl_examples/bug_report_model.py

Lines changed: 2 additions & 1 deletion
@@ -20,11 +20,12 @@
 # --------------------------------------------
 # --------------------------------------------
 import os
+
 import torch
 from torch.utils.data import Dataset
 
 from pl_examples import cli_lightning_logo
-from pytorch_lightning import Trainer, LightningModule
+from pytorch_lightning import LightningModule, Trainer
 
 
 class RandomDataset(Dataset):

pl_examples/domain_templates/computer_vision_fine_tuning.py

Lines changed: 4 additions & 5 deletions
@@ -38,22 +38,21 @@
 from collections import OrderedDict
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Optional, Generator, Union
+from typing import Generator, Optional, Union
 
 import torch
-import torch.nn.functional as F
 from torch import optim
 from torch.nn import Module
+import torch.nn.functional as F
 from torch.optim.lr_scheduler import MultiStepLR
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader
-from torchvision import models
-from torchvision import transforms
+from torchvision import models, transforms
 from torchvision.datasets import ImageFolder
 from torchvision.datasets.utils import download_and_extract_archive
 
-import pytorch_lightning as pl
 from pl_examples import cli_lightning_logo
+import pytorch_lightning as pl
 from pytorch_lightning import _logger as log
 
 BN_TYPES = (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)

pl_examples/domain_templates/generative_adversarial_net.py

Lines changed: 4 additions & 4 deletions
@@ -19,20 +19,20 @@
 
 tensorboard --logdir default
 """
-import os
 from argparse import ArgumentParser, Namespace
+import os
 
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F # noqa
-import torchvision
-import torchvision.transforms as transforms
 from torch.utils.data import DataLoader
+import torchvision
 from torchvision.datasets import MNIST
+import torchvision.transforms as transforms
 
 from pl_examples import cli_lightning_logo
-from pytorch_lightning.core import LightningModule, LightningDataModule
+from pytorch_lightning.core import LightningDataModule, LightningModule
 from pytorch_lightning.trainer import Trainer
 
 
0 commit comments
