diff --git a/torchchat/cli/convert_hf_checkpoint.py b/torchchat/cli/convert_hf_checkpoint.py
index 7e3e2d676..f95cbdaef 100644
--- a/torchchat/cli/convert_hf_checkpoint.py
+++ b/torchchat/cli/convert_hf_checkpoint.py
@@ -81,10 +81,17 @@ def convert_hf_checkpoint(
         "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
         "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
         "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
+        "model.layers.{}.self_attn.q_proj.bias": "layers.{}.attention.wq.bias",
+        "model.layers.{}.self_attn.k_proj.bias": "layers.{}.attention.wk.bias",
+        "model.layers.{}.self_attn.v_proj.bias": "layers.{}.attention.wv.bias",
+        "model.layers.{}.self_attn.o_proj.bias": "layers.{}.attention.wo.bias",
         "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
         "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
         "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
         "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
+        "model.layers.{}.mlp.gate_proj.bias": "layers.{}.feed_forward.w1.bias",
+        "model.layers.{}.mlp.up_proj.bias": "layers.{}.feed_forward.w3.bias",
+        "model.layers.{}.mlp.down_proj.bias": "layers.{}.feed_forward.w2.bias",
         "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
         "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
         "model.norm.weight": "norm.weight",
@@ -93,11 +100,10 @@ def convert_hf_checkpoint(
     bin_files = {model_dir / bin for bin in bin_index["weight_map"].values()}
 
     def permute(w, n_heads):
-        dim = config.dim
         return (
-            w.view(n_heads, 2, config.head_dim // 2, dim)
+            w.view(n_heads, 2, config.head_dim // 2, *w.shape[1:])
             .transpose(1, 2)
-            .reshape(config.head_dim * n_heads, dim)
+            .reshape(w.shape)
         )
 
     merged_result = {}
@@ -130,6 +136,7 @@ def load_safetensors():
                 continue
             assert state_dict is not None, f"Unable to load tensors from {file}"
             merged_result.update(state_dict)
+
     final_result = {}
     for key, value in merged_result.items():
         if "layers" in key:
@@ -145,16 +152,18 @@ def load_safetensors():
             final_result[new_key] = value
 
     for key in tuple(final_result.keys()):
-        if "wq" in key:
+        if "wq.weight" in key or "wq.bias" in key:
+            wk_key = key.replace("wq", "wk")
+            wv_key = key.replace("wq", "wv")
             q = final_result[key]
-            k = final_result[key.replace("wq", "wk")]
-            v = final_result[key.replace("wq", "wv")]
+            k = final_result[wk_key]
+            v = final_result[wv_key]
             q = permute(q, config.n_heads)
             k = permute(k, config.n_local_heads)
             final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
             del final_result[key]
-            del final_result[key.replace("wq", "wk")]
-            del final_result[key.replace("wq", "wv")]
+            del final_result[wk_key]
+            del final_result[wv_key]
     print(f"Saving checkpoint to {model_dir / 'model.pth'}. This may take a while.")
     torch.save(final_result, model_dir / "model.pth")
     print("Done.")
diff --git a/torchchat/model.py b/torchchat/model.py
index 93a0d7366..673b582d3 100644
--- a/torchchat/model.py
+++ b/torchchat/model.py
@@ -34,7 +34,7 @@ try:
     # TODO: remove this after we figure out where in torchtune an `evaluate` module
     # is being imported, which is being confused with huggingface's `evaluate``.
-    import lm_eval # noqa
+    import lm_eval  # noqa
 except Exception:
     pass
@@ -278,6 +278,9 @@ class TransformerArgs:
     # For pipeline parallel
     n_stages: int = 1
     stage_idx: int = 0
+    # Optional biases
+    attention_bias: bool = False
+    feed_forward_bias: bool = False
 
     def __post_init__(self):
         if self.n_local_heads == -1:
@@ -394,7 +397,7 @@ def from_name(cls, name: str):
         config = [
             config
             for config in known_model_params
-            if config in str(name).upper() or config in str(name)
+            if config.upper() in str(name).upper() or config in str(name)
         ]
 
         # We may have two or more configs matched (e.g., "7B" and
@@ -471,7 +474,7 @@ def build_model(self) -> nn.Module:
                 modules[name] = module_class(TransformerArgs.from_params(config_args))
             else:
                 modules[name] = module_class(**config_args)
-        
+
         # Temporary add extra params to the DeepFusionModel.
         # TODO: Remove it once we can make fusion model configurable in model_param.
         if recipe.fusion_class == DeepFusionModel:
@@ -730,16 +733,16 @@ def __init__(self, config: TransformerArgs):
         # key, query, value projections for all heads, but in a batch
         # total_head_dim = (config.n_heads + 2 * config.n_local_heads) * config.head_dim
-        # self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
-        self.wq = nn.Linear(config.dim, config.n_heads * config.head_dim, bias=False)
+        # self.wqkv = nn.Linear(config.dim, total_head_dim, bias=config.attention_bias)
+        self.wq = nn.Linear(config.dim, config.n_heads * config.head_dim, bias=config.attention_bias)
         self.wk = nn.Linear(
-            config.dim, config.n_local_heads * config.head_dim, bias=False
+            config.dim, config.n_local_heads * config.head_dim, bias=config.attention_bias
         )
         self.wv = nn.Linear(
-            config.dim, config.n_local_heads * config.head_dim, bias=False
+            config.dim, config.n_local_heads * config.head_dim, bias=config.attention_bias
         )
-        self.wo = nn.Linear(config.dim, config.dim, bias=False)
+        self.wo = nn.Linear(config.dim, config.dim, bias=config.attention_bias)
 
         self.kv_cache = None
 
         self.n_heads = config.n_heads
@@ -766,14 +769,16 @@ def load_hook(self, state_dict, prefix, *args):
         #     wv = state_dict.pop(prefix + "wv.weight")
         #     state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
 
-        if prefix + "wqkv.weight" in state_dict:
-            wqkv = state_dict.pop(prefix + "wqkv.weight")
-            q_size = self.n_heads * self.head_dim
-            kv_size = self.n_local_heads * self.head_dim
-            wq, wk, wv = torch.split(wqkv, (q_size, kv_size, kv_size), dim=0)
-            state_dict[prefix + "wq.weight"] = wq
-            state_dict[prefix + "wk.weight"] = wk
-            state_dict[prefix + "wv.weight"] = wv
+        for tensor_suffix in ["weight", "bias"]:
+            wqkv_key = f"{prefix}wqkv.{tensor_suffix}"
+            if wqkv_key in state_dict:
+                wqkv = state_dict.pop(wqkv_key)
+                q_size = self.n_heads * self.head_dim
+                kv_size = self.n_local_heads * self.head_dim
+                wq, wk, wv = torch.split(wqkv, (q_size, kv_size, kv_size), dim=0)
+                state_dict[f"{prefix}wq.{tensor_suffix}"] = wq
+                state_dict[f"{prefix}wk.{tensor_suffix}"] = wk
+                state_dict[f"{prefix}wv.{tensor_suffix}"] = wv
 
         return
 
@@ -852,9 +857,9 @@ def forward(
 class FeedForward(nn.Module):
     def __init__(self, config: TransformerArgs) -> None:
         super().__init__()
-        self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
-        self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
-        self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
+        self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=config.feed_forward_bias)
+        self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=config.feed_forward_bias)
+        self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=config.feed_forward_bias)
 
     def distribute(self, device_mesh: DeviceMesh):
         parallelize_module(self.w1, device_mesh, ColwiseParallel())