Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed a bug that caused spurious `AttributeError` when multiple `DataLoader` classes are imported ([#14117](https://github.com/Lightning-AI/lightning/pull/14117))


- Fixed epoch-end logging results not being reset after the end of the epoch ([#14061](https://github.com/Lightning-AI/lightning/pull/14061))


Expand Down
10 changes: 6 additions & 4 deletions src/pytorch_lightning/utilities/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,15 +501,17 @@ def _replace_init_method(base_cls: Type, store_explicit_arg: Optional[str] = Non
It patches the ``__init__`` method.
"""
classes = _get_all_subclasses(base_cls) | {base_cls}
wrapped = set()
for cls in classes:
if cls.__init__ not in wrapped:
# Check that __init__ belongs to the class
# https://stackoverflow.com/a/5253424
if "__init__" in cls.__dict__:
cls._old_init = cls.__init__
cls.__init__ = _wrap_init_method(cls.__init__, store_explicit_arg)
wrapped.add(cls.__init__)
yield
for cls in classes:
if hasattr(cls, "_old_init"):
# Check that _old_init belongs to the class
# https://stackoverflow.com/a/5253424
if "_old_init" in cls.__dict__:
cls.__init__ = cls._old_init
del cls._old_init

Expand Down
25 changes: 25 additions & 0 deletions tests/tests_pytorch/utilities/test_data.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import random
from dataclasses import dataclass

import pytest
Expand Down Expand Up @@ -173,6 +174,30 @@ def __init__(self, randomize, *args, **kwargs):
assert isinstance(new_dataloader, GoodImpl)


def test_replace_init_method_multiple_loaders_without_init():
"""In case of a class, that inherits from a class that we are patching, but doesn't define its own `__init__`
method (the one we are wrapping), it can happen, that `hasattr(cls, "_old_init")` is True because of parent
class, but it is impossible to delete, because that method is owned by parent class. Furthermore, the error
occured only sometimes because it depends on the order in which we are iterating over a set of classes we are
patching.

This test simulates the behavior by generating sufficient number of dummy classes, which do not define `__init__`
and are children of `DataLoader`. We are testing that a) context manager `_replace_init_method` exits cleanly, and
b) the mechanism checking for presence of `_old_init` works as expected.
"""
classes = [DataLoader]
for i in range(100):
classes.append(type(f"DataLoader_{i}", (random.choice(classes),), {}))

with _replace_init_method(DataLoader, "dataset"):
for cls in classes[1:]: # First one is `DataLoader`
assert "_old_init" not in cls.__dict__
assert hasattr(cls, "_old_init")

assert "_old_init" in DataLoader.__dict__
assert hasattr(DataLoader, "_old_init")


class DataLoaderSubclass1(DataLoader):
def __init__(self, attribute1, *args, **kwargs):
self.at1 = attribute1
Expand Down