
Commit 6ed41b4

Authored by kig (Ilmari Heikkinen), with co-authors patrickvonplaten and pcuenca
8k Stable Diffusion with tiled VAE (huggingface#1441)
8k Stable Diffusion with tiled VAE (huggingface#1441)

* Tiled VAE for high-res text2img and img2img
* vae tiling, fix formatting
* enable_vae_tiling API and tests
* tiled vae docs, disable tiling for images that would have only one tile
* tiled vae tests, use channels_last memory format
* tiled vae tests, use smaller test image
* tiled vae tests, remove tiling test from fast tests
* up
* up
* make style
* Apply suggestions from code review
* Apply suggestions from code review
* Apply suggestions from code review
* make style
* improve naming
* finish
* apply suggestions
* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <[email protected]>

* up

---------

Co-authored-by: Ilmari Heikkinen <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
1 parent 7b71e6a commit 6ed41b4

File tree

3 files changed: 181 additions, 16 deletions


models/autoencoder_kl.py

Lines changed: 149 additions & 16 deletions
@@ -107,10 +107,54 @@ def __init__(
 
         self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
         self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
+
+        self.use_slicing = False
+        self.use_tiling = False
+
+        # only relevant if vae tiling is enabled
+        self.tile_sample_min_size = self.config.sample_size
+        sample_size = (
+            self.config.sample_size[0]
+            if isinstance(self.config.sample_size, (list, tuple))
+            else self.config.sample_size
+        )
+        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
+        self.tile_overlap_factor = 0.25
+
+    def enable_tiling(self, use_tiling: bool = True):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow
+        the processing of larger images.
+        """
+        self.use_tiling = use_tiling
+
+    def disable_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_tiling` was previously invoked, this method will go back to
+        computing decoding in one step.
+        """
+        self.enable_tiling(False)
+
+    def enable_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+
+    def disable_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
+        decoding in one step.
+        """
         self.use_slicing = False
 
     @apply_forward_hook
     def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
+        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
+            return self.tiled_encode(x, return_dict=return_dict)
+
         h = self.encoder(x)
         moments = self.quant_conv(h)
         posterior = DiagonalGaussianDistribution(moments)

@@ -121,6 +165,9 @@ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
         return AutoencoderKLOutput(latent_dist=posterior)
 
     def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+        if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
+            return self.tiled_decode(z, return_dict=return_dict)
+
         z = self.post_quant_conv(z)
         dec = self.decoder(z)
 
@@ -129,22 +176,6 @@ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
 
         return DecoderOutput(sample=dec)
 
-    def enable_slicing(self):
-        r"""
-        Enable sliced VAE decoding.
-
-        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
-        steps. This is useful to save some memory and allow larger batch sizes.
-        """
-        self.use_slicing = True
-
-    def disable_slicing(self):
-        r"""
-        Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
-        decoding in one step.
-        """
-        self.use_slicing = False
-
     @apply_forward_hook
     def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
         if self.use_slicing and z.shape[0] > 1:

@@ -158,6 +189,108 @@ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
 
         return DecoderOutput(sample=decoded)
 
+    def blend_v(self, a, b, blend_extent):
+        for y in range(blend_extent):
+            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
+        return b
+
+    def blend_h(self, a, b, blend_extent):
+        for x in range(blend_extent):
+            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
+        return b
+
+    def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
+        r"""Encode a batch of images using a tiled encoder.
+        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
+        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding
+        is different from non-tiled encoding due to each tile using a different encoder. To avoid tiling artifacts,
+        the tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in
+        the look of the output, but they should be much less noticeable.
+        Args:
+            x (`torch.FloatTensor`): Input batch of images.
+            return_dict (`bool`, *optional*, defaults to `True`): Whether to return an [`AutoencoderKLOutput`] instead of a plain tuple.
+        """
+        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
+        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
+        row_limit = self.tile_latent_min_size - blend_extent
+
+        # Split the image into 512x512 tiles and encode them separately.
+        rows = []
+        for i in range(0, x.shape[2], overlap_size):
+            row = []
+            for j in range(0, x.shape[3], overlap_size):
+                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
+                tile = self.encoder(tile)
+                tile = self.quant_conv(tile)
+                row.append(tile)
+            rows.append(row)
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_extent)
+                result_row.append(tile[:, :, :row_limit, :row_limit])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        moments = torch.cat(result_rows, dim=2)
+        posterior = DiagonalGaussianDistribution(moments)
+
+        if not return_dict:
+            return (posterior,)
+
+        return AutoencoderKLOutput(latent_dist=posterior)
+
+    def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+        r"""Decode a batch of images using a tiled decoder.
+        When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several
+        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled decoding
+        is different from non-tiled decoding due to each tile using a different decoder. To avoid tiling artifacts,
+        the tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in
+        the look of the output, but they should be much less noticeable.
+        Args:
+            z (`torch.FloatTensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`): Whether to return a [`DecoderOutput`] instead of
+                a plain tuple.
+        """
+        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
+        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
+        row_limit = self.tile_sample_min_size - blend_extent
+
+        # Split z into overlapping 64x64 tiles and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, z.shape[2], overlap_size):
+            row = []
+            for j in range(0, z.shape[3], overlap_size):
+                tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
+                tile = self.post_quant_conv(tile)
+                decoded = self.decoder(tile)
+                row.append(decoded)
+            rows.append(row)
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_extent)
+                result_row.append(tile[:, :, :row_limit, :row_limit])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        dec = torch.cat(result_rows, dim=2)
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
+
     def forward(
         self,
         sample: torch.FloatTensor,
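The seam removal in `blend_v` and `blend_h` above is a plain linear cross-fade: at offset `x` into the overlap region, the first tile is weighted by `1 - x / blend_extent` and the second by `x / blend_extent`, so the weights always sum to 1. With the 512 × 512 sample tiles the code comments assume and `tile_overlap_factor = 0.25`, `tiled_encode` walks the image in strides of int(512 × 0.75) = 384 pixels; the matching 64 × 64 latent tiles give `blend_extent = 16` and `row_limit = 48`, so each tile contributes a 48 × 48 latent patch after blending. The following standalone sketch replays the horizontal blend on dummy tensors to show the ramp; the shapes and values are illustrative, not taken from the commit:

import torch

def blend_h(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Cross-fade the last `blend_extent` columns of `a` into the first
    # `blend_extent` columns of `b` (NCHW tensors), same math as the commit.
    for x in range(blend_extent):
        b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
    return b

a = torch.zeros(1, 1, 2, 8)  # left tile, constant 0.0
b = torch.ones(1, 1, 2, 8)   # right tile, constant 1.0
blended = blend_h(a, b, blend_extent=4)
print(blended[0, 0, 0])
# tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000, 1.0000, 1.0000, 1.0000])

Note that `b` is modified in place and the blended columns ramp linearly from `a`'s values to `b`'s, which is what hides the tile seams.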

pipelines/alt_diffusion/pipeline_alt_diffusion.py

Lines changed: 16 additions & 0 deletions
@@ -183,6 +183,22 @@ def disable_vae_slicing(self):
         """
         self.vae.disable_slicing()
 
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
+        several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
+        """
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
     def enable_sequential_cpu_offload(self, gpu_id=0):
         r"""
         Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,

pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 16 additions & 0 deletions
@@ -186,6 +186,22 @@ def disable_vae_slicing(self):
         """
         self.vae.disable_slicing()
 
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
+        several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
+        """
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
     def enable_sequential_cpu_offload(self, gpu_id=0):
         r"""
         Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
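The pipeline half of this commit is a thin passthrough to the new `AutoencoderKL` methods. A minimal usage sketch follows; the checkpoint id, prompt, and resolution are illustrative and not part of the commit, and note that tiling only bounds the VAE's memory, while the UNet's cost still grows with resolution:

import torch
from diffusers import StableDiffusionPipeline

# Illustrative checkpoint; any Stable Diffusion model works the same way.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Decode (and encode) tile by tile so VAE memory stays roughly constant
# even far above the 512x512 training resolution.
pipe.enable_vae_tiling()

image = pipe(
    "a panoramic photo of a mountain range at sunset",
    width=3072,
    height=1536,
).images[0]
image.save("panorama.png")

# Revert to single-pass VAE decoding for ordinary sizes.
pipe.disable_vae_tiling()

Because `AutoencoderKL.encode` takes the tiled path only when the input exceeds `tile_sample_min_size`, small images are unaffected even while tiling is enabled.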
