Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 42 additions & 3 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6237,9 +6237,11 @@ def prepare_tensors(self):
raise ValueError(f"Unprocessed experts: {experts}")


@ModelBase.register("DeepseekV2ForCausalLM")
@ModelBase.register("DeepseekV3ForCausalLM")
@ModelBase.register("KimiVLForConditionalGeneration")
@ModelBase.register(
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"KimiVLForConditionalGeneration",
)
class DeepseekV2Model(TextModel):
model_arch = gguf.MODEL_ARCH.DEEPSEEK2

Expand Down Expand Up @@ -8484,6 +8486,43 @@ def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", "
return "mm.2.weight"
return super().map_tensor_name(name, try_suffixes)


@ModelBase.register("KimiVLForConditionalGeneration")
class KimiVLModel(MmprojModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.hparams_vision is not None
self.hparams_vision["image_size"] = 64 * 14 # for compatibility

def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.KIMIVL)
self.gguf_writer.add_vision_use_gelu(True)
self.gguf_writer.add_vision_projector_scale_factor(2)
# eps is the same as pytorch's default value
assert self.hparams_vision is not None
self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("layer_norm_eps", 1e-5))

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
is_vision_tensor = "vision_tower" in name or "multi_modal_projector" in name

if is_vision_tensor:
if "pos_emb.weight" in name:
data_torch = data_torch.view(data_torch.shape[0] * data_torch.shape[1], data_torch.shape[2])
elif "wqkv" in name:
split_dim = 0 if "weight" in name else -1
wq, wk, wv = data_torch.chunk(3, dim=split_dim)
return [
(self.map_tensor_name(name.replace("wqkv", "wq")), wq),
(self.map_tensor_name(name.replace("wqkv", "wk")), wk),
(self.map_tensor_name(name.replace("wqkv", "wv")), wv)
]

return [(self.map_tensor_name(name), data_torch)]

return [] # skip other tensors

###### CONVERSION LOGIC ######


Expand Down
1 change: 1 addition & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2833,6 +2833,7 @@ class VisionProjectorType:
QWEN25O = "qwen2.5o" # omni
VOXTRAL = "voxtral"
LFM2 = "lfm2"
KIMIVL = "kimivl"


# Items here are (block size, type size)
Expand Down
12 changes: 12 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -1123,6 +1123,7 @@ class TensorNameMap:
"vision_encoder.patch_conv", # pixtral
"vision_model.patch_embedding.linear", # llama 4
"visual.patch_embed.proj", # qwen2vl
"vision_tower.patch_embed.proj", # kimi-vl
),

MODEL_TENSOR.V_ENC_EMBD_POS: (
Expand All @@ -1131,6 +1132,7 @@ class TensorNameMap:
"vpm.embeddings.position_embedding",
"model.vision_model.embeddings.position_embedding", # SmolVLM
"vision_model.positional_embedding_vlm", # llama 4
"vision_tower.patch_embed.pos_emb", # kimi-vl
),

MODEL_TENSOR.V_ENC_ATTN_Q: (
Expand All @@ -1142,6 +1144,7 @@ class TensorNameMap:
"vision_tower.transformer.layers.{bid}.attention.q_proj", # pixtral-hf
"vision_encoder.transformer.layers.{bid}.attention.wq", # pixtral
"visual.blocks.{bid}.attn.q", # qwen2vl, generated
"vision_tower.encoder.blocks.{bid}.wq", # kimi-vl, generated
),

MODEL_TENSOR.V_ENC_ATTN_Q_NORM: (
Expand All @@ -1158,6 +1161,7 @@ class TensorNameMap:
"vision_tower.transformer.layers.{bid}.attention.k_proj", # pixtral-hf
"vision_encoder.transformer.layers.{bid}.attention.wk", # pixtral
"visual.blocks.{bid}.attn.k", # qwen2vl, generated
"vision_tower.encoder.blocks.{bid}.wk", # kimi-vl, generated
),

MODEL_TENSOR.V_ENC_ATTN_K_NORM: (
Expand All @@ -1174,6 +1178,7 @@ class TensorNameMap:
"vision_tower.transformer.layers.{bid}.attention.v_proj", # pixtral-hf
"vision_encoder.transformer.layers.{bid}.attention.wv", # pixtral
"visual.blocks.{bid}.attn.v", # qwen2vl, generated
"vision_tower.encoder.blocks.{bid}.wv", # kimi-vl, generated
),

MODEL_TENSOR.V_ENC_INPUT_NORM: (
Expand All @@ -1186,6 +1191,7 @@ class TensorNameMap:
"vision_encoder.transformer.layers.{bid}.attention_norm", # pixtral
"vision_model.model.layers.{bid}.input_layernorm", # llama4
"visual.blocks.{bid}.norm1", # qwen2vl
"vision_tower.encoder.blocks.{bid}.norm0", # kimi-vl (norm0/norm1)
),

MODEL_TENSOR.V_ENC_ATTN_O: (
Expand All @@ -1198,6 +1204,7 @@ class TensorNameMap:
"vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral-hf
"vision_encoder.transformer.layers.{bid}.attention.wo", # pixtral
"visual.blocks.{bid}.attn.proj", # qwen2vl
"vision_tower.encoder.blocks.{bid}.wo", # kimi-vl
),

MODEL_TENSOR.V_ENC_POST_ATTN_NORM: (
Expand All @@ -1210,6 +1217,7 @@ class TensorNameMap:
"vision_tower.transformer.layers.{bid}.ffn_norm", # pixtral-hf
"vision_encoder.transformer.layers.{bid}.ffn_norm", # pixtral
"visual.blocks.{bid}.norm2", # qwen2vl
"vision_tower.encoder.blocks.{bid}.norm1", # kimi-vl (norm0/norm1)
),

MODEL_TENSOR.V_ENC_FFN_UP: (
Expand All @@ -1222,6 +1230,7 @@ class TensorNameMap:
"vision_model.model.layers.{bid}.mlp.fc1", # llama4
"visual.blocks.{bid}.mlp.fc1", # qwen2vl
"visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl
"vision_tower.encoder.blocks.{bid}.mlp.fc0", # kimi-vl (fc0/fc1)
),

MODEL_TENSOR.V_ENC_FFN_GATE: (
Expand All @@ -1240,6 +1249,7 @@ class TensorNameMap:
"vision_model.model.layers.{bid}.mlp.fc2", # llama4
"visual.blocks.{bid}.mlp.fc2", # qwen2vl
"visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl
"vision_tower.encoder.blocks.{bid}.mlp.fc1", # kimi-vl (fc0/fc1)
),

MODEL_TENSOR.V_LAYER_SCALE_1: (
Expand All @@ -1264,6 +1274,7 @@ class TensorNameMap:
"model.vision_model.post_layernorm", # SmolVLM
"vision_model.layernorm_post", # llama4
"visual.merger.ln_q", # qwen2vl
"vision_tower.encoder.final_layernorm", # kimi-vl
),

MODEL_TENSOR.V_MM_INP_PROJ: (
Expand All @@ -1273,6 +1284,7 @@ class TensorNameMap:
MODEL_TENSOR.V_MM_INP_NORM: (
"multi_modal_projector.norm",
"multi_modal_projector.layer_norm",
"multi_modal_projector.pre_norm",
"pre_mm_projector_norm",
),

Expand Down
2 changes: 2 additions & 0 deletions tools/mtmd/clip-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ enum projector_type {
PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx
PROJECTOR_TYPE_VOXTRAL,
PROJECTOR_TYPE_LFM2,
PROJECTOR_TYPE_KIMIVL,
PROJECTOR_TYPE_UNKNOWN,
};

Expand All @@ -156,6 +157,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_QWEN25O, "qwen2.5o"},
{ PROJECTOR_TYPE_VOXTRAL, "voxtral"},
{ PROJECTOR_TYPE_LFM2, "lfm2"},
{ PROJECTOR_TYPE_KIMIVL, "kimivl"},
};

static projector_type clip_projector_type_from_string(const std::string & str) {
Expand Down
Loading
Loading