29 | 29 | UNetMidBlock2DCrossAttn, |
30 | 30 | get_down_block, |
31 | 31 | ) |
| 32 | +from .unet_2d_condition import UNet2DConditionModel |
32 | 33 |
33 | 34 |
34 | 35 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name |
@@ -257,6 +258,60 @@ def __init__( |
257 | 258 | upcast_attention=upcast_attention, |
258 | 259 | ) |
259 | 260 |
| 261 | + @classmethod |
| 262 | + def from_unet( |
| 263 | + cls, |
| 264 | + unet: UNet2DConditionModel, |
| 265 | + controlnet_conditioning_channel_order: str = "rgb", |
| 266 | + conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), |
| 267 | + load_weights_from_unet: bool = True, |
| 268 | + ): |
| 269 | + r""" |
| 270 | + Instantiate a ControlNet class from a UNet2DConditionModel. |
| 271 | +
| 272 | + Parameters: |
| 273 | + unet (`UNet2DConditionModel`): |
| 274 | + UNet model whose weights are copied to the ControlNet. Note that all configuration options are also |
| 275 | + copied where applicable. |
| 276 | + """ |
| 277 | + controlnet = cls( |
| 278 | + in_channels=unet.config.in_channels, |
| 279 | + flip_sin_to_cos=unet.config.flip_sin_to_cos, |
| 280 | + freq_shift=unet.config.freq_shift, |
| 281 | + down_block_types=unet.config.down_block_types, |
| 282 | + only_cross_attention=unet.config.only_cross_attention, |
| 283 | + block_out_channels=unet.config.block_out_channels, |
| 284 | + layers_per_block=unet.config.layers_per_block, |
| 285 | + downsample_padding=unet.config.downsample_padding, |
| 286 | + mid_block_scale_factor=unet.config.mid_block_scale_factor, |
| 287 | + act_fn=unet.config.act_fn, |
| 288 | + norm_num_groups=unet.config.norm_num_groups, |
| 289 | + norm_eps=unet.config.norm_eps, |
| 290 | + cross_attention_dim=unet.config.cross_attention_dim, |
| 291 | + attention_head_dim=unet.config.attention_head_dim, |
| 292 | + use_linear_projection=unet.config.use_linear_projection, |
| 293 | + class_embed_type=unet.config.class_embed_type, |
| 294 | + num_class_embeds=unet.config.num_class_embeds, |
| 295 | + upcast_attention=unet.config.upcast_attention, |
| 296 | + resnet_time_scale_shift=unet.config.resnet_time_scale_shift, |
| 297 | + projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, |
| 298 | + controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, |
| 299 | + conditioning_embedding_out_channels=conditioning_embedding_out_channels, |
| 300 | + ) |
| 301 | + |
| 302 | + if load_weights_from_unet: |
| 303 | + controlnet.conv_in.load_state_dict(unet.conv_in.state_dict()) |
| 304 | + controlnet.time_proj.load_state_dict(unet.time_proj.state_dict()) |
| 305 | + controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict()) |
| 306 | + |
| 307 | + if controlnet.class_embedding: |
| 308 | + controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict()) |
| 309 | + |
| 310 | + controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict()) |
| 311 | + controlnet.mid_block.load_state_dict(unet.mid_block.state_dict()) |
| 312 | + |
| 313 | + return controlnet |
| 314 | + |
260 | 315 | @property |
261 | 316 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors |
262 | 317 | def attn_processors(self) -> Dict[str, AttnProcessor]: |
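
For context, a minimal usage sketch of the classmethod added in this commit. It assumes the diff lands on `ControlNetModel` (the owning class is not visible in this hunk), and the checkpoint id below is only an example.

```python
# Minimal usage sketch, assuming the new `from_unet` classmethod lives on
# `ControlNetModel`; "runwayml/stable-diffusion-v1-5" is just an example checkpoint.
from diffusers import ControlNetModel, UNet2DConditionModel

# Load the UNet whose config (and, optionally, weights) the ControlNet starts from.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# Build a ControlNet with a matching config; with load_weights_from_unet=True the
# conv_in, time embeddings, class embedding, down blocks, and mid block are copied.
controlnet = ControlNetModel.from_unet(unet, load_weights_from_unet=True)
```

Per the copy list in the diff, only the encoder-side modules are taken from the UNet; the ControlNet-specific conditioning layers keep their own initialization.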