diff --git a/ignite/base/mixins.py b/ignite/base/mixins.py index 3ecb2922f039..bcf511df8d8f 100644 --- a/ignite/base/mixins.py +++ b/ignite/base/mixins.py @@ -1,11 +1,18 @@ from collections import OrderedDict from collections.abc import Mapping -from typing import Tuple +from typing import List, Tuple class Serializable: - _state_dict_all_req_keys: Tuple = () - _state_dict_one_of_opt_keys: Tuple = () + _state_dict_all_req_keys: Tuple[str, ...] = () + _state_dict_one_of_opt_keys: Tuple[Tuple[str, ...], ...] = ((),) + + def __init__(self) -> None: + self._state_dict_user_keys: List[str] = [] + + @property + def state_dict_user_keys(self) -> List: + return self._state_dict_user_keys def state_dict(self) -> OrderedDict: raise NotImplementedError @@ -19,6 +26,21 @@ def load_state_dict(self, state_dict: Mapping) -> None: raise ValueError( f"Required state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'" ) - opts = [k in state_dict for k in self._state_dict_one_of_opt_keys] - if len(opts) > 0 and ((not any(opts)) or (all(opts))): - raise ValueError(f"state_dict should contain only one of '{self._state_dict_one_of_opt_keys}' keys") + + # Handle groups of one-of optional keys + for one_of_opt_keys in self._state_dict_one_of_opt_keys: + if len(one_of_opt_keys) > 0: + opts = [k in state_dict for k in one_of_opt_keys] + num_present = sum(opts) + if num_present == 0: + raise ValueError(f"state_dict should contain at least one of '{one_of_opt_keys}' keys") + if num_present > 1: + raise ValueError(f"state_dict should contain only one of '{one_of_opt_keys}' keys") + + # Check user keys + if hasattr(self, "_state_dict_user_keys") and isinstance(self._state_dict_user_keys, list): + for k in self._state_dict_user_keys: + if k not in state_dict: + raise ValueError( + f"Required user state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'" + ) diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py index c814f770e77a..4762cb79c50a 100644 --- a/ignite/engine/engine.py +++ b/ignite/engine/engine.py @@ -128,13 +128,14 @@ def compute_mean_std(engine, batch): """ - _state_dict_all_req_keys = ("epoch_length", "max_epochs") - _state_dict_one_of_opt_keys = ("iteration", "epoch") + _state_dict_all_req_keys = ("epoch_length",) + _state_dict_one_of_opt_keys = (("iteration", "epoch"), ("max_epochs", "max_iters")) # Flag to disable engine._internal_run as generator feature for BC interrupt_resume_enabled = True def __init__(self, process_function: Callable[["Engine", Any], Any]): + super(Engine, self).__init__() self._event_handlers: Dict[Any, List] = defaultdict(list) self.logger = logging.getLogger(__name__ + "." + self.__class__.__name__) self._process_function = process_function @@ -147,7 +148,6 @@ def __init__(self, process_function: Callable[["Engine", Any], Any]): self.should_terminate_single_epoch: Union[bool, str] = False self.should_interrupt = False self.state = State() - self._state_dict_user_keys: List[str] = [] self._allowed_events: List[EventEnum] = [] self._dataloader_iter: Optional[Iterator[Any]] = None @@ -691,14 +691,20 @@ def save_engine(_): a dictionary containing engine's state """ - keys: Tuple[str, ...] = self._state_dict_all_req_keys + (self._state_dict_one_of_opt_keys[0],) + keys: Tuple[str, ...] 
= self._state_dict_all_req_keys + keys += ("iteration",) + # Include either max_epochs or max_iters based on which was originally set + if self.state.max_iters is not None: + keys += ("max_iters",) + else: + keys += ("max_epochs",) keys += tuple(self._state_dict_user_keys) return OrderedDict([(k, getattr(self.state, k)) for k in keys]) def load_state_dict(self, state_dict: Mapping) -> None: """Setups engine from `state_dict`. - State dictionary should contain keys: `iteration` or `epoch`, `max_epochs` and `epoch_length`. + State dictionary should contain keys: `iteration` or `epoch`, `max_epochs` or `max_iters`, and `epoch_length`. If `engine.state_dict_user_keys` contains keys, they should be also present in the state dictionary. Iteration and epoch values are 0-based: the first iteration or epoch is zero. @@ -709,10 +715,12 @@ def load_state_dict(self, state_dict: Mapping) -> None: .. code-block:: python - # Restore from the 4rd epoch + # Restore from the 4th epoch state_dict = {"epoch": 3, "max_epochs": 100, "epoch_length": len(data_loader)} # or 500th iteration # state_dict = {"iteration": 499, "max_epochs": 100, "epoch_length": len(data_loader)} + # or with max_iters + # state_dict = {"iteration": 499, "max_iters": 1000, "epoch_length": len(data_loader)} trainer = Engine(...) trainer.load_state_dict(state_dict) @@ -721,22 +729,20 @@ def load_state_dict(self, state_dict: Mapping) -> None: """ super(Engine, self).load_state_dict(state_dict) - for k in self._state_dict_user_keys: - if k not in state_dict: - raise ValueError( - f"Required user state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'" - ) - self.state.max_epochs = state_dict["max_epochs"] + # Set epoch_length self.state.epoch_length = state_dict["epoch_length"] + + # Set user keys for k in self._state_dict_user_keys: setattr(self.state, k, state_dict[k]) + # Set iteration or epoch if "iteration" in state_dict: self.state.iteration = state_dict["iteration"] self.state.epoch = 0 - if self.state.epoch_length is not None: + if self.state.epoch_length is not None and self.state.epoch_length > 0: self.state.epoch = self.state.iteration // self.state.epoch_length - elif "epoch" in state_dict: + else: # epoch is in state_dict self.state.epoch = state_dict["epoch"] if self.state.epoch_length is None: raise ValueError( @@ -745,6 +751,36 @@ def load_state_dict(self, state_dict: Mapping) -> None: ) self.state.iteration = self.state.epoch_length * self.state.epoch + # Set max_epochs or max_iters with validation + max_epochs_value = state_dict.get("max_epochs", None) + max_iters_value = state_dict.get("max_iters", None) + + # Validate max_epochs if present + if max_epochs_value is not None: + if max_epochs_value < 1: + raise ValueError("max_epochs in state_dict is invalid. Please, set a correct max_epochs positive value") + if max_epochs_value < self.state.epoch: + raise ValueError( + "max_epochs in state_dict should be larger than or equal to the current epoch " + f"defined in the state: {max_epochs_value} vs {self.state.epoch}. " + ) + self.state.max_epochs = max_epochs_value + else: + self.state.max_epochs = None + + # Validate max_iters if present + if max_iters_value is not None: + if max_iters_value < 1: + raise ValueError("max_iters in state_dict is invalid. 
Please, set a correct max_iters positive value") + if max_iters_value < self.state.iteration: + raise ValueError( + "max_iters in state_dict should be larger than or equal to the current iteration " + f"defined in the state: {max_iters_value} vs {self.state.iteration}. " + ) + self.state.max_iters = max_iters_value + else: + self.state.max_iters = None + @staticmethod def _is_done(state: State) -> bool: is_done_iters = state.max_iters is not None and state.iteration >= state.max_iters @@ -756,6 +792,59 @@ def _is_done(state: State) -> bool: is_done_epochs = state.max_epochs is not None and state.epoch >= state.max_epochs return is_done_iters or is_done_count or is_done_epochs + def _check_and_set_max_epochs(self, max_epochs: Optional[int] = None) -> None: + """Validate and set max_epochs with proper checks.""" + if max_epochs is not None: + if max_epochs < 1: + raise ValueError("Argument max_epochs is invalid. Please, set a correct max_epochs positive value") + # Only validate if training is actually done - allow resuming interrupted training + if self.state.max_epochs is not None and max_epochs < self.state.epoch: + raise ValueError( + "Argument max_epochs should be greater than or equal to the start " + f"epoch defined in the state: {max_epochs} vs {self.state.epoch}. " + "Please, set engine.state.max_epochs = None " + "before calling engine.run() in order to restart the training from the beginning." + ) + self.state.max_epochs = max_epochs + + def _check_and_set_max_iters(self, max_iters: Optional[int] = None) -> None: + """Validate and set max_iters with proper checks.""" + if max_iters is not None: + if max_iters < 1: + raise ValueError("Argument max_iters is invalid. Please, set a correct max_iters positive value") + # Only validate if training is actually done - allow resuming interrupted training + if (self.state.max_iters is not None) and max_iters < self.state.iteration: + raise ValueError( + "Argument max_iters should be greater than or equal to the start " + f"iteration defined in the state: {max_iters} vs {self.state.iteration}. " + "Please, set engine.state.max_iters = None " + "before calling engine.run() in order to restart the training from the beginning." + ) + self.state.max_iters = max_iters + + def _check_and_set_epoch_length(self, data: Optional[Iterable], epoch_length: Optional[int] = None) -> None: + """Validate and set epoch_length.""" + # Check if we can redefine epoch_length + if self.state.epoch_length is not None: + if epoch_length is not None: + if epoch_length != self.state.epoch_length: + raise ValueError( + "Argument epoch_length should be same as in the state, " + f"but given {epoch_length} vs {self.state.epoch_length}" + ) + else: + if epoch_length is None: + if data is not None: + epoch_length = self._get_data_length(data) + + if epoch_length is not None: + if epoch_length < 1: + raise ValueError( + "Argument epoch_length is invalid. Please, either set a correct epoch_length value or " + "check if input data has non-zero size." + ) + self.state.epoch_length = epoch_length + def set_data(self, data: Union[Iterable, DataLoader]) -> None: """Method to set data. After calling the method the next batch passed to `processing_function` is from newly provided data. Please, note that epoch length is not modified. 
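For orientation, a minimal sketch of the save/restore round trip the two hunks above aim to support (illustrative process function and data, not taken from the patch):

    from ignite.engine import Engine

    trainer = Engine(lambda engine, batch: None)  # placeholder process function
    trainer.run(range(100), max_iters=250)

    ckpt = trainer.state_dict()
    # expected keys: "epoch_length", "iteration", "max_iters" (plus any user keys)

    restored = Engine(lambda engine, batch: None)
    restored.load_state_dict(ckpt)  # accepts "max_iters" in place of "max_epochs"
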
@@ -854,46 +943,59 @@ def switch_batch(engine): if data is not None and not isinstance(data, Iterable): raise TypeError("Argument data should be iterable") - if self.state.max_epochs is not None: - # Check and apply overridden parameters - if max_epochs is not None: - if max_epochs < self.state.epoch: - raise ValueError( - "Argument max_epochs should be greater than or equal to the start " - f"epoch defined in the state: {max_epochs} vs {self.state.epoch}. " - "Please, set engine.state.max_epochs = None " - "before calling engine.run() in order to restart the training from the beginning." - ) - self.state.max_epochs = max_epochs - if epoch_length is not None: - if epoch_length != self.state.epoch_length: - raise ValueError( - "Argument epoch_length should be same as in the state, " - f"but given {epoch_length} vs {self.state.epoch_length}" - ) + if max_epochs is not None and max_iters is not None: + raise ValueError( + "Arguments max_iters and max_epochs are mutually exclusive." + "Please provide only max_epochs or max_iters." + ) - if self.state.max_epochs is None or (self._is_done(self.state) and self._internal_run_generator is None): - # Create new state - if epoch_length is None: - if data is None: - raise ValueError("epoch_length should be provided if data is None") + # Check if we need to create new state or resume + # Create new state if: + # 1. No termination params set (first run), OR + # 2. Training is done AND generator is None AND no new params provided + # 3. Training is done AND same termination params provided (restart case) + should_create_new_state = ( + (self.state.max_epochs is None and self.state.max_iters is None) + or ( + self._is_done(self.state) + and self._internal_run_generator is None + and max_epochs is None + and max_iters is None + ) + or ( + self._is_done(self.state) + and self._internal_run_generator is None + and ( + (max_epochs is not None and max_epochs == self.state.max_epochs) + or (max_iters is not None and max_iters == self.state.max_iters) + ) + ) + ) - epoch_length = self._get_data_length(data) - if epoch_length is not None and epoch_length < 1: - raise ValueError("Input data has zero size. Please provide non-empty data") + if should_create_new_state: + # Create new state + if data is None and epoch_length is None and self.state.epoch_length is None: + raise ValueError("epoch_length should be provided if data is None") + # Set epoch_length for new state + if epoch_length is None: + # Try to get from data first, then fall back to existing state + if data is not None: + epoch_length = self._get_data_length(data) + if epoch_length is None and self.state.epoch_length is not None: + epoch_length = self.state.epoch_length + if epoch_length is not None and epoch_length < 1: + raise ValueError("Input data has zero size. Please provide non-empty data") + + # Determine max_epochs/max_iters if max_iters is None: if max_epochs is None: max_epochs = 1 else: - if max_epochs is not None: - raise ValueError( - "Arguments max_iters and max_epochs are mutually exclusive." - "Please provide only max_epochs or max_iters." 
- ) if epoch_length is not None: max_epochs = math.ceil(max_iters / epoch_length) + # Initialize new state self.state.iteration = 0 self.state.epoch = 0 self.state.max_epochs = max_epochs @@ -901,12 +1003,38 @@ def switch_batch(engine): self.state.epoch_length = epoch_length # Reset generator if previously used self._internal_run_generator = None - self.logger.info(f"Engine run starting with max_epochs={max_epochs}.") + + # Log start message + if self.state.max_epochs is not None: + self.logger.info(f"Engine run starting with max_epochs={self.state.max_epochs}.") + else: + self.logger.info(f"Engine run starting with max_iters={self.state.max_iters}.") else: - self.logger.info( - f"Engine run resuming from iteration {self.state.iteration}, " - f"epoch {self.state.epoch} until {self.state.max_epochs} epochs" - ) + # Resume from existing state + # Apply overridden parameters using helper methods + self._check_and_set_max_epochs(max_epochs) + self._check_and_set_max_iters(max_iters) + + # Handle epoch_length validation (simplified from original) + if epoch_length is not None: + if epoch_length != self.state.epoch_length: + raise ValueError( + "Argument epoch_length should be same as in the state, " + f"but given {epoch_length} vs {self.state.epoch_length}" + ) + + # Log resuming message + if self.state.max_epochs is not None: + self.logger.info( + f"Engine run resuming from iteration {self.state.iteration}, " + f"epoch {self.state.epoch} until {self.state.max_epochs} epochs" + ) + else: + self.logger.info( + f"Engine run resuming from iteration {self.state.iteration}, " + f"epoch {self.state.epoch} until {self.state.max_iters} iterations" + ) + if self.state.epoch_length is None and data is None: raise ValueError("epoch_length should be provided if data is None") diff --git a/tests/ignite/base/test_mixins_update.py b/tests/ignite/base/test_mixins_update.py new file mode 100644 index 000000000000..0a66e3f2ea67 --- /dev/null +++ b/tests/ignite/base/test_mixins_update.py @@ -0,0 +1,308 @@ +#!/usr/bin/env python +"""Test the updated mixins functionality for the max_iters fix.""" + +from ignite.base import Serializable +from collections import OrderedDict +import pytest + + +class ExampleSerializable(Serializable): + _state_dict_all_req_keys = ("a", "b") + _state_dict_one_of_opt_keys = (("c", "d"), ("e", "f")) + + def __init__(self): + super().__init__() + self.data = {} + + def state_dict(self): + return {"a": 1, "b": 2, "c": 3, "e": 5} + + +class EngineStyleSerializable(Serializable): + """Serializable that mimics Engine's key structure.""" + + _state_dict_all_req_keys = ("epoch_length",) + _state_dict_one_of_opt_keys = (("iteration", "epoch"), ("max_epochs", "max_iters")) + + def __init__(self): + super().__init__() + self.data = {} + + def state_dict(self): + result = OrderedDict() + for key in self._state_dict_all_req_keys: + if key in self.data: + result[key] = self.data[key] + + # Add user keys + for key in self._state_dict_user_keys: + if key in self.data: + result[key] = self.data[key] + + return result + + +def test_load_state_dict_validation(): + """Test the updated load_state_dict validation.""" + s = ExampleSerializable() + + # Test type check + with pytest.raises(TypeError, match=r"Argument state_dict should be a dictionary"): + s.load_state_dict("not a dict") + + # Test missing required keys + with pytest.raises(ValueError, match=r"Required state attribute 'a' is absent"): + s.load_state_dict({}) + + with pytest.raises(ValueError, match=r"Required state attribute 'b' is absent"): + 
s.load_state_dict({"a": 1}) + + # Test one-of optional keys - missing all + with pytest.raises(ValueError, match=r"should contain at least one of"): + s.load_state_dict({"a": 1, "b": 2}) + + # Test one-of optional keys - having all from one group + with pytest.raises(ValueError, match=r"should contain only one of '\('c', 'd'\)'"): + s.load_state_dict({"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}) + + # Test one-of optional keys - having all from another group + with pytest.raises(ValueError, match=r"should contain only one of '\('e', 'f'\)'"): + s.load_state_dict({"a": 1, "b": 2, "c": 3, "e": 5, "f": 6}) + + # Test user keys + s.state_dict_user_keys.append("alpha") + with pytest.raises(ValueError, match=r"Required user state attribute 'alpha' is absent"): + s.load_state_dict({"a": 1, "b": 2, "c": 3, "e": 5}) + + # Valid state dict + s.load_state_dict({"a": 1, "b": 2, "c": 3, "e": 5, "alpha": 0.1}) + print("✓ Valid state dict loaded successfully") + + +def test_state_dict_user_keys_property(): + """Test the state_dict_user_keys property.""" + s = ExampleSerializable() + + assert hasattr(s, "state_dict_user_keys") + assert isinstance(s.state_dict_user_keys, list) + assert len(s.state_dict_user_keys) == 0 + + s.state_dict_user_keys.append("test_key") + assert len(s.state_dict_user_keys) == 1 + assert s.state_dict_user_keys[0] == "test_key" + + +def test_empty_optional_groups(): + """Test handling of empty optional groups.""" + + class EmptyOptionalSerializable(Serializable): + _state_dict_all_req_keys = ("required",) + _state_dict_one_of_opt_keys = ((),) # Empty tuple + + def state_dict(self): + return {} + + s = EmptyOptionalSerializable() + + # Should pass validation with just required key + s.load_state_dict({"required": "value"}) + + +def test_multiple_empty_groups(): + """Test multiple empty groups in _state_dict_one_of_opt_keys.""" + + class MultiEmptySerializable(Serializable): + _state_dict_all_req_keys = ("base",) + _state_dict_one_of_opt_keys = ((), (), ()) # Multiple empty groups + + def state_dict(self): + return {} + + s = MultiEmptySerializable() + + # Should pass with just required key + s.load_state_dict({"base": "value"}) + + +def test_mixed_empty_and_filled_groups(): + """Test mix of empty and filled optional groups.""" + + class MixedSerializable(Serializable): + _state_dict_all_req_keys = ("base",) + _state_dict_one_of_opt_keys = ((), ("opt1", "opt2"), ()) + + def state_dict(self): + return {} + + s = MixedSerializable() + + # Should require one from non-empty group + with pytest.raises(ValueError, match="should contain at least one of"): + s.load_state_dict({"base": "value"}) + + # Should pass with one from non-empty group + s.load_state_dict({"base": "value", "opt1": "option"}) + + +def test_engine_style_validation(): + """Test validation that mimics Engine usage.""" + s = EngineStyleSerializable() + + # Valid: iteration + max_iters + s.load_state_dict({"epoch_length": 100, "iteration": 150, "max_iters": 500}) + + # Valid: epoch + max_epochs + s2 = EngineStyleSerializable() + s2.load_state_dict({"epoch_length": 100, "epoch": 3, "max_epochs": 10}) + + # Invalid: both iteration and epoch + s3 = EngineStyleSerializable() + with pytest.raises(ValueError, match="should contain only one of.*iteration.*epoch"): + s3.load_state_dict({"epoch_length": 100, "iteration": 150, "epoch": 3, "max_epochs": 10}) + + # Invalid: both max_epochs and max_iters + s4 = EngineStyleSerializable() + with pytest.raises(ValueError, match="should contain only one of.*max_epochs.*max_iters"): + 
s4.load_state_dict({"epoch_length": 100, "iteration": 150, "max_epochs": 10, "max_iters": 500}) + + +def test_single_option_group(): + """Test group with single option.""" + + class SingleOptionSerializable(Serializable): + _state_dict_all_req_keys = ("base",) + _state_dict_one_of_opt_keys = (("single",),) + + def state_dict(self): + return {} + + s = SingleOptionSerializable() + + # Should require the single option + with pytest.raises(ValueError, match="should contain at least one of"): + s.load_state_dict({"base": "value"}) + + # Should pass with single option + s.load_state_dict({"base": "value", "single": "option"}) + + +def test_inheritance_overrides(): + """Test that subclasses can override validation rules.""" + + class BaseSerializable(Serializable): + _state_dict_all_req_keys = ("base_req",) + _state_dict_one_of_opt_keys = (("base_opt1", "base_opt2"),) + + def state_dict(self): + return {} + + class DerivedSerializable(BaseSerializable): + _state_dict_all_req_keys = ("derived_req1", "derived_req2") + _state_dict_one_of_opt_keys = (("derived_opt1", "derived_opt2"),) + + # Base class uses its own rules + base = BaseSerializable() + base.load_state_dict({"base_req": "value", "base_opt1": "opt"}) + + # Derived class uses overridden rules + derived = DerivedSerializable() + with pytest.raises(ValueError, match="Required state attribute.*derived_req1"): + derived.load_state_dict({"base_req": "value", "base_opt1": "opt"}) + + # Valid for derived class + derived.load_state_dict({"derived_req1": "d1", "derived_req2": "d2", "derived_opt2": "opt"}) + + +def test_user_keys_with_groups(): + """Test user keys work with grouped optional keys.""" + s = EngineStyleSerializable() + s.state_dict_user_keys.append("custom_param") + s.state_dict_user_keys.append("learning_rate") + + # Valid with all requirements + s.load_state_dict( + {"epoch_length": 100, "iteration": 250, "max_iters": 500, "custom_param": 42, "learning_rate": 0.01} + ) + + # Missing user key should fail + s2 = EngineStyleSerializable() + s2.state_dict_user_keys.append("custom_param") + with pytest.raises(ValueError, match="Required user state attribute.*custom_param"): + s2.load_state_dict({"epoch_length": 100, "iteration": 250, "max_iters": 500}) + + +def test_error_messages(): + """Test that error messages are clear and helpful.""" + s = EngineStyleSerializable() + + # Check specific error message format for grouped keys + try: + s.load_state_dict({"epoch_length": 100, "max_epochs": 5}) + assert False, "Should have raised ValueError" + except ValueError as e: + assert "should contain at least one of" in str(e) + assert "iteration" in str(e) and "epoch" in str(e) + + # Check error message for having both from a group + try: + s.load_state_dict({"epoch_length": 100, "iteration": 150, "epoch": 3, "max_epochs": 5}) + assert False, "Should have raised ValueError" + except ValueError as e: + assert "should contain only one of" in str(e) + assert "iteration" in str(e) and "epoch" in str(e) + + +def test_backwards_compatibility(): + """Test that old style validation still works.""" + + class OldStyleSerializable(Serializable): + _state_dict_all_req_keys = ("req1", "req2") + # No _state_dict_one_of_opt_keys defined - should default to empty + + def state_dict(self): + return {} + + s = OldStyleSerializable() + + # Should work with just required keys + s.load_state_dict({"req1": "r1", "req2": "r2"}) + + # Should fail without required keys + with pytest.raises(ValueError, match="Required state attribute"): + s.load_state_dict({"req1": "r1"}) 
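+
+
+def test_user_keys_are_per_instance():
+    """Illustrative sketch: the user-key list is created in Serializable.__init__,
+    so two instances should not share it (assumption based on the new __init__)."""
+    s1 = ExampleSerializable()
+    s2 = ExampleSerializable()
+    s1.state_dict_user_keys.append("only_on_s1")
+    assert "only_on_s1" not in s2.state_dict_user_keys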
+ + +def test_complex_scenario(): + """Test complex scenario with multiple groups and user keys.""" + + class ComplexSerializable(Serializable): + _state_dict_all_req_keys = ("base1", "base2") + _state_dict_one_of_opt_keys = ( + ("pos1", "pos2", "pos3"), + ("term1", "term2"), + ("opt1", "opt2", "opt3", "opt4"), + ) + + def state_dict(self): + return {} + + s = ComplexSerializable() + s.state_dict_user_keys.extend(["user1", "user2"]) + + # Valid complex state + s.load_state_dict( + { + "base1": "b1", + "base2": "b2", + "pos2": "position", + "term1": "termination", + "opt3": "option", + "user1": "u1", + "user2": "u2", + } + ) + + # Missing from one group should fail + s2 = ComplexSerializable() + with pytest.raises(ValueError, match="should contain at least one of.*term1.*term2"): + s2.load_state_dict({"base1": "b1", "base2": "b2", "pos1": "pos", "opt4": "opt"}) diff --git a/tests/ignite/engine/test_engine_state_dict.py b/tests/ignite/engine/test_engine_state_dict.py index 4ccfb7ea7720..7e78eecd02f8 100644 --- a/tests/ignite/engine/test_engine_state_dict.py +++ b/tests/ignite/engine/test_engine_state_dict.py @@ -18,7 +18,7 @@ def test_state_dict(): def _test(state): engine.state = state sd = engine.state_dict() - assert isinstance(sd, Mapping) and len(sd) == len(engine._state_dict_all_req_keys) + 1 + assert isinstance(sd, Mapping) and len(sd) == len(engine._state_dict_all_req_keys) + 2 assert sd["iteration"] == engine.state.iteration assert sd["epoch_length"] == engine.state.epoch_length assert sd["max_epochs"] == engine.state.max_epochs @@ -35,7 +35,7 @@ def test_state_dict_with_user_keys(): def _test(state): engine.state = state sd = engine.state_dict() - assert isinstance(sd, Mapping) and len(sd) == len(engine._state_dict_all_req_keys) + 1 + len( + assert isinstance(sd, Mapping) and len(sd) == len(engine._state_dict_all_req_keys) + 2 + len( engine.state_dict_user_keys ) assert sd["iteration"] == engine.state.iteration @@ -52,7 +52,7 @@ def test_state_dict_integration(): data = range(100) engine.run(data, max_epochs=10) sd = engine.state_dict() - assert isinstance(sd, Mapping) and len(sd) == len(engine._state_dict_all_req_keys) + 1 + assert isinstance(sd, Mapping) and len(sd) == len(engine._state_dict_all_req_keys) + 2 assert sd["iteration"] == engine.state.iteration == 10 * 100 assert sd["epoch_length"] == engine.state.epoch_length == 100 assert sd["max_epochs"] == engine.state.max_epochs == 10 @@ -67,7 +67,7 @@ def test_load_state_dict_asserts(): with pytest.raises(ValueError, match=r"is absent in provided state_dict"): engine.load_state_dict({}) - with pytest.raises(ValueError, match=r"state_dict should contain only one of"): + with pytest.raises(ValueError, match=r"state_dict should contain at least one of"): engine.load_state_dict({"max_epochs": 100, "epoch_length": 120}) with pytest.raises(ValueError, match=r"state_dict should contain only one of"): diff --git a/tests/ignite/engine/test_max_iters_fix.py b/tests/ignite/engine/test_max_iters_fix.py new file mode 100644 index 000000000000..610bb3804e73 --- /dev/null +++ b/tests/ignite/engine/test_max_iters_fix.py @@ -0,0 +1,423 @@ +#!/usr/bin/env python +"""Test script to verify max_iters handling in Engine state dict. + +This tests the fix for issue #1521. 
+""" +import pytest +from ignite.engine import Engine, Events, State + + +def test_state_dict_with_max_epochs(): + """Test state_dict with max_epochs set.""" + engine = Engine(lambda e, b: 1) + data = range(100) + engine.run(data, max_epochs=5) + + sd = engine.state_dict() + assert "iteration" in sd + assert "epoch_length" in sd + assert "max_epochs" in sd + assert "max_iters" not in sd + assert sd["max_epochs"] == 5 + assert sd["epoch_length"] == 100 + assert sd["iteration"] == 500 + + +def test_state_dict_with_max_iters(): + """Test state_dict with max_iters set.""" + engine = Engine(lambda e, b: 1) + data = range(100) + engine.run(data, max_iters=250) + + sd = engine.state_dict() + assert "iteration" in sd + assert "epoch_length" in sd + assert "max_iters" in sd + assert "max_epochs" not in sd + assert sd["max_iters"] == 250 + assert sd["epoch_length"] == 100 + assert sd["iteration"] == 250 + + +def test_load_state_dict_with_max_epochs(): + """Test load_state_dict with max_epochs.""" + engine = Engine(lambda e, b: 1) + + state_dict = {"epoch": 2, "max_epochs": 5, "epoch_length": 100} + + engine.load_state_dict(state_dict) + assert engine.state.epoch == 2 + assert engine.state.max_epochs == 5 + assert engine.state.epoch_length == 100 + assert engine.state.iteration == 200 + + +def test_load_state_dict_with_max_iters(): + """Test load_state_dict with max_iters.""" + engine = Engine(lambda e, b: 1) + + state_dict = {"iteration": 150, "max_iters": 250, "epoch_length": 100} + + engine.load_state_dict(state_dict) + assert engine.state.iteration == 150 + assert engine.state.max_iters == 250 + assert engine.state.epoch_length == 100 + assert engine.state.epoch == 1 # 150 // 100 + + +def test_save_and_load_with_max_iters(): + """Test saving and loading engine state with max_iters.""" + # Create and run engine with max_iters + engine1 = Engine(lambda e, b: b) + data = list(range(20)) + engine1.run(data, max_iters=50, epoch_length=10) + + # Save state + state_dict = engine1.state_dict() + assert state_dict["iteration"] == 50 + assert state_dict["max_iters"] == 50 + assert state_dict["epoch_length"] == 10 + assert "max_epochs" not in state_dict + + # Load state in new engine + engine2 = Engine(lambda e, b: b) + engine2.load_state_dict(state_dict) + + assert engine2.state.iteration == 50 + assert engine2.state.max_iters == 50 + assert engine2.state.epoch_length == 10 + assert engine2.state.epoch == 5 # 50 // 10 + + +def test_resume_with_max_iters(): + """Test resuming engine run with max_iters using early termination.""" + counter = [0] + + def update_fn(engine, batch): + counter[0] += 1 + return batch + + engine = Engine(update_fn) + data = list(range(10)) + + # Set up early termination at iteration 15 + @engine.on(Events.ITERATION_COMPLETED(once=15)) + def stop_early(engine): + engine.terminate() + + # Run with max_iters=25 but terminate early at 15 + engine.run(data, max_iters=25, epoch_length=10) + assert counter[0] == 15 + assert engine.state.iteration == 15 + assert engine.state.max_iters == 25 # Still has the original max_iters + + # Save and reload state + state_dict = engine.state_dict() + counter[0] = 0 # Reset counter + + engine2 = Engine(update_fn) + engine2.load_state_dict(state_dict) + + # Resume running - should continue from iteration 15 to 25 + engine2.run(data) + assert counter[0] == 10 # 25 - 15 + assert engine2.state.iteration == 25 + + +def test_mutually_exclusive_max_epochs_max_iters(): + """Test that max_epochs and max_iters are mutually exclusive.""" + engine = Engine(lambda 
e, b: 1) + data = range(10) + + with pytest.raises(ValueError, match="mutually exclusive"): + engine.run(data, max_epochs=5, max_iters=50) + + +def test_validation_errors(): + """Test validation errors for invalid states.""" + engine = Engine(lambda e, b: 1) + + # Test invalid max_epochs in state_dict + with pytest.raises(ValueError, match="larger than or equal to the current epoch"): + state_dict = {"epoch": 5, "max_epochs": 3, "epoch_length": 10} + engine.load_state_dict(state_dict) + + # Test invalid max_iters in state_dict + with pytest.raises(ValueError, match="larger than or equal to the current iteration"): + state_dict = {"iteration": 100, "max_iters": 50, "epoch_length": 10} + engine.load_state_dict(state_dict) + + +def test_unknown_epoch_length_with_max_iters(): + """Test handling unknown epoch_length with max_iters.""" + counter = [0] + + def update_fn(engine, batch): + counter[0] += 1 + return batch + + def data_iter(): + for i in range(15): + yield i + + engine = Engine(update_fn) + + # Run with unknown epoch length and max_iters that completes before first epoch ends + engine.run(data_iter(), max_iters=10) + assert counter[0] == 10 + assert engine.state.iteration == 10 + # epoch_length remains None since we stopped before completing an epoch + assert engine.state.epoch_length is None + + # State dict should have max_iters + sd = engine.state_dict() + assert "max_iters" in sd + assert sd["max_iters"] == 10 + + # Test case where we complete a full epoch + engine2 = Engine(update_fn) + counter[0] = 0 + engine2.run(data_iter(), max_iters=20) + assert counter[0] == 15 # Iterator exhausted after 15 + assert engine2.state.iteration == 15 + # epoch_length should be determined when iterator is exhausted + assert engine2.state.epoch_length == 15 + + +def test_engine_attributes(): + """Test basic engine attributes and state.""" + engine = Engine(lambda e, b: 1) + + # Test basic attributes exist + assert hasattr(engine, "state") + assert hasattr(engine, "logger") + assert hasattr(engine, "state_dict_user_keys") + + # Test initial state + assert engine.state.iteration == 0 + assert engine.state.epoch == 0 + assert engine.state.max_epochs is None + assert engine.state.max_iters is None + assert engine.state.epoch_length is None + + +def test_helper_methods(): + """Test the helper validation methods.""" + engine = Engine(lambda e, b: 1) + data = range(10) + engine.run(data, max_epochs=3) + + # Test _check_and_set_max_epochs + with pytest.raises(ValueError, match="greater than or equal to the start"): + engine._check_and_set_max_epochs(2) + + engine._check_and_set_max_epochs(5) + assert engine.state.max_epochs == 5 + + # Test _check_and_set_max_iters + engine.state.max_epochs = None + engine.state.max_iters = 30 + + with pytest.raises(ValueError, match="greater than or equal to the start"): + engine._check_and_set_max_iters(25) + + engine._check_and_set_max_iters(40) + assert engine.state.max_iters == 40 + + +def test_backward_compatibility(): + """Test backward compatibility with old state dicts.""" + engine = Engine(lambda e, b: 1) + + # Old state dict format (with max_epochs) + old_state_dict = {"iteration": 200, "max_epochs": 5, "epoch_length": 100} + + engine.load_state_dict(old_state_dict) + assert engine.state.iteration == 200 + assert engine.state.max_epochs == 5 + assert engine.state.epoch_length == 100 + assert engine.state.epoch == 2 # 200 // 100 + + +def test_invalid_state_dict_both_termination_params(): + """Test that state dict with both max_epochs and max_iters fails.""" + 
engine = Engine(lambda e, b: 1) + + state_dict = {"iteration": 100, "max_epochs": 5, "max_iters": 500, "epoch_length": 100} + + with pytest.raises(ValueError, match="should contain only one of"): + engine.load_state_dict(state_dict) + + +def test_invalid_state_dict_both_position_params(): + """Test that state dict with both iteration and epoch fails.""" + engine = Engine(lambda e, b: 1) + + state_dict = {"iteration": 100, "epoch": 2, "max_epochs": 5, "epoch_length": 100} + + with pytest.raises(ValueError, match="should contain only one of"): + engine.load_state_dict(state_dict) + + +def test_invalid_state_dict_missing_termination(): + """Test that state dict without max_epochs or max_iters fails.""" + engine = Engine(lambda e, b: 1) + + state_dict = {"iteration": 100, "epoch_length": 100} + + with pytest.raises(ValueError, match="should contain at least one of"): + engine.load_state_dict(state_dict) + + +def test_user_keys_with_max_iters(): + """Test user-defined keys work with max_iters.""" + engine = Engine(lambda e, b: b) + data = list(range(10)) + + # Add user keys + engine.state_dict_user_keys.append("custom_value") + engine.state_dict_user_keys.append("another_value") + + @engine.on(Events.STARTED) + def init_custom_values(engine): + engine.state.custom_value = 42 + engine.state.another_value = "test" + + engine.run(data, max_iters=5) + + # Check state dict contains user keys + sd = engine.state_dict() + assert "custom_value" in sd + assert "another_value" in sd + assert sd["custom_value"] == 42 + assert sd["another_value"] == "test" + assert "max_iters" in sd + assert "max_epochs" not in sd + + # Load into new engine + engine2 = Engine(lambda e, b: b) + engine2.state_dict_user_keys.append("custom_value") + engine2.state_dict_user_keys.append("another_value") + + engine2.load_state_dict(sd) + assert engine2.state.custom_value == 42 + assert engine2.state.another_value == "test" + assert engine2.state.max_iters == 5 + + +def test_is_done_method_with_max_iters(): + """Test the _is_done static method with max_iters.""" + # Test with max_iters + state = State() + state.iteration = 100 + state.max_iters = 100 + state.epoch_length = 25 + state.epoch = 4 + state.max_epochs = None + + assert Engine._is_done(state) is True + + state.iteration = 99 + assert Engine._is_done(state) is False + + state.iteration = 101 + assert Engine._is_done(state) is True + + # Test with both set (shouldn't happen but test logic) + state.iteration = 50 + state.max_iters = 100 + state.max_epochs = 3 + state.epoch = 2 + state.epoch_length = 25 + assert Engine._is_done(state) is False + + state.iteration = 100 + assert Engine._is_done(state) is True + + state.iteration = 75 + state.epoch = 3 + assert Engine._is_done(state) is True + + +def test_none_data_with_max_iters(): + """Test running with None data and max_iters.""" + counter = [0] + + def update_fn(engine, batch): + assert batch is None + counter[0] += 1 + return 1 + + engine = Engine(update_fn) + + # Should work with None data if epoch_length provided + engine.run(data=None, max_iters=30, epoch_length=10) + + assert counter[0] == 30 + assert engine.state.iteration == 30 + assert engine.state.max_iters == 30 + assert engine.state.epoch_length == 10 + assert engine.state.epoch == 3 # ceil(30/10) = 3 + + +def test_epoch_calculation_with_max_iters(): + """Test epoch calculation when using max_iters.""" + engine = Engine(lambda e, b: b) + data = list(range(25)) + + # Run with max_iters that doesn't divide evenly + engine.run(data, max_iters=60) + + assert 
engine.state.iteration == 60 + assert engine.state.max_iters == 60 + assert engine.state.epoch_length == 25 + assert engine.state.epoch == 3 # ceil(60/25) = 3 + + # Save and verify state dict + sd = engine.state_dict() + assert sd["iteration"] == 60 + assert sd["max_iters"] == 60 + assert sd["epoch_length"] == 25 + + +def test_resume_with_higher_max_iters(): + """Test loading state and running with higher max_iters value.""" + counter = [0] + + def update_fn(engine, batch): + counter[0] += 1 + return batch + + # First engine: run until 15 then save state + engine1 = Engine(update_fn) + data = list(range(10)) + + # Use early termination to simulate partial run + @engine1.on(Events.ITERATION_COMPLETED(once=15)) + def stop_early(engine): + engine.terminate() + + engine1.run(data, max_iters=20) + assert counter[0] == 15 + assert engine1.state.iteration == 15 + assert engine1.state.max_iters == 20 + + # Save state + sd = engine1.state_dict() + counter[0] = 0 + + # Second engine: load state and increase max_iters + engine2 = Engine(update_fn) + engine2.load_state_dict(sd) + + # Directly set higher max_iters and run + engine2.state.max_iters = 25 + engine2.run(data) + assert counter[0] == 10 # 25 - 15 + assert engine2.state.iteration == 25 + assert engine2.state.max_iters == 25 + + # Final state dict + final_sd = engine2.state_dict() + assert final_sd["iteration"] == 25 + assert final_sd["max_iters"] == 25
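+
+
+def test_resume_run_with_max_iters_override():
+    """Illustrative sketch (assumes the resume branch applies a max_iters value
+    passed to run()): after loading a partially finished state, calling
+    run(data, max_iters=...) with a larger value should continue up to the new limit."""
+    engine = Engine(lambda e, b: b)
+    engine.load_state_dict({"iteration": 15, "max_iters": 20, "epoch_length": 10})
+
+    engine.run(list(range(10)), max_iters=25)
+    assert engine.state.iteration == 25
+    assert engine.state.max_iters == 25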