4 changes: 2 additions & 2 deletions README.md
@@ -84,7 +84,7 @@ For more examples see [schedulers](https://github.com/huggingface/diffusers/tree

```python
import torch
-from diffusers import UNetUnconditionalModel, DDIMScheduler
+from diffusers import UNet2DModel, DDIMScheduler
import PIL.Image
import numpy as np
import tqdm
@@ -93,7 +93,7 @@ torch_device = "cuda" if torch.cuda.is_available() else "cpu"

# 1. Load models
scheduler = DDIMScheduler.from_config("fusing/ddpm-celeba-hq", tensor_format="pt")
-unet = UNetUnconditionalModel.from_pretrained("fusing/ddpm-celeba-hq", ddpm=True).to(torch_device)
+unet = UNet2DModel.from_pretrained("fusing/ddpm-celeba-hq", ddpm=True).to(torch_device)

# 2. Sample gaussian noise
generator = torch.manual_seed(23)
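The only user-facing change in the README snippet is the class name: `UNetUnconditionalModel` becomes `UNet2DModel`. A minimal smoke-test sketch of the renamed class, reusing the checkpoint id and `ddpm=True` flag from the snippet above and the dict-style `["sample"]` output that `scripts/generate_logits.py` later in this diff relies on (the `in_channels`/`sample_size` config attributes are assumed to match that script):

```python
import torch
from diffusers import UNet2DModel

# Load the renamed model class exactly as the updated README does
unet = UNet2DModel.from_pretrained("fusing/ddpm-celeba-hq", ddpm=True)

# One denoising forward pass on random noise; model outputs are dict-style in this era
noise = torch.randn(1, unet.config.in_channels, unet.config.sample_size, unet.config.sample_size)
timestep = torch.tensor([10])
with torch.no_grad():
    residual = unet(noise, timestep)["sample"]

print(residual.shape)  # same shape as the input noise
```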
38 changes: 19 additions & 19 deletions scripts/convert_ddpm_original_checkpoint_to_diffusers.py
@@ -1,4 +1,4 @@
-from diffusers import UNetUnconditionalModel, DDPMScheduler, DDPMPipeline
+from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline
import argparse
import json
import torch
@@ -80,7 +80,7 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
continue

new_path = new_path.replace('down.', 'downsample_blocks.')
-new_path = new_path.replace('up.', 'upsample_blocks.')
+new_path = new_path.replace('up.', 'up_blocks.')

if additional_replacements is not None:
for replacement in additional_replacements:
@@ -114,8 +114,8 @@ def convert_ddpm_checkpoint(checkpoint, config):
num_downsample_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'down' in layer})
downsample_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_downsample_blocks)}

-num_upsample_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'up' in layer})
-upsample_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_upsample_blocks)}
+num_up_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'up' in layer})
+up_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)}

for i in range(num_downsample_blocks):
block_id = (i - 1) // (config['num_res_blocks'] + 1)
@@ -164,34 +164,34 @@ def convert_ddpm_checkpoint(checkpoint, config):
{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'attn_1', 'new': 'attentions.0'}
])

-for i in range(num_upsample_blocks):
-block_id = num_upsample_blocks - 1 - i
+for i in range(num_up_blocks):
+block_id = num_up_blocks - 1 - i

-if any('upsample' in layer for layer in upsample_blocks[i]):
-new_checkpoint[f'upsample_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'up.{i}.upsample.conv.weight']
-new_checkpoint[f'upsample_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'up.{i}.upsample.conv.bias']
+if any('upsample' in layer for layer in up_blocks[i]):
+new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'up.{i}.upsample.conv.weight']
+new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'up.{i}.upsample.conv.bias']

-if any('block' in layer for layer in upsample_blocks[i]):
-num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in upsample_blocks[i] if 'block' in layer})
-blocks = {layer_id: [key for key in upsample_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
+if any('block' in layer for layer in up_blocks[i]):
+num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in up_blocks[i] if 'block' in layer})
+blocks = {layer_id: [key for key in up_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}

if num_blocks > 0:
for j in range(config['num_res_blocks'] + 1):
-replace_indices = {'old': f'upsample_blocks.{i}', 'new': f'upsample_blocks.{block_id}'}
+replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
paths = renew_resnet_paths(blocks[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])

-if any('attn' in layer for layer in upsample_blocks[i]):
-num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in upsample_blocks[i] if 'attn' in layer})
-attns = {layer_id: [key for key in upsample_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
+if any('attn' in layer for layer in up_blocks[i]):
+num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in up_blocks[i] if 'attn' in layer})
+attns = {layer_id: [key for key in up_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}

if num_attn > 0:
for j in range(config['num_res_blocks'] + 1):
-replace_indices = {'old': f'upsample_blocks.{i}', 'new': f'upsample_blocks.{block_id}'}
+replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
paths = renew_attention_paths(attns[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])

-new_checkpoint = {k.replace('mid_new_2', 'mid'): v for k, v in new_checkpoint.items()}
+new_checkpoint = {k.replace('mid_new_2', 'mid_block'): v for k, v in new_checkpoint.items()}
return new_checkpoint


@@ -225,7 +225,7 @@ def convert_ddpm_checkpoint(checkpoint, config):
if "ddpm" in config:
del config["ddpm"]

-model = UNetUnconditionalModel(**config)
+model = UNet2DModel(**config)
model.load_state_dict(converted_checkpoint)

scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))
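Most of this converter is mechanical state-dict key rewriting: keys that started with `up.<i>` now land under `up_blocks.<block_id>`, and the temporary `mid_new_2` prefix is rewritten to `mid_block` at the end. A stripped-down illustration of that renaming idea (a hypothetical helper, not the script's `assign_to_checkpoint`, which also reindexes blocks and renumbers resnets):

```python
import torch


def rename_key_prefixes(state_dict, replacements):
    """Illustrative only: rewrite checkpoint key prefixes according to `replacements`."""
    renamed = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in replacements.items():
            if new_key.startswith(old):
                new_key = new + new_key[len(old):]
        renamed[new_key] = value
    return renamed


# Dummy keys, just to show the mapping applied by this PR
dummy = {
    "up.0.upsample.conv.weight": torch.zeros(1),
    "mid_new_2.resnets.0.conv1.weight": torch.zeros(1),
}
replacements = {"up.": "up_blocks.", "mid_new_2.": "mid_block."}
print(list(rename_key_prefixes(dummy, replacements)))
# ['up_blocks.0.upsample.conv.weight', 'mid_block.resnets.0.conv1.weight']
```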
38 changes: 19 additions & 19 deletions scripts/convert_ldm_original_checkpoint_to_diffusers.py
@@ -17,7 +17,7 @@
import argparse
import json
import torch
-from diffusers import VQModel, DDPMScheduler, UNetUnconditionalModel, LatentDiffusionUncondPipeline
+from diffusers import VQModel, DDPMScheduler, UNet2DModel, LatentDiffusionUncondPipeline


def shave_segments(path, n_shave_prefix_segments=1):
@@ -207,14 +207,14 @@ def convert_ldm_checkpoint(checkpoint, config):
attentions_paths = renew_attention_paths(attentions)
to_split = {
'middle_block.1.qkv.bias': {
-'key': 'mid.attentions.0.key.bias',
-'query': 'mid.attentions.0.query.bias',
-'value': 'mid.attentions.0.value.bias',
+'key': 'mid_block.attentions.0.key.bias',
+'query': 'mid_block.attentions.0.query.bias',
+'value': 'mid_block.attentions.0.value.bias',
},
'middle_block.1.qkv.weight': {
-'key': 'mid.attentions.0.key.weight',
-'query': 'mid.attentions.0.query.weight',
-'value': 'mid.attentions.0.value.weight',
+'key': 'mid_block.attentions.0.key.weight',
+'query': 'mid_block.attentions.0.query.weight',
+'value': 'mid_block.attentions.0.value.weight',
},
}
assign_to_checkpoint(attentions_paths, new_checkpoint, checkpoint, attention_paths_to_split=to_split, config=config)
@@ -239,13 +239,13 @@ def convert_ldm_checkpoint(checkpoint, config):
resnet_0_paths = renew_resnet_paths(resnets)
paths = renew_resnet_paths(resnets)

-meta_path = {'old': f'output_blocks.{i}.0', 'new': f'upsample_blocks.{block_id}.resnets.{layer_in_block_id}'}
+meta_path = {'old': f'output_blocks.{i}.0', 'new': f'up_blocks.{block_id}.resnets.{layer_in_block_id}'}
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path], config=config)

if ['conv.weight', 'conv.bias'] in output_block_list.values():
index = list(output_block_list.values()).index(['conv.weight', 'conv.bias'])
-new_checkpoint[f'upsample_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'output_blocks.{i}.{index}.conv.weight']
-new_checkpoint[f'upsample_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'output_blocks.{i}.{index}.conv.bias']
+new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'output_blocks.{i}.{index}.conv.weight']
+new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'output_blocks.{i}.{index}.conv.bias']

# Clear attentions as they have been attributed above.
if len(attentions) == 2:
@@ -255,18 +255,18 @@ def convert_ldm_checkpoint(checkpoint, config):
paths = renew_attention_paths(attentions)
meta_path = {
'old': f'output_blocks.{i}.1',
-'new': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}'
+'new': f'up_blocks.{block_id}.attentions.{layer_in_block_id}'
}
to_split = {
f'output_blocks.{i}.1.qkv.bias': {
-'key': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias',
-'query': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias',
-'value': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias',
+'key': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias',
+'query': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias',
+'value': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias',
},
f'output_blocks.{i}.1.qkv.weight': {
-'key': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight',
-'query': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight',
-'value': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight',
+'key': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight',
+'query': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight',
+'value': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight',
},
}
assign_to_checkpoint(
@@ -281,7 +281,7 @@ def convert_ldm_checkpoint(checkpoint, config):
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
for path in resnet_0_paths:
old_path = '.'.join(['output_blocks', str(i), path['old']])
-new_path = '.'.join(['upsample_blocks', str(block_id), 'resnets', str(layer_in_block_id), path['new']])
+new_path = '.'.join(['up_blocks', str(block_id), 'resnets', str(layer_in_block_id), path['new']])

new_checkpoint[new_path] = checkpoint[old_path]

@@ -319,7 +319,7 @@ def convert_ldm_checkpoint(checkpoint, config):
if "ldm" in config:
del config["ldm"]

-model = UNetUnconditionalModel(**config)
+model = UNet2DModel(**config)
model.load_state_dict(converted_checkpoint)

try:
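One detail this file handles beyond plain renaming: the original LDM checkpoints store attention projections as fused `qkv` tensors, which the converter splits into separate `query`/`key`/`value` entries under the new `mid_block.attentions.0` and `up_blocks.*.attentions.*` names. A toy sketch of such a split, assuming the fused tensor stacks the three projections along the first dimension (the real logic lives in `assign_to_checkpoint` via `attention_paths_to_split`):

```python
import torch

channels = 512  # hypothetical attention width

# Hypothetical fused projection: query, key and value stacked along dim 0 (assumed layout)
qkv_weight = torch.randn(3 * channels, channels)
qkv_bias = torch.randn(3 * channels)

query_w, key_w, value_w = torch.chunk(qkv_weight, 3, dim=0)
query_b, key_b, value_b = torch.chunk(qkv_bias, 3, dim=0)

split_entries = {
    "mid_block.attentions.0.query.weight": query_w,
    "mid_block.attentions.0.key.weight": key_w,
    "mid_block.attentions.0.value.weight": value_w,
    "mid_block.attentions.0.query.bias": query_b,
    "mid_block.attentions.0.key.bias": key_b,
    "mid_block.attentions.0.value.bias": value_b,
}
```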
20 changes: 10 additions & 10 deletions scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py
@@ -17,16 +17,16 @@
import argparse
import json
import torch
-from diffusers import UNetUnconditionalModel
+from diffusers import UNet2DModel


def convert_ncsnpp_checkpoint(checkpoint, config):
"""
Takes a state dict and the path to
"""
-new_model_architecture = UNetUnconditionalModel(**config)
-new_model_architecture.time_steps.W.data = checkpoint["all_modules.0.W"].data
-new_model_architecture.time_steps.weight.data = checkpoint["all_modules.0.W"].data
+new_model_architecture = UNet2DModel(**config)
+new_model_architecture.time_proj.W.data = checkpoint["all_modules.0.W"].data
+new_model_architecture.time_proj.weight.data = checkpoint["all_modules.0.W"].data
new_model_architecture.time_embedding.linear_1.weight.data = checkpoint["all_modules.1.weight"].data
new_model_architecture.time_embedding.linear_1.bias.data = checkpoint["all_modules.1.bias"].data

@@ -92,14 +92,14 @@ def set_resnet_weights(new_layer, old_checkpoint, index):
block.skip_conv.bias.data = checkpoint[f"all_modules.{module_index}.Conv_0.bias"].data
module_index += 1

-set_resnet_weights(new_model_architecture.mid.resnets[0], checkpoint, module_index)
+set_resnet_weights(new_model_architecture.mid_block.resnets[0], checkpoint, module_index)
module_index += 1
-set_attention_weights(new_model_architecture.mid.attentions[0], checkpoint, module_index)
+set_attention_weights(new_model_architecture.mid_block.attentions[0], checkpoint, module_index)
module_index += 1
-set_resnet_weights(new_model_architecture.mid.resnets[1], checkpoint, module_index)
+set_resnet_weights(new_model_architecture.mid_block.resnets[1], checkpoint, module_index)
module_index += 1

-for i, block in enumerate(new_model_architecture.upsample_blocks):
+for i, block in enumerate(new_model_architecture.up_blocks):
has_attentions = hasattr(block, "attentions")
for j in range(len(block.resnets)):
set_resnet_weights(block.resnets[j], checkpoint, module_index)
@@ -134,7 +134,7 @@ def set_resnet_weights(new_layer, old_checkpoint, index):

parser.add_argument(
"--checkpoint_path",
default="/Users/arthurzucker/Work/diffusers/ArthurZ/diffusion_model.pt",
default="/Users/arthurzucker/Work/diffusers/ArthurZ/diffusion_pytorch_model.bin",
type=str,
required=False,
help="Path to the checkpoint to convert.",
@@ -171,7 +171,7 @@ def set_resnet_weights(new_layer, old_checkpoint, index):
if "sde" in config:
del config["sde"]

-model = UNetUnconditionalModel(**config)
+model = UNet2DModel(**config)
model.load_state_dict(converted_checkpoint)

try:
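The new default checkpoint path ends in `diffusion_pytorch_model.bin`, which is the filename the diffusers `save_pretrained` layout uses for model weights. A hedged sketch of round-tripping a converted model through that layout (the path below is a placeholder, not one used by the script):

```python
from diffusers import UNet2DModel


def save_and_reload(model: UNet2DModel, dump_path: str) -> UNet2DModel:
    # save_pretrained writes config.json plus the weights file
    # (named diffusion_pytorch_model.bin in this era of the library)
    model.save_pretrained(dump_path)
    return UNet2DModel.from_pretrained(dump_path)


# reloaded = save_and_reload(converted_model, "/tmp/ncsnpp-converted")  # placeholder path
```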
17 changes: 10 additions & 7 deletions scripts/generate_logits.py
@@ -1,6 +1,6 @@
from huggingface_hub import HfApi
from transformers.file_utils import has_file
-from diffusers import UNetUnconditionalModel
+from diffusers import UNet2DModel
import random
import torch
api = HfApi()
@@ -70,19 +70,22 @@
models = api.list_models(filter="diffusers")
for mod in models:
if "google" in mod.author or mod.modelId == "CompVis/ldm-celebahq-256":

-if mod.modelId == "CompVis/ldm-celebahq-256" or not has_file(mod.modelId, "config.json"):
-model = UNetUnconditionalModel.from_pretrained(mod.modelId, subfolder = "unet")
+local_checkpoint = "/home/patrick/google_checkpoints/" + mod.modelId.split("/")[-1]
+
+print(f"Started running {mod.modelId}!!!")
+
+if mod.modelId.startswith("CompVis"):
+model = UNet2DModel.from_pretrained(local_checkpoint, subfolder = "unet")
else:
-model = UNetUnconditionalModel.from_pretrained(mod.modelId)
+model = UNet2DModel.from_pretrained(local_checkpoint)

torch.manual_seed(0)
random.seed(0)

-noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
+noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
time_step = torch.tensor([10] * noise.shape[0])
with torch.no_grad():
logits = model(noise, time_step)['sample']

-torch.allclose(logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3)
+assert torch.allclose(logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3)
print(f"{mod.modelId} has passed succesfully!!!")
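The added `assert` is the substantive fix in that last hunk: `torch.allclose` only returns a bool, so without the assert a logits mismatch would have been silently discarded. A two-line illustration:

```python
import torch

a, b = torch.zeros(3), torch.ones(3)
torch.allclose(a, b)          # evaluates to False, but execution just continues
assert torch.allclose(a, b)   # raises AssertionError, so a mismatch actually fails the run
```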
2 changes: 1 addition & 1 deletion src/diffusers/__init__.py
@@ -7,7 +7,7 @@
__version__ = "0.0.4"

from .modeling_utils import ModelMixin
-from .models import AutoencoderKL, UNetConditionalModel, UNetUnconditionalModel, VQModel
+from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
from .pipeline_utils import DiffusionPipeline
from .pipelines import DDIMPipeline, DDPMPipeline, LatentDiffusionUncondPipeline, PNDMPipeline, ScoreSdeVePipeline
from .schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, SchedulerMixin, ScoreSdeVeScheduler
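For downstream code, the practical effect of this export change is that the old class names disappear from the public namespace. A minimal before/after sketch of the import update:

```python
# Before this PR (names no longer exported):
# from diffusers import UNetUnconditionalModel, UNetConditionalModel

# After this PR:
from diffusers import UNet2DModel, UNet2DConditionModel
```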
8 changes: 4 additions & 4 deletions src/diffusers/configuration_utils.py
@@ -161,10 +161,10 @@ def get_config_dict(

except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed"
" on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token"
" having permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and"
" pass `use_auth_token=True`."
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
" listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
" token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
" login` and pass `use_auth_token=True`."
)
except RevisionNotFoundError:
raise EnvironmentError(