@@ -107,10 +107,54 @@ def __init__(

         self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
         self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
+
+        self.use_slicing = False
+        self.use_tiling = False
+
+        # only relevant if vae tiling is enabled
+        self.tile_sample_min_size = self.config.sample_size
+        sample_size = (
+            self.config.sample_size[0]
+            if isinstance(self.config.sample_size, (list, tuple))
+            else self.config.sample_size
+        )
+        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
+        self.tile_overlap_factor = 0.25
+
+    def enable_tiling(self, use_tiling: bool = True):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and for
+        processing larger images.
+        """
+        self.use_tiling = use_tiling
+
+    def disable_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_tiling` was previously invoked, this method will go back to computing
+        decoding in one step.
+        """
+        self.enable_tiling(False)
+
+    def enable_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+
+    def disable_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
+        decoding in one step.
+        """
         self.use_slicing = False

     @apply_forward_hook
     def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
+        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
+            return self.tiled_encode(x, return_dict=return_dict)
+
         h = self.encoder(x)
         moments = self.quant_conv(h)
         posterior = DiagonalGaussianDistribution(moments)
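For context, a minimal usage sketch of the new switches (not part of this diff): all four methods are called directly on an `AutoencoderKL` instance, and the checkpoint name below is only illustrative.

from diffusers import AutoencoderKL

# Illustrative checkpoint; any AutoencoderKL checkpoint behaves the same way.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")

vae.enable_tiling()   # encode/decode large inputs tile by tile
vae.enable_slicing()  # decode batched latents one sample at a time

# ... call vae.encode(...) / vae.decode(...) as usual ...

vae.disable_tiling()
vae.disable_slicing()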
@@ -121,6 +165,9 @@ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
         return AutoencoderKLOutput(latent_dist=posterior)

     def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+        if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
+            return self.tiled_decode(z, return_dict=return_dict)
+
         z = self.post_quant_conv(z)
         dec = self.decoder(z)

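To make the dispatch thresholds above concrete, here is a small sketch assuming a VAE configured with `sample_size=512` and four `block_out_channels` entries (the usual Stable Diffusion VAE layout; these values are assumptions, not read from this diff):

sample_size = 512  # assumed config.sample_size
num_blocks = 4     # assumed len(config.block_out_channels)

tile_sample_min_size = sample_size                                 # 512
tile_latent_min_size = int(sample_size / (2 ** (num_blocks - 1)))  # 512 / 8 = 64

# encode() only falls back to tiled_encode() when an input spatial dim exceeds 512,
# and _decode() only falls back to tiled_decode() when a latent spatial dim exceeds 64.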
@@ -129,22 +176,6 @@ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
         return DecoderOutput(sample=dec)


-    def enable_slicing(self):
-        r"""
-        Enable sliced VAE decoding.
-
-        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
-        steps. This is useful to save some memory and allow larger batch sizes.
-        """
-        self.use_slicing = True
-
-    def disable_slicing(self):
-        r"""
-        Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
-        decoding in one step.
-        """
-        self.use_slicing = False
-
     @apply_forward_hook
     def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
         if self.use_slicing and z.shape[0] > 1:
@@ -158,6 +189,108 @@ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:

         return DecoderOutput(sample=decoded)

+    def blend_v(self, a, b, blend_extent):
+        for y in range(blend_extent):
+            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
+        return b
+
+    def blend_h(self, a, b, blend_extent):
+        for x in range(blend_extent):
+            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
+        return b
+
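`blend_v` and `blend_h` are a plain linear cross-fade over the `blend_extent` overlapping rows (dim 2) or columns (dim 3) of two neighbouring tiles. A standalone 1-D sketch of the same ramp, for illustration only:

import torch

def blend_1d(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Fade the last `blend_extent` samples of `a` into the first `blend_extent` samples
    # of `b`: the weight on `a` ramps from 1 to 0 while the weight on `b` ramps from 0 to 1.
    for i in range(blend_extent):
        b[i] = a[-blend_extent + i] * (1 - i / blend_extent) + b[i] * (i / blend_extent)
    return b

print(blend_1d(torch.ones(8), torch.zeros(8), 4))
# tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000])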
+    def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
+        r"""Encode a batch of images using a tiled encoder.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
+        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding
+        is different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts,
+        the tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in
+        the look of the output, but they should be much less noticeable.
+
+        Args:
+            x (`torch.FloatTensor`): Input batch of images.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return an [`AutoencoderKLOutput`] instead of a plain tuple.
+        """
+        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
+        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
+        row_limit = self.tile_latent_min_size - blend_extent
+
+        # Split the image into overlapping tiles of size tile_sample_min_size and encode them separately.
+        rows = []
+        for i in range(0, x.shape[2], overlap_size):
+            row = []
+            for j in range(0, x.shape[3], overlap_size):
+                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
+                tile = self.encoder(tile)
+                tile = self.quant_conv(tile)
+                row.append(tile)
+            rows.append(row)
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile into the current tile,
+                # then add the cropped tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_extent)
+                result_row.append(tile[:, :, :row_limit, :row_limit])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        moments = torch.cat(result_rows, dim=2)
+        posterior = DiagonalGaussianDistribution(moments)
+
+        if not return_dict:
+            return (posterior,)
+
+        return AutoencoderKLOutput(latent_dist=posterior)
+
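A worked sketch of the encode-side tile geometry, under the same assumed configuration as above (sample_size=512, latent downscale factor 8, tile_overlap_factor=0.25; values are assumptions, not read from this diff):

tile_sample_min_size = 512   # assumed pixel-space tile edge
tile_latent_min_size = 64    # assumed latent-space tile edge (512 / 8)
tile_overlap_factor = 0.25

overlap_size = int(tile_sample_min_size * (1 - tile_overlap_factor))  # 384: stride between tile origins in pixels
blend_extent = int(tile_latent_min_size * tile_overlap_factor)        # 16: latent rows/cols that are cross-faded
row_limit = tile_latent_min_size - blend_extent                       # 48: latent rows/cols kept per tile

# A 1024x1024 image is therefore encoded as a 3x3 grid of 512x512 tiles taken every 384 px;
# all but the last tile in each row and column contribute 48x48 latent patches after blending
# and cropping, and the final torch.cat calls stitch them into the full 128x128 latent.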
+    def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+        r"""Decode a batch of latents using a tiled decoder.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several
+        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled decoding
+        is different from non-tiled decoding because each tile uses a different decoder. To avoid tiling artifacts,
+        the tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in
+        the look of the output, but they should be much less noticeable.
+
+        Args:
+            z (`torch.FloatTensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+        """
+        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
+        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
+        row_limit = self.tile_sample_min_size - blend_extent
+
+        # Split z into overlapping tiles of size tile_latent_min_size and decode them separately.
+        # The tiles overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, z.shape[2], overlap_size):
+            row = []
+            for j in range(0, z.shape[3], overlap_size):
+                tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
+                tile = self.post_quant_conv(tile)
+                decoded = self.decoder(tile)
+                row.append(decoded)
+            rows.append(row)
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile into the current tile,
+                # then add the cropped tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_extent)
+                result_row.append(tile[:, :, :row_limit, :row_limit])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        dec = torch.cat(result_rows, dim=2)
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
+
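The decode path mirrors this with the sample and latent roles swapped; a sketch under the same assumed configuration:

tile_latent_min_size = 64    # assumed latent-space tile edge
tile_sample_min_size = 512   # assumed pixel-space tile edge
tile_overlap_factor = 0.25

overlap_size = int(tile_latent_min_size * (1 - tile_overlap_factor))  # 48: stride between tile origins in latent space
blend_extent = int(tile_sample_min_size * tile_overlap_factor)        # 128: pixel rows/cols that are cross-faded
row_limit = tile_sample_min_size - blend_extent                       # 384: pixel rows/cols kept per tile

# A 128x128 latent (i.e. a 1024x1024 image) is decoded as a 3x3 grid of 64x64 latent tiles
# taken every 48 latent positions; each decoded tile is at most 512x512 pixels, so peak
# decoder memory stays at the single-tile level regardless of the overall image size.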
     def forward(
         self,
         sample: torch.FloatTensor,