Merged

Commits (27)
502950b  Updating torch.load To Load Weights Only (ericspod, Sep 11, 2025)
77d2d82  Autofix (ericspod, Sep 11, 2025)
79c2cf8  StateCacher should be fine with default pickle protocol (ericspod, Sep 11, 2025)
5f1f57c  Merge branch 'dev' into torch_load_fix (ericspod, Sep 11, 2025)
f6f9867  Docstring Update (ericspod, Sep 11, 2025)
93a5dd1  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 11, 2025)
10b5de3  Removing pickle_operations (ericspod, Sep 12, 2025)
64221d0  Fixes loading with weights_only for PersistenDataset by force convert… (ericspod, Sep 12, 2025)
8a75795  Tweak (ericspod, Sep 12, 2025)
a60569c  Comment unneeded components (ericspod, Sep 13, 2025)
b54e55d  Modify convert_to_tensor to skip converting primitives (ericspod, Sep 14, 2025)
7dc3ad3  Merge branch 'pickle_fixes' into torch_load_fix (ericspod, Sep 14, 2025)
52f8694  Trying safe torch load save usage in place of pickle (ericspod, Sep 14, 2025)
14e5e6b  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 14, 2025)
38a618b  Updates to further remove pickle usage (ericspod, Sep 14, 2025)
28c7df2  Merge branch 'torch_load_fix' of github.com:ericspod/MONAI into torch… (ericspod, Sep 14, 2025)
77d6992  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 14, 2025)
79e0966  Autofix (ericspod, Sep 14, 2025)
2a88d83  Merge branch 'torch_load_fix' of github.com:ericspod/MONAI into torch… (ericspod, Sep 14, 2025)
11c0ee5  Removing commented code (ericspod, Sep 14, 2025)
2edf46c  Pass argument in recursive call of convert_to_tensor (ericspod, Sep 14, 2025)
9b171d4  Type fix (ericspod, Sep 14, 2025)
3d6e0ca  Merge branch 'dev' into torch_load_fix (ericspod, Sep 15, 2025)
65a7b6d  Merge branch 'dev' into torch_load_fix (KumoLiu, Sep 16, 2025)
149a5bb  Fixing pickle protocol issue (ericspod, Sep 16, 2025)
58561b3  Merge branch 'torch_load_fix' of github.com:ericspod/MONAI into torch… (ericspod, Sep 16, 2025)
dd1de4a  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 16, 2025)
26 changes: 17 additions & 9 deletions monai/apps/nnunet/nnunet_bundle.py
@@ -133,7 +133,7 @@ def get_nnunet_trainer(
cudnn.benchmark = True

if pretrained_model is not None:
state_dict = torch.load(pretrained_model)
state_dict = torch.load(pretrained_model, weights_only=True)
if "network_weights" in state_dict:
nnunet_trainer.network._orig_mod.load_state_dict(state_dict["network_weights"])
return nnunet_trainer
@@ -182,7 +182,9 @@ def __init__(self, predictor: object, model_folder: Union[str, Path], model_name
parameters = []

checkpoint = torch.load(
join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"), map_location=torch.device("cpu")
join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"),
map_location=torch.device("cpu"),
weights_only=True,
)
trainer_name = checkpoint["trainer_name"]
configuration_name = checkpoint["init_args"]["configuration"]
@@ -192,7 +194,9 @@ def __init__(self, predictor: object, model_folder: Union[str, Path], model_name
else None
)
if Path(model_training_output_dir).joinpath(model_name).is_file():
monai_checkpoint = torch.load(join(model_training_output_dir, model_name), map_location=torch.device("cpu"))
monai_checkpoint = torch.load(
join(model_training_output_dir, model_name), map_location=torch.device("cpu"), weights_only=True
)
if "network_weights" in monai_checkpoint.keys():
parameters.append(monai_checkpoint["network_weights"])
else:
@@ -383,8 +387,12 @@ def convert_nnunet_to_monai_bundle(nnunet_config: dict, bundle_root_folder: str,
dataset_name, f"{nnunet_trainer}__{nnunet_plans}__{nnunet_configuration}"
)

nnunet_checkpoint_final = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_final.pth"))
nnunet_checkpoint_best = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_best.pth"))
nnunet_checkpoint_final = torch.load(
Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_final.pth"), weights_only=True
)
nnunet_checkpoint_best = torch.load(
Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_best.pth"), weights_only=True
)

nnunet_checkpoint = {}
nnunet_checkpoint["inference_allowed_mirroring_axes"] = nnunet_checkpoint_final["inference_allowed_mirroring_axes"]
@@ -470,7 +478,7 @@ def get_network_from_nnunet_plans(
if model_ckpt is None:
return network
else:
state_dict = torch.load(model_ckpt)
state_dict = torch.load(model_ckpt, weights_only=True)
network.load_state_dict(state_dict[model_key_in_ckpt])
return network

@@ -534,7 +542,7 @@ def subfiles(

Path(nnunet_model_folder).joinpath(f"fold_{fold}").mkdir(parents=True, exist_ok=True)

nnunet_checkpoint: dict = torch.load(f"{bundle_root_folder}/models/nnunet_checkpoint.pth")
nnunet_checkpoint: dict = torch.load(f"{bundle_root_folder}/models/nnunet_checkpoint.pth", weights_only=True)
latest_checkpoints: list[str] = subfiles(
Path(bundle_root_folder).joinpath("models", f"fold_{fold}"), prefix="checkpoint_epoch", sort=True
)
@@ -545,7 +553,7 @@
epochs.sort()
final_epoch: int = epochs[-1]
monai_last_checkpoint: dict = torch.load(
f"{bundle_root_folder}/models/fold_{fold}/checkpoint_epoch={final_epoch}.pt"
f"{bundle_root_folder}/models/fold_{fold}/checkpoint_epoch={final_epoch}.pt", weights_only=True
)

best_checkpoints: list[str] = subfiles(
@@ -558,7 +566,7 @@
key_metrics.sort()
best_key_metric: str = key_metrics[-1]
monai_best_checkpoint: dict = torch.load(
f"{bundle_root_folder}/models/fold_{fold}/checkpoint_key_metric={best_key_metric}.pt"
f"{bundle_root_folder}/models/fold_{fold}/checkpoint_key_metric={best_key_metric}.pt", weights_only=True
)

nnunet_checkpoint["optimizer_state"] = monai_last_checkpoint["optimizer_state"]
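Note that the pattern throughout this file is uniform: every `torch.load` call gains `weights_only=True`, which restricts unpickling to tensors, primitive containers, and explicitly allowlisted types. A minimal sketch of the behavioural difference (file names here are hypothetical):

```python
import pickle

import torch

# Checkpoints built only from tensors and primitives load fine with weights_only=True.
ckpt = {"network_weights": {"layer.weight": torch.randn(2, 2)}, "epoch": 5}
torch.save(ckpt, "safe_ckpt.pth")
state = torch.load("safe_ckpt.pth", map_location="cpu", weights_only=True)  # ok

# A checkpoint containing an arbitrary, unregistered class is rejected.
class Custom:
    pass

torch.save({"obj": Custom()}, "unsafe_ckpt.pth")
try:
    torch.load("unsafe_ckpt.pth", weights_only=True)
except pickle.UnpicklingError as e:
    print("refused:", e)
```

On recent PyTorch versions the rejected load surfaces as `pickle.UnpicklingError`, which is why the `PersistentDataset` changes below start catching it.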
1 change: 0 additions & 1 deletion monai/data/__init__.py
@@ -78,7 +78,6 @@
from .thread_buffer import ThreadBuffer, ThreadDataLoader
from .torchscript_utils import load_net_with_metadata, save_net_with_metadata
from .utils import (
PICKLE_KEY_SUFFIX,
affine_to_spacing,
compute_importance_map,
compute_shape_offset,
50 changes: 34 additions & 16 deletions monai/data/dataset.py
@@ -13,7 +13,6 @@

import collections.abc
import math
import pickle
import shutil
import sys
import tempfile
@@ -22,9 +21,11 @@
import warnings
from collections.abc import Callable, Sequence
from copy import copy, deepcopy
from io import BytesIO
from multiprocessing.managers import ListProxy
from multiprocessing.pool import ThreadPool
from pathlib import Path
from pickle import UnpicklingError
from typing import IO, TYPE_CHECKING, Any, cast

import numpy as np
@@ -207,6 +208,11 @@ class PersistentDataset(Dataset):
not guaranteed, so caution should be used when modifying transforms to avoid unexpected
errors. If in doubt, it is advisable to clear the cache directory.

Cached data is expected to be tensors, primitives, or dictionaries mapping to these values. Numpy arrays will
be converted to tensors; however, any other object type returned by transforms will not be loadable, since
`torch.load` will be used with `weights_only=True` to prevent loading of potentially malicious objects.
Legacy cache files may not be loadable and may need to be recomputed.

Lazy Resampling:
If you make use of the lazy resampling feature of `monai.transforms.Compose`, please refer to
its documentation to familiarize yourself with the interaction between `PersistentDataset` and
@@ -248,8 +254,8 @@ def __init__(
this arg is used by `torch.save`, for more details, please check:
https://pytorch.org/docs/stable/generated/torch.save.html#torch.save,
and ``monai.data.utils.SUPPORTED_PICKLE_MOD``.
pickle_protocol: can be specified to override the default protocol, default to `2`.
this arg is used by `torch.save`, for more details, please check:
pickle_protocol: specifies the pickle protocol to use when saving with `torch.save`.
Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check:
https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
hash_transform: a callable to compute hash from the transform information when caching.
This may reduce errors due to transforms changing during experiments. Default to None (no hash).
Expand Down Expand Up @@ -371,12 +377,12 @@ def _cachecheck(self, item_transformed):

if hashfile is not None and hashfile.is_file(): # cache hit
try:
return torch.load(hashfile, weights_only=False)
return torch.load(hashfile, weights_only=True)
except PermissionError as e:
if sys.platform != "win32":
raise e
except RuntimeError as e:
if "Invalid magic number; corrupt file" in str(e):
except (UnpicklingError, RuntimeError) as e: # corrupt or unloadable cached files are recomputed
if "Invalid magic number; corrupt file" in str(e) or isinstance(e, UnpicklingError):
warnings.warn(f"Corrupt cache file detected: {hashfile}. Deleting and recomputing.")
hashfile.unlink()
else:
@@ -392,7 +398,7 @@ def _cachecheck(self, item_transformed):
with tempfile.TemporaryDirectory() as tmpdirname:
temp_hash_file = Path(tmpdirname) / hashfile.name
torch.save(
obj=_item_transformed,
obj=convert_to_tensor(_item_transformed, convert_numeric=False),
f=temp_hash_file,
pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),
pickle_protocol=self.pickle_protocol,
@@ -455,8 +461,8 @@ def __init__(
this arg is used by `torch.save`, for more details, please check:
https://pytorch.org/docs/stable/generated/torch.save.html#torch.save,
and ``monai.data.utils.SUPPORTED_PICKLE_MOD``.
pickle_protocol: can be specified to override the default protocol, default to `2`.
this arg is used by `torch.save`, for more details, please check:
pickle_protocol: specifies the pickle protocol to use when saving with `torch.save`.
Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check:
https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
hash_transform: a callable to compute hash from the transform information when caching.
This may reduce errors due to transforms changing during experiments. Default to None (no hash).
@@ -531,7 +537,7 @@ def __init__(
hash_func: Callable[..., bytes] = pickle_hashing,
db_name: str = "monai_cache",
progress: bool = True,
pickle_protocol=pickle.HIGHEST_PROTOCOL,
pickle_protocol=DEFAULT_PROTOCOL,
hash_transform: Callable[..., bytes] | None = None,
reset_ops_id: bool = True,
lmdb_kwargs: dict | None = None,
@@ -551,8 +557,9 @@
defaults to `monai.data.utils.pickle_hashing`.
db_name: lmdb database file name. Defaults to "monai_cache".
progress: whether to display a progress bar.
pickle_protocol: pickle protocol version. Defaults to pickle.HIGHEST_PROTOCOL.
https://docs.python.org/3/library/pickle.html#pickle-protocols
pickle_protocol: specifies the pickle protocol to use when saving with `torch.save`.
Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check:
https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
hash_transform: a callable to compute hash from the transform information when caching.
This may reduce errors due to transforms changing during experiments. Default to None (no hash).
Other options are `pickle_hashing` and `json_hashing` functions from `monai.data.utils`.
@@ -594,6 +601,15 @@ def set_data(self, data: Sequence):
super().set_data(data=data)
self._read_env = self._fill_cache_start_reader(show_progress=self.progress)

def _safe_serialize(self, val):
out = BytesIO()
torch.save(convert_to_tensor(val), out, pickle_protocol=self.pickle_protocol)
out.seek(0)
return out.read()

def _safe_deserialize(self, val):
return torch.load(BytesIO(val), map_location="cpu", weights_only=True)

def _fill_cache_start_reader(self, show_progress=True):
"""
Check the LMDB cache and write the cache if needed. py-lmdb doesn't have a good support for concurrent write.
@@ -619,7 +635,8 @@
continue
if val is None:
val = self._pre_transform(deepcopy(item)) # keep the original hashed
val = pickle.dumps(val, protocol=self.pickle_protocol)
# val = pickle.dumps(val, protocol=self.pickle_protocol)
val = self._safe_serialize(val)
with env.begin(write=True) as txn:
txn.put(key, val)
done = True
@@ -664,7 +681,8 @@ def _cachecheck(self, item_transformed):
warnings.warn("LMDBDataset: cache key not found, running fallback caching.")
return super()._cachecheck(item_transformed)
try:
return pickle.loads(data)
# return pickle.loads(data)
return self._safe_deserialize(data)
except Exception as err:
raise RuntimeError("Invalid cache value, corrupted lmdb file?") from err

@@ -1650,7 +1668,7 @@ def _create_new_cache(self, data, data_hashfile, meta_hash_file_name):
meta_hash_file = self.cache_dir / meta_hash_file_name
temp_hash_file = Path(tmpdirname) / meta_hash_file_name
torch.save(
obj=self._meta_cache[meta_hash_file_name],
obj=convert_to_tensor(self._meta_cache[meta_hash_file_name], convert_numeric=False),
f=temp_hash_file,
pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),
pickle_protocol=self.pickle_protocol,
@@ -1670,4 +1688,4 @@ def _load_meta_cache(self, meta_hash_file_name):
if meta_hash_file_name in self._meta_cache:
return self._meta_cache[meta_hash_file_name]
else:
return torch.load(self.cache_dir / meta_hash_file_name, weights_only=False)
return torch.load(self.cache_dir / meta_hash_file_name, weights_only=True)
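The `_safe_serialize`/`_safe_deserialize` pair above replaces `pickle.dumps`/`pickle.loads` for LMDB cache values with a `torch.save`/`torch.load` round trip over an in-memory buffer. A standalone sketch of that round trip, substituting plain `torch.as_tensor` for MONAI's `convert_to_tensor`:

```python
from io import BytesIO

import numpy as np
import torch

def serialize(val) -> bytes:
    # Tensors/primitives only; mirrors the _safe_serialize pattern above.
    buf = BytesIO()
    torch.save(val, buf)
    return buf.getvalue()

def deserialize(raw: bytes):
    # weights_only=True refuses anything but tensors, primitives, and allowlisted types.
    return torch.load(BytesIO(raw), map_location="cpu", weights_only=True)

item = {"image": torch.as_tensor(np.ones((2, 3), dtype=np.float32)), "label": 1}
restored = deserialize(serialize(item))
assert torch.equal(restored["image"], item["image"]) and restored["label"] == 1
```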
2 changes: 1 addition & 1 deletion monai/data/meta_tensor.py
@@ -611,4 +611,4 @@ def print_verbose(self) -> None:

# needed in later versions of Pytorch to indicate the class is safe for serialisation
if hasattr(torch.serialization, "add_safe_globals"):
torch.serialization.add_safe_globals([MetaTensor])
torch.serialization.add_safe_globals([MetaObj, MetaTensor, MetaKeys, SpaceKeys])
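`weights_only=True` only admits allowlisted types, so this registration is what keeps saved `MetaTensor` values loadable; the expanded list also covers the base class and the enum types that can appear in a `MetaTensor`'s metadata. A sketch of the effect, assuming a PyTorch version that provides `add_safe_globals` and that importing `monai.data` has performed the registration:

```python
import torch
from monai.data import MetaTensor  # importing monai.data runs the add_safe_globals call

t = MetaTensor(torch.zeros(3), meta={"some_key": 1.0})
torch.save(t, "meta.pt")
loaded = torch.load("meta.pt", weights_only=True)  # ok: MetaTensor and friends are allowlisted
print(type(loaded).__name__, loaded.meta.get("some_key"))
```

Whether a given checkpoint loads still depends on every type it actually contains; anything outside the allowlist is rejected as before.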
46 changes: 10 additions & 36 deletions monai/data/utils.py
@@ -30,7 +30,6 @@
import torch
from torch.utils.data._utils.collate import default_collate

from monai import config
from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor, PathLike
from monai.data.meta_obj import MetaObj
from monai.utils import (
@@ -93,7 +92,6 @@
"remove_keys",
"remove_extra_metadata",
"get_extra_metadata_keys",
"PICKLE_KEY_SUFFIX",
"is_no_channel",
]

@@ -418,32 +416,6 @@ def dev_collate(batch, level: int = 1, logger_name: str = "dev_collate"):
return


PICKLE_KEY_SUFFIX = TraceKeys.KEY_SUFFIX


def pickle_operations(data, key=PICKLE_KEY_SUFFIX, is_encode: bool = True):
"""
Applied_operations are dictionaries with varying sizes, this method converts them to bytes so that we can (de-)collate.

Args:
data: a list or dictionary with substructures to be pickled/unpickled.
key: the key suffix for the target substructures, defaults to "_transforms" (`data.utils.PICKLE_KEY_SUFFIX`).
is_encode: whether it's encoding using pickle.dumps (True) or decoding using pickle.loads (False).
"""
if isinstance(data, Mapping):
data = dict(data)
for k in data:
if f"{k}".endswith(key):
if is_encode and not isinstance(data[k], bytes):
data[k] = pickle.dumps(data[k], 0)
if not is_encode and isinstance(data[k], bytes):
data[k] = pickle.loads(data[k])
return {k: pickle_operations(v, key=key, is_encode=is_encode) for k, v in data.items()}
elif isinstance(data, (list, tuple)):
return [pickle_operations(item, key=key, is_encode=is_encode) for item in data]
return data


def collate_meta_tensor_fn(batch, *, collate_fn_map=None):
"""
Collate a sequence of meta tensor into a single batched metatensor. This is called by `collage_meta_tensor`
@@ -500,8 +472,8 @@ def list_data_collate(batch: Sequence):
key = None
collate_fn = default_collate
try:
if config.USE_META_DICT:
data = pickle_operations(data) # bc 0.9.0
# if config.USE_META_DICT:
# data = pickle_operations(data) # bc 0.9.0
if isinstance(elem, Mapping):
ret = {}
for k in elem:
@@ -654,15 +626,17 @@ def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None):
if isinstance(deco, Mapping):
_gen = zip_longest(*deco.values(), fillvalue=fill_value) if pad else zip(*deco.values())
ret = [dict(zip(deco, item)) for item in _gen]
if not config.USE_META_DICT:
return ret
return pickle_operations(ret, is_encode=False) # bc 0.9.0
# if not config.USE_META_DICT:
# return ret
# return pickle_operations(ret, is_encode=False) # bc 0.9.0
return ret
if isinstance(deco, Iterable):
_gen = zip_longest(*deco, fillvalue=fill_value) if pad else zip(*deco)
ret_list = [list(item) for item in _gen]
if not config.USE_META_DICT:
return ret_list
return pickle_operations(ret_list, is_encode=False) # bc 0.9.0
# if not config.USE_META_DICT:
# return ret_list
# return pickle_operations(ret_list, is_encode=False) # bc 0.9.0
return ret_list
raise NotImplementedError(f"Unable to de-collate: {batch}, type: {type(batch)}.")


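With `pickle_operations` removed, collation and decollation pass `applied_operations` through unchanged instead of encoding them to bytes and back. A quick sanity sketch of the round trip:

```python
import torch
from monai.data import decollate_batch, list_data_collate

batch = [{"img": torch.zeros(1, 4, 4), "label": 0}, {"img": torch.ones(1, 4, 4), "label": 1}]
collated = list_data_collate(batch)   # batched dict, no pickle pass over "*_transforms" keys
items = decollate_batch(collated)     # back to a list of per-sample dicts
assert len(items) == 2 and items[1]["label"].item() == 1
```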
2 changes: 1 addition & 1 deletion monai/handlers/checkpoint_loader.py
@@ -122,7 +122,7 @@ def __call__(self, engine: Engine) -> None:
Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
"""
checkpoint = torch.load(self.load_path, map_location=self.map_location, weights_only=False)
checkpoint = torch.load(self.load_path, map_location=self.map_location, weights_only=True)

k, _ = list(self.load_dict.items())[0]
# single object and checkpoint is directly a state_dict
6 changes: 3 additions & 3 deletions monai/utils/state_cacher.py
@@ -64,8 +64,8 @@ def __init__(
pickle_module: module used for pickling metadata and objects, default to `pickle`.
this arg is used by `torch.save`, for more details, please check:
https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
pickle_protocol: can be specified to override the default protocol, default to `2`.
this arg is used by `torch.save`, for more details, please check:
pickle_protocol: specifies the pickle protocol to use when saving with `torch.save`.
Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check:
https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.

"""
@@ -124,7 +124,7 @@ def retrieve(self, key: Hashable) -> Any:
fn = self.cached[key]["obj"] # pytype: disable=attribute-error
if not os.path.exists(fn): # pytype: disable=wrong-arg-types
raise RuntimeError(f"Failed to load state in {fn}. File doesn't exist anymore.")
data_obj = torch.load(fn, map_location=lambda storage, location: storage, weights_only=False)
data_obj = torch.load(fn, map_location=lambda storage, location: storage, weights_only=True)
# copy back to device if necessary
if "device" in self.cached[key]:
data_obj = data_obj.to(self.cached[key]["device"])
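Since `retrieve` now loads with `weights_only=True`, anything handed to `StateCacher` must reduce to tensors and primitives; model and optimizer `state_dict()` objects qualify. A brief usage sketch:

```python
import torch
from monai.utils import StateCacher

net = torch.nn.Linear(4, 2)
cacher = StateCacher(in_memory=False)  # spill to a temporary file rather than RAM
cacher.store("net", net.state_dict())

# ...training that mutates the parameters...
with torch.no_grad():
    net.weight.zero_()

net.load_state_dict(cacher.retrieve("net"))  # round-trips via torch.save/torch.load(weights_only=True)
```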