2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -74,6 +74,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed accumulation parameter and suggestion method for learning rate finder ([#1801](https://github.com/PyTorchLightning/pytorch-lightning/pull/1801))

- Fixed an issue with Trainer constructor silently ignoring unknown/misspelled arguments ([#1820](https://github.com/PyTorchLightning/pytorch-lightning/pull/1820))

## [0.7.5] - 2020-04-27

### Changed
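To illustrate the user-facing effect of the new changelog entry, here is a minimal sketch assuming the post-fix `Trainer` signature; `max_epoch` is a deliberately misspelled stand-in for `max_epochs`, chosen only for illustration:

```python
from pytorch_lightning import Trainer

# Hypothetical misspelling of `max_epochs`. Before this fix the constructor's
# **kwargs catch-all accepted it and silently dropped it; the user's setting
# never took effect.
try:
    Trainer(max_epoch=5)
except TypeError as err:
    # With **kwargs removed, Python itself rejects the unknown keyword:
    #   __init__() got an unexpected keyword argument 'max_epoch'
    print(err)
```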
9 changes: 4 additions & 5 deletions pytorch_lightning/trainer/callback_hook.py
@@ -6,11 +6,10 @@

class TrainerCallbackHookMixin(ABC):

def __init__(self):
# this is just a summary on variables used in this abstract class,
# the proper values/initialisation should be done in child class
self.callbacks: List[Callback] = []
self.get_model: Callable = ...
# this is just a summary on variables used in this abstract class,
# the proper values/initialisation should be done in child class
callbacks: List[Callback] = []
get_model: Callable = ...

def on_init_start(self):
"""Called when the trainer initialization begins, model has not yet been set."""
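As a minimal sketch of the pattern introduced above: declaring `callbacks` and `get_model` as class-level annotations means the mixin no longer needs an `__init__` of its own, so a child class can run a cooperative `super().__init__()` without the mixin overwriting attributes the child has set. The `Callback` stand-in and `MyTrainer` below are illustrative only, not the library's classes:

```python
from abc import ABC
from typing import Callable, List


class Callback:
    """Stand-in for pytorch_lightning.Callback, for illustration only."""

    def on_init_start(self, trainer):
        print("callback notified")


class TrainerCallbackHookMixin(ABC):
    # Class-level annotations only document the attributes this mixin relies on;
    # the child class (the real Trainer) is responsible for initialising them.
    callbacks: List[Callback] = []
    get_model: Callable = ...

    def on_init_start(self):
        for callback in self.callbacks:
            callback.on_init_start(self)


class MyTrainer(TrainerCallbackHookMixin):
    def __init__(self, callbacks=None):
        super().__init__()  # safe: the mixin no longer defines __init__
        self.callbacks = callbacks or []


trainer = MyTrainer(callbacks=[Callback()])
trainer.on_init_start()  # prints "callback notified"
```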
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -142,7 +142,6 @@ def __init__(
use_amp=None, # backward compatible, todo: remove in v0.9.0
show_progress_bar=None, # backward compatible, todo: remove in v0.9.0
nb_sanity_val_steps=None, # backward compatible, todo: remove in v0.8.0
**kwargs
):
r"""

@@ -305,6 +304,7 @@ def __init__(
Additionally, can be set to either `power` that estimates the batch size through
a power search or `binsearch` that estimates the batch size through a binary search.
"""
super().__init__()

self.deterministic = deterministic
torch.backends.cudnn.deterministic = self.deterministic
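The root cause removed here is worth spelling out: a `**kwargs` catch-all in `__init__` accepts any keyword and simply drops whatever it does not use, whereas an explicit signature lets Python raise `TypeError` for typos. A Trainer-independent sketch (class names `Lenient` and `Strict` are illustrative only):

```python
class Lenient:
    def __init__(self, max_epochs=1000, **kwargs):
        # A typo such as `max_epoch=5` lands in kwargs and is silently lost.
        self.max_epochs = max_epochs


class Strict:
    def __init__(self, max_epochs=1000):
        self.max_epochs = max_epochs


Lenient(max_epoch=5)        # silently ignored: max_epochs stays 1000
try:
    Strict(max_epoch=5)
except TypeError as err:
    print(err)              # __init__() got an unexpected keyword argument 'max_epoch'
```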
21 changes: 1 addition & 20 deletions tests/trainer/test_dataloaders.py
@@ -289,7 +289,7 @@ def test_inf_train_dataloader(tmpdir, check_interval):
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
train_check_interval=check_interval,
val_check_interval=check_interval
)
result = trainer.fit(model)
# verify training completed
@@ -315,25 +315,6 @@ def test_inf_val_dataloader(tmpdir, check_interval):
assert result == 1


@pytest.mark.parametrize('check_interval', [50, 1.0])
def test_inf_test_dataloader(tmpdir, check_interval):
"""Test inf test data loader (e.g. IterableDataset)"""

model = EvalModelTemplate()
model.test_dataloader = model.test_dataloader__infinite

# logger file to get meta
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
test_check_interval=check_interval,
)
result = trainer.fit(model)

# verify training completed
assert result == 1


def test_error_on_zero_len_dataloader(tmpdir):
""" Test that error is raised if a zero-length dataloader is defined """

38 changes: 38 additions & 0 deletions tests/trainer/test_trainer.py
@@ -772,3 +772,41 @@ def test_trainer_config(trainer_kwargs, expected):
assert trainer.on_gpu is expected["on_gpu"]
assert trainer.single_gpu is expected["single_gpu"]
assert trainer.num_processes == expected["num_processes"]


def test_trainer_subclassing():
model = EvalModelTemplate()

# First way of pulling out args from signature is to list them
class TrainerSubclass(Trainer):

def __init__(self, custom_arg, *args, custom_kwarg='test', **kwargs):
super().__init__(*args, **kwargs)
self.custom_arg = custom_arg
self.custom_kwarg = custom_kwarg

trainer = TrainerSubclass(123, custom_kwarg='custom', fast_dev_run=True)
result = trainer.fit(model)
assert result == 1
assert trainer.custom_arg == 123
assert trainer.custom_kwarg == 'custom'
assert trainer.fast_dev_run

# Second way is to pop from the dict
# It's a special case because Trainer does not have any positional args
class TrainerSubclass(Trainer):

def __init__(self, **kwargs):
self.custom_arg = kwargs.pop('custom_arg', 0)
self.custom_kwarg = kwargs.pop('custom_kwarg', 'test')
super().__init__(**kwargs)

trainer = TrainerSubclass(custom_kwarg='custom', fast_dev_run=True)
result = trainer.fit(model)
assert result == 1
assert trainer.custom_kwarg == 'custom'
assert trainer.fast_dev_run

# when we pass in an unknown arg, the base class should complain
with pytest.raises(TypeError, match=r"__init__\(\) got an unexpected keyword argument 'abcdefg'") as e:
TrainerSubclass(abcdefg='unknown_arg')