Skip to content

Commit 71cf5b6

Browse files
Big Model Renaming (huggingface#109)
* up * change model name * renaming * more changes * up * up * up * save checkpoint * finish api / naming * finish config renaming * rename all weights * finish really
1 parent 382b6bc commit 71cf5b6

18 files changed

+453
-523
lines changed

__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
__version__ = "0.0.4"
88

99
from .modeling_utils import ModelMixin
10-
from .models import AutoencoderKL, UNetConditionalModel, UNetUnconditionalModel, VQModel
10+
from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
1111
from .pipeline_utils import DiffusionPipeline
1212
from .pipelines import DDIMPipeline, DDPMPipeline, LatentDiffusionUncondPipeline, PNDMPipeline, ScoreSdeVePipeline
1313
from .schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, SchedulerMixin, ScoreSdeVeScheduler

configuration_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -161,10 +161,10 @@ def get_config_dict(
161161

162162
except RepositoryNotFoundError:
163163
raise EnvironmentError(
164-
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed"
165-
" on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token"
166-
" having permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and"
167-
" pass `use_auth_token=True`."
164+
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
165+
" listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
166+
" token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
167+
" login` and pass `use_auth_token=True`."
168168
)
169169
except RevisionNotFoundError:
170170
raise EnvironmentError(

modeling_utils.py

Lines changed: 42 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
)
3535

3636

37-
WEIGHTS_NAME = "diffusion_model.pt"
37+
WEIGHTS_NAME = "diffusion_pytorch_model.bin"
3838

3939

4040
logger = logging.get_logger(__name__)
@@ -147,7 +147,7 @@ class ModelMixin(torch.nn.Module):
147147
models, `pixel_values` for vision models and `input_values` for speech models).
148148
"""
149149
config_name = CONFIG_NAME
150-
_automatically_saved_args = ["_diffusers_version", "_class_name", "name_or_path"]
150+
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
151151

152152
def __init__(self):
153153
super().__init__()
@@ -341,7 +341,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
341341
subfolder=subfolder,
342342
**kwargs,
343343
)
344-
model.register_to_config(name_or_path=pretrained_model_name_or_path)
344+
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
345345
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
346346
# Load model
347347
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
@@ -497,46 +497,45 @@ def _find_mismatched_keys(
497497
)
498498
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
499499

500-
if False:
501-
if len(unexpected_keys) > 0:
502-
logger.warning(
503-
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
504-
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
505-
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
506-
" or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
507-
" BertForPreTraining model).\n- This IS NOT expected if you are initializing"
508-
f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
509-
" identical (initializing a BertForSequenceClassification model from a"
510-
" BertForSequenceClassification model)."
511-
)
512-
else:
513-
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
514-
if len(missing_keys) > 0:
515-
logger.warning(
516-
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
517-
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
518-
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
519-
)
520-
elif len(mismatched_keys) == 0:
521-
logger.info(
522-
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
523-
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
524-
f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
525-
" without further training."
526-
)
527-
if len(mismatched_keys) > 0:
528-
mismatched_warning = "\n".join(
529-
[
530-
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
531-
for key, shape1, shape2 in mismatched_keys
532-
]
533-
)
534-
logger.warning(
535-
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
536-
f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
537-
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
538-
" able to use it for predictions and inference."
539-
)
500+
if len(unexpected_keys) > 0:
501+
logger.warning(
502+
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
503+
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
504+
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
505+
" or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
506+
" BertForPreTraining model).\n- This IS NOT expected if you are initializing"
507+
f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
508+
" identical (initializing a BertForSequenceClassification model from a"
509+
" BertForSequenceClassification model)."
510+
)
511+
else:
512+
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
513+
if len(missing_keys) > 0:
514+
logger.warning(
515+
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
516+
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
517+
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
518+
)
519+
elif len(mismatched_keys) == 0:
520+
logger.info(
521+
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
522+
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
523+
f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
524+
" without further training."
525+
)
526+
if len(mismatched_keys) > 0:
527+
mismatched_warning = "\n".join(
528+
[
529+
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
530+
for key, shape1, shape2 in mismatched_keys
531+
]
532+
)
533+
logger.warning(
534+
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
535+
f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
536+
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
537+
" able to use it for predictions and inference."
538+
)
540539

541540
return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
542541

models/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,6 @@
1616
# See the License for the specific language governing permissions and
1717
# limitations under the License.
1818

19-
from .unet_conditional import UNetConditionalModel
20-
from .unet_unconditional import UNetUnconditionalModel
19+
from .unet_2d import UNet2DModel
20+
from .unet_2d_condition import UNet2DConditionModel
2121
from .vae import AutoencoderKL, VQModel

models/attention.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,15 @@ class AttentionBlockNew(nn.Module):
1717
def __init__(
1818
self,
1919
channels,
20-
num_heads=1,
2120
num_head_channels=None,
2221
num_groups=32,
2322
rescale_output_factor=1.0,
2423
eps=1e-5,
2524
):
2625
super().__init__()
2726
self.channels = channels
28-
if num_head_channels is None:
29-
self.num_heads = num_heads
30-
else:
31-
assert (
32-
channels % num_head_channels == 0
33-
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
34-
self.num_heads = channels // num_head_channels
3527

28+
self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
3629
self.num_head_size = num_head_channels
3730
self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
3831

models/resnet.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,11 @@ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name=
7878

7979
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
8080
if name == "conv":
81+
self.Conv2d_0 = conv
8182
self.conv = conv
8283
elif name == "Conv2d_0":
83-
self.Conv2d_0 = conv
8484
self.conv = conv
8585
else:
86-
self.op = conv
8786
self.conv = conv
8887

8988
def forward(self, x):

models/unet_2d.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
from typing import Dict, Union
2+
3+
import torch
4+
import torch.nn as nn
5+
6+
from ..configuration_utils import ConfigMixin, register_to_config
7+
from ..modeling_utils import ModelMixin
8+
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
9+
from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block
10+
11+
12+
class UNet2DModel(ModelMixin, ConfigMixin):
13+
@register_to_config
14+
def __init__(
15+
self,
16+
sample_size=None,
17+
in_channels=3,
18+
out_channels=3,
19+
center_input_sample=False,
20+
time_embedding_type="positional",
21+
freq_shift=0,
22+
flip_sin_to_cos=True,
23+
down_block_types=("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
24+
up_block_types=("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
25+
block_out_channels=(224, 448, 672, 896),
26+
layers_per_block=2,
27+
mid_block_scale_factor=1,
28+
downsample_padding=1,
29+
act_fn="silu",
30+
attention_head_dim=8,
31+
norm_num_groups=32,
32+
norm_eps=1e-5,
33+
):
34+
super().__init__()
35+
36+
self.sample_size = sample_size
37+
time_embed_dim = block_out_channels[0] * 4
38+
39+
# input
40+
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
41+
42+
# time
43+
if time_embedding_type == "fourier":
44+
self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
45+
timestep_input_dim = 2 * block_out_channels[0]
46+
elif time_embedding_type == "positional":
47+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
48+
timestep_input_dim = block_out_channels[0]
49+
50+
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
51+
52+
self.down_blocks = nn.ModuleList([])
53+
self.mid_block = None
54+
self.up_blocks = nn.ModuleList([])
55+
56+
# down
57+
output_channel = block_out_channels[0]
58+
for i, down_block_type in enumerate(down_block_types):
59+
input_channel = output_channel
60+
output_channel = block_out_channels[i]
61+
is_final_block = i == len(block_out_channels) - 1
62+
63+
down_block = get_down_block(
64+
down_block_type,
65+
num_layers=layers_per_block,
66+
in_channels=input_channel,
67+
out_channels=output_channel,
68+
temb_channels=time_embed_dim,
69+
add_downsample=not is_final_block,
70+
resnet_eps=norm_eps,
71+
resnet_act_fn=act_fn,
72+
attn_num_head_channels=attention_head_dim,
73+
downsample_padding=downsample_padding,
74+
)
75+
self.down_blocks.append(down_block)
76+
77+
# mid
78+
self.mid_block = UNetMidBlock2D(
79+
in_channels=block_out_channels[-1],
80+
temb_channels=time_embed_dim,
81+
resnet_eps=norm_eps,
82+
resnet_act_fn=act_fn,
83+
output_scale_factor=mid_block_scale_factor,
84+
resnet_time_scale_shift="default",
85+
attn_num_head_channels=attention_head_dim,
86+
resnet_groups=norm_num_groups,
87+
)
88+
89+
# up
90+
reversed_block_out_channels = list(reversed(block_out_channels))
91+
output_channel = reversed_block_out_channels[0]
92+
for i, up_block_type in enumerate(up_block_types):
93+
prev_output_channel = output_channel
94+
output_channel = reversed_block_out_channels[i]
95+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
96+
97+
is_final_block = i == len(block_out_channels) - 1
98+
99+
up_block = get_up_block(
100+
up_block_type,
101+
num_layers=layers_per_block + 1,
102+
in_channels=input_channel,
103+
out_channels=output_channel,
104+
prev_output_channel=prev_output_channel,
105+
temb_channels=time_embed_dim,
106+
add_upsample=not is_final_block,
107+
resnet_eps=norm_eps,
108+
resnet_act_fn=act_fn,
109+
attn_num_head_channels=attention_head_dim,
110+
)
111+
self.up_blocks.append(up_block)
112+
prev_output_channel = output_channel
113+
114+
# out
115+
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
116+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
117+
self.conv_act = nn.SiLU()
118+
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
119+
120+
def forward(
121+
self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int]
122+
) -> Dict[str, torch.FloatTensor]:
123+
124+
# 0. center input if necessary
125+
if self.config.center_input_sample:
126+
sample = 2 * sample - 1.0
127+
128+
# 1. time
129+
timesteps = timestep
130+
if not torch.is_tensor(timesteps):
131+
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
132+
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
133+
timesteps = timesteps[None].to(sample.device)
134+
135+
t_emb = self.time_proj(timesteps)
136+
emb = self.time_embedding(t_emb)
137+
138+
# 2. pre-process
139+
skip_sample = sample
140+
sample = self.conv_in(sample)
141+
142+
# 3. down
143+
down_block_res_samples = (sample,)
144+
for downsample_block in self.down_blocks:
145+
if hasattr(downsample_block, "skip_conv"):
146+
sample, res_samples, skip_sample = downsample_block(
147+
hidden_states=sample, temb=emb, skip_sample=skip_sample
148+
)
149+
else:
150+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
151+
152+
down_block_res_samples += res_samples
153+
154+
# 4. mid
155+
sample = self.mid_block(sample, emb)
156+
157+
# 5. up
158+
skip_sample = None
159+
for upsample_block in self.up_blocks:
160+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
161+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
162+
163+
if hasattr(upsample_block, "skip_conv"):
164+
sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
165+
else:
166+
sample = upsample_block(sample, res_samples, emb)
167+
168+
# 6. post-process
169+
sample = self.conv_norm_out(sample)
170+
sample = self.conv_act(sample)
171+
sample = self.conv_out(sample)
172+
173+
if skip_sample is not None:
174+
sample += skip_sample
175+
176+
if self.config.time_embedding_type == "fourier":
177+
timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
178+
sample = sample / timesteps
179+
180+
output = {"sample": sample}
181+
182+
return output

0 commit comments

Comments
 (0)