Skip to content

Conversation

@MarkRich
Copy link
Contributor

@MarkRich MarkRich commented Oct 23, 2022

Part of #876 and bigger story #841

Followed the rough format of #900 and the code here: https://github.com/energy-based-model/Compositional-Visual-Generation-with-Composable-Diffusion-Models-PyTorch/tree/16e630f7fee5483bb986cb828762d6e06c2869c3, here's a script that allows for users to create "composable" prompts delimited by |.

For instance:

import torch as th
import numpy as np
import torchvision.utils as tvu
from diffusers import DiffusionPipeline

has_cuda = th.cuda.is_available()
device = th.device('cpu' if not has_cuda else 'cuda')

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    use_auth_token=True,
    custom_pipeline="composable_stable_diffusion"
).to(device)


def dummy(images, **kwargs):
    return images, False

pipe.safety_checker = dummy

images = []
generator = th.Generator("cuda").manual_seed(0)

seed = 0
prompt = "a forest | a camel"
weights = " 1 | 1"  # Equal weight to each prompt. Cna be negative

images = []
for i in range(4):
    res = pipe(
        prompt,
        guidance_scale=7.5,
        num_inference_steps=50,
        weights=weights,
        generator=generator)
    image = res.images[0]
    images.append(image)

for i, img in enumerate(images):
    img.save(f"./composable_diffusion/image_{i}.png")

yields:
image_0
image_1
image_2
image_3

Note for reviewers:
The script above is actually not 100% accurate, if I run it exactly as is I get the following error:

Could not locate the pipeline.py inside composable_stable_diffusion.
Traceback (most recent call last):
  File "/home/mark/.local/lib/python3.8/site-packages/huggingface_hub/utils/_errors.py", line 213, in hf_raise_for_status
    response.raise_for_status()
  File "/usr/local/lib/python3.8/dist-packages/requests/models.py", line 941, in raise_for_status
    raise HTTPError(http_error_msg, response=self)
requests.exceptions.HTTPError: 404 Client Error: Not Found for url: https://raw.githubusercontent.com/huggingface/diffusers/main/examples/community/composable_stable_diffusion.py

The trick is I have to change one line in the script to:

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    use_auth_token=True,
    custom_pipeline="/home/mark/open_source/diffusers/examples/community"
).to(device)

and in diffusers/src/pipeline_utils.py I need to change the following line:

CUSTOM_PIPELINE_FILE_NAME = "pipeline.py"

to

CUSTOM_PIPELINE_FILE_NAME = "composable_stable_diffusion.py"

Given the nature of the error above (i.e. it is searching on github for composable_stable_diffusion.py" I am guessing that I am just missing something about how we're expected to test these changes? 😅 . Open to more testing!

Anyway, based on the initial tests seems like it works

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Oct 23, 2022

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten
Copy link
Contributor

Very cool! Thanks a lot for the PR :-)

@patrickvonplaten
Copy link
Contributor

cc @osanseviero @apolinario FYI

@patrickvonplaten patrickvonplaten merged commit 38ae5a2 into huggingface:main Oct 25, 2022
PhaneeshB pushed a commit to nod-ai/diffusers that referenced this pull request Mar 1, 2023
This reverts commit 8115b26.
Additionally fixes img2col by adding detach elementwise from named op
passes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants