
Commit 1af7dd4

Authored by Ttl, williamberman, sayakpaul, and patrickvonplaten
Controlnet training (huggingface#2545)
* Controlnet training code initial commit. Works with the circle dataset: https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md
* Script for adding a controlnet to an existing model
* Fix control image transform. The control image should be in the 0..1 range.
* Add license header and remove more unused configs
* controlnet training readme
* Allow nonlocal model in add_controlnet.py
* Formatting
* Remove unused code
* Code quality
* Initialize controlnet in training script
* Formatting
* Address review comments
* doc style
* explicit constructor args and submodule names
* hub dataset (NOTE: not tested)
* empty prompts
* add conditioning image
* rename
* remove instance data dir
* image_transforms -> [-1, 1]; conditioning_image_transforms -> [0, 1] (see the sketch after this list)
* nits
* remove local rank config (not necessary in any of our training scripts)
* validation images
* proportion_empty_prompts typo
* weight copying to controlnet bug
* call log validation fix
* fix
* gitignore wandb
* fix progress bar and resume from checkpoint iteration
* initial step fix
* log multiple images
* fix
* fixes
* tracker project name configurable
* misc
* add controlnet requirements.txt
* update docs
* image labels
* small fixes
* log validation using existing models for pipeline
* fix for deepspeed saving
* memory usage docs
* Update examples/controlnet/train_controlnet.py (2 commits, Co-authored-by: Sayak Paul <[email protected]>)
* Update examples/controlnet/README.md (8 commits, Co-authored-by: Sayak Paul <[email protected]>)
* remove extra is_main_process check
* link to dataset in intro paragraph
* remove unnecessary paragraph
* note on deepspeed
* Update examples/controlnet/README.md (Co-authored-by: Patrick von Platen <[email protected]>)
* assert -> ValueError
* weights and biases note
* move images out of git
* remove .gitignore

Co-authored-by: William Berman <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
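One detail from the list above worth calling out: the training script keeps two separate image pipelines. Target images are scaled to [-1, 1] before the VAE encodes them, while conditioning images stay in [0, 1] for the ControlNet's conditioning embedding. A minimal sketch of what such transforms can look like with torchvision (the 512 resolution and the variable names are illustrative, not taken from this commit):

import torchvision.transforms as transforms

resolution = 512  # illustrative value; the training script takes the resolution as an argument

# Target images: [0, 1] after ToTensor, then normalized to [-1, 1] for the VAE/UNet.
image_transforms = transforms.Compose(
    [
        transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(resolution),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

# Conditioning images: left in [0, 1] for the ControlNet conditioning embedding.
conditioning_image_transforms = transforms.Compose(
    [
        transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(resolution),
        transforms.ToTensor(),
    ]
)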
1 parent 68552ed commit 1af7dd4

File tree: 2 files changed (+66, -3 lines)

models/controlnet.py

Lines changed: 55 additions & 0 deletions
@@ -29,6 +29,7 @@
     UNetMidBlock2DCrossAttn,
     get_down_block,
 )
+from .unet_2d_condition import UNet2DConditionModel


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -257,6 +258,60 @@ def __init__(
             upcast_attention=upcast_attention,
         )

+    @classmethod
+    def from_unet(
+        cls,
+        unet: UNet2DConditionModel,
+        controlnet_conditioning_channel_order: str = "rgb",
+        conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
+        load_weights_from_unet: bool = True,
+    ):
+        r"""
+        Instantiate Controlnet class from UNet2DConditionModel.
+
+        Parameters:
+            unet (`UNet2DConditionModel`):
+                UNet model which weights are copied to the ControlNet. Note that all configuration options are also
+                copied where applicable.
+        """
+        controlnet = cls(
+            in_channels=unet.config.in_channels,
+            flip_sin_to_cos=unet.config.flip_sin_to_cos,
+            freq_shift=unet.config.freq_shift,
+            down_block_types=unet.config.down_block_types,
+            only_cross_attention=unet.config.only_cross_attention,
+            block_out_channels=unet.config.block_out_channels,
+            layers_per_block=unet.config.layers_per_block,
+            downsample_padding=unet.config.downsample_padding,
+            mid_block_scale_factor=unet.config.mid_block_scale_factor,
+            act_fn=unet.config.act_fn,
+            norm_num_groups=unet.config.norm_num_groups,
+            norm_eps=unet.config.norm_eps,
+            cross_attention_dim=unet.config.cross_attention_dim,
+            attention_head_dim=unet.config.attention_head_dim,
+            use_linear_projection=unet.config.use_linear_projection,
+            class_embed_type=unet.config.class_embed_type,
+            num_class_embeds=unet.config.num_class_embeds,
+            upcast_attention=unet.config.upcast_attention,
+            resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
+            projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
+            controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
+            conditioning_embedding_out_channels=conditioning_embedding_out_channels,
+        )
+
+        if load_weights_from_unet:
+            controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
+            controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
+            controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
+
+            if controlnet.class_embedding:
+                controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
+
+            controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
+            controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
+
+        return controlnet
+
     @property
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
     def attn_processors(self) -> Dict[str, AttnProcessor]:
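As a rough usage sketch of the new classmethod (assuming the class is the one diffusers exports as ControlNetModel; runwayml/stable-diffusion-v1-5 is just an example checkpoint), a training script can seed the ControlNet from the weights of a pretrained Stable Diffusion UNet:

from diffusers import ControlNetModel, UNet2DConditionModel

# Load the UNet of an existing Stable Diffusion checkpoint.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# Mirror the UNet's config into a ControlNet and copy the shared submodules
# (conv_in, time_proj, time_embedding, class_embedding if present, down_blocks, mid_block).
controlnet = ControlNetModel.from_unet(unet, load_weights_from_unet=True)

load_weights_from_unet defaults to True, so it is passed explicitly here only for emphasis; with False, the ControlNet only inherits the UNet's configuration.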

pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py

Lines changed: 11 additions & 3 deletions
@@ -611,9 +611,17 @@ def prepare_image(
             image = [image]

         if isinstance(image[0], PIL.Image.Image):
-            image = [
-                np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image
-            ]
+            images = []
+
+            for image_ in image:
+                image_ = image_.convert("RGB")
+                image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
+                image_ = np.array(image_)
+                image_ = image_[None, :]
+                images.append(image_)
+
+            image = images
+
             image = np.concatenate(image, axis=0)
             image = np.array(image).astype(np.float32) / 255.0
             image = image.transpose(0, 3, 1, 2)
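For context, here is a standalone sketch of the preprocessing path that the reworked prepare_image applies to a list of PIL images; the helper name and the 512x512 default are illustrative, and PIL.Image.LANCZOS stands in for diffusers' PIL_INTERPOLATION["lanczos"]. The functional change in this hunk is the explicit per-image loop with the added convert("RGB") call:

import numpy as np
import PIL.Image

def preprocess_conditioning_images(images, width=512, height=512):
    # Per image: force RGB, resize, convert to an array, and add a batch axis.
    arrays = []
    for image in images:
        image = image.convert("RGB")
        image = image.resize((width, height), resample=PIL.Image.LANCZOS)
        arrays.append(np.array(image)[None, :])

    # Stack, scale uint8 [0, 255] to float32 [0, 1], and move to NCHW layout.
    batch = np.concatenate(arrays, axis=0)
    batch = batch.astype(np.float32) / 255.0
    return batch.transpose(0, 3, 1, 2)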