diff --git a/torchchat/cli/convert_hf_checkpoint.py b/torchchat/cli/convert_hf_checkpoint.py
index f428e4cc6..122ab0f28 100644
--- a/torchchat/cli/convert_hf_checkpoint.py
+++ b/torchchat/cli/convert_hf_checkpoint.py
@@ -39,19 +39,14 @@ def convert_hf_checkpoint(
     config = TransformerArgs.from_params(config_args)
     print(f"Model config {config.__dict__}")

-    # Load the json file containing weight mapping
+    # Find all candidate weight mapping index files
     model_map_json_matches = [Path(m) for m in glob.glob(str(model_dir / "*.index.json"))]
-    assert len(model_map_json_matches) <= 1, "Found multiple weight mapping files"
-    if len(model_map_json_matches):
-        model_map_json = model_map_json_matches[0]
-    else:
-        model_map_json = model_dir / "pytorch_model.bin.index.json"

     # If there is no weight mapping, check for a consolidated model and
     # tokenizer we can move. Llama 2 and Mistral have weight mappings, while
     # Llama 3 has a consolidated model and tokenizer.
     # Otherwise raise an error.
-    if not model_map_json.is_file():
+    if not model_map_json_matches:
         consolidated_pth = model_dir / "original" / "consolidated.00.pth"
         tokenizer_pth = model_dir / "original" / "tokenizer.model"
         if consolidated_pth.is_file() and tokenizer_pth.is_file():
@@ -68,11 +63,30 @@ def convert_hf_checkpoint(
             return
         else:
             raise RuntimeError(
-                f"Could not find {model_map_json} or {consolidated_pth} plus {tokenizer_pth}"
+                f"Could not find a valid model weight map or {consolidated_pth} plus {tokenizer_pth}"
             )

-    with open(model_map_json) as json_map:
-        bin_index = json.load(json_map)
+    # Load the json file(s) containing weight mapping
+    #
+    # NOTE: If there are multiple index files, there are two possibilities:
+    #   1. The files could be mapped to different weight format files (e.g. .bin
+    #      vs .safetensors)
+    #   2. The files could be split subsets of the mappings that need to be
+    #      merged
+    #
+    # In either case, we can simply keep the mappings where the target file is
+    # valid in the model dir.
+    bin_index = {}
+    for weight_map_file in model_map_json_matches:
+        with open(weight_map_file, "r") as handle:
+            weight_map = json.load(handle)
+            valid_mappings = {
+                k: model_dir / v
+                for (k, v) in weight_map.get("weight_map", {}).items()
+                if (model_dir / v).is_file()
+            }
+        bin_index.update(valid_mappings)
+    bin_files = set(bin_index.values())

     weight_map = {
         "model.embed_tokens.weight": "tok_embeddings.weight",
@@ -96,7 +110,6 @@ def convert_hf_checkpoint(
         "model.norm.weight": "norm.weight",
         "lm_head.weight": "output.weight",
     }
-    bin_files = {model_dir / bin for bin in bin_index["weight_map"].values()}

     def permute(w, n_heads):
         return (
diff --git a/torchchat/cli/download.py b/torchchat/cli/download.py
index f334eb555..4da2bc390 100644
--- a/torchchat/cli/download.py
+++ b/torchchat/cli/download.py
@@ -35,11 +35,12 @@ def _download_hf_snapshot(
         model_info = model_info(model_config.distribution_path, token=hf_token)
         model_fnames = [f.rfilename for f in model_info.siblings]

-        # Check the model config for preference between safetensors and pth
+        # Check the model config for preference between safetensors and pth/bin
         has_pth = any(f.endswith(".pth") for f in model_fnames)
+        has_bin = any(f.endswith(".bin") for f in model_fnames)
         has_safetensors = any(f.endswith(".safetensors") for f in model_fnames)

-        # If told to prefer safetensors, ignore pth files
+        # If told to prefer safetensors, ignore pth/bin files
         if model_config.prefer_safetensors:
             if not has_safetensors:
                 print(
@@ -47,10 +48,10 @@
                     file=sys.stderr,
                 )
                 exit(1)
-            ignore_patterns = "*.pth"
+            ignore_patterns = ["*.pth", "*.bin"]

         # If the model has both, prefer pth files over safetensors
-        elif has_pth and has_safetensors:
+        elif (has_pth or has_bin) and has_safetensors:
             ignore_patterns = "*safetensors*"

         # Otherwise, download everything
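
For reviewers, here is a standalone sketch of the merge-and-filter behavior the new loop implements. It builds a throwaway model dir (all file names invented for illustration) containing both a `.bin` index and a `.safetensors` index but only the safetensors shards on disk, and shows that only mappings whose target file actually exists survive the merge:

```python
import json
import tempfile
from pathlib import Path

# Throwaway model dir mimicking an HF checkpoint with two index files,
# where only the safetensors shards were actually downloaded.
model_dir = Path(tempfile.mkdtemp())
(model_dir / "model-00001-of-00002.safetensors").touch()
(model_dir / "model-00002-of-00002.safetensors").touch()
(model_dir / "model.safetensors.index.json").write_text(json.dumps({
    "weight_map": {
        "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
        "lm_head.weight": "model-00002-of-00002.safetensors",
    }
}))
(model_dir / "pytorch_model.bin.index.json").write_text(json.dumps({
    "weight_map": {
        "model.embed_tokens.weight": "pytorch_model-00001-of-00002.bin",
        "lm_head.weight": "pytorch_model-00002-of-00002.bin",
    }
}))

# Same merge-and-filter logic as the patch: keep a mapping only when the
# shard it points to is a real file in the model dir.
bin_index = {}
for weight_map_file in model_dir.glob("*.index.json"):
    weight_map = json.loads(weight_map_file.read_text())
    bin_index.update({
        k: model_dir / v
        for k, v in weight_map.get("weight_map", {}).items()
        if (model_dir / v).is_file()
    })
bin_files = set(bin_index.values())

# Only the safetensors shards survive; the stale .bin index contributes nothing.
assert all(f.suffix == ".safetensors" for f in bin_files)
print(sorted(f.name for f in bin_files))
```

On the download.py side, changing `ignore_patterns` from the string `"*.pth"` to the list `["*.pth", "*.bin"]` is valid as-is: `huggingface_hub.snapshot_download` accepts either a single pattern string or a list of patterns for that argument.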