@@ -8,9 +8,10 @@
 import json
 import re
 import shutil
+import sys
 from pathlib import Path
 from typing import Optional
-
+from safetensors.torch import load_file as load_safetensors_file
 import torch
 
 from torchao._models.llama.model import ModelArgs
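Note on the new import: `safetensors.torch.load_file` is the safetensors counterpart of `torch.load`; it reads a single `.safetensors` shard and returns a flat dict of CPU tensors. A minimal sketch of how the two loaders are used interchangeably later in this diff (the shard path below is a made-up example, not something the commit ships):

# Sketch only: load one shard with whichever loader matches its file extension.
from pathlib import Path

import torch
from safetensors.torch import load_file as load_safetensors_file

shard = Path("checkpoints/meta-llama/Meta-Llama-3-8B/model-00001-of-00004.safetensors")  # hypothetical path
if shard.suffix == ".safetensors":
    state_dict = load_safetensors_file(str(shard), device="cpu")
else:
    state_dict = torch.load(str(shard), map_location="cpu", mmap=True, weights_only=True)
print(type(state_dict), len(state_dict))  # dict of tensor name -> torch.Tensor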
@@ -24,63 +25,49 @@ def convert_hf_checkpoint(
 ) -> None:
     if model_name is None:
         model_name = checkpoint_dir.name
-
-    # Llama 3 8B doesn't need conversion; instead, the original/consolidated.NN.pth files
-    # need to be copied into model.pth.
-    # Llama 3 70B can't be easily merged into one model.pth file, though, since names of the
-    # weights is state dict are the same in each consolidated.NN.pth file. Thus, it is not
-    # currently supported.
-    # Along this, we need to copy the original/tokenizer.model file to tokenizer.model.tiktoken
-    is_llama3 = "Llama-3" in model_name
-    if is_llama3:
-        # Check if we have multiple original/consolidated.NN.pth files and report error
-        # if we do for Llama 3.
-        original_dir = checkpoint_dir / "original"
-        pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
-        bin_files = [bin for bin in original_dir.iterdir() if pattern.match(bin.name)]
-        if len(bin_files) > 1:
-            raise ValueError(
-                f"Multiple consolidated.NN.pth files found in {original_dir}. "
-                "Merging them into one model.pth file is not supported for Llama 3.")
-
-
     config = ModelArgs.from_name(model_name)
     print(f"Model config {config.__dict__}")
 
     # Load the json file containing weight mapping
-    if not is_llama3:
-        model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"
-
-        assert model_map_json.is_file()
-
-        with open(model_map_json) as json_map:
-            bin_index = json.load(json_map)
-
-        weight_map = {
-            "model.embed_tokens.weight": "tok_embeddings.weight",
-            "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
-            "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
-            "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
-            "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
-            'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
-            'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
-            "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
-            "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
-            "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
-            "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
-            "model.norm.weight": "norm.weight",
-            "lm_head.weight": "output.weight",
-        }
-        bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
-    else:
-        # There is no separate pytorch_model.bin.index.json file for llama3.
-        # Instead, we will just use all original/consolidated.NN.pth files.
-        # so, we use model.safetensors.index.json
-        weight_map = None
-        original_dir = checkpoint_dir / "original"
-        pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
-        bin_files = {bin for bin in original_dir.iterdir() if pattern.match(bin.name)}
-
+    model_map_json_safetensors = checkpoint_dir / 'model.safetensors.index.json'
+    model_map_json_pytorch = checkpoint_dir / "pytorch_model.bin.index.json"
+    model_map_json = None
+
+    try:
+        assert model_map_json_safetensors.is_file()
+        model_map_json = model_map_json_safetensors
+        print(f"Found safetensors index at {model_map_json_safetensors}")
+    except AssertionError:
+        print(f"{model_map_json_safetensors} not found")
+    if model_map_json is None:
+        try:
+            assert model_map_json_pytorch.is_file()
+            model_map_json = model_map_json_pytorch
+            print(f"Found pytorch index at {model_map_json_pytorch}")
+        except AssertionError:
+            print(f"{model_map_json_pytorch} not found")
+
+    if model_map_json is None: raise Exception("No model map found!")
+
+    with open(model_map_json) as json_map:
+        bin_index = json.load(json_map)
+
+    weight_map = {
+        "model.embed_tokens.weight": "tok_embeddings.weight",
+        "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
+        "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
+        "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
+        "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
+        'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
+        'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
+        "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
+        "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
+        "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
+        "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
+        "model.norm.weight": "norm.weight",
+        "lm_head.weight": "output.weight",
+    }
+    bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
 
     def permute(w, n_head):
         dim = config.dim
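The index probing added above amounts to: prefer `model.safetensors.index.json`, fall back to `pytorch_model.bin.index.json`, and fail if neither exists. A compact sketch of the same fallback logic, shown only for reference (it is not what the commit adds), assuming the same two index filenames:

# Sketch: equivalent index-file fallback written as a candidate loop.
from pathlib import Path

def find_model_map_json(checkpoint_dir: Path) -> Path:
    candidates = [
        checkpoint_dir / "model.safetensors.index.json",  # preferred: safetensors index
        checkpoint_dir / "pytorch_model.bin.index.json",   # fallback: PyTorch bin index
    ]
    for candidate in candidates:
        if candidate.is_file():
            print(f"Found index at {candidate}")
            return candidate
    raise FileNotFoundError(f"No model map found in {checkpoint_dir}")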
@@ -92,40 +79,44 @@ def permute(w, n_head):
 
     merged_result = {}
     for file in sorted(bin_files):
-        state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
-        merged_result.update(state_dict)
+        if "safetensors" in str(file):
+            state_dict = load_safetensors_file(str(file), device="cpu")
+            merged_result.update(state_dict)
+        else:
+            state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
+            merged_result.update(state_dict)
     final_result = {}
-    if weight_map is not None:
-        for key, value in merged_result.items():
-            if "layers" in key:
-                abstract_key = re.sub(r'(\d+)', '{}', key)
-                layer_num = re.search(r'\d+', key).group(0)
-                new_key = weight_map[abstract_key]
-                if new_key is None:
-                    continue
-                new_key = new_key.format(layer_num)
-            else:
-                new_key = weight_map[key]
-
-            final_result[new_key] = value
-
-        for key in tuple(final_result.keys()):
-            if "wq" in key:
-                q = final_result[key]
-                k = final_result[key.replace("wq", "wk")]
-                v = final_result[key.replace("wq", "wv")]
-                q = permute(q, config.n_head)
-                k = permute(k, config.n_local_heads)
-                final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
-                del final_result[key]
-                del final_result[key.replace("wq", "wk")]
-                del final_result[key.replace("wq", "wv")]
-    else:
-        final_result = merged_result
+    for key, value in merged_result.items():
+        if "layers" in key:
+            abstract_key = re.sub(r'(\d+)', '{}', key)
+            layer_num = re.search(r'\d+', key).group(0)
+            new_key = weight_map[abstract_key]
+            if new_key is None:
+                continue
+            new_key = new_key.format(layer_num)
+        else:
+            new_key = weight_map[key]
+
+        final_result[new_key] = value
+
+    for key in tuple(final_result.keys()):
+        if "wq" in key:
+            q = final_result[key]
+            k = final_result[key.replace("wq", "wk")]
+            v = final_result[key.replace("wq", "wv")]
+            q = permute(q, config.n_head)
+            k = permute(k, config.n_local_heads)
+            final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
+            del final_result[key]
+            del final_result[key.replace("wq", "wk")]
+            del final_result[key.replace("wq", "wv")]
     print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
     torch.save(final_result, checkpoint_dir / "model.pth")
-    if is_llama3:
-        original_dir = checkpoint_dir / "original"
+    if 'llama-3-' in model_name.lower() or 'llama-3.1-' in model_name.lower():
+        if 'llama-3.1-405b' in model_name.lower():
+            original_dir = checkpoint_dir / "original" / "mp16"
+        else:
+            original_dir = checkpoint_dir / "original"
         tokenizer_model = original_dir / "tokenizer.model"
         tokenizer_model_tiktoken = checkpoint_dir / "tokenizer.model"
        print(f"Copying {tokenizer_model} to {tokenizer_model_tiktoken}")
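For reference, here is how a single Hugging Face key flows through the renaming loop above; the layer index 12 is just an illustrative value, and only one entry of the real `weight_map` is repeated here:

# Worked example of the key-renaming step.
import re

weight_map = {
    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
}

key = "model.layers.12.self_attn.q_proj.weight"
abstract_key = re.sub(r"(\d+)", "{}", key)   # "model.layers.{}.self_attn.q_proj.weight"
layer_num = re.search(r"\d+", key).group(0)  # "12"
new_key = weight_map[abstract_key].format(layer_num)
print(new_key)                               # layers.12.attention.wq.weight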