19 commits
b75ecf7  First pass at doing classifier-free guidance sampling with Cog (dashstander, Feb 7, 2022)
1b513ea  Not installing any system packages (dashstander, Feb 7, 2022)
7b2447b  Update name of predictor class (dashstander, Feb 8, 2022)
d53dcc6  Lower pytorch version to 1.8 (dashstander, Feb 8, 2022)
9077ca3  Fixed bug in the import (dashstander, Feb 8, 2022)
dc3f574  Merge branch 'cog' of github.com:dashstander/v-diffusion-pytorch into… (dashstander, Feb 8, 2022)
3ef585d  Bring into line with torch 1.8 (dashstander, Feb 8, 2022)
5360f0b  Merge branch 'cog' of github.com:dashstander/v-diffusion-pytorch into… (dashstander, Feb 8, 2022)
dcaa631  cfg_sample_fn needs to have a specific signature (dashstander, Feb 8, 2022)
c838257  Merge branch 'cog' of github.com:dashstander/v-diffusion-pytorch into… (dashstander, Feb 8, 2022)
fe674f4  cfg_sample_fn needs to have a specific signature (dashstander, Feb 8, 2022)
fa5e671  Merge branch 'cog' of github.com:dashstander/v-diffusion-pytorch into… (dashstander, Feb 8, 2022)
21f547a  Dumb dumb dumb (dashstander, Feb 8, 2022)
d74c675  Merge branch 'cog' of github.com:dashstander/v-diffusion-pytorch into… (dashstander, Feb 8, 2022)
3b8743b  Add cog.yaml and cfg_predict.py files to be able to run a demo of the… (dashstander, Feb 8, 2022)
e1b9c1b  Merge branch 'master' into cog (dashstander, Feb 11, 2022)
0f3d53b  Update to use PLMS sampling (dashstander, Feb 11, 2022)
3cf9f96  Merge in changes (dashstander, Feb 11, 2022)
486eac8  How did I forget the git conflict lines??? (dashstander, Feb 11, 2022)
2 changes: 2 additions & 0 deletions README.md
@@ -46,6 +46,8 @@ If they are somewhere else, you need to specify the path to the checkpoint with

### CFG sampling (best, but only cc12m_1_cfg supports it)

[Demo and Docker image on Replicate](https://replicate.ai/crowsonkb/clip-guided-diffusion-cfg)

```
usage: cfg_sample.py [-h] [--images [IMAGE ...]] [--batch-size BATCH_SIZE]
[--checkpoint CHECKPOINT] [--device DEVICE] [--eta ETA] [--init INIT]
  ...
```
91 changes: 91 additions & 0 deletions cfg_predict.py
@@ -0,0 +1,91 @@
# Prediction interface for Cog ⚙️
# Reference: https://github.com/replicate/cog/blob/main/docs/python.md

import cog
from pathlib import Path
from PIL import Image
import tempfile
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF

from CLIP import clip
from diffusion import get_model, sampling, utils


def resize_and_center_crop(image, size):
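    """Resize `image` to cover `size` (width, height), then center-crop to it."""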
fac = max(size[0] / image.size[0], size[1] / image.size[1])
image = image.resize((int(fac * image.size[0]), int(fac * image.size[1])), Image.LANCZOS)
return TF.center_crop(image, size[::-1])


def parse_prompt(prompt, default_weight=3.):
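    """Split an optional ':weight' suffix off a prompt, e.g. 'a cat:2' ->
    ('a cat', 2.0); prompts without a weight get `default_weight`. For
    http(s) URLs the scheme's colon is left intact."""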
if prompt.startswith('http://') or prompt.startswith('https://'):
vals = prompt.rsplit(':', 2)
vals = [vals[0] + ':' + vals[1], *vals[2:]]
else:
vals = prompt.rsplit(':', 1)
vals = vals + ['', default_weight][len(vals):]
return vals[0], float(vals[1])


class ClassifierFreeGuidanceDiffusionPredictor(cog.Predictor):
model_name = 'cc12m_1_cfg'
checkpoint_path = 'checkpoints/cc12m_1_cfg.pth'
device = 'cuda:0'

def setup(self):
"""Load the model into memory to make running multiple predictions efficient"""
assert torch.cuda.is_available()
self.model = get_model(self.model_name)()
self.model.load_state_dict(torch.load(self.checkpoint_path, map_location='cpu'))
self.model.half()
self.model.to(self.device).eval().requires_grad_(False)
self.clip = clip.load('ViT-B/16', jit=False, device=self.device)[0]
self.clip.eval().requires_grad_(False)
self.normalize_fn = transforms.Normalize(
mean=[0.48145466, 0.4578275, 0.40821073],
std=[0.26862954, 0.26130258, 0.27577711]
)

@property
def output_dim(self):
return self.clip.visual.output_dim

def normalize(self, image):
return self.normalize_fn(image)

    def run_sampling(self, x, steps, sample_fn):
        # PLMS sampling is deterministic, so it takes no eta parameter.
        return sampling.plms_sample(sample_fn, x, steps, {})

@cog.input('prompt', type=str, help='The prompt for image generation')
@cog.input('seed', type=int, default=0, help='Random seed for reproducibility.')
@cog.input('steps', type=int, default=20, max=100, min=1, help='Number of steps to sample for.')
    def predict(self, prompt: str, seed: int, steps: int):
"""Run a single prediction on the model"""
_, side_y, side_x = self.model.shape
torch.manual_seed(seed)
zero_embed = torch.zeros([1, self.output_dim], device=self.device)
target_embeds, weights = [zero_embed], []
txt, weight = parse_prompt(prompt)
target_embeds.append(self.clip.encode_text(clip.tokenize(txt).to(self.device)).float())
weights.append(weight)
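        # Weight 1 - sum(weights) on the unconditional (zero) embedding makes
        # the weights sum to 1; a prompt weight above 1 therefore extrapolates
        # past the conditional prediction (classifier-free guidance).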
weights = torch.tensor([1 - sum(weights), *weights], device=self.device)
def cfg_sample_fn(x, t):
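            # Run the unconditional and conditional cases as one batched
            # forward pass, then combine the predicted v's with the weights.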
n = x.shape[0]
n_conds = len(target_embeds)
x_in = x.repeat([n_conds, 1, 1, 1])
t_in = t.repeat([n_conds])
clip_embed_in = torch.cat([*target_embeds]).repeat_interleave(n, 0)
vs = self.model(x_in, t_in, clip_embed_in).view([n_conds, n, *x.shape[1:]])
v = vs.mul(weights[:, None, None, None, None]).sum(0)
return v
x = torch.randn([1, 3, side_y, side_x], device=self.device)
t = torch.linspace(1, 0, steps + 1, device=self.device)[:-1]
steps = utils.get_spliced_ddpm_cosine_schedule(t)
        output_image = self.run_sampling(x, steps, cfg_sample_fn)
out_path = Path(tempfile.mkdtemp()) / "out.png"
utils.to_pil_image(output_image).save(out_path)
return out_path
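The heart of `cfg_sample_fn` is a weighted sum of the model's unconditional and conditional outputs. A minimal sketch of that arithmetic with a stand-in model (all names and sizes below are illustrative, not part of the PR):

```
import torch

def toy_model(x, t, clip_embed):
    # Stand-in for the diffusion net: any function of x, t, and the CLIP
    # embedding is enough to demonstrate the guidance arithmetic.
    return x * 0.5 + clip_embed.mean(dim=1)[:, None, None, None]

n = 2                                   # batch size
zero_embed = torch.zeros([1, 512])      # "unconditional" embedding
text_embed = torch.randn([1, 512])      # stand-in CLIP text embedding
target_embeds = [zero_embed, text_embed]
n_conds = len(target_embeds)
weights = torch.tensor([1 - 3.0, 3.0])  # prompt weight 3 -> uncond weight -2

x = torch.randn([n, 3, 8, 8])
t = torch.full([n], 0.5)
x_in = x.repeat([n_conds, 1, 1, 1])
t_in = t.repeat([n_conds])
clip_embed_in = torch.cat(target_embeds).repeat_interleave(n, 0)
vs = toy_model(x_in, t_in, clip_embed_in).view([n_conds, n, *x.shape[1:]])
v = vs.mul(weights[:, None, None, None, None]).sum(0)
assert v.shape == x.shape
```

With a prompt weight of 3 the unconditional weight is -2, so the combined prediction is pushed away from the unconditional output and toward the text-conditioned one.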
24 changes: 24 additions & 0 deletions cog.yaml
@@ -0,0 +1,24 @@
# Configuration for Cog ⚙️
# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md

build:
# set to true if your model requires a GPU
gpu: true

# python version in the form '3.8' or '3.8.12'
python_version: "3.8"

# a list of packages in the format <package-name>==<version>
python_packages:
- "torch==1.9.0"
- "torchvision==0.9.0"
- "ftfy==6.0.3"
- "tqdm==4.62.3"
- "pillow==9.0.1"
  # commands run after the environment is set up
run:
    # Most recent commit to master as of 2022-02-07; pinned to a single commit
- "pip install git+https://github.com/openai/CLIP.git@40f5484c1c74edd83cb9cf687c6ab92b28d8b656"

# predict.py defines how predictions are run on your model
predict: "cfg_predict.py:ClassifierFreeGuidanceDiffusionPredictor"
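`cog predict` is the normal entry point once the image is built, but the predictor can also be exercised directly. A hypothetical smoke test, assuming a CUDA machine with the checkpoint at `checkpoints/cc12m_1_cfg.pth` (prompt and settings are just examples):

```
from cfg_predict import ClassifierFreeGuidanceDiffusionPredictor

predictor = ClassifierFreeGuidanceDiffusionPredictor()
predictor.setup()  # loads the diffusion model and CLIP onto the GPU
out_path = predictor.predict(prompt='a watercolor landscape:3', seed=0, steps=20)
print(out_path)  # path of the saved out.png
```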
2 changes: 1 addition & 1 deletion diffusion/utils.py
@@ -62,7 +62,7 @@ def alpha_sigma_to_t(alpha, sigma):

def get_ddpm_schedule(ddpm_t):
"""Returns timesteps for the noise schedule from the DDPM paper."""
-    log_snr = -torch.special.expm1(1e-4 + 10 * ddpm_t**2).log()
+    log_snr = -torch.expm1(1e-4 + 10 * ddpm_t**2).log()
alpha, sigma = log_snr_to_alpha_sigma(log_snr)
return alpha_sigma_to_t(alpha, sigma)

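For context on this one-line change: `torch.special` was only added in PyTorch 1.9, while `torch.expm1` already exists in 1.8, and on newer versions both spellings dispatch to the same op, so the schedule itself is unchanged. A quick sanity check (assumes torch >= 1.9, where both spellings exist):

```
import torch

ddpm_t = torch.linspace(0, 1, 5)
old = -torch.special.expm1(1e-4 + 10 * ddpm_t**2).log()
new = -torch.expm1(1e-4 + 10 * ddpm_t**2).log()
assert torch.equal(old, new)  # identical schedule either way
```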