Skip to content

Conversation

@pcuenca
Copy link
Member

@pcuenca pcuenca commented Sep 19, 2022

Otherwise jitting/parallelization don't work properly as they don't know how to deal with traced objects.

I started with FlaxPNDMScheduler, and temporarily removed step_prk.

Otherwise jitting/parallelization don't work properly as they don't know
how to deal with traced objects.

I temporarily removed `step_prk`.
@pcuenca pcuenca marked this pull request as draft September 19, 2022 17:33
@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Sep 19, 2022

The documentation is not available anymore as the PR was closed or merged.

@pcuenca
Copy link
Member Author

pcuenca commented Sep 19, 2022

We can test with just jax.jit like this:

import jax
import jax.numpy as jnp

from diffusers import FlaxPNDMScheduler

scheduler = FlaxPNDMScheduler.from_config(PATH_TO_SCHEDULER_DIR)

latents_shape = (1, 64, 64, 3)
scheduler_state = scheduler.set_timesteps(
    scheduler.state,
    shape = latents_shape,         # Needs to be known in advance to reserve space
    num_inference_steps = 50,
)

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
latents = jax.random.normal(key1, shape=latents_shape, dtype=jnp.float32)
noise = jax.random.normal(key2, shape=latents_shape, dtype=jnp.float32)

p_step = jax.jit(scheduler.step, static_argnums=4)
latents, scheduler_state = p_step(scheduler_state, noise, 37, latents, return_dict=False)

This example should work with both jax.jit and without (invoking scheduler.step directly), for all schedulers.
I'll prepare a more complete example with pmap later.

@patrickvonplaten
Copy link
Contributor

Cool! Will focus on DDIM for now to get the pipeline working with it

@pcuenca pcuenca mentioned this pull request Sep 20, 2022
@pcuenca
Copy link
Member Author

pcuenca commented Sep 21, 2022

Replaced by #583 for PNDM. We'll open separate PRs for others.

@pcuenca pcuenca closed this Sep 21, 2022
@pcuenca pcuenca deleted the functional-schedulers branch October 2, 2022 18:03
PhaneeshB pushed a commit to nod-ai/diffusers that referenced this pull request Mar 1, 2023
* Minor fixes to benchmark runner

* Add Mnasnet to tank.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants