Commit fb2d542

Merge branch 'master' into bug/scheduler-name
2 parents f917592 + 16feb51 commit fb2d542


44 files changed (+985 / -991 lines)

.github/workflows/ci_test-base.yml

Lines changed: 1 addition & 2 deletions
@@ -76,8 +76,7 @@ jobs:
       with:
         name: pytest-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}
         path: junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml
-      # Use always() to always run this step to publish test results when there are test failures
-      if: always()
+      if: failure()

     - name: Statistics
       if: success()

.github/workflows/ci_test-conda.yml

Lines changed: 1 addition & 2 deletions
@@ -50,5 +50,4 @@ jobs:
       with:
         name: pytest-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}
         path: junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml
-      # Use always() to always run this step to publish test results when there are test failures
-      if: always()
+      if: failure()

.github/workflows/ci_test-full.yml

Lines changed: 1 addition & 2 deletions
@@ -129,8 +129,7 @@ jobs:
       with:
         name: pytest-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}
         path: junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml
-      # Use always() to always run this step to publish test results when there are test failures
-      if: always()
+      if: failure()

     - name: Statistics
       if: success()

CHANGELOG.md

Lines changed: 43 additions & 2 deletions
@@ -5,6 +5,47 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


+## [unreleased.Features] - YYYY-MM-DD
+
+### Added
+
+
+### Changed
+
+
+### Deprecated
+
+
+### Removed
+
+
+### Fixed
+
+
+
+## [unreleased.BugFix] - YYYY-MM-DD
+
+### Added
+
+
+### Changed
+
+
+### Deprecated
+
+
+### Removed
+
+
+### Fixed
+
+- Fixed trainer by default `None` in `DDPAccelerator` ([#4915](https://github.com/PyTorchLightning/pytorch-lightning/pull/4915))
+
+
+- Fixed `LightningOptimizer` exposes optimizer attributes ([#5095](https://github.com/PyTorchLightning/pytorch-lightning/pull/5095))
+
+
 ## [1.1.0] - 2020-12-09

 ### Added

@@ -44,9 +85,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Changed

-- Removed `multiclass_roc` and `multiclass_precision_recall_curve`, use `roc` and `precision_recall_curve` instead ([#4549](https://github.com/PyTorchLightning/pytorch-lightning/pull/4549))
 - Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903))
-- WandbLogger does not force wandb `reinit` arg to True anymore and creates a run only when needed ([#4648](https://github.com/PyTorchLightning/pytorch-lightning/pull/4648))
+- `WandbLogger` does not force wandb `reinit` arg to True anymore and creates a run only when needed ([#4648](https://github.com/PyTorchLightning/pytorch-lightning/pull/4648))
 - Changed `automatic_optimization` to be a model attribute ([#4602](https://github.com/PyTorchLightning/pytorch-lightning/pull/4602))
 - Changed `Simple Profiler` report to order by percentage time spent + num calls ([#4880](https://github.com/PyTorchLightning/pytorch-lightning/pull/4880))
 - Simplify optimization Logic ([#4984](https://github.com/PyTorchLightning/pytorch-lightning/pull/4984))

@@ -64,6 +104,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Removed

 - Removed `reorder` parameter of the `auc` metric ([#5004](https://github.com/PyTorchLightning/pytorch-lightning/pull/5004))
+- Removed `multiclass_roc` and `multiclass_precision_recall_curve`, use `roc` and `precision_recall_curve` instead ([#4549](https://github.com/PyTorchLightning/pytorch-lightning/pull/4549))

 ### Fixed

docs/source/introduction_guide.rst

Lines changed: 2 additions & 2 deletions
@@ -601,8 +601,8 @@ In this method we do all the preparation we need to do once (instead of on every
     def setup(self, stage):
         # transform
         transform=transforms.Compose([transforms.ToTensor()])
-        MNIST(os.getcwd(), train=True, download=False, transform=transform)
-        MNIST(os.getcwd(), train=False, download=False, transform=transform)
+        mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transform)
+        mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transform)

         # train/val split
         mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])
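
For reference, a minimal sketch of how the corrected `setup` hook might sit inside a full module; the class name, dataloader methods, and batch size are illustrative, and the snippet assumes the data was already fetched in `prepare_data` as the surrounding guide describes:

    import os

    from torch.utils.data import DataLoader, random_split
    from torchvision import transforms
    from torchvision.datasets import MNIST
    import pytorch_lightning as pl


    class LitMNIST(pl.LightningModule):
        def setup(self, stage):
            # transform applied to every sample
            transform = transforms.Compose([transforms.ToTensor()])
            mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transform)
            self.mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transform)

            # train/val split; kept on self so the dataloaders below can reach the splits
            self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])

        def train_dataloader(self):
            return DataLoader(self.mnist_train, batch_size=64)

        def val_dataloader(self):
            return DataLoader(self.mnist_val, batch_size=64)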

docs/source/multi_gpu.rst

Lines changed: 1 addition & 1 deletion
@@ -663,7 +663,7 @@ It is highly recommended to use Sharded Training in multi-GPU environments where
 A technical note: as batch size scales, storing activations for the backwards pass becomes the bottleneck in training. As a result, sharding optimizer state and gradients becomes less impactful.
 Work within the future will bring optional sharding to activations and model parameters to reduce memory further, but come with a speed cost.

-To use Sharded Training, you need to first install FairScale using the command below or install all extras using ``pip install pytorch-lightning["extra"]``.
+To use Sharded Training, you need to first install FairScale using the command below.

 .. code-block:: bash

docs/source/optimizers.rst

Lines changed: 36 additions & 13 deletions
@@ -191,46 +191,69 @@ override the :meth:`optimizer_step` function.

 For example, here step optimizer A every 2 batches and optimizer B every 4 batches

-.. testcode::
+.. note:: When using Trainer(enable_pl_optimizer=True), there is no need to call `.zero_grad()`.

-    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, second_order_closure=None, on_tpu=False, using_native_amp=False, using_lbfgs=False):
-        optimizer.step()
+.. testcode::

     def optimizer_zero_grad(self, current_epoch, batch_idx, optimizer, opt_idx):
         optimizer.zero_grad()

     # Alternating schedule for optimizer steps (ie: GANs)
-    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, second_order_closure=None, on_tpu=False, using_native_amp=False, using_lbfgs=False):
+    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
         # update generator opt every 2 steps
         if optimizer_i == 0:
             if batch_nb % 2 == 0 :
-                optimizer.step()
-                optimizer.zero_grad()
+                optimizer.step(closure=closure)

         # update discriminator opt every 4 steps
         if optimizer_i == 1:
             if batch_nb % 4 == 0 :
-                optimizer.step()
-                optimizer.zero_grad()
+                optimizer.step(closure=closure)
+
+.. note:: When using ``Trainer(enable_pl_optimizer=True)``, ``.step`` accepts a boolean ``make_optimizer_step`` which can be used as follow.
+
+.. testcode::
+
+    def optimizer_zero_grad(self, current_epoch, batch_idx, optimizer, opt_idx):
+        optimizer.zero_grad()
+
+    # Alternating schedule for optimizer steps (ie: GANs)
+    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
+        # update generator opt every 2 steps
+        if optimizer_i == 0:
+            optimizer.step(closure=closure, make_optimizer_step=(batch_nb % 2) == 0)

-        # ...
-        # add as many optimizers as you want
+        # update discriminator opt every 4 steps
+        if optimizer_i == 1:
+            optimizer.step(closure=closure, make_optimizer_step=(batch_nb % 4) == 0)

 Here we add a learning-rate warm up

 .. testcode::

     # learning rate warm-up
-    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, second_order_closure=None, on_tpu=False, using_native_amp=False, using_lbfgs=False):
+    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
         # warm up lr
         if self.trainer.global_step < 500:
             lr_scale = min(1., float(self.trainer.global_step + 1) / 500.)
             for pg in optimizer.param_groups:
                 pg['lr'] = lr_scale * self.hparams.learning_rate

         # update params
-        optimizer.step()
-        optimizer.zero_grad()
+        optimizer.step(closure=closure)
+
+The default ``optimizer_step`` is relying on the internal ``LightningOptimizer`` to properly perform a step.
+
+.. testcode::
+
+    from pytorch_lightning.core.optimizer import LightningOptimizer
+
+    # function hook in LightningModule
+    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
+        if not isinstance(optimizer, LightningOptimizer):
+            # wraps into LightingOptimizer only for running step
+            optimizer = LightningOptimizer.to_lightning_optimizer(optimizer, self.trainer)
+        optimizer.step(closure=closure)

 ----------
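
For orientation, a minimal sketch of where the warm-up override from the updated docs could live inside a module; the `LitClassifier` class, its single linear layer, and the default learning rate are illustrative, while the hook signature follows the `closure` form shown in the diff:

    import torch
    import pytorch_lightning as pl


    class LitClassifier(pl.LightningModule):
        def __init__(self, learning_rate=0.02):
            super().__init__()
            self.save_hyperparameters()
            self.layer = torch.nn.Linear(32, 2)

        def forward(self, x):
            return self.layer(x)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=self.hparams.learning_rate)

        def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure,
                           on_tpu=False, using_native_amp=False, using_lbfgs=False):
            # linearly ramp the learning rate over the first 500 optimizer steps
            if self.trainer.global_step < 500:
                lr_scale = min(1., float(self.trainer.global_step + 1) / 500.)
                for pg in optimizer.param_groups:
                    pg['lr'] = lr_scale * self.hparams.learning_rate

            # passing the closure through lets the optimizer run training_step
            # and backward before the actual parameter update
            optimizer.step(closure=closure)

With `Trainer(enable_pl_optimizer=True)`, as the added notes state, no explicit `.zero_grad()` call is needed in this hook.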

pytorch_lightning/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 """Root package info."""

-__version__ = '1.1.0'
+__version__ = '1.1.1rc0'
 __author__ = 'William Falcon et al.'
 __author_email__ = '[email protected]'
 __license__ = 'Apache-2.0'

pytorch_lightning/core/hooks.py

Lines changed: 3 additions & 2 deletions
@@ -14,7 +14,7 @@

 """Various hooks to be used in the Lightning code."""

-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Optional, Union

 import torch
 from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn

@@ -501,7 +501,7 @@ def val_dataloader(self):
         will have an argument ``dataloader_idx`` which matches the order here.
         """

-    def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
+    def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any:
         """
         Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors
         wrapped in a custom data structure.

@@ -549,6 +549,7 @@ def transfer_batch_to_device(self, batch, device)
         - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
         - :func:`~pytorch_lightning.utilities.apply_func.apply_to_collection`
         """
+        device = device or self.device
         return move_data_to_device(batch, device)
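
With the `device` argument now optional, a call without it falls back to the module's own device. A minimal sketch of an override under the new signature; the `CustomBatch` container and its field names are hypothetical:

    from typing import Any, Optional

    import torch
    import pytorch_lightning as pl


    class CustomBatch:
        """Hypothetical structure a DataLoader might yield."""

        def __init__(self, inputs: torch.Tensor, targets: torch.Tensor):
            self.inputs = inputs
            self.targets = targets


    class LitModel(pl.LightningModule):
        def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any:
            # mirror the new default: fall back to the module's device
            device = device or self.device
            if isinstance(batch, CustomBatch):
                batch.inputs = batch.inputs.to(device)
                batch.targets = batch.targets.to(device)
                return batch
            # defer to the built-in handling for tensors and collections
            return super().transfer_batch_to_device(batch, device)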

pytorch_lightning/core/lightning.py

Lines changed: 52 additions & 39 deletions
@@ -22,6 +22,7 @@
 import tempfile
 from abc import ABC
 from argparse import Namespace
+from pathlib import Path
 from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union

 import torch

@@ -1171,7 +1172,6 @@ def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):

     def optimizer_step(
         self,
-        *args,
         epoch: int = None,
         batch_idx: int = None,
         optimizer: Optimizer = None,

@@ -1180,7 +1180,6 @@ def optimizer_step(
         on_tpu: bool = None,
         using_native_amp: bool = None,
         using_lbfgs: bool = None,
-        **kwargs,
     ) -> None:
         r"""
         Override this method to adjust the default way the

@@ -1255,7 +1254,7 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                 if not isinstance(optimizer, LightningOptimizer):
                     # wraps into LightingOptimizer only for running step
                     optimizer = LightningOptimizer.to_lightning_optimizer(optimizer, self.trainer)
-                optimizer.step(closure=optimizer_closure, *args, **kwargs)
+                optimizer.step(closure=optimizer_closure)

     def optimizer_zero_grad(
         self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int

@@ -1533,12 +1532,19 @@ def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None:
         else:
             self._hparams = hp

-    def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwargs):
-        """Saves the model in ONNX format
+    @torch.no_grad()
+    def to_onnx(
+        self,
+        file_path: Union[str, Path],
+        input_sample: Optional[Any] = None,
+        **kwargs,
+    ):
+        """
+        Saves the model in ONNX format

         Args:
-            file_path: The path of the file the model should be saved to.
-            input_sample: A sample of an input tensor for tracing.
+            file_path: The path of the file the onnx model should be saved to.
+            input_sample: An input for tracing. Default: None (Use self.example_input_array)
             **kwargs: Will be passed to torch.onnx.export function.

         Example:

@@ -1557,31 +1563,32 @@ def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwarg
         ... os.path.isfile(tmpfile.name)
         True
         """
+        mode = self.training

-        if isinstance(input_sample, Tensor):
-            input_data = input_sample
-        elif self.example_input_array is not None:
-            input_data = self.example_input_array
-        else:
-            if input_sample is not None:
+        if input_sample is None:
+            if self.example_input_array is None:
                 raise ValueError(
-                    f"Received `input_sample` of type {type(input_sample)}. Expected type is `Tensor`"
+                    "Could not export to ONNX since neither `input_sample` nor"
+                    " `model.example_input_array` attribute is set."
                 )
-            raise ValueError(
-                "Could not export to ONNX since neither `input_sample` nor"
-                " `model.example_input_array` attribute is set."
-            )
-        input_data = input_data.to(self.device)
+            input_sample = self.example_input_array
+
+        input_sample = self.transfer_batch_to_device(input_sample)
+
         if "example_outputs" not in kwargs:
             self.eval()
-            with torch.no_grad():
-                kwargs["example_outputs"] = self(input_data)
+            kwargs["example_outputs"] = self(input_sample)

-        torch.onnx.export(self, input_data, file_path, **kwargs)
+        torch.onnx.export(self, input_sample, file_path, **kwargs)
+        self.train(mode)

+    @torch.no_grad()
     def to_torchscript(
-        self, file_path: Optional[str] = None, method: Optional[str] = 'script',
-        example_inputs: Optional[Union[torch.Tensor, Tuple[torch.Tensor]]] = None, **kwargs
+        self,
+        file_path: Optional[Union[str, Path]] = None,
+        method: Optional[str] = 'script',
+        example_inputs: Optional[Any] = None,
+        **kwargs,
     ) -> Union[ScriptModule, Dict[str, ScriptModule]]:
         """
         By default compiles the whole model to a :class:`~torch.jit.ScriptModule`.

@@ -1593,7 +1600,7 @@ def to_torchscript(
         Args:
             file_path: Path where to save the torchscript. Default: None (no file saved).
             method: Whether to use TorchScript's script or trace method. Default: 'script'
-            example_inputs: Tensor to be used to do tracing when method is set to 'trace'.
+            example_inputs: An input to be used to do tracing when method is set to 'trace'.
                 Default: None (Use self.example_input_array)
             **kwargs: Additional arguments that will be passed to the :func:`torch.jit.script` or
                 :func:`torch.jit.trace` function.

@@ -1627,21 +1634,27 @@ def to_torchscript(
             This LightningModule as a torchscript, regardless of whether file_path is
             defined or not.
         """
-
         mode = self.training
-        with torch.no_grad():
-            if method == 'script':
-                torchscript_module = torch.jit.script(self.eval(), **kwargs)
-            elif method == 'trace':
-                # if no example inputs are provided, try to see if model has example_input_array set
-                if example_inputs is None:
-                    example_inputs = self.example_input_array
-                # automatically send example inputs to the right device and use trace
-                example_inputs = self.transfer_batch_to_device(example_inputs, device=self.device)
-                torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs)
-            else:
-                raise ValueError(f"The 'method' parameter only supports 'script' or 'trace', but value given was:"
-                                 f"{method}")
+
+        if method == 'script':
+            torchscript_module = torch.jit.script(self.eval(), **kwargs)
+        elif method == 'trace':
+            # if no example inputs are provided, try to see if model has example_input_array set
+            if example_inputs is None:
+                if self.example_input_array is None:
+                    raise ValueError(
+                        'Choosing method=`trace` requires either `example_inputs`'
+                        ' or `model.example_input_array` to be defined'
+                    )
+                example_inputs = self.example_input_array
+
+            # automatically send example inputs to the right device and use trace
+            example_inputs = self.transfer_batch_to_device(example_inputs)
+            torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs)
+        else:
+            raise ValueError("The 'method' parameter only supports 'script' or 'trace',"
+                             f" but value given was: {method}")

         self.train(mode)

         if file_path is not None: