examples/llava/llava-surgery.py (18 changes: 17 additions & 1 deletion)
@@ -16,13 +16,29 @@
 mm_tensors = [k for k, v in checkpoint.items() if k.startswith("model.mm_projector")]
 
 # store these tensors in a new dictionary and torch.save them
-projector = {name: checkpoint[name] for name in mm_tensors}
+projector = {name: checkpoint[name].float() for name in mm_tensors}
 torch.save(projector, f"{args.model}/llava.projector")
 
 # remove these tensors from the checkpoint and save it again
 for name in mm_tensors:
     del checkpoint[name]
 
+# BakLLaVA models contain CLIP tensors in the checkpoint
+clip_tensors = [k for k, v in checkpoint.items() if k.startswith("model.vision_tower")]
+if len(clip_tensors) > 0:
+    clip = {name.replace("vision_tower.vision_tower.", ""): checkpoint[name].float() for name in clip_tensors}
+    torch.save(clip, f"{args.model}/llava.clip")
+
+    # remove these tensors
+    for name in clip_tensors:
+        del checkpoint[name]
+
+    # added tokens should be removed to be able to convert Mistral models
+    if os.path.exists(f"{args.model}/added_tokens.json"):
+        with open(f"{args.model}/added_tokens.json", "w") as f:
+            f.write("{}\n")
+
+
 torch.save(checkpoint, path)
 
 print("Done!")