diff --git a/tensorflow_addons/callbacks/time_stopping.py b/tensorflow_addons/callbacks/time_stopping.py index 846885fc78..5198a23322 100644 --- a/tensorflow_addons/callbacks/time_stopping.py +++ b/tensorflow_addons/callbacks/time_stopping.py @@ -38,18 +38,18 @@ def __init__(self, seconds: int = 86400, verbose: int = 0): self.seconds = seconds self.verbose = verbose - self.stopped_epoch = 0 + self.stopped_epoch = None def on_train_begin(self, logs=None): self.stopping_time = time.time() + self.seconds def on_epoch_end(self, epoch, logs={}): - self.stopped_epoch = epoch if time.time() >= self.stopping_time: self.model.stop_training = True + self.stopped_epoch = epoch def on_train_end(self, logs=None): - if self.verbose > 0: + if self.stopped_epoch is not None and self.verbose > 0: formatted_time = datetime.timedelta(seconds=self.seconds) msg = "Timed stopping at epoch {} after training for {}".format( self.stopped_epoch + 1, formatted_time