Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
158 commits
Select commit Hold shift + click to select a range
8d1a17c
re-add RL model code
natolambert Jul 19, 2022
84e94d7
match model forward api
natolambert Jul 19, 2022
f67b036
add register_to_config, pass training tests
natolambert Jul 26, 2022
e42d1c0
fix tests, update forward outputs
natolambert Oct 3, 2022
2dd514e
remove unused code, some comments
natolambert Oct 3, 2022
b4c6188
add to docs
natolambert Oct 3, 2022
c53bba9
remove extra embedding code
natolambert Oct 6, 2022
effcbdb
unify time embedding
natolambert Oct 7, 2022
7865231
remove conv1d output sequential
natolambert Oct 8, 2022
35b0a43
remove sequential from conv1dblock
natolambert Oct 8, 2022
9b1379d
style and deleting duplicated code
natolambert Oct 8, 2022
e97a610
clean files
natolambert Oct 8, 2022
f29ace4
valuefunction code
bglick13 Oct 8, 2022
1684e8b
start example scripts
bglick13 Oct 8, 2022
c757985
missing imports
bglick13 Oct 8, 2022
b315918
bug fixes and placeholder example script
bglick13 Oct 8, 2022
f01c014
add value function scheduler
bglick13 Oct 9, 2022
7b60c93
load value function from hub and get best actions in example
bglick13 Oct 9, 2022
8642560
remove unused variables
natolambert Oct 10, 2022
f58c915
clean variables
natolambert Oct 10, 2022
ad8376d
Merge branch 'main' into rl
natolambert Oct 10, 2022
0de435e
very close to working example
bglick13 Oct 10, 2022
a396529
larger batch size for planning
bglick13 Oct 10, 2022
3b08bea
add 1d resnet block structure for downsample
natolambert Oct 10, 2022
aae2a9a
rename as unet1d
natolambert Oct 10, 2022
dd872af
fix renaming
natolambert Oct 10, 2022
713bd80
more tests
bglick13 Oct 11, 2022
e3fb50f
Merge branch 'main' into rl
bglick13 Oct 11, 2022
686069f
Merge branch 'hf_rl' into rl
bglick13 Oct 11, 2022
d9384ff
merge unet1d changes
bglick13 Oct 11, 2022
52e2668
wandb for debugging, use newer models
bglick13 Oct 11, 2022
75fe8b4
success!
bglick13 Oct 11, 2022
9b67bb7
rename files
natolambert Oct 12, 2022
c7fe1dc
turns out we just need more diffusion steps
bglick13 Oct 12, 2022
db012eb
add get_block(...) api
natolambert Oct 12, 2022
4db6e0b
unify args for model1d like model2d
natolambert Oct 12, 2022
634a526
minor cleaning
natolambert Oct 12, 2022
aebf547
fix docs
natolambert Oct 12, 2022
305ecd8
improve 1d resnet blocks
natolambert Oct 12, 2022
42855b9
Merge branch 'main' into rl
natolambert Oct 12, 2022
95d3a1c
fix tests, remove permuts
natolambert Oct 12, 2022
6cbb73b
fix style
natolambert Oct 12, 2022
a6871b1
run on modal
bglick13 Oct 12, 2022
13a443c
Merge branch 'hf_rl' into rl
bglick13 Oct 12, 2022
38616cf
merge and code cleanup
bglick13 Oct 12, 2022
d37b472
use same api for rl model
bglick13 Oct 12, 2022
798263f
init v-pred pr
natolambert Oct 13, 2022
b7d0c1e
placeholder code
natolambert Oct 13, 2022
7eb4bfa
up
natolambert Oct 13, 2022
3eb2593
a few more additions
natolambert Oct 13, 2022
aa19286
fix variance type
bglick13 Oct 13, 2022
02293e2
wrong normalization function
bglick13 Oct 13, 2022
56818e5
add tests
bglick13 Oct 17, 2022
d085725
style
bglick13 Oct 17, 2022
93fe3ef
style and quality
bglick13 Oct 17, 2022
4c68504
add ddim
natolambert Oct 18, 2022
ac6be90
style
natolambert Oct 18, 2022
ffb7355
add output activation
natolambert Oct 18, 2022
a6314f6
rename flax blocks file
natolambert Oct 18, 2022
4e378e9
edits based on comments
bglick13 Oct 18, 2022
e7e6963
style and quality
bglick13 Oct 18, 2022
4f77d89
remove unused var
bglick13 Oct 19, 2022
5de8a6a
Merge branch 'hf_rl' into rl
bglick13 Oct 19, 2022
fb02b7d
Merge branch 'main' into fork-main
bglick13 Oct 19, 2022
6bd8397
hack unet1d into a value function
bglick13 Oct 20, 2022
435ad26
add pipeline
bglick13 Oct 20, 2022
5653408
fix arg order
bglick13 Oct 20, 2022
1491932
add pipeline to core library
bglick13 Oct 20, 2022
1a8098e
community pipeline
bglick13 Oct 20, 2022
0e4be75
fix couple shape bugs
bglick13 Oct 21, 2022
5ef88ef
style
bglick13 Oct 21, 2022
c6d94ce
Apply suggestions from code review
Oct 21, 2022
48a7414
Add Value Function and corresponding example script to Diffuser imple…
bglick13 Oct 21, 2022
3acddb5
update post merge of scripts
natolambert Oct 21, 2022
a9cee78
clean up comments
bglick13 Oct 21, 2022
5c8cfc2
Merge remote-tracking branch 'bglick13/rl' into rl
bglick13 Oct 21, 2022
b7fac18
convert older script to using pipeline and add readme
bglick13 Oct 21, 2022
b3edd7b
rename scripts
bglick13 Oct 21, 2022
8b01b93
style, update tests
bglick13 Oct 21, 2022
b0b8b0b
Merge branch 'hf_rl' into rl
bglick13 Oct 22, 2022
3c668a7
delete unet rl model file
bglick13 Oct 22, 2022
469779b
Merge branch 'main' into fork-main
bglick13 Oct 24, 2022
713e8f2
add mdiblock / outblock architecture
natolambert Oct 24, 2022
af26faa
remove imports in src
Oct 24, 2022
268ebdf
Pipeline cleanup (#947)
bglick13 Oct 24, 2022
daa05fb
Update src/diffusers/models/unet_1d_blocks.py
Oct 24, 2022
ea5f231
Update tests/test_models_unet.py
Oct 24, 2022
84efdac
add specific vf block and update tests
bglick13 Oct 24, 2022
3bf848f
Merge branch 'hf_rl' into rl
bglick13 Oct 24, 2022
9faf55a
style
bglick13 Oct 24, 2022
24bb52a
Update tests/test_models_unet.py
Oct 24, 2022
4f7a3a4
RL Cleanup v2 (#965)
bglick13 Oct 24, 2022
d90b8b1
fix quality in tests
natolambert Oct 24, 2022
35f03be
quality
bglick13 Oct 24, 2022
ad8b6cf
fix quality style, split test file
natolambert Oct 24, 2022
e06a4a4
Merge branch 'main' into rl
natolambert Oct 24, 2022
99b2c81
fix checks / tests
natolambert Oct 24, 2022
8828f73
Merge remote-tracking branch 'bglick13/rl' into rl
bglick13 Oct 24, 2022
1fed4f1
Merge branch 'main' into fork-main
bglick13 Oct 25, 2022
eceafd5
placeholder script
bglick13 Oct 25, 2022
de4b6e4
make timesteps closer to main
natolambert Oct 25, 2022
ef6ca1f
unify block API
natolambert Oct 25, 2022
6e3485c
Merge branch 'main' into rl
natolambert Oct 25, 2022
e6f1a83
unify forward api
natolambert Oct 25, 2022
c35a925
delete lines in examples
natolambert Oct 25, 2022
949b93a
style
natolambert Oct 25, 2022
2f6462b
examples style
natolambert Oct 25, 2022
a2dd559
all tests pass
natolambert Oct 26, 2022
39dff73
make style
natolambert Oct 26, 2022
7653c4f
update conversion script
bglick13 Oct 26, 2022
2f97adf
first go
Oct 26, 2022
d5eedff
make dance_diff test pass
natolambert Oct 26, 2022
1f93de2
use ddim insstead of ddpm
Oct 27, 2022
826b459
close on multi step, but still some quality loss
Oct 27, 2022
b76d084
equation fixes and comments
Oct 27, 2022
8900921
correct beta schedule
Oct 27, 2022
8b13860
Merge branch 'hf_rl' into rl
bglick13 Oct 27, 2022
671e55f
Merge branch 'progressive-distillation' into rl_distillation
bglick13 Oct 27, 2022
a6105a7
add distillation pipeline
bglick13 Oct 27, 2022
23d8c05
add some code to make it work with rl example
bglick13 Oct 29, 2022
1c6dfb3
Merge branch 'rl_distillation' into progressive-distillation
bglick13 Oct 29, 2022
beeb9b1
code cleanup, use pipeline in example
bglick13 Oct 29, 2022
e6a2425
code cleanup
bglick13 Oct 31, 2022
c0cfe79
remove diffuser stuff from this pr
bglick13 Oct 31, 2022
778744d
Merge branch 'main' into fork-main
bglick13 Oct 31, 2022
57af150
Merge branch 'fork-main' into progressive-distillation
bglick13 Oct 31, 2022
79021e1
remove more diffusers stuff
bglick13 Oct 31, 2022
fcd5dee
rebase main onto branch
bglick13 Oct 31, 2022
c6ac284
make fix copies
bglick13 Oct 31, 2022
1ab27a4
Update docs/source/api/models.mdx
Oct 31, 2022
9bb7818
code cleanup and start writing tests
bglick13 Nov 1, 2022
c584c3c
fast test passing
bglick13 Nov 1, 2022
1261c2f
accomodate dict collation
bglick13 Nov 1, 2022
47d2913
colab script for debugging
bglick13 Nov 1, 2022
97801e3
some v diffusion support
Nov 2, 2022
fc47982
Merge branch 'main' into fork-main
bglick13 Nov 2, 2022
fbe5b5c
Merge branch 'fork-main' into progressive-distillation
bglick13 Nov 2, 2022
2aec5b1
v diffusion and training on butterflies example with it
Nov 3, 2022
a48b026
v diffusion support for ddpm
bglick13 Nov 3, 2022
3d702c6
quality and style
bglick13 Nov 3, 2022
0889fd1
variable name consistency
bglick13 Nov 3, 2022
f7c7095
missing base case
bglick13 Nov 3, 2022
0c23e11
pass prediction type along in the pipeline
bglick13 Nov 3, 2022
d887e58
correct variance type
Nov 3, 2022
b46327e
put prediction type in scheduler config
bglick13 Nov 7, 2022
45c36c8
style
bglick13 Nov 7, 2022
f00d896
DDPM changes to support v diffusion (#1121)
bglick13 Nov 9, 2022
8fe2ff4
Merge branch 'main' into v_prediction
natolambert Nov 9, 2022
56164f5
quality
natolambert Nov 9, 2022
13404a6
try to train on ddim
bglick13 Nov 9, 2022
1fa3cc8
changes to ddim
Nov 15, 2022
0b60c2b
ddim v prediction works to train butterflies example
Nov 16, 2022
1892306
Merge branch 'v_prediction' into v_prediction_ddim
bglick13 Nov 16, 2022
8311d89
fix bad merge, style and quality
bglick13 Nov 16, 2022
24af3d9
Merge branch 'v_prediction_ddim' into progressive-distillation
bglick13 Nov 16, 2022
f0c0dee
closest yet
Nov 18, 2022
d50cd2f
weird torch.rand bug
Nov 19, 2022
eeb8e31
remove literals
bglick13 Nov 20, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -163,4 +163,5 @@ tags
*.lock

# DS_Store (MacOS)
.DS_Store
.DS_Store
*.png
93 changes: 93 additions & 0 deletions examples/progressive_distillation/colab.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from dataclasses import dataclass


@dataclass
class TrainingConfig:
image_size = 128 # the generated image resolution
train_batch_size = 16
eval_batch_size = 16 # how many images to sample during evaluation
num_epochs = 50
gradient_accumulation_steps = 1
learning_rate = 1e-4
lr_warmup_steps = 500
save_image_epochs = 10
save_model_epochs = 30
mixed_precision = "fp16" # `no` for float32, `fp16` for automatic mixed precision
output_dir = "ddpm-butterflies-128" # the model namy locally and on the HF Hub

push_to_hub = True # whether to upload the saved model to the HF Hub
hub_private_repo = False
overwrite_output_dir = True # overwrite the old model when re-running the notebook
seed = 0


config = TrainingConfig()

from datasets import load_dataset

config.dataset_name = "huggan/smithsonian_butterflies_subset"
dataset = load_dataset(config.dataset_name, split="train")
from torchvision import transforms

preprocess = transforms.Compose(
[
transforms.Resize((config.image_size, config.image_size)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)


def transform(examples):
images = [preprocess(image.convert("RGB")) for image in examples["image"]]
return {"images": images}


dataset.set_transform(transform)
import torch
import os

from diffusers import UNet2DModel, DistillationPipeline, DDPMPipeline, DDPMScheduler, DDIMPipeline, DDIMScheduler
from accelerate import Accelerator


teacher = UNet2DModel.from_pretrained("bglick13/ddim-butterflies-128-v-diffusion", subfolder="unet")

# accelerator = Accelerator(
# mixed_precision=config.mixed_precision,
# gradient_accumulation_steps=config.gradient_accumulation_steps,
# log_with="tensorboard",
# logging_dir=os.path.join(config.output_dir, "logs"),
# )
# teacher = accelerator.prepare(teacher)
distiller = DistillationPipeline()
n_teacher_trainsteps = 1000
new_teacher, distilled_ema, distill_accelrator = distiller(
teacher,
n_teacher_trainsteps,
dataset,
epochs=100,
batch_size=32,
mixed_precision="fp16",
sample_every=1,
gamma=0.0,
lr=1e-4,
)
new_scheduler = DDIMScheduler(
num_train_timesteps=500, beta_schedule="squaredcos_cap_v2", variance_type="v_diffusion", prediction_type="v"
)
pipeline = DDIMPipeline(
unet=distill_accelrator.unwrap_model(distilled_ema.averaged_model),
scheduler=new_scheduler,
)

# run pipeline in inference (sample random noise and denoise)
images = pipeline(batch_size=4, output_type="numpy", generator=torch.manual_seed(0)).images

# denormalize the images and save to tensorboard
images_processed = (images * 255).round().astype("uint8")
from PIL import Image

img = Image.fromarray(images_processed[0])
img.save("denoised.png")
273 changes: 273 additions & 0 deletions examples/progressive_distillation/image_diffusion.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,273 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n",
"WARNING:torch.distributed.elastic.multiprocessing.redirects:NOTE: Redirects are currently not supported in Windows or MacOs.\n"
]
}
],
"source": [
"import torch\n",
"from PIL import Image\n",
"from diffusers import AutoencoderKL, UNet2DModel, DDIMPipeline, DDIMScheduler, DDPMPipeline, DDPMScheduler, DistillationPipeline\n",
"from diffusers.optimization import get_scheduler\n",
"from diffusers.training_utils import EMAModel\n",
"import math\n",
"import requests\n",
"from torchvision.transforms import (\n",
" CenterCrop,\n",
" Compose,\n",
" InterpolationMode,\n",
" Normalize,\n",
" RandomHorizontalFlip,\n",
" Resize,\n",
" ToTensor,\n",
" ToPILImage\n",
")\n",
"from torch.utils.data import Dataset\n",
"from accelerate import Accelerator\n",
"import utils\n",
"from tqdm import tqdm\n",
"import torch.nn.functional as F\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch._C.Generator at 0x7f9a051d2010>"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.manual_seed(0)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"training_config = utils.DiffusionTrainingArgs()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Load an image of my dog for this example\n",
"\n",
"image_url = \"https://i.imgur.com/IJcs4Aa.jpeg\"\n",
"image = Image.open(requests.get(image_url, stream=True).raw)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Define the transforms to apply to the image for training\n",
"augmentations = utils.get_train_transforms(training_config)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"class SingleImageDataset(Dataset):\n",
" def __init__(self, image, batch_size):\n",
" self.image = image\n",
" self.batch_size = batch_size\n",
"\n",
" def __len__(self):\n",
" return self.batch_size\n",
"\n",
" def __getitem__(self, idx):\n",
" return self.image\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"train_image = augmentations(image.convert(\"RGB\"))\n",
"train_dataset = SingleImageDataset(train_image, training_config.batch_size)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2b23b591496741a299b75e4e9448b29a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading: 0%| | 0.00/455M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1df9166b338f49adbaac183384972ea0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading: 0%| | 0.00/665 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"teacher = UNet2DModel.from_pretrained(\"bglick13/minnie-diffusion\")\n",
"distiller = DistillationPipeline()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"N = 1000\n",
"generator = torch.manual_seed(0)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"teacher = UNet2DModel.from_pretrained(\"bglick13/minnie-diffusion\")\n",
"N = 1000\n",
"distilled_images = []\n",
"for distill_step in range(2):\n",
" print(f\"Distill step {distill_step} from {N} -> {N // 2}\")\n",
" teacher, distilled_ema, distill_accelrator = distiller(teacher, N, train_dataset, epochs=300, batch_size=training_config.batch_size)\n",
" N = N // 2\n",
" new_scheduler = DDPMScheduler(num_train_timesteps=N, beta_schedule=\"squaredcos_cap_v2\")\n",
" pipeline = DDPMPipeline(\n",
" unet=distill_accelrator.unwrap_model(distilled_ema.averaged_model if training_config.use_ema else teacher),\n",
" scheduler=new_scheduler,\n",
" )\n",
"\n",
" # run pipeline in inference (sample random noise and denoise)\n",
" images = pipeline(generator=generator, batch_size=training_config.batch_size, output_type=\"numpy\").images\n",
"\n",
" # denormalize the images and save to tensorboard\n",
" images_processed = (images * 255).round().astype(\"uint8\")\n",
" distilled_images.append(images_processed[0])\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Display train image for reference\n",
"train_image_display = train_image * 0.5 + 0.5\n",
"train_image_display = ToPILImage()(train_image_display)\n",
"display(train_image_display)\n",
"\n",
"for i, image in enumerate(distilled_images):\n",
" print(f\"Distilled image {i}\")\n",
" display(Image.fromarray(image))\n",
" Image.fromarray(image).save(f\"distilled_{i}.png\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"display(Image.fromarray(images_processed[0]))\n",
"display(Image.fromarray(images_processed[1]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10.6 ('diffusers')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "77f6871a522595648ebba7232d315a2f946cc4cd5f56470cb61e517ec9b94e2e"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading