Skip to content

Commit 7a1623f

Browse files
committed
conditioner: make t5 optional for chroma
1 parent 250e60f commit 7a1623f

File tree

1 file changed

+121
-87
lines changed

1 file changed

+121
-87
lines changed

conditioner.hpp

Lines changed: 121 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -666,9 +666,6 @@ struct SD3CLIPEmbedder : public Conditioner {
666666
bool offload_params_to_cpu,
667667
const String2GGMLType& tensor_types = {})
668668
: clip_g_tokenizer(0) {
669-
if (clip_skip <= 0) {
670-
clip_skip = 2;
671-
}
672669

673670
for (auto pair : tensor_types) {
674671
if (pair.first.find("text_encoders.clip_l") != std::string::npos) {
@@ -684,12 +681,12 @@ struct SD3CLIPEmbedder : public Conditioner {
684681
return;
685682
}
686683
if (use_clip_l) {
687-
clip_l = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip, false);
684+
clip_l = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, false);
688685
} else {
689686
LOG_WARN("clip_l text encoder not found! Prompt adherence might be degraded.");
690687
}
691688
if (use_clip_g) {
692-
clip_g = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.clip_g.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
689+
clip_g = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.clip_g.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, false);
693690
} else {
694691
LOG_WARN("clip_g text encoder not found! Prompt adherence might be degraded.");
695692
}
@@ -698,19 +695,6 @@ struct SD3CLIPEmbedder : public Conditioner {
698695
} else {
699696
LOG_WARN("t5xxl text encoder not found! Prompt adherence might be degraded.");
700697
}
701-
set_clip_skip(clip_skip);
702-
}
703-
704-
void set_clip_skip(int clip_skip) {
705-
if (clip_skip <= 0) {
706-
clip_skip = 2;
707-
}
708-
if (use_clip_l) {
709-
clip_l->set_clip_skip(clip_skip);
710-
}
711-
if (use_clip_g) {
712-
clip_g->set_clip_skip(clip_skip);
713-
}
714698
}
715699

716700
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
@@ -1113,10 +1097,30 @@ struct FluxCLIPEmbedder : public Conditioner {
11131097
FluxCLIPEmbedder(ggml_backend_t backend,
11141098
bool offload_params_to_cpu,
11151099
const String2GGMLType& tensor_types = {}) {
1116-
clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, true);
1117-
t5 = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
1118-
}
1100+
for (auto pair : tensor_types) {
1101+
if (pair.first.find("text_encoders.clip_l") != std::string::npos) {
1102+
use_clip_l = true;
1103+
} else if (pair.first.find("text_encoders.t5xxl") != std::string::npos) {
1104+
use_t5 = true;
1105+
}
1106+
}
1107+
1108+
if (!use_clip_l && !use_t5) {
1109+
LOG_WARN("IMPORTANT NOTICE: No text encoders provided, cannot process prompts!");
1110+
return;
1111+
}
11191112

1113+
if (use_clip_l) {
1114+
clip_l = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, true);
1115+
} else {
1116+
LOG_WARN("clip_l text encoder not found! Prompt adherence might be degraded.");
1117+
}
1118+
if (use_t5) {
1119+
t5 = std::make_shared<T5Runner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.t5xxl.transformer");
1120+
} else {
1121+
LOG_WARN("t5xxl text encoder not found! Prompt adherence might be degraded.");
1122+
}
1123+
}
11201124

11211125
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
11221126
if (use_clip_l) {
@@ -1296,7 +1300,6 @@ struct FluxCLIPEmbedder : public Conditioner {
12961300
ggml_set_f32(chunk_hidden_states, 0.f);
12971301
}
12981302

1299-
13001303
int64_t t1 = ggml_time_ms();
13011304
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
13021305
if (zero_out_masked) {
@@ -1305,12 +1308,12 @@ struct FluxCLIPEmbedder : public Conditioner {
13051308
vec[i] = 0;
13061309
}
13071310
}
1308-
1311+
13091312
hidden_states_vec.insert(hidden_states_vec.end(),
1310-
(float*)chunk_hidden_states->data,
1311-
((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states));
1313+
(float*)chunk_hidden_states->data,
1314+
((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states));
13121315
}
1313-
1316+
13141317
if (hidden_states_vec.size() > 0) {
13151318
hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
13161319
hidden_states = ggml_reshape_2d(work_ctx,
@@ -1364,7 +1367,8 @@ struct T5CLIPEmbedder : public Conditioner {
13641367
size_t chunk_len = 512;
13651368
bool use_mask = false;
13661369
int mask_pad = 1;
1367-
bool is_umt5 = false;
1370+
bool use_t5 = false;
1371+
13681372

13691373
T5CLIPEmbedder(ggml_backend_t backend,
13701374
bool offload_params_to_cpu,
@@ -1373,26 +1377,43 @@ struct T5CLIPEmbedder : public Conditioner {
13731377
int mask_pad = 1,
13741378
bool is_umt5 = false)
13751379
: use_mask(use_mask), mask_pad(mask_pad), t5_tokenizer(is_umt5) {
1376-
t5 = std::make_shared<T5Runner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.t5xxl.transformer", is_umt5);
1380+
for (auto pair : tensor_types) {
1381+
if (pair.first.find("text_encoders.t5xxl") != std::string::npos) {
1382+
use_t5 = true;
1383+
}
1384+
}
1385+
1386+
if (!use_t5) {
1387+
LOG_WARN("IMPORTANT NOTICE: No text encoders provided, cannot process prompts!");
1388+
return;
1389+
} else {
1390+
t5 = std::make_shared<T5Runner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.t5xxl.transformer", is_umt5);
1391+
}
13771392
}
13781393

13791394
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
1380-
t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
1395+
if (use_t5) {
1396+
t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
1397+
}
13811398
}
13821399

13831400
void alloc_params_buffer() {
1384-
t5->alloc_params_buffer();
1401+
if (use_t5) {
1402+
t5->alloc_params_buffer();
1403+
}
13851404
}
13861405

13871406
void free_params_buffer() {
1388-
t5->free_params_buffer();
1407+
if (use_t5) {
1408+
t5->free_params_buffer();
1409+
}
13891410
}
13901411

13911412
size_t get_params_buffer_size() {
13921413
size_t buffer_size = 0;
1393-
1394-
buffer_size += t5->get_params_buffer_size();
1395-
1414+
if (use_t5) {
1415+
buffer_size += t5->get_params_buffer_size();
1416+
}
13961417
return buffer_size;
13971418
}
13981419

@@ -1418,17 +1439,18 @@ struct T5CLIPEmbedder : public Conditioner {
14181439
std::vector<int> t5_tokens;
14191440
std::vector<float> t5_weights;
14201441
std::vector<float> t5_mask;
1421-
for (const auto& item : parsed_attention) {
1422-
const std::string& curr_text = item.first;
1423-
float curr_weight = item.second;
1424-
1425-
std::vector<int> curr_tokens = t5_tokenizer.Encode(curr_text, true);
1426-
t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
1427-
t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
1428-
}
1442+
if (use_t5) {
1443+
for (const auto& item : parsed_attention) {
1444+
const std::string& curr_text = item.first;
1445+
float curr_weight = item.second;
14291446

1430-
t5_tokenizer.pad_tokens(t5_tokens, t5_weights, &t5_mask, max_length, padding);
1447+
std::vector<int> curr_tokens = t5_tokenizer.Encode(curr_text, true);
1448+
t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
1449+
t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
1450+
}
14311451

1452+
t5_tokenizer.pad_tokens(t5_tokens, t5_weights, &t5_mask, max_length, padding);
1453+
}
14321454
return {t5_tokens, t5_weights, t5_mask};
14331455
}
14341456

@@ -1465,66 +1487,78 @@ struct T5CLIPEmbedder : public Conditioner {
14651487
std::vector<float> hidden_states_vec;
14661488

14671489
size_t chunk_count = t5_tokens.size() / chunk_len;
1468-
14691490
for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) {
14701491
// t5
1471-
std::vector<int> chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len,
1472-
t5_tokens.begin() + (chunk_idx + 1) * chunk_len);
1473-
std::vector<float> chunk_weights(t5_weights.begin() + chunk_idx * chunk_len,
1474-
t5_weights.begin() + (chunk_idx + 1) * chunk_len);
1475-
std::vector<float> chunk_mask(t5_attn_mask_vec.begin() + chunk_idx * chunk_len,
1476-
t5_attn_mask_vec.begin() + (chunk_idx + 1) * chunk_len);
1477-
1478-
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
1479-
auto t5_attn_mask_chunk = use_mask ? vector_to_ggml_tensor(work_ctx, chunk_mask) : NULL;
1480-
1481-
t5->compute(n_threads,
1482-
input_ids,
1483-
t5_attn_mask_chunk,
1484-
&chunk_hidden_states,
1485-
work_ctx);
1486-
{
1487-
auto tensor = chunk_hidden_states;
1488-
float original_mean = ggml_tensor_mean(tensor);
1489-
for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
1490-
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
1491-
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
1492-
float value = ggml_tensor_get_f32(tensor, i0, i1, i2);
1493-
value *= chunk_weights[i1];
1494-
ggml_tensor_set_f32(tensor, value, i0, i1, i2);
1492+
1493+
if (use_t5) {
1494+
std::vector<int> chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len,
1495+
t5_tokens.begin() + (chunk_idx + 1) * chunk_len);
1496+
std::vector<float> chunk_weights(t5_weights.begin() + chunk_idx * chunk_len,
1497+
t5_weights.begin() + (chunk_idx + 1) * chunk_len);
1498+
std::vector<float> chunk_mask(t5_attn_mask_vec.begin() + chunk_idx * chunk_len,
1499+
t5_attn_mask_vec.begin() + (chunk_idx + 1) * chunk_len);
1500+
1501+
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
1502+
auto t5_attn_mask_chunk = use_mask ? vector_to_ggml_tensor(work_ctx, chunk_mask) : NULL;
1503+
t5->compute(n_threads,
1504+
input_ids,
1505+
t5_attn_mask_chunk,
1506+
&chunk_hidden_states,
1507+
work_ctx);
1508+
{
1509+
auto tensor = chunk_hidden_states;
1510+
float original_mean = ggml_tensor_mean(tensor);
1511+
for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
1512+
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
1513+
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
1514+
float value = ggml_tensor_get_f32(tensor, i0, i1, i2);
1515+
value *= chunk_weights[i1];
1516+
ggml_tensor_set_f32(tensor, value, i0, i1, i2);
1517+
}
14951518
}
14961519
}
1520+
float new_mean = ggml_tensor_mean(tensor);
1521+
ggml_tensor_scale(tensor, (original_mean / new_mean));
14971522
}
1498-
float new_mean = ggml_tensor_mean(tensor);
1499-
ggml_tensor_scale(tensor, (original_mean / new_mean));
1500-
}
1501-
1502-
int64_t t1 = ggml_time_ms();
1503-
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
1504-
if (zero_out_masked) {
1505-
auto tensor = chunk_hidden_states;
1506-
for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
1507-
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
1508-
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
1509-
if (chunk_mask[i1] < 0.f) {
1510-
ggml_tensor_set_f32(tensor, 0.f, i0, i1, i2);
1523+
if (zero_out_masked) {
1524+
auto tensor = chunk_hidden_states;
1525+
for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
1526+
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
1527+
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
1528+
if (chunk_mask[i1] < 0.f) {
1529+
ggml_tensor_set_f32(tensor, 0.f, i0, i1, i2);
1530+
}
15111531
}
15121532
}
15131533
}
15141534
}
1535+
} else {
1536+
chunk_hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, chunk_len);
1537+
ggml_set_f32(chunk_hidden_states, 0.f);
1538+
t5_attn_mask = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, chunk_len);
1539+
ggml_set_f32(t5_attn_mask, -HUGE_VALF);
15151540
}
15161541

1542+
int64_t t1 = ggml_time_ms();
1543+
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
1544+
15171545
hidden_states_vec.insert(hidden_states_vec.end(),
15181546
(float*)chunk_hidden_states->data,
15191547
((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states));
15201548
}
15211549

1522-
GGML_ASSERT(hidden_states_vec.size() > 0);
1523-
hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
1524-
hidden_states = ggml_reshape_2d(work_ctx,
1525-
hidden_states,
1526-
chunk_hidden_states->ne[0],
1527-
ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
1550+
if (hidden_states_vec.size() > 0) {
1551+
hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
1552+
hidden_states = ggml_reshape_2d(work_ctx,
1553+
hidden_states,
1554+
chunk_hidden_states->ne[0],
1555+
ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
1556+
} else {
1557+
hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, chunk_len);
1558+
ggml_set_f32(hidden_states, 0.f);
1559+
t5_attn_mask = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, chunk_len);
1560+
ggml_set_f32(t5_attn_mask, -HUGE_VALF);
1561+
}
15281562

15291563
modify_mask_to_attend_padding(t5_attn_mask, ggml_nelements(t5_attn_mask), mask_pad);
15301564

0 commit comments

Comments
 (0)