@@ -165,6 +165,7 @@ def training_epoch_end(self, outputs) -> None:
     with mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward) as bwd_mock:
         trainer.fit(model)
     assert bwd_mock.call_count == limit_train_batches * 3
+    assert trainer.global_step == limit_train_batches * 2


 def test_multiple_optimizers_manual_log(tmpdir):
@@ -524,18 +525,14 @@ def optimizer_closure():
             weight_after = self.layer.weight.clone()
             assert not torch.equal(weight_before, weight_after)

-        def configure_optimizers(self):
-            return torch.optim.SGD(self.layer.parameters(), lr=0.1)
-
     model = TestModel()
-    model.val_dataloader = None
     model.training_epoch_end = None

     limit_train_batches = 2
     trainer = Trainer(
         default_root_dir=tmpdir,
         limit_train_batches=limit_train_batches,
-        limit_val_batches=2,
+        limit_val_batches=0,
         max_epochs=1,
         log_every_n_steps=1,
     )
@@ -547,58 +544,45 @@ def configure_optimizers(self):
     assert trainer.progress_bar_metrics["train_loss_epoch"] == torch.stack(model._losses).mean()


-def test_step_with_optimizer_closure_and_accumulated_grad(tmpdir):
-    """Tests that `step` works with optimizer_closure and accumulated_grad."""
-
+def test_step_with_optimizer_closure_2(tmpdir):
     class TestModel(BoringModel):
         def __init__(self):
             super().__init__()
             self.automatic_optimization = False

         def training_step(self, batch, batch_idx):
-            # manual
             opt = self.optimizers()
             x = batch[0]
-
-            loss_1 = self(x)
-            loss_1 = self.loss(loss_1, loss_1)
+            loss = self(x).sum()

             def optimizer_closure():
                 # emulate bayesian optimization.
                 num_backward = 1
                 for backward_idx in range(num_backward + 1):
                     retain_graph = num_backward != backward_idx
-                    self.manual_backward(loss_1, retain_graph=retain_graph)
+                    self.manual_backward(loss, retain_graph=retain_graph)

             weight_before = self.layer.weight.clone()
-
             opt.step(closure=optimizer_closure)
-
             weight_after = self.layer.weight.clone()
-            if not self.trainer.fit_loop._should_accumulate():
-                assert not torch.equal(weight_before, weight_after)
-            else:
-                assert self.layer.weight.grad is not None
-
-        def configure_optimizers(self):
-            return torch.optim.SGD(self.layer.parameters(), lr=0.1)
+            assert not torch.equal(weight_before, weight_after)

     model = TestModel()
-    model.val_dataloader = None
     model.training_epoch_end = None

     limit_train_batches = 4
     trainer = Trainer(
         default_root_dir=tmpdir,
         limit_train_batches=limit_train_batches,
-        limit_val_batches=2,
+        limit_val_batches=0,
         max_epochs=1,
         log_every_n_steps=1,
     )

     with mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward) as bwd_mock:
         trainer.fit(model)
     assert bwd_mock.call_count == limit_train_batches * 2
+    assert trainer.global_step == limit_train_batches


 @patch("torch.optim.SGD.step")
@@ -614,41 +598,23 @@ def on_train_start(self) -> None:
             step_mock.reset_mock()

         def training_step(self, batch, batch_idx):
-            # manual
             opt = self.optimizers()
-            x = batch[0]
-
-            loss_1 = self(x)
-            loss_1 = self.loss(loss_1, loss_1)
-
-            def optimizer_closure():
-                # emulate bayesian optimization.
-                num_backward = 1
-                for backward_idx in range(num_backward + 1):
-                    retain_graph = num_backward != backward_idx
-                    self.manual_backward(loss_1, retain_graph=retain_graph)
-
-            opt.step(closure=optimizer_closure)
-            opt.zero_grad()
-
-        def configure_optimizers(self):
-            return torch.optim.SGD(self.layer.parameters(), lr=0.1)
+            opt.step(closure=lambda: ..., foo=123)

     model = TestModel()
-    model.val_dataloader = None
     model.training_epoch_end = None

-    limit_train_batches = 4
+    limit_train_batches = 2
     trainer = Trainer(
         default_root_dir=tmpdir,
         limit_train_batches=limit_train_batches,
-        limit_val_batches=2,
+        limit_val_batches=0,
         max_epochs=1,
-        log_every_n_steps=1,
     )

     trainer.fit(model)
-    assert step_mock.mock_calls == [call(closure=ANY) for _ in range(limit_train_batches)]
+    assert step_mock.mock_calls == [call(closure=ANY, foo=123) for _ in range(limit_train_batches)]
+    assert trainer.global_step == limit_train_batches


 @patch("torch.optim.Adam.step")
@@ -724,6 +690,7 @@ def configure_optimizers(self):
     trainer.fit(model)
     assert mock_sgd_step.mock_calls == [call(closure=ANY, optim="sgd") for _ in range(4)]
     assert mock_adam_step.mock_calls == [call(closure=ANY) for _ in range(2)]
+    assert trainer.global_step == 4 + 2


 class TesManualOptimizationDDPModel(BoringModel):
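For reference, a minimal self-contained sketch (not taken from this diff) of the pattern the tests above exercise: manual optimization with `automatic_optimization = False`, gradients produced via `manual_backward` inside a closure, and the closure passed to `optimizer.step`. The `ClosureModel` name, layer sizes, dataset, and Trainer flags here are illustrative assumptions, not part of the change set.

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


class ClosureModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        # opt out of automatic optimization so training_step drives the optimizer itself
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self.layer(batch[0]).sum()

        def closure():
            # backward happens inside the closure, mirroring the tests above
            self.manual_backward(loss)

        opt.step(closure=closure)
        opt.zero_grad()

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


if __name__ == "__main__":
    train_data = DataLoader(TensorDataset(torch.randn(8, 32)), batch_size=2)
    trainer = Trainer(max_epochs=1, limit_train_batches=2, logger=False, enable_checkpointing=False)
    trainer.fit(ClosureModel(), train_dataloaders=train_data)
    assert trainer.global_step == 2  # one optimizer step per training batch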