@@ -1370,14 +1370,24 @@ class StableDiffusionGGML {
13701370 ggml_tensor* encode_first_stage (ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false ) {
13711371 int64_t t0 = ggml_time_ms ();
13721372 ggml_tensor* result = NULL ;
1373+ int W = x->ne [0 ] / 8 ;
1374+ int H = x->ne [1 ] / 8 ;
1375+ if (vae_tiling && !decode_video) {
1376+ // TODO wan2.2 vae support?
1377+ int C = sd_version_is_dit (version) ? 16 : 4 ;
1378+ if (!use_tiny_autoencoder) {
1379+ C *= 2 ;
1380+ }
1381+ result = ggml_new_tensor_4d (work_ctx, GGML_TYPE_F32, W, H, C, x->ne [3 ]);
1382+ }
13731383 // TODO: args instead of env for tile size / overlap?
13741384 if (!use_tiny_autoencoder) {
13751385 float tile_overlap = 0 .5f ;
13761386 int tile_size_x = 32 ;
13771387 int tile_size_y = 32 ;
13781388
13791389 get_vae_tile_overlap (tile_overlap);
1380- get_vae_tile_sizes (tile_size_x, tile_size_y, tile_overlap, x-> ne [ 0 ] / 8 , x-> ne [ 1 ] / 8 );
1390+ get_vae_tile_sizes (tile_size_x, tile_size_y, tile_overlap, W, H );
13811391
13821392 // TODO: also use an arg for this one?
13831393 // multiply tile size for encode to keep the compute buffer size consistent
@@ -1387,7 +1397,7 @@ class StableDiffusionGGML {
13871397 process_vae_input_tensor (x);
13881398 if (vae_tiling && !decode_video) {
13891399 auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
1390- first_stage_model->compute (n_threads, in, true , &out, NULL );
1400+ first_stage_model->compute (n_threads, in, false , &out, work_ctx );
13911401 };
13921402 sd_tiling_non_square (x, result, 8 , tile_size_x, tile_size_y, tile_overlap, on_tiling);
13931403 } else {
@@ -1398,7 +1408,7 @@ class StableDiffusionGGML {
13981408 if (vae_tiling && !decode_video) {
13991409 // split latent in 32x32 tiles and compute in several steps
14001410 auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
1401- tae_first_stage->compute (n_threads, in, true , &out, NULL );
1411+ tae_first_stage->compute (n_threads, in, false , &out, NULL );
14021412 };
14031413 sd_tiling (x, result, 8 , 64 , 0 .5f , on_tiling);
14041414 } else {
0 commit comments