@@ -1382,14 +1382,24 @@ class StableDiffusionGGML {
13821382 ggml_tensor* encode_first_stage (ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false ) {
13831383 int64_t t0 = ggml_time_ms ();
13841384 ggml_tensor* result = NULL ;
1385+ int W = x->ne [0 ] / 8 ;
1386+ int H = x->ne [1 ] / 8 ;
1387+ if (vae_tiling && !decode_video) {
1388+ // TODO wan2.2 vae support?
1389+ int C = sd_version_is_dit (version) ? 16 : 4 ;
1390+ if (!use_tiny_autoencoder) {
1391+ C *= 2 ;
1392+ }
1393+ result = ggml_new_tensor_4d (work_ctx, GGML_TYPE_F32, W, H, C, x->ne [3 ]);
1394+ }
13851395 // TODO: args instead of env for tile size / overlap?
13861396 if (!use_tiny_autoencoder) {
13871397 float tile_overlap = 0 .5f ;
13881398 int tile_size_x = 32 ;
13891399 int tile_size_y = 32 ;
13901400
13911401 get_vae_tile_overlap (tile_overlap);
1392- get_vae_tile_sizes (tile_size_x, tile_size_y, tile_overlap, x-> ne [ 0 ] / 8 , x-> ne [ 1 ] / 8 );
1402+ get_vae_tile_sizes (tile_size_x, tile_size_y, tile_overlap, W, H );
13931403
13941404 // TODO: also use an arg for this one?
13951405 // multiply tile size for encode to keep the compute buffer size consistent
@@ -1399,7 +1409,7 @@ class StableDiffusionGGML {
13991409 process_vae_input_tensor (x);
14001410 if (vae_tiling && !decode_video) {
14011411 auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
1402- first_stage_model->compute (n_threads, in, true , &out, NULL );
1412+ first_stage_model->compute (n_threads, in, false , &out, work_ctx );
14031413 };
14041414 sd_tiling_non_square (x, result, 8 , tile_size_x, tile_size_y, tile_overlap, on_tiling);
14051415 } else {
@@ -1410,7 +1420,7 @@ class StableDiffusionGGML {
14101420 if (vae_tiling && !decode_video) {
14111421 // split latent in 32x32 tiles and compute in several steps
14121422 auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
1413- tae_first_stage->compute (n_threads, in, true , &out, NULL );
1423+ tae_first_stage->compute (n_threads, in, false , &out, NULL );
14141424 };
14151425 sd_tiling (x, result, 8 , 64 , 0 .5f , on_tiling);
14161426 } else {
0 commit comments