Skip to content

Commit c3d50c7

Browse files
committed
Kontext support
1 parent 10c6501 commit c3d50c7

File tree

5 files changed

+197
-73
lines changed

5 files changed

+197
-73
lines changed

diffusion_model.hpp

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@ struct DiffusionModel {
1313
struct ggml_tensor* c_concat,
1414
struct ggml_tensor* y,
1515
struct ggml_tensor* guidance,
16-
int num_video_frames = -1,
17-
std::vector<struct ggml_tensor*> controls = {},
18-
float control_strength = 0.f,
19-
struct ggml_tensor** output = NULL,
20-
struct ggml_context* output_ctx = NULL,
21-
std::vector<int> skip_layers = std::vector<int>()) = 0;
16+
int num_video_frames = -1,
17+
std::vector<struct ggml_tensor*> controls = {},
18+
float control_strength = 0.f,
19+
std::vector<struct ggml_tensor*> kontext_imgs = std::vector<struct ggml_tensor*>(),
20+
struct ggml_tensor** output = NULL,
21+
struct ggml_context* output_ctx = NULL,
22+
std::vector<int> skip_layers = std::vector<int>()) = 0;
2223
virtual void alloc_params_buffer() = 0;
2324
virtual void free_params_buffer() = 0;
2425
virtual void free_compute_buffer() = 0;
@@ -68,12 +69,13 @@ struct UNetModel : public DiffusionModel {
6869
struct ggml_tensor* c_concat,
6970
struct ggml_tensor* y,
7071
struct ggml_tensor* guidance,
71-
int num_video_frames = -1,
72-
std::vector<struct ggml_tensor*> controls = {},
73-
float control_strength = 0.f,
74-
struct ggml_tensor** output = NULL,
75-
struct ggml_context* output_ctx = NULL,
76-
std::vector<int> skip_layers = std::vector<int>()) {
72+
int num_video_frames = -1,
73+
std::vector<struct ggml_tensor*> controls = {},
74+
float control_strength = 0.f,
75+
std::vector<struct ggml_tensor*> kontext_imgs = std::vector<struct ggml_tensor*>(),
76+
struct ggml_tensor** output = NULL,
77+
struct ggml_context* output_ctx = NULL,
78+
std::vector<int> skip_layers = std::vector<int>()) {
7779
(void)skip_layers; // SLG doesn't work with UNet models
7880
return unet.compute(n_threads, x, timesteps, context, c_concat, y, num_video_frames, controls, control_strength, output, output_ctx);
7981
}
@@ -118,12 +120,13 @@ struct MMDiTModel : public DiffusionModel {
118120
struct ggml_tensor* c_concat,
119121
struct ggml_tensor* y,
120122
struct ggml_tensor* guidance,
121-
int num_video_frames = -1,
122-
std::vector<struct ggml_tensor*> controls = {},
123-
float control_strength = 0.f,
124-
struct ggml_tensor** output = NULL,
125-
struct ggml_context* output_ctx = NULL,
126-
std::vector<int> skip_layers = std::vector<int>()) {
123+
int num_video_frames = -1,
124+
std::vector<struct ggml_tensor*> controls = {},
125+
float control_strength = 0.f,
126+
std::vector<struct ggml_tensor*> kontext_imgs = std::vector<struct ggml_tensor*>(),
127+
struct ggml_tensor** output = NULL,
128+
struct ggml_context* output_ctx = NULL,
129+
std::vector<int> skip_layers = std::vector<int>()) {
127130
return mmdit.compute(n_threads, x, timesteps, context, y, output, output_ctx, skip_layers);
128131
}
129132
};
@@ -169,13 +172,14 @@ struct FluxModel : public DiffusionModel {
169172
struct ggml_tensor* c_concat,
170173
struct ggml_tensor* y,
171174
struct ggml_tensor* guidance,
172-
int num_video_frames = -1,
173-
std::vector<struct ggml_tensor*> controls = {},
174-
float control_strength = 0.f,
175-
struct ggml_tensor** output = NULL,
176-
struct ggml_context* output_ctx = NULL,
177-
std::vector<int> skip_layers = std::vector<int>()) {
178-
return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, output, output_ctx, skip_layers);
175+
int num_video_frames = -1,
176+
std::vector<struct ggml_tensor*> controls = {},
177+
float control_strength = 0.f,
178+
std::vector<struct ggml_tensor*> kontext_imgs = std::vector<struct ggml_tensor*>(),
179+
struct ggml_tensor** output = NULL,
180+
struct ggml_context* output_ctx = NULL,
181+
std::vector<int> skip_layers = std::vector<int>()) {
182+
return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, kontext_imgs, output, output_ctx, skip_layers);
179183
}
180184
};
181185

examples/cli/main.cpp

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ struct SDParams {
9090
std::string mask_path;
9191
std::string control_image_path;
9292

93+
std::vector<std::string> kontext_image_paths;
94+
9395
std::string prompt;
9496
std::string negative_prompt;
9597
float min_cfg = 1.0f;
@@ -245,6 +247,7 @@ void print_usage(int argc, const char* argv[]) {
245247
printf(" --canny apply canny preprocessor (edge detection)\n");
246248
printf(" --color Colors the logging tags according to level\n");
247249
printf(" -v, --verbose print extra info\n");
250+
printf(" -ki, --kontext_img [PATH] Reference image for Flux Kontext models (can be used multiple times) \n");
248251
}
249252

250253
void parse_args(int argc, const char** argv, SDParams& params) {
@@ -629,6 +632,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
629632
break;
630633
}
631634
params.skip_layer_end = std::stof(argv[i]);
635+
} else if (arg == "-ki" || arg == "--kontext-img") {
636+
if (++i >= argc) {
637+
invalid_arg = true;
638+
break;
639+
}
640+
params.kontext_image_paths.push_back(argv[i]);
632641
} else {
633642
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
634643
print_usage(argc, argv);
@@ -821,8 +830,40 @@ int main(int argc, const char* argv[]) {
821830
fprintf(stderr, "SVD support is broken, do not use it!!!\n");
822831
return 1;
823832
}
833+
bool vae_decode_only = true;
834+
835+
std::vector<sd_image_t> kontext_imgs;
836+
for (auto& path : params.kontext_image_paths) {
837+
vae_decode_only = false;
838+
int c = 0;
839+
int width = 0;
840+
int height = 0;
841+
uint8_t* image_buffer = stbi_load(path.c_str(), &width, &height, &c, 3);
842+
if (image_buffer == NULL) {
843+
fprintf(stderr, "load image from '%s' failed\n", path.c_str());
844+
return 1;
845+
}
846+
if (c < 3) {
847+
fprintf(stderr, "the number of channels for the input image must be >= 3, but got %d channels\n", c);
848+
free(image_buffer);
849+
return 1;
850+
}
851+
if (width <= 0) {
852+
fprintf(stderr, "error: the width of image must be greater than 0\n");
853+
free(image_buffer);
854+
return 1;
855+
}
856+
if (height <= 0) {
857+
fprintf(stderr, "error: the height of image must be greater than 0\n");
858+
free(image_buffer);
859+
return 1;
860+
}
861+
kontext_imgs.push_back({(uint32_t)width,
862+
(uint32_t)height,
863+
3,
864+
image_buffer});
865+
}
824866

825-
bool vae_decode_only = true;
826867
uint8_t* input_image_buffer = NULL;
827868
uint8_t* control_image_buffer = NULL;
828869
uint8_t* mask_image_buffer = NULL;
@@ -963,6 +1004,7 @@ int main(int argc, const char* argv[]) {
9631004
params.style_ratio,
9641005
params.normalize_input,
9651006
params.input_id_images_path.c_str(),
1007+
kontext_imgs.data(), kontext_imgs.size(),
9661008
params.skip_layers.data(),
9671009
params.skip_layers.size(),
9681010
params.slg_scale,
@@ -1032,6 +1074,7 @@ int main(int argc, const char* argv[]) {
10321074
params.style_ratio,
10331075
params.normalize_input,
10341076
params.input_id_images_path.c_str(),
1077+
kontext_imgs.data(), kontext_imgs.size(),
10351078
params.skip_layers.data(),
10361079
params.skip_layers.size(),
10371080
params.slg_scale,
@@ -1075,19 +1118,19 @@ int main(int argc, const char* argv[]) {
10751118

10761119
std::string dummy_name, ext, lc_ext;
10771120
bool is_jpg;
1078-
size_t last = params.output_path.find_last_of(".");
1121+
size_t last = params.output_path.find_last_of(".");
10791122
size_t last_path = std::min(params.output_path.find_last_of("/"),
10801123
params.output_path.find_last_of("\\"));
1081-
if (last != std::string::npos // filename has extension
1082-
&& (last_path == std::string::npos || last > last_path)) {
1124+
if (last != std::string::npos // filename has extension
1125+
&& (last_path == std::string::npos || last > last_path)) {
10831126
dummy_name = params.output_path.substr(0, last);
10841127
ext = lc_ext = params.output_path.substr(last);
10851128
std::transform(ext.begin(), ext.end(), lc_ext.begin(), ::tolower);
10861129
is_jpg = lc_ext == ".jpg" || lc_ext == ".jpeg" || lc_ext == ".jpe";
10871130
} else {
10881131
dummy_name = params.output_path;
10891132
ext = lc_ext = "";
1090-
is_jpg = false;
1133+
is_jpg = false;
10911134
}
10921135
// appending ".png" to absent or unknown extension
10931136
if (!is_jpg && lc_ext != ".png") {
@@ -1099,7 +1142,7 @@ int main(int argc, const char* argv[]) {
10991142
continue;
11001143
}
11011144
std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1) + ext : dummy_name + ext;
1102-
if(is_jpg) {
1145+
if (is_jpg) {
11031146
stbi_write_jpg(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
11041147
results[i].data, 90, get_image_params(params, params.seed + i).c_str());
11051148
printf("save result JPEG image to '%s'\n", final_image_path.c_str());

0 commit comments

Comments
 (0)