@@ -474,7 +474,8 @@ static void sample_k_diffusion(sample_method_t method,
                                ggml_context* work_ctx,
                                ggml_tensor* x,
                                std::vector<float> sigmas,
-                               std::shared_ptr<RNG> rng) {
+                               std::shared_ptr<RNG> rng,
+                               float eta) {
     size_t steps = sigmas.size() - 1;
     // sample_euler_ancestral
     switch (method) {
@@ -1014,6 +1015,8 @@ static void sample_k_diffusion(sample_method_t method,
             // structure hides from the denoiser), and the sigmas are
             // also needed to invert the behavior of CompVisDenoiser
             // (k-diffusion's LMSDiscreteScheduler)
+            float beta_start = 0.00085f;
+            float beta_end = 0.0120f;
             std::vector<double> alphas_cumprod;
             std::vector<double> compvis_sigmas;
 
@@ -1023,21 +1026,41 @@ static void sample_k_diffusion(sample_method_t method,
                 alphas_cumprod[i] =
                     (i == 0 ? 1.0f : alphas_cumprod[i - 1]) *
                     (1.0f -
-                     std::pow(sqrtf(0.00085f) +
-                              (sqrtf(0.0120f) - sqrtf(0.00085f)) *
+                     std::pow(sqrtf(beta_start) +
+                              (sqrtf(beta_end) - sqrtf(beta_start)) *
                               ((float)i / (TIMESTEPS - 1)), 2));
                 compvis_sigmas[i] =
                     std::sqrt((1 - alphas_cumprod[i]) /
                               alphas_cumprod[i]);
             }
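+            // In closed form: alphas_cumprod[i] is the cumulative
+            // product of (1 - beta_k) for k <= i under the "scaled
+            // linear" schedule beta_k = (sqrt(beta_start) +
+            // (sqrt(beta_end) - sqrt(beta_start)) * k /
+            // (TIMESTEPS - 1))^2, and compvis_sigmas[i] =
+            // sqrt((1 - alphas_cumprod[i]) / alphas_cumprod[i]) is
+            // the CompVisDenoiser sigma for discrete timestep i.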
+
+            struct ggml_tensor* pred_original_sample =
+                ggml_dup_tensor(work_ctx, x);
+            struct ggml_tensor* variance_noise =
+                ggml_dup_tensor(work_ctx, x);
+
             for (int i = 0; i < steps; i++) {
                 // The "trailing" DDIM timestep, see S. Lin et al.,
                 // "Common Diffusion Noise Schedules and Sample Steps
                 // are Flawed", arXiv:2305.08891 [cs], p. 4, Table
-                // 2. Most variables below follow Diffusers naming.
+                // 2. Most variables below follow Diffusers naming.
+                //
+                // Diffusers naming vs. J. Song et al., "Denoising
+                // Diffusion Implicit Models", arXiv:2010.02502, p. 5,
+                // (12) and p. 16, (16) (<variable name> -> <name in
+                // paper>):
+                //
+                // - pred_noise_t -> epsilon_theta^(t)(x_t)
+                // - pred_original_sample -> f_theta^(t)(x_t) or x_0
+                // - std_dev_t -> sigma_t (not the LMS sigma)
+                // - eta -> eta (0 recovers deterministic DDIM)
+                // - pred_sample_direction -> "direction pointing to
+                //   x_t"
+                // - pred_prev_sample -> "x_t-1"
                 int timestep =
                     roundf(TIMESTEPS -
                            i * ((float)TIMESTEPS / steps)) - 1;
+                // 1. get previous step value (= t - 1)
                 int prev_timestep = timestep - TIMESTEPS / steps;
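+                // E.g. with TIMESTEPS = 1000 and steps = 10 this
+                // visits t = 999, 899, ..., 99, with prev_timestep
+                // 899, 799, ..., -1 (clamped below via
+                // final_alpha_cumprod).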
                 // The sigma here is chosen to cause the
                 // CompVisDenoiser to produce t = timestep
@@ -1066,51 +1089,53 @@ static void sample_k_diffusion(sample_method_t method,
                 }
                 else {
                     // For the subsequent steps after the first one,
-                    // at this point x = latents (pipeline) or x =
-                    // sample (scheduler), and needs to be prescaled
-                    // with x <- latents / c_in to compensate for
-                    // model() applying the scale c_in before the
-                    // U-net F_theta
+                    // at this point x = latents or x = sample, and
+                    // needs to be prescaled with x <- sample / c_in
+                    // to compensate for model() applying the scale
+                    // c_in before the U-net F_theta
                     float* vec_x = (float*)x->data;
                     for (int j = 0; j < ggml_nelements(x); j++) {
                         vec_x[j] *= std::sqrt(sigma * sigma + 1);
                     }
                 }
-                // Note model() is the D(x, sigma) as defined in
-                // T. Karras et al., arXiv:2206.00364, p. 3, Table 1
-                // and p. 8 (7)
-                struct ggml_tensor* noise_pred =
+                // Note (also noise_pred in Diffusers' pipeline)
+                // model_output = model() is the D(x, sigma) as
+                // defined in T. Karras et al., arXiv:2206.00364,
+                // p. 3, Table 1 and p. 8 (7), compare also p. 38
+                // (226) therein.
+                struct ggml_tensor* model_output =
                     model(x, sigma, i + 1);
-                // Here noise_pred is still the k-diffusion denoiser
+                // Here model_output is still the k-diffusion denoiser
                 // output, not the U-net output F_theta(c_in(sigma) x;
                 // ...) in Karras et al. (2022), whereas Diffusers'
-                // noise_pred is F_theta(...). Recover the actual
-                // noise_pred, which is also referred to as the
+                // model_output is F_theta(...). Recover the actual
+                // model_output, which is also referred to as the
                 // "Karras ODE derivative" d or d_cur in several
                 // samplers above.
                 {
                     float* vec_x = (float*)x->data;
-                    float* vec_noise_pred = (float*)noise_pred->data;
+                    float* vec_model_output =
+                        (float*)model_output->data;
                     for (int j = 0; j < ggml_nelements(x); j++) {
-                        vec_noise_pred[j] =
-                            (vec_x[j] - vec_noise_pred[j]) *
+                        vec_model_output[j] =
+                            (vec_x[j] - vec_model_output[j]) *
                             (1 / sigma);
                     }
                 }
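+                // i.e. F_theta(x; sigma) = (x - D(x, sigma)) / sigma,
+                // inverting D(x, sigma) = x - sigma * F_theta(...)
+                // (epsilon prediction: c_skip = 1, c_out = -sigma).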
                 // 2. compute alphas, betas
                 float alpha_prod_t = alphas_cumprod[timestep];
-                // Note final_alpha_cumprod = alphas_cumprod[0]
+                // Note final_alpha_cumprod = alphas_cumprod[0] due to
+                // trailing timestep spacing
                 float alpha_prod_t_prev = prev_timestep >= 0 ?
                     alphas_cumprod[prev_timestep] : alphas_cumprod[0];
                 float beta_prod_t = 1 - alpha_prod_t;
                 // 3. compute predicted original sample from predicted
                 // noise also called "predicted x_0" of formula (12)
                 // from https://arxiv.org/pdf/2010.02502.pdf
-                struct ggml_tensor* pred_original_sample =
-                    ggml_dup_tensor(work_ctx, x);
                 {
                     float* vec_x = (float*)x->data;
-                    float* vec_noise_pred = (float*)noise_pred->data;
+                    float* vec_model_output =
+                        (float*)model_output->data;
                     float* vec_pred_original_sample =
                         (float*)pred_original_sample->data;
                     // Note the substitution of latents or sample = x
@@ -1119,12 +1144,12 @@ static void sample_k_diffusion(sample_method_t method,
                     for (int j = 0; j < ggml_nelements(x); j++) {
                         vec_pred_original_sample[j] =
                             (vec_x[j] / std::sqrt(sigma * sigma + 1) -
                              std::sqrt(beta_prod_t) *
-                             vec_noise_pred[j]) *
+                             vec_model_output[j]) *
                             (1 / std::sqrt(alpha_prod_t));
                     }
                 }
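+                // i.e. x_0 = (x_t - sqrt(1 - alpha_prod_t) *
+                // pred_epsilon) / sqrt(alpha_prod_t), with x_t
+                // recovered from the prescaled x as
+                // x_t = x / sqrt(sigma^2 + 1) = c_in(sigma) * x.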
                 // Assuming the "epsilon" prediction type, where below
-                // pred_epsilon = noise_pred is inserted, and is not
+                // pred_epsilon = model_output is inserted, and is not
                 // defined/copied explicitly.
                 //
                 // 5. compute variance: "sigma_t(eta)" -> see formula
@@ -1135,34 +1160,35 @@ static void sample_k_diffusion(sample_method_t method,
                 float beta_prod_t_prev = 1 - alpha_prod_t_prev;
                 float variance = (beta_prod_t_prev / beta_prod_t) *
                     (1 - alpha_prod_t / alpha_prod_t_prev);
-                float std_dev_t = 0 * std::sqrt(variance);
+                float std_dev_t = eta * std::sqrt(variance);
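+                // i.e. sigma_t(eta) = eta *
+                // sqrt((1 - alpha_{t-1}) / (1 - alpha_t)) *
+                // sqrt(1 - alpha_t / alpha_{t-1}), formula (16) in
+                // Song et al. (2020); eta = 0 is deterministic DDIM.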
                 // 6. compute "direction pointing to x_t" of formula
                 // (12) from https://arxiv.org/pdf/2010.02502.pdf
-                struct ggml_tensor* pred_sample_direction =
-                    ggml_dup_tensor(work_ctx, noise_pred);
-                {
-                    float* vec_noise_pred = (float*)noise_pred->data;
-                    float* vec_pred_sample_direction =
-                        (float*)pred_sample_direction->data;
-                    for (int j = 0; j < ggml_nelements(x); j++) {
-                        vec_pred_sample_direction[j] =
-                            std::sqrt(1 - alpha_prod_t_prev -
-                                      std::pow(std_dev_t, 2)) *
-                            vec_noise_pred[j];
-                    }
-                }
                 // 7. compute x_t without "random noise" of formula
                 // (12) from https://arxiv.org/pdf/2010.02502.pdf
                 {
+                    float* vec_model_output = (float*)model_output->data;
                     float* vec_pred_original_sample =
                         (float*)pred_original_sample->data;
-                    float* vec_pred_sample_direction =
-                        (float*)pred_sample_direction->data;
                     float* vec_x = (float*)x->data;
                     for (int j = 0; j < ggml_nelements(x); j++) {
+                        // Two-step inner loop without an explicit
+                        // tensor
+                        float pred_sample_direction =
+                            std::sqrt(1 - alpha_prod_t_prev -
+                                      std::pow(std_dev_t, 2)) *
+                            vec_model_output[j];
                         vec_x[j] = std::sqrt(alpha_prod_t_prev) *
                                    vec_pred_original_sample[j] +
-                                   vec_pred_sample_direction[j];
+                                   pred_sample_direction;
+                    }
+                }
+                if (eta > 0) {
+                    ggml_tensor_set_f32_randn(variance_noise, rng);
+                    float* vec_variance_noise =
+                        (float*)variance_noise->data;
+                    float* vec_x = (float*)x->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_x[j] += std_dev_t * vec_variance_noise[j];
                     }
                 }
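+                // Together with step 7 above this completes formula
+                // (12) of Song et al. (2020):
+                //   x_{t-1} = sqrt(alpha_{t-1}) * x_0
+                //           + sqrt(1 - alpha_{t-1} - sigma_t^2) *
+                //             pred_epsilon
+                //           + sigma_t * z,  z ~ N(0, I)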
                 // See the note above: x = latents or sample here, and
@@ -1173,6 +1199,174 @@ static void sample_k_diffusion(sample_method_t method,
                 // factor c_in.
             }
         } break;
+        case TCD: // Strategic Stochastic Sampling (Algorithm 4) in
+                  // Trajectory Consistency Distillation
+        {
+            float beta_start = 0.00085f;
+            float beta_end = 0.0120f;
+            std::vector<double> alphas_cumprod;
+            std::vector<double> compvis_sigmas;
+
+            // resize (not just reserve) so that operator[] below
+            // stays within the vectors' size
+            alphas_cumprod.resize(TIMESTEPS);
+            compvis_sigmas.resize(TIMESTEPS);
+            for (int i = 0; i < TIMESTEPS; i++) {
+                alphas_cumprod[i] =
+                    (i == 0 ? 1.0f : alphas_cumprod[i - 1]) *
+                    (1.0f -
+                     std::pow(sqrtf(beta_start) +
+                              (sqrtf(beta_end) - sqrtf(beta_start)) *
+                              ((float)i / (TIMESTEPS - 1)), 2));
+                compvis_sigmas[i] =
+                    std::sqrt((1 - alphas_cumprod[i]) /
+                              alphas_cumprod[i]);
+            }
+            int original_steps = 50;
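+            // 50 presumably mirrors the original_inference_steps
+            // default of Diffusers' TCDScheduler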
+
+            struct ggml_tensor* pred_original_sample =
+                ggml_dup_tensor(work_ctx, x);
+            struct ggml_tensor* noise =
+                ggml_dup_tensor(work_ctx, x);
+
+            for (int i = 0; i < steps; i++) {
+                // Analytic form for TCD timesteps
+                int timestep = TIMESTEPS - 1 -
+                    (TIMESTEPS / original_steps) *
+                    (int)floor(i * ((float)original_steps / steps));
+                // 1. get previous step value
+                int prev_timestep = i >= steps - 1 ? 0 :
+                    TIMESTEPS - 1 - (TIMESTEPS / original_steps) *
+                    (int)floor((i + 1) *
+                               ((float)original_steps / steps));
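+                // E.g. with TIMESTEPS = 1000, original_steps = 50
+                // and steps = 4 this visits t = 999, 759, 499, 259
+                // with prev_timestep 759, 499, 259, 0.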
+                // Here timestep_s is tau_n' in Algorithm 4. The _s
+                // notation appears to be that from DPM-Solver, C. Lu,
+                // arXiv:2206.00927 [cs.LG], but this notation is not
+                // continued in Algorithm 4, where _n' is used.
+                int timestep_s =
+                    (int)floor((1 - eta) * prev_timestep);
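+                // E.g. eta = 0.3 and prev_timestep = 759 give
+                // timestep_s = floor(0.7 * 759) = 531; eta = 0 gives
+                // timestep_s = prev_timestep (fully deterministic).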
+                // Begin k-diffusion specific workaround for
+                // evaluating F_theta(x; ...) from D(x, sigma), same
+                // as in DDIM (and see there for detailed comments)
+                float sigma = compvis_sigmas[timestep];
+                if (i == 0) {
+                    float* vec_x = (float*)x->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_x[j] *= std::sqrt(sigma * sigma + 1) /
+                            sigma;
+                    }
+                }
+                else {
+                    float* vec_x = (float*)x->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_x[j] *= std::sqrt(sigma * sigma + 1);
+                    }
+                }
+                struct ggml_tensor* model_output =
+                    model(x, sigma, i + 1);
+                {
+                    float* vec_x = (float*)x->data;
+                    float* vec_model_output =
+                        (float*)model_output->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_model_output[j] =
+                            (vec_x[j] - vec_model_output[j]) *
+                            (1 / sigma);
+                    }
+                }
+                // 2. compute alphas, betas
+                //
+                // When comparing TCD with DDPM/DDIM note that Zheng
+                // et al. (2024) follows the DPM-Solver notation for
+                // alpha. One can find the following comment in the
+                // original DPM-Solver code
+                // (https://github.com/LuChengTHU/dpm-solver/):
+                // "**Important**: Please pay special attention for
+                // the args for `alphas_cumprod`: The `alphas_cumprod`
+                // is the \hat{alpha_n} arrays in the notations of
+                // DDPM. [...] Therefore, the notation \hat{alpha_n}
+                // is different from the notation alpha_t in
+                // DPM-Solver. In fact, we have alpha_{t_n} =
+                // \sqrt{\hat{alpha_n}}, [...]"
+                float alpha_prod_t = alphas_cumprod[timestep];
+                float beta_prod_t = 1 - alpha_prod_t;
+                // Note final_alpha_cumprod = alphas_cumprod[0] since
+                // TCD is always "trailing"
+                float alpha_prod_t_prev = prev_timestep >= 0 ?
+                    alphas_cumprod[prev_timestep] : alphas_cumprod[0];
+                // The subscript _s is the only portion in this
+                // section (2) unique to TCD
+                float alpha_prod_s = alphas_cumprod[timestep_s];
+                float beta_prod_s = 1 - alpha_prod_s;
+                // 3. Compute the predicted noised sample x_s based on
+                // the model parameterization
+                //
+                // This section is also exactly the same as DDIM
+                {
+                    float* vec_x = (float*)x->data;
+                    float* vec_model_output =
+                        (float*)model_output->data;
+                    float* vec_pred_original_sample =
+                        (float*)pred_original_sample->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_pred_original_sample[j] =
+                            (vec_x[j] / std::sqrt(sigma * sigma + 1) -
+                             std::sqrt(beta_prod_t) *
+                             vec_model_output[j]) *
+                            (1 / std::sqrt(alpha_prod_t));
+                    }
+                }
+                // This consistency function step can be difficult to
+                // decipher from Algorithm 4, as it involves an
+                // unusual notation ("|->"). In Diffusers it is
+                // borrowed verbatim (with the same comments below for
+                // step (4)) from LCMScheduler's noise injection step,
+                // compare S. Luo et al., arXiv:2310.04378, p. 14,
+                // Algorithm 3.
+                {
+                    float* vec_pred_original_sample =
+                        (float*)pred_original_sample->data;
+                    float* vec_model_output =
+                        (float*)model_output->data;
+                    float* vec_x = (float*)x->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        // Substituting x = pred_noised_sample and
+                        // pred_epsilon = model_output
+                        vec_x[j] =
+                            std::sqrt(alpha_prod_s) *
+                            vec_pred_original_sample[j] +
+                            std::sqrt(beta_prod_s) *
+                            vec_model_output[j];
+                    }
+                }
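+                // i.e. x_s = sqrt(alpha_prod_s) * x_0 +
+                // sqrt(1 - alpha_prod_s) * pred_epsilon: the
+                // predicted x_0 is re-noised to the intermediate
+                // timestep tau_n' = s.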
+                // 4. Sample and inject noise z ~ N(0, I) for
+                // MultiStep Inference. Noise is not used on the final
+                // timestep of the timestep schedule. This also means
+                // that noise is not used for one-step sampling. Eta
+                // (referred to as "gamma" in the paper) was
+                // introduced to control the stochasticity in every
+                // step. When eta = 0, it represents deterministic
+                // sampling, whereas eta = 1 indicates full stochastic
+                // sampling.
+                if (eta > 0 && i != steps - 1) {
+                    // In this case, x is still pred_noised_sample,
+                    // continue in-place
+                    ggml_tensor_set_f32_randn(noise, rng);
+                    float* vec_x = (float*)x->data;
+                    float* vec_noise = (float*)noise->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        // Corresponding to (35) in Zheng et
+                        // al. (2024), substituting x =
+                        // pred_noised_sample
+                        vec_x[j] =
+                            std::sqrt(alpha_prod_t_prev /
+                                      alpha_prod_s) *
+                            vec_x[j] +
+                            std::sqrt(1 - alpha_prod_t_prev /
+                                          alpha_prod_s) *
+                            vec_noise[j];
+                    }
+                }
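+                // i.e. x_{t-1} = sqrt(alpha_prod_t_prev /
+                // alpha_prod_s) * x_s + sqrt(1 - alpha_prod_t_prev /
+                // alpha_prod_s) * z, lifting the sample from noise
+                // level s back up to level t-1.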
+            }
+        } break;
 
         default:
             LOG_ERROR("Attempting to sample with nonexisting sample method %i", method);