|
3 | 3 |
|
4 | 4 | # This source code is licensed under the license found in the |
5 | 5 | # LICENSE file in the root directory of this source tree. |
| 6 | +import glob |
6 | 7 | import json |
7 | 8 | import os |
8 | 9 | import re |
@@ -41,7 +42,12 @@ def convert_hf_checkpoint( |
41 | 42 | print(f"Model config {config.__dict__}") |
42 | 43 |
|
43 | 44 | # Load the json file containing weight mapping |
44 | | - model_map_json = model_dir / "pytorch_model.bin.index.json" |
| 45 | + model_map_json_matches = [Path(m) for m in glob.glob(str(model_dir / "*.index.json"))] |
| 46 | + assert len(model_map_json_matches) <= 1, "Found multiple weight mapping files" |
| 47 | + if len(model_map_json_matches): |
| 48 | + model_map_json = model_map_json_matches[0] |
| 49 | + else: |
| 50 | + model_map_json = model_dir / "pytorch_model.bin.index.json" |
45 | 51 |
|
46 | 52 | # If there is no weight mapping, check for a consolidated model and |
47 | 53 | # tokenizer we can move. Llama 2 and Mistral have weight mappings, while |
@@ -96,9 +102,33 @@ def permute(w, n_heads): |
96 | 102 |
|
97 | 103 | merged_result = {} |
98 | 104 | for file in sorted(bin_files): |
99 | | - state_dict = torch.load( |
| 105 | + |
| 106 | + # The state_dict can be loaded from either a torch zip file or |
| 107 | + # safetensors. We take our best guess from the name and try all |
| 108 | + # possibilities |
| 109 | + load_pt_mmap = lambda: torch.load( |
100 | 110 | str(file), map_location="cpu", mmap=True, weights_only=True |
101 | 111 | ) |
| 112 | + load_pt_no_mmap = lambda: torch.load( |
| 113 | + str(file), map_location="cpu", mmap=False, weights_only=True |
| 114 | + ) |
| 115 | + def load_safetensors(): |
| 116 | + import safetensors.torch |
| 117 | + with open(file, "rb") as handle: |
| 118 | + return safetensors.torch.load(handle.read()) |
| 119 | + if "safetensors" in str(file): |
| 120 | + loaders = [load_safetensors, load_pt_mmap, load_pt_no_mmap] |
| 121 | + else: |
| 122 | + loaders = [load_pt_mmap, load_pt_no_mmap, load_safetensors] |
| 123 | + |
| 124 | + state_dict = None |
| 125 | + for loader in loaders: |
| 126 | + try: |
| 127 | + state_dict = loader() |
| 128 | + break |
| 129 | + except Exception: |
| 130 | + continue |
| 131 | + assert state_dict is not None, f"Unable to load tensors from {file}" |
102 | 132 | merged_result.update(state_dict) |
103 | 133 | final_result = {} |
104 | 134 | for key, value in merged_result.items(): |
|
0 commit comments