Skip to content

Commit 21f5844

Browse files
committed
Update defaults for WandbLogger's run name and project name (#14145)
1 parent 1a65a9a commit 21f5844

File tree

4 files changed

+25
-14
lines changed

4 files changed

+25
-14
lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1515

1616
- Updated compatibility for LightningLite to run with the latest DeepSpeed 0.7.0 ([#13967](https://github.com/Lightning-AI/lightning/pull/13967))
1717
- Raised a `MisconfigurationException` if batch transfer hooks are overridden with `IPUAccelerator` ([#13961](https://github.com/Lightning-AI/lightning/pull/13961))
18-
18+
- The default project name in `WandbLogger` is now "lightning_logs" ([#14145](https://github.com/Lightning-AI/lightning/pull/14145))
19+
- The `WandbLogger.name` property no longer returns the name of the experiment, and instead returns the project's name ([#14145](https://github.com/Lightning-AI/lightning/pull/14145))
1920
### Fixed
2021

2122
- Fixed a bug that caused spurious `AttributeError` when multiple `DataLoader` classes are imported ([#14117](https://github.com/Lightning-AI/lightning/pull/14117))
@@ -28,6 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2829
- Avoid raising the sampler warning if num_replicas=1 ([#14097](https://github.com/Lightning-AI/lightning/pull/14097))
2930
- Fixed resuming from a checkpoint when using Stochastic Weight Averaging (SWA) ([#9938](https://github.com/Lightning-AI/lightning/pull/9938))
3031
- Avoided requiring the FairScale package to use precision with the fsdp native strategy ([#14092](https://github.com/Lightning-AI/lightning/pull/14092))
32+
- Fixed an issue in which the default name for a run in `WandbLogger` would be set to the project name instead of a randomly generated string ([#14145](https://github.com/Lightning-AI/lightning/pull/14145))
3133
- Fixed not preserving set attributes on `DataLoader` and `BatchSampler` when instantiated inside `*_dataloader` hooks ([#14212](https://github.com/Lightning-AI/lightning/pull/14212))
3234

3335

src/pytorch_lightning/loggers/wandb.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ def __init__(
260260
id: Optional[str] = None,
261261
anonymous: Optional[bool] = None,
262262
version: Optional[str] = None,
263-
project: Optional[str] = None,
263+
project: str = "lightning_logs",
264264
log_model: Union[str, bool] = False,
265265
experiment: Union[Run, RunDisabled, None] = None,
266266
prefix: str = "",
@@ -297,7 +297,7 @@ def __init__(
297297
self._checkpoint_callback: Optional["ReferenceType[Checkpoint]"] = None
298298
# set wandb init arguments
299299
self._wandb_init: Dict[str, Any] = dict(
300-
name=name or project,
300+
name=name,
301301
project=project,
302302
id=version or id,
303303
dir=save_dir,
@@ -306,6 +306,7 @@ def __init__(
306306
)
307307
self._wandb_init.update(**kwargs)
308308
# extract parameters
309+
self._project = self._wandb_init.get("project")
309310
self._save_dir = self._wandb_init.get("dir")
310311
self._name = self._wandb_init.get("name")
311312
self._id = self._wandb_init.get("id")
@@ -450,13 +451,13 @@ def save_dir(self) -> Optional[str]:
450451

451452
@property
452453
def name(self) -> Optional[str]:
453-
"""Gets the name of the experiment.
454+
"""The project name of this experiment.
454455
455456
Returns:
456-
The name of the experiment if the experiment exists else the name given to the constructor.
457+
The name of the project the current experiment belongs to. This name is not the same as `wandb.Run`'s
458+
name. To access wandb's internal experiment name, use ``logger.experiment.name`` instead.
457459
"""
458-
# don't create an experiment if we don't have one
459-
return self._experiment.name if self._experiment else self._name
460+
return self._project
460461

461462
@property
462463
def version(self) -> Optional[str]:

tests/tests_pytorch/loggers/test_all.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
300300

301301

302302
@pytest.mark.parametrize("logger_class", ALL_LOGGER_CLASSES_WO_NEPTUNE_WANDB)
303-
@RunIf(skip_windows=True, skip_hanging_spawn=True)
303+
@RunIf(skip_windows=True)
304304
def test_logger_created_on_rank_zero_only(tmpdir, monkeypatch, logger_class):
305305
"""Test that loggers get replaced by dummy loggers on global rank > 0."""
306306
_patch_comet_atexit(monkeypatch)

tests/tests_pytorch/loggers/test_wandb.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,16 @@
2525
from tests_pytorch.helpers.utils import no_warning_call
2626

2727

28+
@mock.patch("pytorch_lightning.loggers.wandb.Run", new=mock.Mock)
29+
@mock.patch("pytorch_lightning.loggers.wandb.wandb")
30+
def test_wandb_project_name(*_):
31+
logger = WandbLogger()
32+
assert logger.name == "lightning_logs"
33+
34+
logger = WandbLogger(project="project")
35+
assert logger.name == "project"
36+
37+
2838
@mock.patch("pytorch_lightning.loggers.wandb.Run", new=mock.Mock)
2939
@mock.patch("pytorch_lightning.loggers.wandb.wandb")
3040
def test_wandb_logger_init(wandb, monkeypatch):
@@ -48,7 +58,7 @@ def test_wandb_logger_init(wandb, monkeypatch):
4858
wandb.init.reset_mock()
4959
WandbLogger(project="test_project").experiment
5060
wandb.init.assert_called_once_with(
51-
name="test_project", dir=None, id=None, project="test_project", resume="allow", anonymous=None
61+
name=None, dir=None, id=None, project="test_project", resume="allow", anonymous=None
5262
)
5363

5464
# test wandb.init and setting logger experiment externally
@@ -91,7 +101,6 @@ def test_wandb_logger_init(wandb, monkeypatch):
91101
logger.watch("model", "log", 10, False)
92102
wandb.init().watch.assert_called_once_with("model", log="log", log_freq=10, log_graph=False)
93103

94-
assert logger.name == wandb.init().name
95104
assert logger.version == wandb.init().id
96105

97106

@@ -140,10 +149,9 @@ def test_wandb_logger_dirs_creation(wandb, monkeypatch, tmpdir):
140149
"""Test that the logger creates the folders and files in the right place."""
141150
monkeypatch.setattr(pytorch_lightning.loggers.wandb, "_WANDB_GREATER_EQUAL_0_12_10", True)
142151
wandb.run = None
143-
logger = WandbLogger(save_dir=str(tmpdir), offline=True)
152+
logger = WandbLogger(project="project", save_dir=str(tmpdir), offline=True)
144153
# the logger get initialized
145154
assert logger.version == wandb.init().id
146-
assert logger.name == wandb.init().name
147155

148156
# mock return values of experiment
149157
wandb.run = None
@@ -154,7 +162,7 @@ def test_wandb_logger_dirs_creation(wandb, monkeypatch, tmpdir):
154162
_ = logger.experiment
155163

156164
assert logger.version == "1"
157-
assert logger.name == "run_name"
165+
assert logger.name == "project"
158166
assert str(tmpdir) == logger.save_dir
159167
assert not os.listdir(tmpdir)
160168

@@ -164,7 +172,7 @@ def test_wandb_logger_dirs_creation(wandb, monkeypatch, tmpdir):
164172
assert trainer.log_dir == logger.save_dir
165173
trainer.fit(model)
166174

167-
assert trainer.checkpoint_callback.dirpath == str(tmpdir / "run_name" / version / "checkpoints")
175+
assert trainer.checkpoint_callback.dirpath == str(tmpdir / "project" / version / "checkpoints")
168176
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {"epoch=0-step=3.ckpt"}
169177
assert trainer.log_dir == logger.save_dir
170178

0 commit comments

Comments (0)