Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions pytorch_lightning/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Root package info."""

import logging
import os

from pytorch_lightning.__about__ import * # noqa: F401, F403

Expand All @@ -14,9 +13,6 @@
_logger.addHandler(logging.StreamHandler())
_logger.propagate = False

_PACKAGE_ROOT = os.path.dirname(__file__)
_PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)

from pytorch_lightning.callbacks import Callback # noqa: E402
from pytorch_lightning.core import LightningDataModule, LightningModule # noqa: E402
from pytorch_lightning.trainer import Trainer # noqa: E402
Expand Down
15 changes: 15 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,21 @@ def reset_deterministic_algorithm():
torch.set_deterministic(False)


@pytest.fixture
def caplog(caplog):
"""Workaround for https://github.com/pytest-dev/pytest/issues/3697.

Setting ``filterwarnings`` with pytest breaks ``caplog`` when ``not logger.propagate``.
"""
import logging

lightning_logger = logging.getLogger("pytorch_lightning")
propagate = lightning_logger.propagate
lightning_logger.propagate = True
yield caplog
lightning_logger.propagate = propagate


@pytest.fixture
def tmpdir_server(tmpdir):
if sys.version_info >= (3, 7):
Expand Down
27 changes: 27 additions & 0 deletions tests/utilities/test_warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,30 @@

assert "test_warnings.py:39: UserWarning: test6" in output
assert "test_warnings.py:40: LightningDeprecationWarning: test7" in output

# check that logging is properly configured
import logging

root_logger = logging.getLogger()
lightning_logger = logging.getLogger("pytorch_lightning")
# should have a `StreamHandler`
assert lightning_logger.hasHandlers() and len(lightning_logger.handlers) == 1
# set our own stream for testing
handler = lightning_logger.handlers[0]
assert isinstance(handler, logging.StreamHandler)
stderr = StringIO()
# necessary with `propagate = False`
lightning_logger.handlers[0].stream = stderr

# necessary with `propagate = True`
with redirect_stderr(stderr):
# Lightning should not configure the root `logging` logger by default
logging.info("test1")
root_logger.info("test1")
# but our logger instance
lightning_logger.info("test2")
# level is set to INFO
lightning_logger.debug("test3")

output = stderr.getvalue()
assert output == "test2\n", repr(output)