From c3d50c7dc287476f1c2d4ae46e827d8c064166b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Fri, 27 Jun 2025 10:57:14 +0200 Subject: [PATCH 1/2] Kontext support --- diffusion_model.hpp | 54 +++++++++++++++------------- examples/cli/main.cpp | 55 ++++++++++++++++++++++++---- flux.hpp | 84 ++++++++++++++++++++++++++++--------------- stable-diffusion.cpp | 73 +++++++++++++++++++++++++++++-------- stable-diffusion.h | 4 +++ 5 files changed, 197 insertions(+), 73 deletions(-) diff --git a/diffusion_model.hpp b/diffusion_model.hpp index ee4d88f0c..6a36f7166 100644 --- a/diffusion_model.hpp +++ b/diffusion_model.hpp @@ -13,12 +13,13 @@ struct DiffusionModel { struct ggml_tensor* c_concat, struct ggml_tensor* y, struct ggml_tensor* guidance, - int num_video_frames = -1, - std::vector controls = {}, - float control_strength = 0.f, - struct ggml_tensor** output = NULL, - struct ggml_context* output_ctx = NULL, - std::vector skip_layers = std::vector()) = 0; + int num_video_frames = -1, + std::vector controls = {}, + float control_strength = 0.f, + std::vector kontext_imgs = std::vector(), + struct ggml_tensor** output = NULL, + struct ggml_context* output_ctx = NULL, + std::vector skip_layers = std::vector()) = 0; virtual void alloc_params_buffer() = 0; virtual void free_params_buffer() = 0; virtual void free_compute_buffer() = 0; @@ -68,12 +69,13 @@ struct UNetModel : public DiffusionModel { struct ggml_tensor* c_concat, struct ggml_tensor* y, struct ggml_tensor* guidance, - int num_video_frames = -1, - std::vector controls = {}, - float control_strength = 0.f, - struct ggml_tensor** output = NULL, - struct ggml_context* output_ctx = NULL, - std::vector skip_layers = std::vector()) { + int num_video_frames = -1, + std::vector controls = {}, + float control_strength = 0.f, + std::vector kontext_imgs = std::vector(), + struct ggml_tensor** output = NULL, + struct ggml_context* output_ctx = NULL, + std::vector skip_layers = std::vector()) { (void)skip_layers; // SLG doesn't work with UNet models return unet.compute(n_threads, x, timesteps, context, c_concat, y, num_video_frames, controls, control_strength, output, output_ctx); } @@ -118,12 +120,13 @@ struct MMDiTModel : public DiffusionModel { struct ggml_tensor* c_concat, struct ggml_tensor* y, struct ggml_tensor* guidance, - int num_video_frames = -1, - std::vector controls = {}, - float control_strength = 0.f, - struct ggml_tensor** output = NULL, - struct ggml_context* output_ctx = NULL, - std::vector skip_layers = std::vector()) { + int num_video_frames = -1, + std::vector controls = {}, + float control_strength = 0.f, + std::vector kontext_imgs = std::vector(), + struct ggml_tensor** output = NULL, + struct ggml_context* output_ctx = NULL, + std::vector skip_layers = std::vector()) { return mmdit.compute(n_threads, x, timesteps, context, y, output, output_ctx, skip_layers); } }; @@ -169,13 +172,14 @@ struct FluxModel : public DiffusionModel { struct ggml_tensor* c_concat, struct ggml_tensor* y, struct ggml_tensor* guidance, - int num_video_frames = -1, - std::vector controls = {}, - float control_strength = 0.f, - struct ggml_tensor** output = NULL, - struct ggml_context* output_ctx = NULL, - std::vector skip_layers = std::vector()) { - return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, output, output_ctx, skip_layers); + int num_video_frames = -1, + std::vector controls = {}, + float control_strength = 0.f, + std::vector kontext_imgs = std::vector(), + struct ggml_tensor** output = NULL, + struct ggml_context* output_ctx = NULL, + std::vector skip_layers = std::vector()) { + return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, kontext_imgs, output, output_ctx, skip_layers); } }; diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index af6b2bbdb..e5003930e 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -90,6 +90,8 @@ struct SDParams { std::string mask_path; std::string control_image_path; + std::vector kontext_image_paths; + std::string prompt; std::string negative_prompt; float min_cfg = 1.0f; @@ -245,6 +247,7 @@ void print_usage(int argc, const char* argv[]) { printf(" --canny apply canny preprocessor (edge detection)\n"); printf(" --color Colors the logging tags according to level\n"); printf(" -v, --verbose print extra info\n"); + printf(" -ki, --kontext_img [PATH] Reference image for Flux Kontext models (can be used multiple times) \n"); } void parse_args(int argc, const char** argv, SDParams& params) { @@ -629,6 +632,12 @@ void parse_args(int argc, const char** argv, SDParams& params) { break; } params.skip_layer_end = std::stof(argv[i]); + } else if (arg == "-ki" || arg == "--kontext-img") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.kontext_image_paths.push_back(argv[i]); } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); print_usage(argc, argv); @@ -821,8 +830,40 @@ int main(int argc, const char* argv[]) { fprintf(stderr, "SVD support is broken, do not use it!!!\n"); return 1; } + bool vae_decode_only = true; + + std::vector kontext_imgs; + for (auto& path : params.kontext_image_paths) { + vae_decode_only = false; + int c = 0; + int width = 0; + int height = 0; + uint8_t* image_buffer = stbi_load(path.c_str(), &width, &height, &c, 3); + if (image_buffer == NULL) { + fprintf(stderr, "load image from '%s' failed\n", path.c_str()); + return 1; + } + if (c < 3) { + fprintf(stderr, "the number of channels for the input image must be >= 3, but got %d channels\n", c); + free(image_buffer); + return 1; + } + if (width <= 0) { + fprintf(stderr, "error: the width of image must be greater than 0\n"); + free(image_buffer); + return 1; + } + if (height <= 0) { + fprintf(stderr, "error: the height of image must be greater than 0\n"); + free(image_buffer); + return 1; + } + kontext_imgs.push_back({(uint32_t)width, + (uint32_t)height, + 3, + image_buffer}); + } - bool vae_decode_only = true; uint8_t* input_image_buffer = NULL; uint8_t* control_image_buffer = NULL; uint8_t* mask_image_buffer = NULL; @@ -963,6 +1004,7 @@ int main(int argc, const char* argv[]) { params.style_ratio, params.normalize_input, params.input_id_images_path.c_str(), + kontext_imgs.data(), kontext_imgs.size(), params.skip_layers.data(), params.skip_layers.size(), params.slg_scale, @@ -1032,6 +1074,7 @@ int main(int argc, const char* argv[]) { params.style_ratio, params.normalize_input, params.input_id_images_path.c_str(), + kontext_imgs.data(), kontext_imgs.size(), params.skip_layers.data(), params.skip_layers.size(), params.slg_scale, @@ -1075,11 +1118,11 @@ int main(int argc, const char* argv[]) { std::string dummy_name, ext, lc_ext; bool is_jpg; - size_t last = params.output_path.find_last_of("."); + size_t last = params.output_path.find_last_of("."); size_t last_path = std::min(params.output_path.find_last_of("/"), params.output_path.find_last_of("\\")); - if (last != std::string::npos // filename has extension - && (last_path == std::string::npos || last > last_path)) { + if (last != std::string::npos // filename has extension + && (last_path == std::string::npos || last > last_path)) { dummy_name = params.output_path.substr(0, last); ext = lc_ext = params.output_path.substr(last); std::transform(ext.begin(), ext.end(), lc_ext.begin(), ::tolower); @@ -1087,7 +1130,7 @@ int main(int argc, const char* argv[]) { } else { dummy_name = params.output_path; ext = lc_ext = ""; - is_jpg = false; + is_jpg = false; } // appending ".png" to absent or unknown extension if (!is_jpg && lc_ext != ".png") { @@ -1099,7 +1142,7 @@ int main(int argc, const char* argv[]) { continue; } std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1) + ext : dummy_name + ext; - if(is_jpg) { + if (is_jpg) { stbi_write_jpg(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel, results[i].data, 90, get_image_params(params, params.seed + i).c_str()); printf("save result JPEG image to '%s'\n", final_image_path.c_str()); diff --git a/flux.hpp b/flux.hpp index 20ff41096..897c42f79 100644 --- a/flux.hpp +++ b/flux.hpp @@ -570,11 +570,11 @@ namespace Flux { } // Generate IDs for image patches and text - std::vector> gen_ids(int h, int w, int patch_size, int bs, int context_len) { + std::vector> gen_ids(int h, int w, int patch_size, int index = 0) { int h_len = (h + (patch_size / 2)) / patch_size; int w_len = (w + (patch_size / 2)) / patch_size; - std::vector> img_ids(h_len * w_len, std::vector(3, 0.0)); + std::vector> img_ids(h_len * w_len, std::vector(3, (float)index)); std::vector row_ids = linspace(0, h_len - 1, h_len); std::vector col_ids = linspace(0, w_len - 1, w_len); @@ -586,10 +586,22 @@ namespace Flux { } } - std::vector> img_ids_repeated(bs * img_ids.size(), std::vector(3)); - for (int i = 0; i < bs; ++i) { - for (int j = 0; j < img_ids.size(); ++j) { - img_ids_repeated[i * img_ids.size() + j] = img_ids[j]; + return img_ids; + } + + // Generate positional embeddings + std::vector gen_pe(std::vector imgs, struct ggml_tensor* context, int patch_size, int theta, const std::vector& axes_dim) { + int context_len = context->ne[1]; + int bs = imgs[0]->ne[3]; + + std::vector> img_ids; + for (int i = 0; i < imgs.size(); i++) { + auto x = imgs[i]; + if (x) { + int h = x->ne[1]; + int w = x->ne[0]; + std::vector> img_ids_i = gen_ids(h, w, patch_size, i); + img_ids.insert(img_ids.end(), img_ids_i.begin(), img_ids_i.end()); } } @@ -600,16 +612,10 @@ namespace Flux { ids[i * (context_len + img_ids.size()) + j] = txt_ids[j]; } for (int j = 0; j < img_ids.size(); ++j) { - ids[i * (context_len + img_ids.size()) + context_len + j] = img_ids_repeated[i * img_ids.size() + j]; + ids[i * (context_len + img_ids.size()) + context_len + j] = img_ids[j]; } } - return ids; - } - - // Generate positional embeddings - std::vector gen_pe(int h, int w, int patch_size, int bs, int context_len, int theta, const std::vector& axes_dim) { - std::vector> ids = gen_ids(h, w, patch_size, bs, context_len); std::vector> trans_ids = transpose(ids); size_t pos_len = ids.size(); int num_axes = axes_dim.size(); @@ -786,7 +792,7 @@ namespace Flux { } struct ggml_tensor* forward(struct ggml_context* ctx, - struct ggml_tensor* x, + std::vector imgs, struct ggml_tensor* timestep, struct ggml_tensor* context, struct ggml_tensor* c_concat, @@ -804,19 +810,31 @@ namespace Flux { // pe: (L, d_head/2, 2, 2) // return: (N, C, H, W) + auto x = imgs[0]; GGML_ASSERT(x->ne[3] == 1); int64_t W = x->ne[0]; int64_t H = x->ne[1]; int64_t C = x->ne[2]; int64_t patch_size = 2; - int pad_h = (patch_size - H % patch_size) % patch_size; - int pad_w = (patch_size - W % patch_size) % patch_size; - x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // [N, C, H + pad_h, W + pad_w] + int pad_h = (patch_size - x->ne[0] % patch_size) % patch_size; + int pad_w = (patch_size - x->ne[1] % patch_size) % patch_size; // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) - auto img = patchify(ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size] - + ggml_tensor* img = NULL; // [N, h*w, C * patch_size * patch_size] + int64_t patchified_img_size; + for (auto& x : imgs) { + int pad_h = (patch_size - x->ne[0] % patch_size) % patch_size; + int pad_w = (patch_size - x->ne[1] % patch_size) % patch_size; + ggml_tensor* pad_x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); + pad_x = patchify(ctx, pad_x, patch_size); + if (img) { + img = ggml_concat(ctx, img, pad_x, 1); + } else { + img = pad_x; + patchified_img_size = img->ne[1]; + } + } if (c_concat != NULL) { ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0); ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C); @@ -831,6 +849,7 @@ namespace Flux { } auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, skip_layers); // [N, h*w, C * patch_size * patch_size] + out = ggml_cont(ctx, ggml_view_2d(ctx, out, out->ne[0], patchified_img_size, out->nb[1], 0)); // rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2) out = unpatchify(ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size, patch_size); // [N, C, H + pad_h, W + pad_w] @@ -909,7 +928,8 @@ namespace Flux { struct ggml_tensor* c_concat, struct ggml_tensor* y, struct ggml_tensor* guidance, - std::vector skip_layers = std::vector()) { + std::vector kontext_imgs = std::vector(), + std::vector skip_layers = std::vector()) { GGML_ASSERT(x->ne[3] == 1); struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false); @@ -918,13 +938,20 @@ namespace Flux { if (c_concat != NULL) { c_concat = to_backend(c_concat); } - y = to_backend(y); + for (auto& img : kontext_imgs) { + img = to_backend(img); + } + + y = to_backend(y); + timesteps = to_backend(timesteps); if (flux_params.guidance_embed) { guidance = to_backend(guidance); } + auto imgs = kontext_imgs; + imgs.insert(imgs.begin(), x); - pe_vec = flux.gen_pe(x->ne[1], x->ne[0], 2, x->ne[3], context->ne[1], flux_params.theta, flux_params.axes_dim); + pe_vec = flux.gen_pe(imgs, context, 2, flux_params.theta, flux_params.axes_dim); int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2; // LOG_DEBUG("pos_len %d", pos_len); auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len); @@ -934,7 +961,7 @@ namespace Flux { set_backend_tensor_data(pe, pe_vec.data()); struct ggml_tensor* out = flux.forward(compute_ctx, - x, + imgs, timesteps, context, c_concat, @@ -955,16 +982,17 @@ namespace Flux { struct ggml_tensor* c_concat, struct ggml_tensor* y, struct ggml_tensor* guidance, - struct ggml_tensor** output = NULL, - struct ggml_context* output_ctx = NULL, - std::vector skip_layers = std::vector()) { + std::vector kontext_imgs = std::vector(), + struct ggml_tensor** output = NULL, + struct ggml_context* output_ctx = NULL, + std::vector skip_layers = std::vector()) { // x: [N, in_channels, h, w] // timesteps: [N, ] // context: [N, max_position, hidden_size] // y: [N, adm_in_channels] or [1, adm_in_channels] // guidance: [N, ] auto get_graph = [&]() -> struct ggml_cgraph* { - return build_graph(x, timesteps, context, c_concat, y, guidance, skip_layers); + return build_graph(x, timesteps, context, c_concat, y, guidance, kontext_imgs, skip_layers); }; GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); @@ -1004,7 +1032,7 @@ namespace Flux { struct ggml_tensor* out = NULL; int t0 = ggml_time_ms(); - compute(8, x, timesteps, context, NULL, y, guidance, &out, work_ctx); + compute(8, x, timesteps, context, NULL, y, guidance, std::vector(), &out, work_ctx); int t1 = ggml_time_ms(); print_ggml_tensor(out); diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index e38a6101f..4b807a5db 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -48,8 +48,7 @@ const char* sampling_methods_str[] = { "iPNDM_v", "LCM", "DDIM \"trailing\"", - "TCD" -}; + "TCD"}; /*================================================== Helper Functions ================================================*/ @@ -618,7 +617,7 @@ class StableDiffusionGGML { int64_t t0 = ggml_time_ms(); struct ggml_tensor* out = ggml_dup_tensor(work_ctx, x_t); - diffusion_model->compute(n_threads, x_t, timesteps, c, concat, NULL, NULL, -1, {}, 0.f, &out); + diffusion_model->compute(n_threads, x_t, timesteps, c, concat, NULL, NULL, -1, {}, 0.f, std::vector(), &out); diffusion_model->free_compute_buffer(); double result = 0.f; @@ -682,7 +681,7 @@ class StableDiffusionGGML { float curr_multiplier = kv.second; lora_state_diff[lora_name] -= curr_multiplier; } - + size_t rm = lora_state_diff.size() - lora_state.size(); if (rm != 0) { LOG_INFO("Attempting to apply %lu LoRAs (removing %lu applied LoRAs)", lora_state.size(), rm); @@ -800,11 +799,12 @@ class StableDiffusionGGML { const std::vector& sigmas, int start_merge_step, SDCondition id_cond, - std::vector skip_layers = {}, - float slg_scale = 0, - float skip_layer_start = 0.01, - float skip_layer_end = 0.2, - ggml_tensor* noise_mask = nullptr) { + std::vector skip_layers = {}, + float slg_scale = 0, + float skip_layer_start = 0.01, + float skip_layer_end = 0.2, + std::vector kontext_imgs = std::vector(), + ggml_tensor* noise_mask = NULL) { LOG_DEBUG("Sample"); struct ggml_init_params params; size_t data_size = ggml_row_size(init_latent->type, init_latent->ne[0]); @@ -890,6 +890,7 @@ class StableDiffusionGGML { -1, controls, control_strength, + kontext_imgs, &out_cond); } else { diffusion_model->compute(n_threads, @@ -902,6 +903,7 @@ class StableDiffusionGGML { -1, controls, control_strength, + kontext_imgs, &out_cond); } @@ -922,6 +924,7 @@ class StableDiffusionGGML { -1, controls, control_strength, + kontext_imgs, &out_uncond); negative_data = (float*)out_uncond->data; } @@ -942,6 +945,7 @@ class StableDiffusionGGML { -1, controls, control_strength, + kontext_imgs, &out_skip, NULL, skip_layers); @@ -1209,11 +1213,12 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, float style_ratio, bool normalize_input, std::string input_id_images_path, - std::vector skip_layers = {}, - float slg_scale = 0, - float skip_layer_start = 0.01, - float skip_layer_end = 0.2, - ggml_tensor* masked_image = NULL) { + std::vector kontext_imgs = std::vector(), + std::vector skip_layers = {}, + float slg_scale = 0, + float skip_layer_start = 0.01, + float skip_layer_end = 0.2, + ggml_tensor* masked_image = NULL) { if (seed < 0) { // Generally, when using the provided command line, the seed is always >0. // However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library @@ -1470,6 +1475,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, slg_scale, skip_layer_start, skip_layer_end, + kontext_imgs, noise_mask); // struct ggml_tensor* x_0 = load_tensor_from_file(ctx, "samples_ddim.bin"); @@ -1539,6 +1545,8 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, float style_ratio, bool normalize_input, const char* input_id_images_path_c_str, + sd_image_t* kontext_imgs, + int kontext_img_count, int* skip_layers = NULL, size_t skip_layers_count = 0, float slg_scale = 0, @@ -1597,6 +1605,22 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, if (sd_version_is_inpaint(sd_ctx->sd->version)) { LOG_WARN("This is an inpainting model, this should only be used in img2img mode with a mask"); } + std::vector kontext_latents = std::vector(); + if (kontext_imgs) { + for (int i = 0; i < kontext_img_count; i++) { + ggml_tensor* img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, kontext_imgs[i].width, kontext_imgs[i].height, 3, 1); + sd_image_to_tensor(kontext_imgs[i].data, img); + + ggml_tensor* latent = NULL; + if (!sd_ctx->sd->use_tiny_autoencoder) { + ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, img); + latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments); + } else { + latent = sd_ctx->sd->encode_first_stage(work_ctx, img); + } + kontext_latents.push_back(latent); + } + } sd_image_t* result_images = generate_image(sd_ctx, work_ctx, @@ -1618,6 +1642,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, style_ratio, normalize_input, input_id_images_path_c_str, + kontext_latents, skip_layers_vec, slg_scale, skip_layer_start, @@ -1651,6 +1676,8 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, float style_ratio, bool normalize_input, const char* input_id_images_path_c_str, + sd_image_t* kontext_imgs, + int kontext_img_count, int* skip_layers = NULL, size_t skip_layers_count = 0, float slg_scale = 0, @@ -1766,6 +1793,23 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img); } + std::vector kontext_latents = std::vector(); + if (kontext_imgs) { + for (int i = 0; i < kontext_img_count; i++) { + ggml_tensor* img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1); + sd_image_to_tensor(kontext_imgs[i].data, img); + + ggml_tensor* latent = NULL; + if (!sd_ctx->sd->use_tiny_autoencoder) { + ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, img); + latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments); + } else { + latent = sd_ctx->sd->encode_first_stage(work_ctx, img); + } + kontext_latents.push_back(latent); + } + } + print_ggml_tensor(init_latent, true); size_t t1 = ggml_time_ms(); LOG_INFO("encode_first_stage completed, taking %.2fs", (t1 - t0) * 1.0f / 1000); @@ -1798,6 +1842,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, style_ratio, normalize_input, input_id_images_path_c_str, + kontext_latents, skip_layers_vec, slg_scale, skip_layer_start, diff --git a/stable-diffusion.h b/stable-diffusion.h index 52dcc848a..e59b2a30d 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -172,6 +172,8 @@ SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx, float style_strength, bool normalize_input, const char* input_id_images_path, + sd_image_t* kontext_imgs, + int kontext_img_count, int* skip_layers, size_t skip_layers_count, float slg_scale, @@ -199,6 +201,8 @@ SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx, float style_strength, bool normalize_input, const char* input_id_images_path, + sd_image_t* kontext_imgs, + int kontext_img_count, int* skip_layers, size_t skip_layers_count, float slg_scale, From 896788957a0b985ef3a7c8787705d3235c116806 Mon Sep 17 00:00:00 2001 From: leejet Date: Sat, 28 Jun 2025 12:54:51 +0800 Subject: [PATCH 2/2] add edit mode --- diffusion_model.hpp | 58 ++++++------ examples/cli/main.cpp | 133 ++++++++++++++++---------- flux.hpp | 181 ++++++++++++++++++++++-------------- stable-diffusion.cpp | 211 ++++++++++++++++++++++++++++++------------ stable-diffusion.h | 30 +++++- 5 files changed, 400 insertions(+), 213 deletions(-) diff --git a/diffusion_model.hpp b/diffusion_model.hpp index 6a36f7166..94e9a2678 100644 --- a/diffusion_model.hpp +++ b/diffusion_model.hpp @@ -13,13 +13,13 @@ struct DiffusionModel { struct ggml_tensor* c_concat, struct ggml_tensor* y, struct ggml_tensor* guidance, - int num_video_frames = -1, - std::vector controls = {}, - float control_strength = 0.f, - std::vector kontext_imgs = std::vector(), - struct ggml_tensor** output = NULL, - struct ggml_context* output_ctx = NULL, - std::vector skip_layers = std::vector()) = 0; + std::vector ref_latents = {}, + int num_video_frames = -1, + std::vector controls = {}, + float control_strength = 0.f, + struct ggml_tensor** output = NULL, + struct ggml_context* output_ctx = NULL, + std::vector skip_layers = std::vector()) = 0; virtual void alloc_params_buffer() = 0; virtual void free_params_buffer() = 0; virtual void free_compute_buffer() = 0; @@ -69,13 +69,13 @@ struct UNetModel : public DiffusionModel { struct ggml_tensor* c_concat, struct ggml_tensor* y, struct ggml_tensor* guidance, - int num_video_frames = -1, - std::vector controls = {}, - float control_strength = 0.f, - std::vector kontext_imgs = std::vector(), - struct ggml_tensor** output = NULL, - struct ggml_context* output_ctx = NULL, - std::vector skip_layers = std::vector()) { + std::vector ref_latents = {}, + int num_video_frames = -1, + std::vector controls = {}, + float control_strength = 0.f, + struct ggml_tensor** output = NULL, + struct ggml_context* output_ctx = NULL, + std::vector skip_layers = std::vector()) { (void)skip_layers; // SLG doesn't work with UNet models return unet.compute(n_threads, x, timesteps, context, c_concat, y, num_video_frames, controls, control_strength, output, output_ctx); } @@ -120,13 +120,13 @@ struct MMDiTModel : public DiffusionModel { struct ggml_tensor* c_concat, struct ggml_tensor* y, struct ggml_tensor* guidance, - int num_video_frames = -1, - std::vector controls = {}, - float control_strength = 0.f, - std::vector kontext_imgs = std::vector(), - struct ggml_tensor** output = NULL, - struct ggml_context* output_ctx = NULL, - std::vector skip_layers = std::vector()) { + std::vector ref_latents = {}, + int num_video_frames = -1, + std::vector controls = {}, + float control_strength = 0.f, + struct ggml_tensor** output = NULL, + struct ggml_context* output_ctx = NULL, + std::vector skip_layers = std::vector()) { return mmdit.compute(n_threads, x, timesteps, context, y, output, output_ctx, skip_layers); } }; @@ -172,14 +172,14 @@ struct FluxModel : public DiffusionModel { struct ggml_tensor* c_concat, struct ggml_tensor* y, struct ggml_tensor* guidance, - int num_video_frames = -1, - std::vector controls = {}, - float control_strength = 0.f, - std::vector kontext_imgs = std::vector(), - struct ggml_tensor** output = NULL, - struct ggml_context* output_ctx = NULL, - std::vector skip_layers = std::vector()) { - return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, kontext_imgs, output, output_ctx, skip_layers); + std::vector ref_latents = {}, + int num_video_frames = -1, + std::vector controls = {}, + float control_strength = 0.f, + struct ggml_tensor** output = NULL, + struct ggml_context* output_ctx = NULL, + std::vector skip_layers = std::vector()) { + return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, ref_latents, output, output_ctx, skip_layers); } }; diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index e5003930e..466fe87c2 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -57,6 +57,7 @@ const char* modes_str[] = { "txt2img", "img2img", "img2vid", + "edit", "convert", }; @@ -64,6 +65,7 @@ enum SDMode { TXT2IMG, IMG2IMG, IMG2VID, + EDIT, CONVERT, MODE_COUNT }; @@ -89,8 +91,7 @@ struct SDParams { std::string input_path; std::string mask_path; std::string control_image_path; - - std::vector kontext_image_paths; + std::vector ref_image_paths; std::string prompt; std::string negative_prompt; @@ -156,6 +157,10 @@ void print_params(SDParams params) { printf(" init_img: %s\n", params.input_path.c_str()); printf(" mask_img: %s\n", params.mask_path.c_str()); printf(" control_image: %s\n", params.control_image_path.c_str()); + printf(" ref_images_paths:\n"); + for (auto& path : params.ref_image_paths) { + printf(" %s\n", path.c_str()); + }; printf(" clip on cpu: %s\n", params.clip_on_cpu ? "true" : "false"); printf(" controlnet cpu: %s\n", params.control_net_cpu ? "true" : "false"); printf(" vae decoder on cpu:%s\n", params.vae_on_cpu ? "true" : "false"); @@ -210,6 +215,7 @@ void print_usage(int argc, const char* argv[]) { printf(" -i, --init-img [IMAGE] path to the input image, required by img2img\n"); printf(" --mask [MASK] path to the mask image, required by img2img with mask\n"); printf(" --control-image [IMAGE] path to image condition, control net\n"); + printf(" -r, --ref_image [PATH] reference image for Flux Kontext models (can be used multiple times) \n"); printf(" -o, --output OUTPUT path to write result image to (default: ./output.png)\n"); printf(" -p, --prompt [PROMPT] the prompt to render\n"); printf(" -n, --negative-prompt PROMPT the negative prompt (default: \"\")\n"); @@ -245,9 +251,8 @@ void print_usage(int argc, const char* argv[]) { printf(" This might crash if it is not supported by the backend.\n"); printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n"); printf(" --canny apply canny preprocessor (edge detection)\n"); - printf(" --color Colors the logging tags according to level\n"); + printf(" --color colors the logging tags according to level\n"); printf(" -v, --verbose print extra info\n"); - printf(" -ki, --kontext_img [PATH] Reference image for Flux Kontext models (can be used multiple times) \n"); } void parse_args(int argc, const char** argv, SDParams& params) { @@ -632,12 +637,12 @@ void parse_args(int argc, const char** argv, SDParams& params) { break; } params.skip_layer_end = std::stof(argv[i]); - } else if (arg == "-ki" || arg == "--kontext-img") { + } else if (arg == "-r" || arg == "--ref-image") { if (++i >= argc) { invalid_arg = true; break; } - params.kontext_image_paths.push_back(argv[i]); + params.ref_image_paths.push_back(argv[i]); } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); print_usage(argc, argv); @@ -666,7 +671,13 @@ void parse_args(int argc, const char** argv, SDParams& params) { } if ((params.mode == IMG2IMG || params.mode == IMG2VID) && params.input_path.length() == 0) { - fprintf(stderr, "error: when using the img2img mode, the following arguments are required: init-img\n"); + fprintf(stderr, "error: when using the img2img/img2vid mode, the following arguments are required: init-img\n"); + print_usage(argc, argv); + exit(1); + } + + if (params.mode == EDIT && params.ref_image_paths.size() == 0) { + fprintf(stderr, "error: when using the edit mode, the following arguments are required: ref-image\n"); print_usage(argc, argv); exit(1); } @@ -830,43 +841,12 @@ int main(int argc, const char* argv[]) { fprintf(stderr, "SVD support is broken, do not use it!!!\n"); return 1; } - bool vae_decode_only = true; - - std::vector kontext_imgs; - for (auto& path : params.kontext_image_paths) { - vae_decode_only = false; - int c = 0; - int width = 0; - int height = 0; - uint8_t* image_buffer = stbi_load(path.c_str(), &width, &height, &c, 3); - if (image_buffer == NULL) { - fprintf(stderr, "load image from '%s' failed\n", path.c_str()); - return 1; - } - if (c < 3) { - fprintf(stderr, "the number of channels for the input image must be >= 3, but got %d channels\n", c); - free(image_buffer); - return 1; - } - if (width <= 0) { - fprintf(stderr, "error: the width of image must be greater than 0\n"); - free(image_buffer); - return 1; - } - if (height <= 0) { - fprintf(stderr, "error: the height of image must be greater than 0\n"); - free(image_buffer); - return 1; - } - kontext_imgs.push_back({(uint32_t)width, - (uint32_t)height, - 3, - image_buffer}); - } + bool vae_decode_only = true; uint8_t* input_image_buffer = NULL; uint8_t* control_image_buffer = NULL; uint8_t* mask_image_buffer = NULL; + std::vector ref_images; if (params.mode == IMG2IMG || params.mode == IMG2VID) { vae_decode_only = false; @@ -918,6 +898,37 @@ int main(int argc, const char* argv[]) { free(input_image_buffer); input_image_buffer = resized_image_buffer; } + } else if (params.mode == EDIT) { + vae_decode_only = false; + for (auto& path : params.ref_image_paths) { + int c = 0; + int width = 0; + int height = 0; + uint8_t* image_buffer = stbi_load(path.c_str(), &width, &height, &c, 3); + if (image_buffer == NULL) { + fprintf(stderr, "load image from '%s' failed\n", path.c_str()); + return 1; + } + if (c < 3) { + fprintf(stderr, "the number of channels for the input image must be >= 3, but got %d channels\n", c); + free(image_buffer); + return 1; + } + if (width <= 0) { + fprintf(stderr, "error: the width of image must be greater than 0\n"); + free(image_buffer); + return 1; + } + if (height <= 0) { + fprintf(stderr, "error: the height of image must be greater than 0\n"); + free(image_buffer); + return 1; + } + ref_images.push_back({(uint32_t)width, + (uint32_t)height, + 3, + image_buffer}); + } } sd_ctx_t* sd_ctx = new_sd_ctx(params.model_path.c_str(), @@ -1004,13 +1015,12 @@ int main(int argc, const char* argv[]) { params.style_ratio, params.normalize_input, params.input_id_images_path.c_str(), - kontext_imgs.data(), kontext_imgs.size(), params.skip_layers.data(), params.skip_layers.size(), params.slg_scale, params.skip_layer_start, params.skip_layer_end); - } else { + } else if (params.mode == IMG2IMG || params.mode == IMG2VID) { sd_image_t input_image = {(uint32_t)params.width, (uint32_t)params.height, 3, @@ -1074,13 +1084,38 @@ int main(int argc, const char* argv[]) { params.style_ratio, params.normalize_input, params.input_id_images_path.c_str(), - kontext_imgs.data(), kontext_imgs.size(), params.skip_layers.data(), params.skip_layers.size(), params.slg_scale, params.skip_layer_start, params.skip_layer_end); } + } else { // EDIT + results = edit(sd_ctx, + ref_images.data(), + ref_images.size(), + params.prompt.c_str(), + params.negative_prompt.c_str(), + params.clip_skip, + params.cfg_scale, + params.guidance, + params.eta, + params.width, + params.height, + params.sample_method, + params.sample_steps, + params.strength, + params.seed, + params.batch_count, + control_image, + params.control_strength, + params.style_ratio, + params.normalize_input, + params.skip_layers.data(), + params.skip_layers.size(), + params.slg_scale, + params.skip_layer_start, + params.skip_layer_end); } if (results == NULL) { @@ -1118,11 +1153,11 @@ int main(int argc, const char* argv[]) { std::string dummy_name, ext, lc_ext; bool is_jpg; - size_t last = params.output_path.find_last_of("."); + size_t last = params.output_path.find_last_of("."); size_t last_path = std::min(params.output_path.find_last_of("/"), params.output_path.find_last_of("\\")); - if (last != std::string::npos // filename has extension - && (last_path == std::string::npos || last > last_path)) { + if (last != std::string::npos // filename has extension + && (last_path == std::string::npos || last > last_path)) { dummy_name = params.output_path.substr(0, last); ext = lc_ext = params.output_path.substr(last); std::transform(ext.begin(), ext.end(), lc_ext.begin(), ::tolower); @@ -1130,7 +1165,7 @@ int main(int argc, const char* argv[]) { } else { dummy_name = params.output_path; ext = lc_ext = ""; - is_jpg = false; + is_jpg = false; } // appending ".png" to absent or unknown extension if (!is_jpg && lc_ext != ".png") { @@ -1142,7 +1177,7 @@ int main(int argc, const char* argv[]) { continue; } std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1) + ext : dummy_name + ext; - if (is_jpg) { + if(is_jpg) { stbi_write_jpg(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel, results[i].data, 90, get_image_params(params, params.seed + i).c_str()); printf("save result JPEG image to '%s'\n", final_image_path.c_str()); @@ -1160,4 +1195,4 @@ int main(int argc, const char* argv[]) { free(input_image_buffer); return 0; -} +} \ No newline at end of file diff --git a/flux.hpp b/flux.hpp index 897c42f79..289e8554f 100644 --- a/flux.hpp +++ b/flux.hpp @@ -570,52 +570,81 @@ namespace Flux { } // Generate IDs for image patches and text - std::vector> gen_ids(int h, int w, int patch_size, int index = 0) { + std::vector> gen_txt_ids(int bs, int context_len) { + return std::vector>(bs * context_len, std::vector(3, 0.0)); + } + + std::vector> gen_img_ids(int h, int w, int patch_size, int bs, int index = 0, int h_offset = 0, int w_offset = 0) { int h_len = (h + (patch_size / 2)) / patch_size; int w_len = (w + (patch_size / 2)) / patch_size; - std::vector> img_ids(h_len * w_len, std::vector(3, (float)index)); + std::vector> img_ids(h_len * w_len, std::vector(3, 0.0)); - std::vector row_ids = linspace(0, h_len - 1, h_len); - std::vector col_ids = linspace(0, w_len - 1, w_len); + std::vector row_ids = linspace(h_offset, h_len - 1 + h_offset, h_len); + std::vector col_ids = linspace(w_offset, w_len - 1 + w_offset, w_len); for (int i = 0; i < h_len; ++i) { for (int j = 0; j < w_len; ++j) { + img_ids[i * w_len + j][0] = index; img_ids[i * w_len + j][1] = row_ids[i]; img_ids[i * w_len + j][2] = col_ids[j]; } } - return img_ids; - } - - // Generate positional embeddings - std::vector gen_pe(std::vector imgs, struct ggml_tensor* context, int patch_size, int theta, const std::vector& axes_dim) { - int context_len = context->ne[1]; - int bs = imgs[0]->ne[3]; - - std::vector> img_ids; - for (int i = 0; i < imgs.size(); i++) { - auto x = imgs[i]; - if (x) { - int h = x->ne[1]; - int w = x->ne[0]; - std::vector> img_ids_i = gen_ids(h, w, patch_size, i); - img_ids.insert(img_ids.end(), img_ids_i.begin(), img_ids_i.end()); + std::vector> img_ids_repeated(bs * img_ids.size(), std::vector(3)); + for (int i = 0; i < bs; ++i) { + for (int j = 0; j < img_ids.size(); ++j) { + img_ids_repeated[i * img_ids.size() + j] = img_ids[j]; } } + return img_ids_repeated; + } - std::vector> txt_ids(bs * context_len, std::vector(3, 0.0)); - std::vector> ids(bs * (context_len + img_ids.size()), std::vector(3)); + std::vector> concat_ids(const std::vector>& a, + const std::vector>& b, + int bs) { + size_t a_len = a.size() / bs; + size_t b_len = b.size() / bs; + std::vector> ids(a.size() + b.size(), std::vector(3)); for (int i = 0; i < bs; ++i) { - for (int j = 0; j < context_len; ++j) { - ids[i * (context_len + img_ids.size()) + j] = txt_ids[j]; + for (int j = 0; j < a_len; ++j) { + ids[i * (a_len + b_len) + j] = a[i * a_len + j]; } - for (int j = 0; j < img_ids.size(); ++j) { - ids[i * (context_len + img_ids.size()) + context_len + j] = img_ids[j]; + for (int j = 0; j < b_len; ++j) { + ids[i * (a_len + b_len) + a_len + j] = b[i * b_len + j]; } } + return ids; + } + std::vector> gen_ids(int h, int w, int patch_size, int bs, int context_len, std::vector ref_latents) { + auto txt_ids = gen_txt_ids(bs, context_len); + auto img_ids = gen_img_ids(h, w, patch_size, bs); + + auto ids = concat_ids(txt_ids, img_ids, bs); + uint64_t curr_h_offset = 0; + uint64_t curr_w_offset = 0; + for (ggml_tensor* ref : ref_latents) { + uint64_t h_offset = 0; + uint64_t w_offset = 0; + if (ref->ne[1] + curr_h_offset > ref->ne[0] + curr_w_offset) { + w_offset = curr_w_offset; + } else { + h_offset = curr_h_offset; + } + + auto ref_ids = gen_img_ids(ref->ne[1], ref->ne[0], patch_size, bs, 1, h_offset, w_offset); + ids = concat_ids(ids, ref_ids, bs); + + curr_h_offset = std::max(curr_h_offset, ref->ne[1] + h_offset); + curr_w_offset = std::max(curr_w_offset, ref->ne[0] + w_offset); + } + return ids; + } + + // Generate positional embeddings + std::vector gen_pe(int h, int w, int patch_size, int bs, int context_len, std::vector ref_latents, int theta, const std::vector& axes_dim) { + std::vector> ids = gen_ids(h, w, patch_size, bs, context_len, ref_latents); std::vector> trans_ids = transpose(ids); size_t pos_len = ids.size(); int num_axes = axes_dim.size(); @@ -732,7 +761,7 @@ namespace Flux { struct ggml_tensor* y, struct ggml_tensor* guidance, struct ggml_tensor* pe, - std::vector skip_layers = std::vector()) { + std::vector skip_layers = {}) { auto img_in = std::dynamic_pointer_cast(blocks["img_in"]); auto time_in = std::dynamic_pointer_cast(blocks["time_in"]); auto vector_in = std::dynamic_pointer_cast(blocks["vector_in"]); @@ -791,15 +820,31 @@ namespace Flux { return img; } + struct ggml_tensor* process_img(struct ggml_context* ctx, + struct ggml_tensor* x) { + + int64_t W = x->ne[0]; + int64_t H = x->ne[1]; + int64_t patch_size = 2; + int pad_h = (patch_size - H % patch_size) % patch_size; + int pad_w = (patch_size - W % patch_size) % patch_size; + x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // [N, C, H + pad_h, W + pad_w] + + // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) + auto img = patchify(ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size] + return img; + } + struct ggml_tensor* forward(struct ggml_context* ctx, - std::vector imgs, + struct ggml_tensor* x, struct ggml_tensor* timestep, struct ggml_tensor* context, struct ggml_tensor* c_concat, struct ggml_tensor* y, struct ggml_tensor* guidance, struct ggml_tensor* pe, - std::vector skip_layers = std::vector()) { + std::vector ref_latents = {}, + std::vector skip_layers = {}) { // Forward pass of DiT. // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) // timestep: (N,) tensor of diffusion timesteps @@ -810,46 +855,41 @@ namespace Flux { // pe: (L, d_head/2, 2, 2) // return: (N, C, H, W) - auto x = imgs[0]; GGML_ASSERT(x->ne[3] == 1); int64_t W = x->ne[0]; int64_t H = x->ne[1]; int64_t C = x->ne[2]; int64_t patch_size = 2; - int pad_h = (patch_size - x->ne[0] % patch_size) % patch_size; - int pad_w = (patch_size - x->ne[1] % patch_size) % patch_size; + int pad_h = (patch_size - H % patch_size) % patch_size; + int pad_w = (patch_size - W % patch_size) % patch_size; + + auto img = process_img(ctx, x); + uint64_t img_tokens = img->ne[1]; - // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) - ggml_tensor* img = NULL; // [N, h*w, C * patch_size * patch_size] - int64_t patchified_img_size; - for (auto& x : imgs) { - int pad_h = (patch_size - x->ne[0] % patch_size) % patch_size; - int pad_w = (patch_size - x->ne[1] % patch_size) % patch_size; - ggml_tensor* pad_x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); - pad_x = patchify(ctx, pad_x, patch_size); - if (img) { - img = ggml_concat(ctx, img, pad_x, 1); - } else { - img = pad_x; - patchified_img_size = img->ne[1]; - } - } if (c_concat != NULL) { ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0); ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C); - masked = ggml_pad(ctx, masked, pad_w, pad_h, 0, 0); - mask = ggml_pad(ctx, mask, pad_w, pad_h, 0, 0); - - masked = patchify(ctx, masked, patch_size); - mask = patchify(ctx, mask, patch_size); + masked = process_img(ctx, masked); + mask = process_img(ctx, mask); img = ggml_concat(ctx, img, ggml_concat(ctx, masked, mask, 0), 0); } - auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, skip_layers); // [N, h*w, C * patch_size * patch_size] - out = ggml_cont(ctx, ggml_view_2d(ctx, out, out->ne[0], patchified_img_size, out->nb[1], 0)); + if (ref_latents.size() > 0) { + for (ggml_tensor* ref : ref_latents) { + ref = process_img(ctx, ref); + img = ggml_concat(ctx, img, ref, 1); + } + } + + auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, skip_layers); // [N, num_tokens, C * patch_size * patch_size] + if (out->ne[1] > img_tokens) { + out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size] + out = ggml_view_3d(ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0); + out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size] + } // rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2) out = unpatchify(ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size, patch_size); // [N, C, H + pad_h, W + pad_w] @@ -928,8 +968,8 @@ namespace Flux { struct ggml_tensor* c_concat, struct ggml_tensor* y, struct ggml_tensor* guidance, - std::vector kontext_imgs = std::vector(), - std::vector skip_layers = std::vector()) { + std::vector ref_latents = {}, + std::vector skip_layers = std::vector()) { GGML_ASSERT(x->ne[3] == 1); struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false); @@ -938,20 +978,16 @@ namespace Flux { if (c_concat != NULL) { c_concat = to_backend(c_concat); } - for (auto& img : kontext_imgs) { - img = to_backend(img); - } - - y = to_backend(y); - + y = to_backend(y); timesteps = to_backend(timesteps); if (flux_params.guidance_embed) { guidance = to_backend(guidance); } - auto imgs = kontext_imgs; - imgs.insert(imgs.begin(), x); + for (int i = 0; i < ref_latents.size(); i++) { + ref_latents[i] = to_backend(ref_latents[i]); + } - pe_vec = flux.gen_pe(imgs, context, 2, flux_params.theta, flux_params.axes_dim); + pe_vec = flux.gen_pe(x->ne[1], x->ne[0], 2, x->ne[3], context->ne[1], ref_latents, flux_params.theta, flux_params.axes_dim); int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2; // LOG_DEBUG("pos_len %d", pos_len); auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len); @@ -961,13 +997,14 @@ namespace Flux { set_backend_tensor_data(pe, pe_vec.data()); struct ggml_tensor* out = flux.forward(compute_ctx, - imgs, + x, timesteps, context, c_concat, y, guidance, pe, + ref_latents, skip_layers); ggml_build_forward_expand(gf, out); @@ -982,17 +1019,17 @@ namespace Flux { struct ggml_tensor* c_concat, struct ggml_tensor* y, struct ggml_tensor* guidance, - std::vector kontext_imgs = std::vector(), - struct ggml_tensor** output = NULL, - struct ggml_context* output_ctx = NULL, - std::vector skip_layers = std::vector()) { + std::vector ref_latents = {}, + struct ggml_tensor** output = NULL, + struct ggml_context* output_ctx = NULL, + std::vector skip_layers = std::vector()) { // x: [N, in_channels, h, w] // timesteps: [N, ] // context: [N, max_position, hidden_size] // y: [N, adm_in_channels] or [1, adm_in_channels] // guidance: [N, ] auto get_graph = [&]() -> struct ggml_cgraph* { - return build_graph(x, timesteps, context, c_concat, y, guidance, kontext_imgs, skip_layers); + return build_graph(x, timesteps, context, c_concat, y, guidance, ref_latents, skip_layers); }; GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); @@ -1032,7 +1069,7 @@ namespace Flux { struct ggml_tensor* out = NULL; int t0 = ggml_time_ms(); - compute(8, x, timesteps, context, NULL, y, guidance, std::vector(), &out, work_ctx); + compute(8, x, timesteps, context, NULL, y, guidance, {}, &out, work_ctx); int t1 = ggml_time_ms(); print_ggml_tensor(out); diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 4b807a5db..cf52c4f97 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -48,7 +48,8 @@ const char* sampling_methods_str[] = { "iPNDM_v", "LCM", "DDIM \"trailing\"", - "TCD"}; + "TCD" +}; /*================================================== Helper Functions ================================================*/ @@ -617,7 +618,7 @@ class StableDiffusionGGML { int64_t t0 = ggml_time_ms(); struct ggml_tensor* out = ggml_dup_tensor(work_ctx, x_t); - diffusion_model->compute(n_threads, x_t, timesteps, c, concat, NULL, NULL, -1, {}, 0.f, std::vector(), &out); + diffusion_model->compute(n_threads, x_t, timesteps, c, concat, NULL, NULL, {}, -1, {}, 0.f, &out); diffusion_model->free_compute_buffer(); double result = 0.f; @@ -681,7 +682,7 @@ class StableDiffusionGGML { float curr_multiplier = kv.second; lora_state_diff[lora_name] -= curr_multiplier; } - + size_t rm = lora_state_diff.size() - lora_state.size(); if (rm != 0) { LOG_INFO("Attempting to apply %lu LoRAs (removing %lu applied LoRAs)", lora_state.size(), rm); @@ -799,12 +800,12 @@ class StableDiffusionGGML { const std::vector& sigmas, int start_merge_step, SDCondition id_cond, - std::vector skip_layers = {}, - float slg_scale = 0, - float skip_layer_start = 0.01, - float skip_layer_end = 0.2, - std::vector kontext_imgs = std::vector(), - ggml_tensor* noise_mask = NULL) { + std::vector ref_latents = {}, + std::vector skip_layers = {}, + float slg_scale = 0, + float skip_layer_start = 0.01, + float skip_layer_end = 0.2, + ggml_tensor* noise_mask = nullptr) { LOG_DEBUG("Sample"); struct ggml_init_params params; size_t data_size = ggml_row_size(init_latent->type, init_latent->ne[0]); @@ -887,10 +888,10 @@ class StableDiffusionGGML { cond.c_concat, cond.c_vector, guidance_tensor, + ref_latents, -1, controls, control_strength, - kontext_imgs, &out_cond); } else { diffusion_model->compute(n_threads, @@ -900,10 +901,10 @@ class StableDiffusionGGML { cond.c_concat, id_cond.c_vector, guidance_tensor, + ref_latents, -1, controls, control_strength, - kontext_imgs, &out_cond); } @@ -921,10 +922,10 @@ class StableDiffusionGGML { uncond.c_concat, uncond.c_vector, guidance_tensor, + ref_latents, -1, controls, control_strength, - kontext_imgs, &out_uncond); negative_data = (float*)out_uncond->data; } @@ -942,10 +943,10 @@ class StableDiffusionGGML { cond.c_concat, cond.c_vector, guidance_tensor, + ref_latents, -1, controls, control_strength, - kontext_imgs, &out_skip, NULL, skip_layers); @@ -1213,12 +1214,12 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, float style_ratio, bool normalize_input, std::string input_id_images_path, - std::vector kontext_imgs = std::vector(), - std::vector skip_layers = {}, - float slg_scale = 0, - float skip_layer_start = 0.01, - float skip_layer_end = 0.2, - ggml_tensor* masked_image = NULL) { + std::vector ref_latents, + std::vector skip_layers = {}, + float slg_scale = 0, + float skip_layer_start = 0.01, + float skip_layer_end = 0.2, + ggml_tensor* masked_image = NULL) { if (seed < 0) { // Generally, when using the provided command line, the seed is always >0. // However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library @@ -1471,11 +1472,11 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, sigmas, start_merge_step, id_cond, + ref_latents, skip_layers, slg_scale, skip_layer_start, skip_layer_end, - kontext_imgs, noise_mask); // struct ggml_tensor* x_0 = load_tensor_from_file(ctx, "samples_ddim.bin"); @@ -1545,8 +1546,6 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, float style_ratio, bool normalize_input, const char* input_id_images_path_c_str, - sd_image_t* kontext_imgs, - int kontext_img_count, int* skip_layers = NULL, size_t skip_layers_count = 0, float slg_scale = 0, @@ -1605,22 +1604,6 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, if (sd_version_is_inpaint(sd_ctx->sd->version)) { LOG_WARN("This is an inpainting model, this should only be used in img2img mode with a mask"); } - std::vector kontext_latents = std::vector(); - if (kontext_imgs) { - for (int i = 0; i < kontext_img_count; i++) { - ggml_tensor* img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, kontext_imgs[i].width, kontext_imgs[i].height, 3, 1); - sd_image_to_tensor(kontext_imgs[i].data, img); - - ggml_tensor* latent = NULL; - if (!sd_ctx->sd->use_tiny_autoencoder) { - ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, img); - latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments); - } else { - latent = sd_ctx->sd->encode_first_stage(work_ctx, img); - } - kontext_latents.push_back(latent); - } - } sd_image_t* result_images = generate_image(sd_ctx, work_ctx, @@ -1642,7 +1625,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, style_ratio, normalize_input, input_id_images_path_c_str, - kontext_latents, + {}, skip_layers_vec, slg_scale, skip_layer_start, @@ -1676,8 +1659,6 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, float style_ratio, bool normalize_input, const char* input_id_images_path_c_str, - sd_image_t* kontext_imgs, - int kontext_img_count, int* skip_layers = NULL, size_t skip_layers_count = 0, float slg_scale = 0, @@ -1793,23 +1774,6 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img); } - std::vector kontext_latents = std::vector(); - if (kontext_imgs) { - for (int i = 0; i < kontext_img_count; i++) { - ggml_tensor* img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1); - sd_image_to_tensor(kontext_imgs[i].data, img); - - ggml_tensor* latent = NULL; - if (!sd_ctx->sd->use_tiny_autoencoder) { - ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, img); - latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments); - } else { - latent = sd_ctx->sd->encode_first_stage(work_ctx, img); - } - kontext_latents.push_back(latent); - } - } - print_ggml_tensor(init_latent, true); size_t t1 = ggml_time_ms(); LOG_INFO("encode_first_stage completed, taking %.2fs", (t1 - t0) * 1.0f / 1000); @@ -1842,7 +1806,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, style_ratio, normalize_input, input_id_images_path_c_str, - kontext_latents, + {}, skip_layers_vec, slg_scale, skip_layer_start, @@ -1988,3 +1952,132 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx, return result_images; } + + +sd_image_t* edit(sd_ctx_t* sd_ctx, + sd_image_t* ref_images, + int ref_images_count, + const char* prompt_c_str, + const char* negative_prompt_c_str, + int clip_skip, + float cfg_scale, + float guidance, + float eta, + int width, + int height, + sample_method_t sample_method, + int sample_steps, + float strength, + int64_t seed, + int batch_count, + const sd_image_t* control_cond, + float control_strength, + float style_ratio, + bool normalize_input, + int* skip_layers = NULL, + size_t skip_layers_count = 0, + float slg_scale = 0, + float skip_layer_start = 0.01, + float skip_layer_end = 0.2) { + std::vector skip_layers_vec(skip_layers, skip_layers + skip_layers_count); + LOG_DEBUG("edit %dx%d", width, height); + if (sd_ctx == NULL) { + return NULL; + } + if (ref_images_count <= 0) { + LOG_ERROR("ref images count should > 0"); + return NULL; + } + + struct ggml_init_params params; + params.mem_size = static_cast(30 * 1024 * 1024); // 10 MB + params.mem_size += width * height * 3 * sizeof(float) * 3 * ref_images_count; + params.mem_size *= batch_count; + params.mem_buffer = NULL; + params.no_alloc = false; + // LOG_DEBUG("mem_size %u ", params.mem_size); + + struct ggml_context* work_ctx = ggml_init(params); + if (!work_ctx) { + LOG_ERROR("ggml_init() failed"); + return NULL; + } + + if (seed < 0) { + srand((int)time(NULL)); + seed = rand(); + } + sd_ctx->sd->rng->manual_seed(seed); + + int C = 4; + if (sd_version_is_sd3(sd_ctx->sd->version)) { + C = 16; + } else if (sd_version_is_flux(sd_ctx->sd->version)) { + C = 16; + } + int W = width / 8; + int H = height / 8; + ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1); + if (sd_version_is_sd3(sd_ctx->sd->version)) { + ggml_set_f32(init_latent, 0.0609f); + } else if (sd_version_is_flux(sd_ctx->sd->version)) { + ggml_set_f32(init_latent, 0.1159f); + } else { + ggml_set_f32(init_latent, 0.f); + } + + size_t t0 = ggml_time_ms(); + + std::vector ref_latents; + for (int i = 0; i < ref_images_count; i++) { + ggml_tensor* img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, ref_images[i].width, ref_images[i].height, 3, 1); + sd_image_to_tensor(ref_images[i].data, img); + + ggml_tensor* latent = NULL; + if (!sd_ctx->sd->use_tiny_autoencoder) { + ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, img); + latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments); + } else { + latent = sd_ctx->sd->encode_first_stage(work_ctx, img); + } + ref_latents.push_back(latent); + } + + size_t t1 = ggml_time_ms(); + LOG_INFO("encode_first_stage completed, taking %.2fs", (t1 - t0) * 1.0f / 1000); + + std::vector sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps); + + sd_image_t* result_images = generate_image(sd_ctx, + work_ctx, + init_latent, + prompt_c_str, + negative_prompt_c_str, + clip_skip, + cfg_scale, + guidance, + eta, + width, + height, + sample_method, + sigmas, + seed, + batch_count, + control_cond, + control_strength, + style_ratio, + normalize_input, + "", + ref_latents, + skip_layers_vec, + slg_scale, + skip_layer_start, + skip_layer_end, + NULL); + + size_t t2 = ggml_time_ms(); + + LOG_INFO("edit completed in %.2fs", (t2 - t0) * 1.0f / 1000); + + return result_images; +} \ No newline at end of file diff --git a/stable-diffusion.h b/stable-diffusion.h index e59b2a30d..804dff71f 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -172,8 +172,6 @@ SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx, float style_strength, bool normalize_input, const char* input_id_images_path, - sd_image_t* kontext_imgs, - int kontext_img_count, int* skip_layers, size_t skip_layers_count, float slg_scale, @@ -201,8 +199,6 @@ SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx, float style_strength, bool normalize_input, const char* input_id_images_path, - sd_image_t* kontext_imgs, - int kontext_img_count, int* skip_layers, size_t skip_layers_count, float slg_scale, @@ -224,6 +220,32 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx, float strength, int64_t seed); +SD_API sd_image_t* edit(sd_ctx_t* sd_ctx, + sd_image_t* ref_images, + int ref_images_count, + const char* prompt, + const char* negative_prompt, + int clip_skip, + float cfg_scale, + float guidance, + float eta, + int width, + int height, + enum sample_method_t sample_method, + int sample_steps, + float strength, + int64_t seed, + int batch_count, + const sd_image_t* control_cond, + float control_strength, + float style_strength, + bool normalize_input, + int* skip_layers, + size_t skip_layers_count, + float slg_scale, + float skip_layer_start, + float skip_layer_end); + typedef struct upscaler_ctx_t upscaler_ctx_t; SD_API upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path,