diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 93915ac946ae9..134eb2e298d4a 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -28,7 +28,7 @@
 from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
 from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum
-from pytorch_lightning.utilities.types import STEP_OUTPUT
+from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT

 if _NATIVE_AMP_AVAILABLE:
     from torch.cuda.amp import GradScaler
@@ -393,7 +393,7 @@ def model_sharded_context(self) -> Generator[None, None, None]:
         with self.training_type_plugin.model_sharded_context():
             yield

-    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
+    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.

         Args:
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index 29d9944a3f33d..7f5975651a100 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -26,7 +26,7 @@
 from pytorch_lightning.overrides.base import unwrap_lightning_module
 from pytorch_lightning.plugins import TorchCheckpointIO
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
-from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT
+from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PATH, _PREDICT_OUTPUT

 TBroadcast = TypeVar("T")

@@ -259,7 +259,7 @@ def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
         model = self.lightning_module
         return model.state_dict()

-    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
+    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.

         Args:
@@ -269,7 +269,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
         if self.should_rank_save_checkpoint:
             return self.checkpoint_io.save_checkpoint(checkpoint, filepath)

-    def remove_checkpoint(self, filepath: str) -> None:
+    def remove_checkpoint(self, filepath: _PATH) -> None:
         """Remove checkpoint filepath from the filesystem.

         Args:
diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 2503fc61f4f7f..dee8aea79a304 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -26,6 +26,7 @@
 from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _fault_tolerant_training
+from pytorch_lightning.utilities.types import _PATH
 from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS

 if _OMEGACONF_AVAILABLE:
@@ -430,7 +431,7 @@ def get_max_ckpt_path_from_folder(self, folder_path: Union[str, Path]) -> str:
         ckpt_number = max_suffix if max_suffix is not None else 0
         return f"{folder_path}/hpc_ckpt_{ckpt_number}.ckpt"

-    def save_checkpoint(self, filepath, weights_only: bool = False) -> None:
+    def save_checkpoint(self, filepath: _PATH, weights_only: bool = False) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.

         Args:
diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py
index fd72e6d4397fe..e33eeb8b7cbfd 100644
--- a/pytorch_lightning/trainer/properties.py
+++ b/pytorch_lightning/trainer/properties.py
@@ -48,6 +48,7 @@
 )
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.model_helpers import is_overridden
+from pytorch_lightning.utilities.types import _PATH


 class TrainerProperties(ABC):
@@ -388,7 +389,7 @@ def checkpoint_callbacks(self) -> List[ModelCheckpoint]:
     def resume_from_checkpoint(self) -> Optional[Union[str, Path]]:
         return self.checkpoint_connector.resume_checkpoint_path

-    def save_checkpoint(self, filepath, weights_only: bool = False) -> None:
+    def save_checkpoint(self, filepath: _PATH, weights_only: bool = False) -> None:
         self.checkpoint_connector.save_checkpoint(filepath, weights_only)

     """
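The diff only widens type annotations from `str` to `_PATH`, so runtime behavior is unchanged; the point is that `Trainer.save_checkpoint` and the plugin-level `save_checkpoint`/`remove_checkpoint` are now documented to accept path-like values, not just strings. Below is a minimal usage sketch, assuming `_PATH` is the path alias (roughly `Union[str, pathlib.Path]`) exported by `pytorch_lightning.utilities.types`; the module, dataloader, and checkpoint names are illustrative only and not part of the change.

from pathlib import Path

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


class TinyModule(LightningModule):
    # Minimal module whose only purpose is to give the Trainer something to checkpoint.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


train_loader = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
trainer = Trainer(max_epochs=1, logger=False)
trainer.fit(TinyModule(), train_loader)

ckpt_dir = Path("checkpoints")
ckpt_dir.mkdir(exist_ok=True)

# With `filepath: _PATH`, both a plain str and a pathlib.Path are valid arguments.
trainer.save_checkpoint("checkpoints/last.ckpt")
trainer.save_checkpoint(ckpt_dir / "last.ckpt")

Either call ends up in `CheckpointIO.save_checkpoint` via the training-type plugin, so the annotation change mainly makes the already-supported path-like inputs explicit to type checkers and readers of the API.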