@@ -35,8 +35,8 @@ def __init__(self):
3535 def disable(self):
3636 self.enable = False
3737
38- def on_train_batch_end(self, trainer, pl_module, outputs):
39- super().on_train_batch_end(trainer, pl_module, outputs) # don't forget this :)
38+ def on_train_batch_end(self, trainer, pl_module, outputs, batch_idx ):
39+ super().on_train_batch_end(trainer, pl_module, outputs, batch_idx ) # don't forget this :)
4040 percent = (self.train_batch_idx / self.total_train_batches) * 100
4141 sys.stdout.flush()
4242 sys.stdout.write(f'{percent:.01f} percent complete \r')
@@ -161,7 +161,7 @@ def on_train_start(self, trainer, pl_module):
161161 def on_train_epoch_start (self , trainer , pl_module ):
162162 self ._train_batch_idx = trainer .fit_loop .epoch_loop .batch_progress .current .completed
163163
164- def on_train_batch_end (self , trainer , pl_module , outputs , batch , batch_idx , dataloader_idx ):
164+ def on_train_batch_end (self , trainer , pl_module , outputs , batch , batch_idx ):
165165 self ._train_batch_idx += 1
166166
167167 def on_validation_start (self , trainer , pl_module ):
0 commit comments