diff --git a/pytorch_lightning/__init__.py b/pytorch_lightning/__init__.py index 2996b5182cd94..c9d914573fe71 100644 --- a/pytorch_lightning/__init__.py +++ b/pytorch_lightning/__init__.py @@ -1,7 +1,6 @@ """Root package info.""" import logging -import os from pytorch_lightning.__about__ import * # noqa: F401, F403 @@ -14,9 +13,6 @@ _logger.addHandler(logging.StreamHandler()) _logger.propagate = False -_PACKAGE_ROOT = os.path.dirname(__file__) -_PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT) - from pytorch_lightning.callbacks import Callback # noqa: E402 from pytorch_lightning.core import LightningDataModule, LightningModule # noqa: E402 from pytorch_lightning.trainer import Trainer # noqa: E402 diff --git a/tests/conftest.py b/tests/conftest.py index 3d5548b7bd0ae..c5c86f780f912 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -99,6 +99,21 @@ def reset_deterministic_algorithm(): torch.set_deterministic(False) +@pytest.fixture +def caplog(caplog): + """Workaround for https://github.com/pytest-dev/pytest/issues/3697. + + Setting ``filterwarnings`` with pytest breaks ``caplog`` when ``not logger.propagate``. + """ + import logging + + lightning_logger = logging.getLogger("pytorch_lightning") + propagate = lightning_logger.propagate + lightning_logger.propagate = True + yield caplog + lightning_logger.propagate = propagate + + @pytest.fixture def tmpdir_server(tmpdir): if sys.version_info >= (3, 7): diff --git a/tests/utilities/test_warnings.py b/tests/utilities/test_warnings.py index d1222672b7595..6ef3793b5e0f3 100644 --- a/tests/utilities/test_warnings.py +++ b/tests/utilities/test_warnings.py @@ -50,3 +50,30 @@ assert "test_warnings.py:39: UserWarning: test6" in output assert "test_warnings.py:40: LightningDeprecationWarning: test7" in output + + # check that logging is properly configured + import logging + + root_logger = logging.getLogger() + lightning_logger = logging.getLogger("pytorch_lightning") + # should have a `StreamHandler` + assert lightning_logger.hasHandlers() and len(lightning_logger.handlers) == 1 + # set our own stream for testing + handler = lightning_logger.handlers[0] + assert isinstance(handler, logging.StreamHandler) + stderr = StringIO() + # necessary with `propagate = False` + lightning_logger.handlers[0].stream = stderr + + # necessary with `propagate = True` + with redirect_stderr(stderr): + # Lightning should not configure the root `logging` logger by default + logging.info("test1") + root_logger.info("test1") + # but our logger instance + lightning_logger.info("test2") + # level is set to INFO + lightning_logger.debug("test3") + + output = stderr.getvalue() + assert output == "test2\n", repr(output)