Skip to content

Commit 6b5da9b

Browse files
committed
Move test
1 parent c34bf81 commit 6b5da9b

File tree

2 files changed

+21
-21
lines changed

2 files changed

+21
-21
lines changed

tests/models/test_hooks.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -229,27 +229,6 @@ def train_dataloader(self):
229229
trainer.fit(model)
230230

231231

232-
@pytest.mark.parametrize('max_epochs,batch_idx_', [(2, 5), (3, 8), (4, 70)])
233-
def test_on_train_batch_start_hook(max_epochs, batch_idx_):
234-
235-
class CurrentModel(BoringModel):
236-
237-
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
238-
if batch_idx == batch_idx_:
239-
return -1
240-
241-
model = CurrentModel()
242-
trainer = Trainer(max_epochs=max_epochs)
243-
trainer.fit(model)
244-
assert len(model.val_dataloader()) < 70
245-
if batch_idx_ > len(model.val_dataloader()) - 1:
246-
assert trainer.train_loop.batch_idx == len(model.val_dataloader()) - 1
247-
assert trainer.global_step == len(model.val_dataloader()) * max_epochs
248-
else:
249-
assert trainer.train_loop.batch_idx == batch_idx_
250-
assert trainer.global_step == batch_idx_ * max_epochs
251-
252-
253232
def test_trainer_model_hook_system(tmpdir):
254233
"""Test the LightningModule hook system."""
255234

tests/trainer/loops/test_training_loop.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import pytest
1415
import torch
1516

1617
from pytorch_lightning import seed_everything, Trainer
@@ -201,3 +202,23 @@ def run_training(**trainer_kwargs):
201202
num_sanity_val_steps=2,
202203
)
203204
assert torch.allclose(sequence0, sequence1)
205+
206+
207+
@pytest.mark.parametrize(['max_epochs', 'batch_idx_'], [(2, 5), (3, 8), (4, 12)])
208+
def test_on_train_batch_start_return_minus_one(max_epochs, batch_idx_):
209+
210+
class CurrentModel(BoringModel):
211+
212+
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
213+
if batch_idx == batch_idx_:
214+
return -1
215+
216+
model = CurrentModel()
217+
trainer = Trainer(max_epochs=max_epochs, limit_train_batches=10)
218+
trainer.fit(model)
219+
if batch_idx_ > trainer.num_training_batches - 1:
220+
assert trainer.train_loop.batch_idx == trainer.num_training_batches - 1
221+
assert trainer.global_step == trainer.num_training_batches * max_epochs
222+
else:
223+
assert trainer.train_loop.batch_idx == batch_idx_
224+
assert trainer.global_step == batch_idx_ * max_epochs

0 commit comments

Comments
 (0)