src/transformers/modeling_utils.py: 33 changes (25 additions, 8 deletions)

@@ -4215,14 +4215,31 @@ def _load_pretrained_model(
             pattern = re.compile(r"(" + "|".join(map(re.escape, keys)) + r")")
             if sharded_metadata is None:
                 k_v_iterator = dict.fromkeys(
-                    safe_open(checkpoint_files[0], framework="pt").keys(), checkpoint_files[0].rsplit("/", 1)[1]
+                    safe_open(checkpoint_files[0], framework="pt").keys(), os.path.basename(checkpoint_files[0])
                 ).items()
             else:
                 k_v_iterator = sharded_metadata["weight_map"].items()

-            merged_state_dict = {}
+            # Create a mapping from filename to full path for all checkpoint files
+            filename_to_path = {os.path.basename(f): f for f in checkpoint_files}
+
+            # Group weights by file to load sequentially and avoid keeping too many files open
+            weights_by_file = {}
             for k, v in k_v_iterator:
-                match = pattern.match(k)
+                if v not in weights_by_file:
+                    weights_by_file[v] = []
+                weights_by_file[v].append(k)
+
+            merged_state_dict = {}
+            # Load each file sequentially
+            for filename, weight_keys in weights_by_file.items():
+                # Use the mapping to get the correct file path instead of joining paths
+                # This handles symbolic links on Windows correctly
+                shard_file_path = filename_to_path.get(
+                    filename, os.path.join(os.path.dirname(checkpoint_files[0]), filename)
+                )
+
+                match = pattern.match(weight_keys[0])
                 if match and match.group(1) != "":
                     device = device_map[match.group(1)]
                 else:
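Note on the path handling in this hunk: `checkpoint_files[0].rsplit("/", 1)[1]` assumes forward-slash separators, which Windows paths do not guarantee, and the new `filename_to_path` lookup reuses the exact paths that were passed in rather than re-deriving them, which is what keeps symlinked shards (as in the Hugging Face cache) resolving correctly. A small standalone illustration, not part of the PR; `ntpath` is used here only so the Windows behavior reproduces on any OS:

```python
import ntpath  # what os.path is on Windows; lets the demo run anywhere

path = r"C:\models\llama\model-00001-of-00002.safetensors"

# The old code split on "/" only. A Windows path may contain no "/", so rsplit
# yields a single-element list and indexing [1] raises IndexError.
print(path.rsplit("/", 1))    # ['C:\\models\\llama\\model-00001-of-00002.safetensors']

# os.path.basename (ntpath.basename on Windows) understands both separators.
print(ntpath.basename(path))  # model-00001-of-00002.safetensors
```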
@@ -4231,11 +4248,11 @@ def _load_pretrained_model(
                     device = device.index  # safetensors only
                 if device == "disk":
                     device = "cpu"  # we read to cpu to then write to disk
-                file_pointer = safe_open(
-                    os.path.join(checkpoint_files[0].rsplit("/", 1)[0], v), framework="pt", device=device
-                )
-                all_pointer.add(file_pointer)
-                merged_state_dict[k] = file_pointer.get_slice(k)  # don't materialize yet
+
+                with safe_open(shard_file_path, framework="pt", device=device) as f:
+                    for k in weight_keys:
+                        # Materialize the tensor immediately instead of keeping a lazy slice
+                        merged_state_dict[k] = f.get_tensor(k)
         elif state_dict is not None:
             merged_state_dict = state_dict
         elif checkpoint_files is not None:
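The second hunk changes the loading strategy as well as the path handling: instead of opening a `safe_open` handle per weight and keeping lazy `get_slice` views alive in `all_pointer`, each shard is now opened once, its tensors are materialized with `get_tensor`, and the context manager closes the file before the next shard is touched. Below is a minimal self-contained sketch of that pattern, assuming a standard sharded checkpoint with a `model.safetensors.index.json`; the directory name is hypothetical, while `safe_open` and `get_tensor` are the actual safetensors API:

```python
import json
import os
from collections import defaultdict

from safetensors import safe_open

ckpt_dir = "/path/to/checkpoint"  # hypothetical directory holding *.safetensors shards

# The index file maps each weight name to the shard file that stores it.
with open(os.path.join(ckpt_dir, "model.safetensors.index.json")) as f:
    weight_map = json.load(f)["weight_map"]

# Group weight names by shard so every file is opened exactly once.
weights_by_file = defaultdict(list)
for key, shard in weight_map.items():
    weights_by_file[shard].append(key)

state_dict = {}
for shard, keys in weights_by_file.items():
    # The context manager closes each shard before the next one is opened,
    # so at most one file handle is held at a time.
    with safe_open(os.path.join(ckpt_dir, shard), framework="pt", device="cpu") as f:
        for key in keys:
            state_dict[key] = f.get_tensor(key)  # materialized tensor, not a lazy slice
```

The trade-off is that weights destined for disk offload are now read into memory immediately rather than sliced on demand, which the `# Materialize the tensor immediately` comment in the hunk acknowledges.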