
Commit 42f2fa6

feat: reduce CLIP memory usage with no embeddings
The CLIP weights need to be converted to f32 for textual inversions (fbd42b6), but that conversion increases the amount of allocated VRAM even when no embeddings are being used. Only force the f32 conversion when an embeddings directory is supplied; otherwise keep the token embedding weight in its on-disk type.
1 parent 06e7340
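The gist of the change: the token-embedding weight type is only forced to f32 when textual inversions may actually be used (an embeddings directory was supplied); otherwise the tensor keeps whatever type the checkpoint stores, which is where the VRAM saving comes from. Below is a minimal, self-contained sketch of that gating logic; the stand-in enum and map are placeholders for ggml's real types, and only the force_clip_f32 flag and the tensor-type lookup mirror the diff that follows.

#include <cstdio>
#include <map>
#include <string>

// Stand-in for ggml's tensor types; only here to keep the sketch self-contained.
enum class wtype { F32, F16, Q8_0 };
using String2Type = std::map<std::string, wtype>;

// Sketch of the dtype gating added in clip.hpp: force f32 only when
// textual-inversion embeddings may be loaded into the CLIP text model.
wtype token_weight_type(bool force_clip_f32,
                        const String2Type& tensor_types,
                        const std::string& prefix) {
    wtype token_wtype = wtype::F32;
    if (!force_clip_f32) {
        auto it = tensor_types.find(prefix + "token_embedding.weight");
        if (it != tensor_types.end())
            token_wtype = it->second;  // keep the checkpoint's (possibly quantized) type
    }
    return token_wtype;
}

int main() {
    String2Type types = {{"te.token_embedding.weight", wtype::Q8_0}};
    // No embeddings directory: the quantized on-disk type is kept, saving VRAM.
    std::printf("no embd dir -> f32 forced: %s\n",
                token_weight_type(false, types, "te.") == wtype::F32 ? "yes" : "no");
    // Embeddings directory present: f32 is forced, as before this commit.
    std::printf("embd dir    -> f32 forced: %s\n",
                token_weight_type(true, types, "te.") == wtype::F32 ? "yes" : "no");
    return 0;
}

In the real code this decision is made per model in conditioner.hpp (force_clip_f32 = embd_dir.size() > 0) and passed down through CLIPTextModelRunner and CLIPTextModel to CLIPEmbeddings, as the diff shows.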

2 files changed: +21 −10 lines

clip.hpp

Lines changed: 16 additions & 6 deletions
@@ -544,9 +544,15 @@ class CLIPEmbeddings : public GGMLBlock {
     int64_t embed_dim;
     int64_t vocab_size;
     int64_t num_positions;
+    bool force_clip_f32;
 
     void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
         enum ggml_type token_wtype = GGML_TYPE_F32;
+        if (!force_clip_f32) {
+            auto tensor_type = tensor_types.find(prefix + "token_embedding.weight");
+            if (tensor_type != tensor_types.end())
+                token_wtype = tensor_type->second;
+        }
         enum ggml_type position_wtype = GGML_TYPE_F32;
 
         params["token_embedding.weight"] = ggml_new_tensor_2d(ctx, token_wtype, embed_dim, vocab_size);
@@ -556,10 +562,12 @@ class CLIPEmbeddings : public GGMLBlock {
 public:
     CLIPEmbeddings(int64_t embed_dim,
                    int64_t vocab_size = 49408,
-                   int64_t num_positions = 77)
+                   int64_t num_positions = 77,
+                   bool force_clip_f32 = false)
         : embed_dim(embed_dim),
           vocab_size(vocab_size),
-          num_positions(num_positions) {
+          num_positions(num_positions),
+          force_clip_f32(force_clip_f32) {
     }
 
     struct ggml_tensor* get_token_embed_weight() {
@@ -677,7 +685,8 @@ class CLIPTextModel : public GGMLBlock {
     bool with_final_ln = true;
 
     CLIPTextModel(CLIPVersion version = OPENAI_CLIP_VIT_L_14,
-                  bool with_final_ln = true)
+                  bool with_final_ln = true,
+                  bool force_clip_f32 = false)
         : version(version), with_final_ln(with_final_ln) {
         if (version == OPEN_CLIP_VIT_H_14) {
             hidden_size = 1024;
@@ -691,7 +700,7 @@ class CLIPTextModel : public GGMLBlock {
             n_layer = 32;
         }
 
-        blocks["embeddings"] = std::shared_ptr<GGMLBlock>(new CLIPEmbeddings(hidden_size, vocab_size, n_token));
+        blocks["embeddings"] = std::shared_ptr<GGMLBlock>(new CLIPEmbeddings(hidden_size, vocab_size, n_token, force_clip_f32));
         blocks["encoder"] = std::shared_ptr<GGMLBlock>(new CLIPEncoder(n_layer, hidden_size, n_head, intermediate_size));
         blocks["final_layer_norm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size));
     }
@@ -862,8 +871,9 @@ struct CLIPTextModelRunner : public GGMLRunner {
                         const String2GGMLType& tensor_types,
                         const std::string prefix,
                         CLIPVersion version = OPENAI_CLIP_VIT_L_14,
-                        bool with_final_ln = true)
-        : GGMLRunner(backend), model(version, with_final_ln) {
+                        bool with_final_ln = true,
+                        bool force_clip_f32 = false)
+        : GGMLRunner(backend), model(version, with_final_ln, force_clip_f32) {
         model.init(params_ctx, tensor_types, prefix);
     }
 
conditioner.hpp

Lines changed: 5 additions & 4 deletions
@@ -62,13 +62,14 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
                                       SDVersion version = VERSION_SD1,
                                       PMVersion pv = PM_VERSION_1)
         : version(version), pm_version(pv), tokenizer(sd_version_is_sd2(version) ? 0 : 49407), embd_dir(embd_dir) {
+        bool force_clip_f32 = embd_dir.size() > 0;
         if (sd_version_is_sd1(version)) {
-            text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14);
+            text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, true, force_clip_f32);
         } else if (sd_version_is_sd2(version)) {
-            text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPEN_CLIP_VIT_H_14);
+            text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPEN_CLIP_VIT_H_14, true, force_clip_f32);
         } else if (sd_version_is_sdxl(version)) {
-            text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, false);
-            text_model2 = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.1.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, false);
+            text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, false, force_clip_f32);
+            text_model2 = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.1.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, false, force_clip_f32);
         }
     }
 