@@ -737,62 +737,67 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> ggml_chunk(struct ggml_contex
737737typedef std::function<void (ggml_tensor*, ggml_tensor*, bool )> on_tile_process;
738738
739739// Tiling
740- __STATIC_INLINE__ void sd_tiling (ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing, bool scaled_out = true ) {
740+ __STATIC_INLINE__ void sd_tiling (ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
741741 output = ggml_set_f32 (output, 0 );
742742
743743 int input_width = (int )input->ne [0 ];
744744 int input_height = (int )input->ne [1 ];
745745 int output_width = (int )output->ne [0 ];
746746 int output_height = (int )output->ne [1 ];
747747
748- int input_tile_size, output_tile_size;
749- if (scaled_out) {
750- input_tile_size = tile_size;
751- output_tile_size = tile_size * scale;
752- } else {
753- input_tile_size = tile_size * scale;
754- output_tile_size = tile_size;
748+ GGML_ASSERT (input_width / output_width == input_height / output_height && output_width / input_width == output_height / input_height);
749+ GGML_ASSERT (input_width / output_width == scale || output_width / input_width == scale);
750+
751+ int small_width = output_width;
752+ int small_height = output_height;
753+
754+ bool big_out = output_width > input_width;
755+ if (big_out) {
756+ // Ex: decode
757+ small_width = input_width;
758+ small_height = input_height;
755759 }
756- int tile_overlap = (input_tile_size * tile_overlap_factor);
757- int non_tile_overlap = input_tile_size - tile_overlap;
758760
759- int num_tiles_x = (input_width - tile_overlap) / non_tile_overlap;
760- int overshoot_x = ((num_tiles_x + 1 ) * non_tile_overlap + tile_overlap) % input_width;
761+ int tile_overlap = (tile_size * tile_overlap_factor);
762+ int non_tile_overlap = tile_size - tile_overlap;
763+
764+ int num_tiles_x = (small_width - tile_overlap) / non_tile_overlap;
765+ int overshoot_x = ((num_tiles_x + 1 ) * non_tile_overlap + tile_overlap) % small_width;
761766
762- if ((overshoot_x != non_tile_overlap) && (overshoot_x <= num_tiles_x * (input_tile_size / 2 - tile_overlap))) {
767+ if ((overshoot_x != non_tile_overlap) && (overshoot_x <= num_tiles_x * (tile_size / 2 - tile_overlap))) {
763768 // if tiles don't fit perfectly using the desired overlap
764769 // and there is enough room to squeeze an extra tile without overlap becoming >0.5
765770 num_tiles_x++;
766771 }
767772
768- float tile_overlap_factor_x = (float )(input_tile_size * num_tiles_x - input_width ) / (float )(input_tile_size * (num_tiles_x - 1 ));
773+ float tile_overlap_factor_x = (float )(tile_size * num_tiles_x - small_width ) / (float )(tile_size * (num_tiles_x - 1 ));
769774 if (num_tiles_x <= 2 ) {
770- if (input_width <= input_tile_size ) {
775+ if (small_width <= tile_size ) {
771776 num_tiles_x = 1 ;
772777 tile_overlap_factor_x = 0 ;
773778 } else {
774779 num_tiles_x = 2 ;
775- tile_overlap_factor_x = (2 * input_tile_size - input_width ) / (float )input_tile_size ;
780+ tile_overlap_factor_x = (2 * tile_size - small_width ) / (float )tile_size ;
776781 }
777782 }
778783
779- int num_tiles_y = (input_height - tile_overlap) / non_tile_overlap;
780- int overshoot_y = ((num_tiles_y + 1 ) * non_tile_overlap + tile_overlap) % input_height ;
784+ int num_tiles_y = (small_height - tile_overlap) / non_tile_overlap;
785+ int overshoot_y = ((num_tiles_y + 1 ) * non_tile_overlap + tile_overlap) % small_height ;
781786
782- if ((overshoot_y != non_tile_overlap) && (overshoot_y <= num_tiles_y * (input_tile_size / 2 - tile_overlap))) {
787+ if ((overshoot_y != non_tile_overlap) && (overshoot_y <= num_tiles_y * (tile_size / 2 - tile_overlap))) {
783788 // if tiles don't fit perfectly using the desired overlap
784789 // and there is enough room to squeeze an extra tile without overlap becoming >0.5
785790 num_tiles_y++;
786791 }
787792
788- float tile_overlap_factor_y = (float )(input_tile_size * num_tiles_y - input_height ) / (float )(input_tile_size * (num_tiles_y - 1 ));
793+ float tile_overlap_factor_y = (float )(tile_size * num_tiles_y - small_height ) / (float )(tile_size * (num_tiles_y - 1 ));
789794 if (num_tiles_y <= 2 ) {
790- if (input_height <= input_tile_size ) {
795+ if (small_height <= tile_size ) {
791796 num_tiles_y = 1 ;
792797 tile_overlap_factor_y = 0 ;
793798 } else {
794799 num_tiles_y = 2 ;
795- tile_overlap_factor_y = (2 * input_tile_size - input_height ) / (float )input_tile_size ;
800+ tile_overlap_factor_y = (2 * tile_size - small_height ) / (float )tile_size ;
796801 }
797802 }
798803
@@ -801,11 +806,20 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
801806
802807 GGML_ASSERT (input_width % 2 == 0 && input_height % 2 == 0 && output_width % 2 == 0 && output_height % 2 == 0 ); // should be multiple of 2
803808
804- int tile_overlap_x = (int32_t )(input_tile_size * tile_overlap_factor_x);
805- int non_tile_overlap_x = input_tile_size - tile_overlap_x;
809+ int tile_overlap_x = (int32_t )(tile_size * tile_overlap_factor_x);
810+ int non_tile_overlap_x = tile_size - tile_overlap_x;
806811
807- int tile_overlap_y = (int32_t )(input_tile_size * tile_overlap_factor_y);
808- int non_tile_overlap_y = input_tile_size - tile_overlap_y;
812+ int tile_overlap_y = (int32_t )(tile_size * tile_overlap_factor_y);
813+ int non_tile_overlap_y = tile_size - tile_overlap_y;
814+
815+ int input_tile_size = tile_size;
816+ int output_tile_size = tile_size;
817+
818+ if (big_out) {
819+ output_tile_size *= scale;
820+ } else {
821+ input_tile_size *= scale;
822+ }
809823
810824 struct ggml_init_params params = {};
811825 params.mem_size += input_tile_size * input_tile_size * input->ne [2 ] * sizeof (float ); // input chunk
@@ -826,37 +840,48 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
826840 // tiling
827841 ggml_tensor* input_tile = ggml_new_tensor_4d (tiles_ctx, GGML_TYPE_F32, input_tile_size, input_tile_size, input->ne [2 ], 1 );
828842 ggml_tensor* output_tile = ggml_new_tensor_4d (tiles_ctx, GGML_TYPE_F32, output_tile_size, output_tile_size, output->ne [2 ], 1 );
829- on_processing (input_tile, NULL , true );
830843 int num_tiles = num_tiles_x * num_tiles_y;
831844 LOG_INFO (" processing %i tiles" , num_tiles);
832- pretty_progress (1 , num_tiles, 0 .0f );
845+ pretty_progress (0 , num_tiles, 0 .0f );
833846 int tile_count = 1 ;
834847 bool last_y = false , last_x = false ;
835848 float last_time = 0 .0f ;
836- for (int y = 0 ; y < input_height && !last_y; y += non_tile_overlap_y) {
849+ for (int y = 0 ; y < small_height && !last_y; y += non_tile_overlap_y) {
837850 int dy = 0 ;
838- if (y + input_tile_size >= input_height ) {
851+ if (y + tile_size >= small_height ) {
839852 int _y = y;
840- y = input_height - input_tile_size ;
853+ y = small_height - tile_size ;
841854 dy = _y - y;
855+ if (big_out) {
856+ dy *= scale;
857+ }
842858 last_y = true ;
843859 }
844- for (int x = 0 ; x < input_width && !last_x; x += non_tile_overlap_x) {
860+ for (int x = 0 ; x < small_width && !last_x; x += non_tile_overlap_x) {
845861 int dx = 0 ;
846- if (x + input_tile_size >= input_width ) {
862+ if (x + tile_size >= small_width ) {
847863 int _x = x;
848- x = input_width - input_tile_size ;
864+ x = small_width - tile_size ;
849865 dx = _x - x;
866+ if (big_out) {
867+ dx *= scale;
868+ }
850869 last_x = true ;
851870 }
871+
872+ int x_in = big_out ? x : scale * x;
873+ int y_in = big_out ? y : scale * y;
874+ int x_out = big_out ? x * scale : x;
875+ int y_out = big_out ? y * scale : y;
876+
877+ int overlap_x_out = big_out ? tile_overlap_x * scale : tile_overlap_x;
878+ int overlap_y_out = big_out ? tile_overlap_y * scale : tile_overlap_y;
879+
852880 int64_t t1 = ggml_time_ms ();
853- ggml_split_tensor_2d (input, input_tile, x, y );
881+ ggml_split_tensor_2d (input, input_tile, x_in, y_in );
854882 on_processing (input_tile, output_tile, false );
855- if (scaled_out) {
856- ggml_merge_tensor_2d (output_tile, output, x * scale, y * scale, tile_overlap_x * scale, tile_overlap_y * scale, dx * scale, dy * scale);
857- } else {
858- ggml_merge_tensor_2d (output_tile, output, x / scale, y / scale, tile_overlap_x / scale, tile_overlap_y / scale, dx / scale, dy / scale);
859- }
883+ ggml_merge_tensor_2d (output_tile, output, x_out, y_out, overlap_x_out, overlap_y_out, dx, dy);
884+
860885 int64_t t2 = ggml_time_ms ();
861886 last_time = (t2 - t1) / 1000 .0f ;
862887 pretty_progress (tile_count, num_tiles, last_time);
0 commit comments