From b75ecf79c2d80b5488bec69c5b76552c649d5498 Mon Sep 17 00:00:00 2001
From: Dashiell Stander
Date: Mon, 7 Feb 2022 15:38:58 -0800
Subject: [PATCH 01/12] First pass at doing classifier-free guidance sampling
 with Cog

---
 cfg_predict.py | 90 ++++++++++++++++++++++++++++++++++++++++++++++++++
 cog.yaml       | 29 ++++++++++++++++
 2 files changed, 119 insertions(+)
 create mode 100644 cfg_predict.py
 create mode 100644 cog.yaml

diff --git a/cfg_predict.py b/cfg_predict.py
new file mode 100644
index 0000000..8754538
--- /dev/null
+++ b/cfg_predict.py
@@ -0,0 +1,90 @@
+# Prediction interface for Cog ⚙️
+# Reference: https://github.com/replicate/cog/blob/main/docs/python.md
+
+import cog
+from pathlib import Path
+from PIL import Image
+import tempfile
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torchvision import transforms
+from torchvision.transforms import functional as TF
+
+from CLIP import clip
+from diffusion import get_model, sampling, utils
+
+
+def resize_and_center_crop(image, size):
+    fac = max(size[0] / image.size[0], size[1] / image.size[1])
+    image = image.resize((int(fac * image.size[0]), int(fac * image.size[1])), Image.LANCZOS)
+    return TF.center_crop(image, size[::-1])
+
+
+def parse_prompt(prompt, default_weight=3.):
+    if prompt.startswith('http://') or prompt.startswith('https://'):
+        vals = prompt.rsplit(':', 2)
+        vals = [vals[0] + ':' + vals[1], *vals[2:]]
+    else:
+        vals = prompt.rsplit(':', 1)
+    vals = vals + ['', default_weight][len(vals):]
+    return vals[0], float(vals[1])
+
+
+class ClassifierFreeGuidanceDiffusionSampler(cog.Predictor):
+    model_name = 'cc12m_1_cfg'
+    checkpoint_path = 'checkpoints/cc12m_1_cfg.pth'
+    device = 'cuda:0'
+
+    def setup(self):
+        """Load the model into memory to make running multiple predictions efficient"""
+        assert torch.cuda.is_available()
+        self.model = get_model(self.model_name)()
+        self.model.load_state_dict(torch.load(self.checkpoint_path, map_location='cpu'))
+        self.model.half()
+        self.model.to(self.device).eval().requires_grad_(False)
+        self.clip = clip.load('ViT-B/16', jit=False, device=self.device)[0]
+        self.clip.eval().requires_grad_(False)
+        self.normalize_fn = transforms.Normalize(
+            mean=[0.48145466, 0.4578275, 0.40821073],
+            std=[0.26862954, 0.26130258, 0.27577711]
+        )
+
+    def normalize(self, image):
+        return self.normalize_fn(image)
+
+
+    def cfg_sample_fn(self, x, t, target_embeds, weights):
+        n = x.shape[0]
+        n_conds = len(target_embeds)
+        x_in = x.repeat([n_conds, 1, 1, 1])
+        t_in = t.repeat([n_conds])
+        clip_embed_in = torch.cat([*target_embeds]).repeat_interleave(n, 0)
+        vs = self.model(x_in, t_in, clip_embed_in).view([n_conds, n, *x.shape[1:]])
+        v = vs.mul(weights[:, None, None, None, None]).sum(0)
+        return v
+
+
+    def run_sampling(self, x, steps, eta):
+        return sampling.sample(self.cfg_sample_fn, x, steps, eta, {})
+
+    @cog.input('prompt', type=str, help='The prompt for image generation')
+    @cog.input("eta", type=float, default=1.0, help='The amount of randomness')
+    @cog.input('seed', type=int, default=0, help='Random seed for reproducibility.')
+    @cog.input('steps', type=int, default=500, max=1000, min=0, help='Number of steps to sample for.')
+    def predict(self, prompt: str, eta: float, seed: int, steps: int):
+        """Run a single prediction on the model"""
+        _, side_y, side_x = self.model.shape
+        torch.manual_seed(seed)
+        zero_embed = torch.zeros([1, clip.visual.output_dim], device=self.device)
+        target_embeds, weights = [zero_embed], []
+        txt, weight = parse_prompt(prompt)
+        target_embeds.append(self.clip.encode_text(clip.tokenize(txt).to(self.device)).float())
+        weights.append(weight)
+        x = torch.randn([1, 3, side_y, side_x], device=self.device)
+        t = torch.linspace(1, 0, steps + 1, device=self.device)[:-1]
+        steps = utils.get_spliced_ddpm_cosine_schedule(t)
+        output_image = self.run_sampling(x, steps, eta)
+        out_path = Path(tempfile.mkdtemp()) / "my-file.txt"
+        utils.to_pil_image(output_image).save(out_path)
+        return out_path
diff --git a/cog.yaml b/cog.yaml
new file mode 100644
index 0000000..8ae297c
--- /dev/null
+++ b/cog.yaml
@@ -0,0 +1,29 @@
+# Configuration for Cog ⚙️
+# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
+
+build:
+  # set to true if your model requires a GPU
+  gpu: true
+
+  # a list of ubuntu apt packages to install
+  system_packages:
+    # - "libgl1-mesa-glx"
+    # - "libglib2.0-0"
+
+  # python version in the form '3.8' or '3.8.12'
+  python_version: "3.8"
+
+  # a list of packages in the format <package-name>==<version>
+  python_packages:
+    - "torch==1.9.0"
+    - "torchvision==0.9.0"
+    - "ftfy==6.0.3"
+    - "tqdm==4.62.3"
+    - "pillow==9.0.1"
+  # commands run after the environment is setup
+  run:
+    # Most recent commit to master as of 2022-02-07, just wanted to pin it to a single one
+    - "pip install git+https://github.com/openai/CLIP.git@40f5484c1c74edd83cb9cf687c6ab92b28d8b656"
+
+# predict.py defines how predictions are run on your model
+predict: "cfg_predict.py:Predictor"

From 1b513ea48a47aad02f227b275a64310a07cf936c Mon Sep 17 00:00:00 2001
From: Dashiell Stander
Date: Mon, 7 Feb 2022 15:42:14 -0800
Subject: [PATCH 02/12] Not installing any system packages

---
 cog.yaml | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/cog.yaml b/cog.yaml
index 8ae297c..1c1ddd9 100644
--- a/cog.yaml
+++ b/cog.yaml
@@ -4,12 +4,7 @@
 build:
   # set to true if your model requires a GPU
   gpu: true
-
-  # a list of ubuntu apt packages to install
-  system_packages:
-    # - "libgl1-mesa-glx"
-    # - "libglib2.0-0"
-
+
   # python version in the form '3.8' or '3.8.12'
   python_version: "3.8"
 

From 7b2447b3f5005e3f0a91c9246067059364699d79 Mon Sep 17 00:00:00 2001
From: Dashiell Stander
Date: Mon, 7 Feb 2022 16:00:59 -0800
Subject: [PATCH 03/12] Update name of predictor class

---
 cfg_predict.py | 2 +-
 cog.yaml       | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/cfg_predict.py b/cfg_predict.py
index 8754538..252f76f 100644
--- a/cfg_predict.py
+++ b/cfg_predict.py
@@ -31,7 +31,7 @@ def parse_prompt(prompt, default_weight=3.):
     return vals[0], float(vals[1])
 
 
-class ClassifierFreeGuidanceDiffusionSampler(cog.Predictor):
+class ClassifierFreeGuidanceDiffusionPredictor(cog.Predictor):
     model_name = 'cc12m_1_cfg'
     checkpoint_path = 'checkpoints/cc12m_1_cfg.pth'
     device = 'cuda:0'
diff --git a/cog.yaml b/cog.yaml
index 1c1ddd9..c782422 100644
--- a/cog.yaml
+++ b/cog.yaml
@@ -4,7 +4,7 @@
 build:
   # set to true if your model requires a GPU
   gpu: true
-
+
   # python version in the form '3.8' or '3.8.12'
   python_version: "3.8"
 
@@ -21,4 +21,4 @@ build:
     - "pip install git+https://github.com/openai/CLIP.git@40f5484c1c74edd83cb9cf687c6ab92b28d8b656"
 
 # predict.py defines how predictions are run on your model
-predict: "cfg_predict.py:Predictor"
+predict: "cfg_predict.py:ClassifierFreeGuidanceDiffusionPredictor"

From d53dcc6c7375e941b9492cfdf5796baf6deae537 Mon Sep 17 00:00:00 2001
From: Dashiell Stander
Date: Mon, 7 Feb 2022 16:02:20 -0800
Subject: [PATCH 04/12] Lower pytorch version to 1.8

---
 cog.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cog.yaml b/cog.yaml
index c782422..21c2383 100644
--- a/cog.yaml
+++ b/cog.yaml
@@ -10,7 +10,7 @@ build:
 
   # a list of packages in the format <package-name>==<version>
   python_packages:
-    - "torch==1.9.0"
+    - "torch==1.8.0"
     - "torchvision==0.9.0"
     - "ftfy==6.0.3"
     - "tqdm==4.62.3"

From 9077ca34fdc8bac337b61756e66f5ea35c14e5fb Mon Sep 17 00:00:00 2001
From: Dashiell Stander
Date: Mon, 7 Feb 2022 16:26:16 -0800
Subject: [PATCH 05/12] Fixed bug in the import

---
 cfg_predict.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/cfg_predict.py b/cfg_predict.py
index 252f76f..7df0e1d 100644
--- a/cfg_predict.py
+++ b/cfg_predict.py
@@ -50,10 +50,13 @@ def setup(self):
             std=[0.26862954, 0.26130258, 0.27577711]
         )
 
+    @property
+    def output_dim(self):
+        return self.clip.visual.output_dim
+
     def normalize(self, image):
         return self.normalize_fn(image)
-
 
     def cfg_sample_fn(self, x, t, target_embeds, weights):
         n = x.shape[0]
         n_conds = len(target_embeds)
@@ -64,7 +67,6 @@ def cfg_sample_fn(self, x, t, target_embeds, weights):
         v = vs.mul(weights[:, None, None, None, None]).sum(0)
         return v
-
 
     def run_sampling(self, x, steps, eta):
         return sampling.sample(self.cfg_sample_fn, x, steps, eta, {})
@@ -76,7 +78,7 @@ def predict(self, prompt: str, eta: float, seed: int, steps: int):
         """Run a single prediction on the model"""
         _, side_y, side_x = self.model.shape
         torch.manual_seed(seed)
-        zero_embed = torch.zeros([1, clip.visual.output_dim], device=self.device)
+        zero_embed = torch.zeros([1, self.output_dim], device=self.device)
         target_embeds, weights = [zero_embed], []
         txt, weight = parse_prompt(prompt)
         target_embeds.append(self.clip.encode_text(clip.tokenize(txt).to(self.device)).float())

From 3ef585dd66d1e5540bd3646524d4676becd6d661 Mon Sep 17 00:00:00 2001
From: Dashiell Stander
Date: Mon, 7 Feb 2022 16:32:34 -0800
Subject: [PATCH 06/12] Bring into line with torch 1.8

---
 diffusion/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/diffusion/utils.py b/diffusion/utils.py
index 0d3cc4e..22aaf49 100644
--- a/diffusion/utils.py
+++ b/diffusion/utils.py
@@ -62,7 +62,7 @@ def alpha_sigma_to_t(alpha, sigma):
 
 def get_ddpm_schedule(ddpm_t):
     """Returns timesteps for the noise schedule from the DDPM paper."""
-    log_snr = -torch.special.expm1(1e-4 + 10 * ddpm_t**2).log()
+    log_snr = -torch.expm1(1e-4 + 10 * ddpm_t**2).log()
     alpha, sigma = log_snr_to_alpha_sigma(log_snr)
     return alpha_sigma_to_t(alpha, sigma)
 

From dcaa6316d8d84e4b390eb0d8022c910505159ba0 Mon Sep 17 00:00:00 2001
From: Dashiell Stander
Date: Mon, 7 Feb 2022 16:47:06 -0800
Subject: [PATCH 07/12] cfg_sample_fn needs to have a specific signature

---
 cfg_predict.py | 27 ++++++++++++++-------------
 1 file changed, 14 insertions(+), 13 deletions(-)

diff --git a/cfg_predict.py b/cfg_predict.py
index 7df0e1d..e5499c4 100644
--- a/cfg_predict.py
+++ b/cfg_predict.py
@@ -57,18 +57,8 @@ def output_dim(self):
     def normalize(self, image):
         return self.normalize_fn(image)
 
-    def cfg_sample_fn(self, x, t, target_embeds, weights):
-        n = x.shape[0]
-        n_conds = len(target_embeds)
-        x_in = x.repeat([n_conds, 1, 1, 1])
-        t_in = t.repeat([n_conds])
-        clip_embed_in = torch.cat([*target_embeds]).repeat_interleave(n, 0)
-        vs = self.model(x_in, t_in, clip_embed_in).view([n_conds, n, *x.shape[1:]])
-        v = vs.mul(weights[:, None, None, None, None]).sum(0)
-        return v
-
-    def run_sampling(self, x, steps, eta):
-        return sampling.sample(self.cfg_sample_fn, x, steps, eta, {})
+    def run_sampling(self, x, steps, eta, sample_fn):
+        return sampling.sample(sample_fn, x, steps, eta, {})
 
     @cog.input('prompt', type=str, help='The prompt for image generation')
     @cog.input("eta", type=float, default=1.0, help='The amount of randomness')
@@ -83,10 +73,21 @@ def predict(self, prompt: str, eta: float, seed: int, steps: int):
         txt, weight = parse_prompt(prompt)
         target_embeds.append(self.clip.encode_text(clip.tokenize(txt).to(self.device)).float())
         weights.append(weight)
+
+        def cfg_sample_fn(x, t):
+            n = x.shape[0]
+            n_conds = len(target_embeds)
+            x_in = x.repeat([n_conds, 1, 1, 1])
+            t_in = t.repeat([n_conds])
+            clip_embed_in = torch.cat([*target_embeds]).repeat_interleave(n, 0)
+            vs = self.model(x_in, t_in, clip_embed_in).view([n_conds, n, *x.shape[1:]])
+            v = vs.mul(weights[:, None, None, None, None]).sum(0)
+            return v
+
         x = torch.randn([1, 3, side_y, side_x], device=self.device)
         t = torch.linspace(1, 0, steps + 1, device=self.device)[:-1]
         steps = utils.get_spliced_ddpm_cosine_schedule(t)
-        output_image = self.run_sampling(x, steps, eta)
+        output_image = self.run_sampling(x, steps, eta, cfg_sample_fn)
         out_path = Path(tempfile.mkdtemp()) / "my-file.txt"
         utils.to_pil_image(output_image).save(out_path)
         return out_path

From fe674f40d311063f272042a93d31bae27aef210b Mon Sep 17 00:00:00 2001
From: Dashiell Stander
Date: Mon, 7 Feb 2022 16:53:58 -0800
Subject: [PATCH 08/12] cfg_sample_fn needs to have a specific signature

---
 cfg_predict.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/cfg_predict.py b/cfg_predict.py
index e5499c4..afe6ebb 100644
--- a/cfg_predict.py
+++ b/cfg_predict.py
@@ -73,7 +73,7 @@ def predict(self, prompt: str, eta: float, seed: int, steps: int):
         txt, weight = parse_prompt(prompt)
         target_embeds.append(self.clip.encode_text(clip.tokenize(txt).to(self.device)).float())
         weights.append(weight)
-
+        weights = torch.tensor([1 - sum(weights), *weights], device=self.device)
         def cfg_sample_fn(x, t):
             n = x.shape[0]
             n_conds = len(target_embeds)
@@ -83,7 +83,7 @@ def cfg_sample_fn(x, t):
             vs = self.model(x_in, t_in, clip_embed_in).view([n_conds, n, *x.shape[1:]])
             v = vs.mul(weights[:, None, None, None, None]).sum(0)
             return v
-
+
         x = torch.randn([1, 3, side_y, side_x], device=self.device)
         t = torch.linspace(1, 0, steps + 1, device=self.device)[:-1]
         steps = utils.get_spliced_ddpm_cosine_schedule(t)

From 21f547aa9ee0b014036953bc6a3538650e8e2087 Mon Sep 17 00:00:00 2001
From: Dashiell Stander
Date: Mon, 7 Feb 2022 17:00:36 -0800
Subject: [PATCH 09/12] Dumb dumb dumb

---
 cfg_predict.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cfg_predict.py b/cfg_predict.py
index afe6ebb..8ac97a4 100644
--- a/cfg_predict.py
+++ b/cfg_predict.py
@@ -88,6 +88,6 @@ def cfg_sample_fn(x, t):
         t = torch.linspace(1, 0, steps + 1, device=self.device)[:-1]
         steps = utils.get_spliced_ddpm_cosine_schedule(t)
         output_image = self.run_sampling(x, steps, eta, cfg_sample_fn)
-        out_path = Path(tempfile.mkdtemp()) / "my-file.txt"
+        out_path = Path(tempfile.mkdtemp()) / "out.png"
         utils.to_pil_image(output_image).save(out_path)
         return out_path

From 3b8743b19598048d6175840eb7a5167acd40a51a Mon Sep 17 00:00:00 2001
From: Dashiell Stander
Date: Mon, 7 Feb 2022 17:27:41 -0800
Subject: [PATCH 10/12] Add cog.yaml and cfg_predict.py files to be able to
 run a demo of the cfg model on replicate.com

---
 README.md          |  2 +
 cfg_predict.py     | 93 ++++++++++++++++++++++++++++++++++++++++++++++
 cog.yaml           | 24 ++++++++++++
 diffusion/utils.py |  2 +-
 4 files changed, 120 insertions(+), 1 deletion(-)
 create mode 100644 cfg_predict.py
 create mode 100644 cog.yaml

diff --git a/README.md b/README.md
index 8d494b5..cb83310 100644
--- a/README.md
+++ b/README.md
@@ -46,6 +46,8 @@ If they are somewhere else, you need to specify the path to the checkpoint with
 
 ### CFG sampling (best, but only cc12m_1_cfg supports it)
 
+[Demo and Docker image on Replicate](https://replicate.ai/crowsonkb/clip-guided-diffusion-cfg)
+
 ```
 usage: cfg_sample.py [-h] [--images [IMAGE ...]] [--batch-size BATCH_SIZE]
                      [--checkpoint CHECKPOINT] [--device DEVICE] [--eta ETA] [--init INIT]
diff --git a/cfg_predict.py b/cfg_predict.py
new file mode 100644
index 0000000..8ac97a4
--- /dev/null
+++ b/cfg_predict.py
@@ -0,0 +1,93 @@
+# Prediction interface for Cog ⚙️
+# Reference: https://github.com/replicate/cog/blob/main/docs/python.md
+
+import cog
+from pathlib import Path
+from PIL import Image
+import tempfile
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torchvision import transforms
+from torchvision.transforms import functional as TF
+
+from CLIP import clip
+from diffusion import get_model, sampling, utils
+
+
+def resize_and_center_crop(image, size):
+    fac = max(size[0] / image.size[0], size[1] / image.size[1])
+    image = image.resize((int(fac * image.size[0]), int(fac * image.size[1])), Image.LANCZOS)
+    return TF.center_crop(image, size[::-1])
+
+
+def parse_prompt(prompt, default_weight=3.):
+    if prompt.startswith('http://') or prompt.startswith('https://'):
+        vals = prompt.rsplit(':', 2)
+        vals = [vals[0] + ':' + vals[1], *vals[2:]]
+    else:
+        vals = prompt.rsplit(':', 1)
+    vals = vals + ['', default_weight][len(vals):]
+    return vals[0], float(vals[1])
+
+
+class ClassifierFreeGuidanceDiffusionPredictor(cog.Predictor):
+    model_name = 'cc12m_1_cfg'
+    checkpoint_path = 'checkpoints/cc12m_1_cfg.pth'
+    device = 'cuda:0'
+
+    def setup(self):
+        """Load the model into memory to make running multiple predictions efficient"""
+        assert torch.cuda.is_available()
+        self.model = get_model(self.model_name)()
+        self.model.load_state_dict(torch.load(self.checkpoint_path, map_location='cpu'))
+        self.model.half()
+        self.model.to(self.device).eval().requires_grad_(False)
+        self.clip = clip.load('ViT-B/16', jit=False, device=self.device)[0]
+        self.clip.eval().requires_grad_(False)
+        self.normalize_fn = transforms.Normalize(
+            mean=[0.48145466, 0.4578275, 0.40821073],
+            std=[0.26862954, 0.26130258, 0.27577711]
+        )
+
+    @property
+    def output_dim(self):
+        return self.clip.visual.output_dim
+
+    def normalize(self, image):
+        return self.normalize_fn(image)
+
+    def run_sampling(self, x, steps, eta, sample_fn):
+        return sampling.sample(sample_fn, x, steps, eta, {})
+
+    @cog.input('prompt', type=str, help='The prompt for image generation')
+    @cog.input("eta", type=float, default=1.0, help='The amount of randomness')
+    @cog.input('seed', type=int, default=0, help='Random seed for reproducibility.')
+    @cog.input('steps', type=int, default=500, max=1000, min=0, help='Number of steps to sample for.')
+    def predict(self, prompt: str, eta: float, seed: int, steps: int):
+        """Run a single prediction on the model"""
+        _, side_y, side_x = self.model.shape
+        torch.manual_seed(seed)
+        zero_embed = torch.zeros([1, self.output_dim], device=self.device)
+        target_embeds, weights = [zero_embed], []
+        txt, weight = parse_prompt(prompt)
+        target_embeds.append(self.clip.encode_text(clip.tokenize(txt).to(self.device)).float())
+        weights.append(weight)
+        weights = torch.tensor([1 - sum(weights), *weights], device=self.device)
+        def cfg_sample_fn(x, t):
+            n = x.shape[0]
+            n_conds = len(target_embeds)
+            x_in = x.repeat([n_conds, 1, 1, 1])
+            t_in = t.repeat([n_conds])
+            clip_embed_in = torch.cat([*target_embeds]).repeat_interleave(n, 0)
+            vs = self.model(x_in, t_in, clip_embed_in).view([n_conds, n, *x.shape[1:]])
+            v = vs.mul(weights[:, None, None, None, None]).sum(0)
+            return v
+
+        x = torch.randn([1, 3, side_y, side_x], device=self.device)
+        t = torch.linspace(1, 0, steps + 1, device=self.device)[:-1]
+        steps = utils.get_spliced_ddpm_cosine_schedule(t)
+        output_image = self.run_sampling(x, steps, eta, cfg_sample_fn)
+        out_path = Path(tempfile.mkdtemp()) / "out.png"
+        utils.to_pil_image(output_image).save(out_path)
+        return out_path
diff --git a/cog.yaml b/cog.yaml
new file mode 100644
index 0000000..c782422
--- /dev/null
+++ b/cog.yaml
@@ -0,0 +1,24 @@
+# Configuration for Cog ⚙️
+# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
+
+build:
+  # set to true if your model requires a GPU
+  gpu: true
+
+  # python version in the form '3.8' or '3.8.12'
+  python_version: "3.8"
+
+  # a list of packages in the format <package-name>==<version>
+  python_packages:
+    - "torch==1.9.0"
+    - "torchvision==0.9.0"
+    - "ftfy==6.0.3"
+    - "tqdm==4.62.3"
+    - "pillow==9.0.1"
+  # commands run after the environment is setup
+  run:
+    # Most recent commit to master as of 2022-02-07, just wanted to pin it to a single one
+    - "pip install git+https://github.com/openai/CLIP.git@40f5484c1c74edd83cb9cf687c6ab92b28d8b656"
+
+# predict.py defines how predictions are run on your model
+predict: "cfg_predict.py:ClassifierFreeGuidanceDiffusionPredictor"
diff --git a/diffusion/utils.py b/diffusion/utils.py
index 0d3cc4e..22aaf49 100644
--- a/diffusion/utils.py
+++ b/diffusion/utils.py
@@ -62,7 +62,7 @@ def alpha_sigma_to_t(alpha, sigma):
 
 def get_ddpm_schedule(ddpm_t):
     """Returns timesteps for the noise schedule from the DDPM paper."""
-    log_snr = -torch.special.expm1(1e-4 + 10 * ddpm_t**2).log()
+    log_snr = -torch.expm1(1e-4 + 10 * ddpm_t**2).log()
     alpha, sigma = log_snr_to_alpha_sigma(log_snr)
     return alpha_sigma_to_t(alpha, sigma)
 

From 0f3d53b6bf9b7d335c4b45e57ccd224f1e8ecf8c Mon Sep 17 00:00:00 2001
From: Dashiell Stander
Date: Fri, 11 Feb 2022 12:13:32 -0800
Subject: [PATCH 11/12] Update to use PLMS sampling

---
 cfg_predict.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/cfg_predict.py b/cfg_predict.py
index 8ac97a4..76fbdfa 100644
--- a/cfg_predict.py
+++ b/cfg_predict.py
@@ -58,12 +58,11 @@ def normalize(self, image):
         return self.normalize_fn(image)
 
     def run_sampling(self, x, steps, eta, sample_fn):
-        return sampling.sample(sample_fn, x, steps, eta, {})
+        return sampling.plms_sample(sample_fn, x, steps, {})
 
     @cog.input('prompt', type=str, help='The prompt for image generation')
-    @cog.input("eta", type=float, default=1.0, help='The amount of randomness')
     @cog.input('seed', type=int, default=0, help='Random seed for reproducibility.')
-    @cog.input('steps', type=int, default=500, max=1000, min=0, help='Number of steps to sample for.')
+    @cog.input('steps', type=int, default=20, max=100, min=1, help='Number of steps to sample for.')
     def predict(self, prompt: str, eta: float, seed: int, steps: int):
         """Run a single prediction on the model"""
         _, side_y, side_x = self.model.shape
@@ -83,7 +82,6 @@ def cfg_sample_fn(x, t):
             vs = self.model(x_in, t_in, clip_embed_in).view([n_conds, n, *x.shape[1:]])
             v = vs.mul(weights[:, None, None, None, None]).sum(0)
             return v
-
         x = torch.randn([1, 3, side_y, side_x], device=self.device)
         t = torch.linspace(1, 0, steps + 1, device=self.device)[:-1]
         steps = utils.get_spliced_ddpm_cosine_schedule(t)
From 486eac86e33fd5da2d515f7209e498f674db8cc3 Mon Sep 17 00:00:00 2001
From: Dashiell Stander
Date: Fri, 11 Feb 2022 12:21:49 -0800
Subject: [PATCH 12/12] How did I forget the git conflict lines???

---
 cfg_predict.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/cfg_predict.py b/cfg_predict.py
index f7e1c47..76fbdfa 100644
--- a/cfg_predict.py
+++ b/cfg_predict.py
@@ -82,10 +82,6 @@ def cfg_sample_fn(x, t):
             vs = self.model(x_in, t_in, clip_embed_in).view([n_conds, n, *x.shape[1:]])
             v = vs.mul(weights[:, None, None, None, None]).sum(0)
             return v
-<<<<<<< HEAD
-
-=======
->>>>>>> 0f3d53b6bf9b7d335c4b45e57ccd224f1e8ecf8c
         x = torch.randn([1, 3, side_y, side_x], device=self.device)
         t = torch.linspace(1, 0, steps + 1, device=self.device)[:-1]
         steps = utils.get_spliced_ddpm_cosine_schedule(t)
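
A note on the guidance arithmetic the series above implements: cfg_sample_fn batches one unconditional pass (the zero CLIP embedding) together with one pass per prompt embedding, then mixes the model outputs with weights [1 - sum(weights), *weights], which always sum to 1. The sketch below isolates just that mixing step; it is a minimal standalone illustration with made-up shapes and a random stand-in for the model outputs, not code from the repo — only the weighting scheme mirrors cfg_sample_fn.

import torch

def cfg_combine(model_outputs, prompt_weights):
    # model_outputs: [n_conds, n, c, h, w]; row 0 is the unconditional
    # (zero-embedding) prediction, rows 1.. are the per-prompt predictions.
    # The unconditional row gets weight 1 - sum(prompt_weights), so the
    # weights sum to 1; a prompt weight above 1 (parse_prompt defaults to 3)
    # pushes the result away from the unconditional prediction -- that is
    # the classifier-free guidance.
    weights = torch.tensor([1 - sum(prompt_weights), *prompt_weights])
    return model_outputs.mul(weights[:, None, None, None, None]).sum(0)

# Illustrative shapes only: 2 conditions (unconditional + one prompt),
# batch of 1, and a hypothetical 3x256x256 image.
outs = torch.randn([2, 1, 3, 256, 256])
v = cfg_combine(outs, [3.0])
assert v.shape == (1, 3, 256, 256)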