diff --git a/docs/source/api/models.mdx b/docs/source/api/models.mdx index a6d342f575a9..893fc6bea0ca 100644 --- a/docs/source/api/models.mdx +++ b/docs/source/api/models.mdx @@ -22,12 +22,15 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module ## UNet2DOutput [[autodoc]] models.unet_2d.UNet2DOutput -## UNet1DModel -[[autodoc]] UNet1DModel - ## UNet2DModel [[autodoc]] UNet2DModel +## UNet1DOutput +[[autodoc]] models.unet_1d.UNet1DOutput + +## UNet1DModel +[[autodoc]] UNet1DModel + ## UNet2DConditionOutput [[autodoc]] models.unet_2d_condition.UNet2DConditionOutput @@ -37,12 +40,6 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module ## DecoderOutput [[autodoc]] models.vae.DecoderOutput -## UNet1DModel -[[autodoc]] UNet1DModel - -## UNet1DOutput -[[autodoc]] models.unet_1d.UNet1DOutput - ## VQEncoderOutput [[autodoc]] models.vae.VQEncoderOutput diff --git a/examples/README.md b/examples/README.md index 2680b638d585..a50fbfc3a713 100644 --- a/examples/README.md +++ b/examples/README.md @@ -36,9 +36,10 @@ If you feel like another important example should exist, we are more than happy Training examples show how to pretrain or fine-tune diffusion models for a variety of tasks. Currently we support: -| Task | 🤗 Accelerate | 🤗 Datasets | Colab -|---|---|:---:|:---:| +| Task | 🤗 Accelerate | 🤗 Datasets | Colab +|---------------------------------------------------------------------------------------------------------------------------------------------------------|---|:---:|:---:| | [**Unconditional Image Generation**](https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/train_unconditional.py) | ✅ | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) +| [**Reinforcement Learning for Control**](https://github.com/huggingface/diffusers/blob/main/examples/rl/run_diffuser_locomotion.py) | | | coming soon. 
## Community diff --git a/examples/community/value_guided_diffuser.py b/examples/community/value_guided_diffuser.py deleted file mode 100644 index 6b28e868eddd..000000000000 --- a/examples/community/value_guided_diffuser.py +++ /dev/null @@ -1,108 +0,0 @@ -import torch - -import tqdm -from diffusers import DiffusionPipeline -from diffusers.models.unet_1d import UNet1DModel -from diffusers.utils.dummy_pt_objects import DDPMScheduler - - -class ValueGuidedDiffuserPipeline(DiffusionPipeline): - def __init__( - self, - value_function: UNet1DModel, - unet: UNet1DModel, - scheduler: DDPMScheduler, - env, - ): - super().__init__() - self.value_function = value_function - self.unet = unet - self.scheduler = scheduler - self.env = env - self.data = env.get_dataset() - self.means = dict() - for key in self.data.keys(): - try: - self.means[key] = self.data[key].mean() - except: - pass - self.stds = dict() - for key in self.data.keys(): - try: - self.stds[key] = self.data[key].std() - except: - pass - self.state_dim = env.observation_space.shape[0] - self.action_dim = env.action_space.shape[0] - - def normalize(self, x_in, key): - return (x_in - self.means[key]) / self.stds[key] - - def de_normalize(self, x_in, key): - return x_in * self.stds[key] + self.means[key] - - def to_torch(self, x_in): - if type(x_in) is dict: - return {k: self.to_torch(v) for k, v in x_in.items()} - elif torch.is_tensor(x_in): - return x_in.to(self.unet.device) - return torch.tensor(x_in, device=self.unet.device) - - def reset_x0(self, x_in, cond, act_dim): - for key, val in cond.items(): - x_in[:, key, act_dim:] = val.clone() - return x_in - - def run_diffusion(self, x, conditions, n_guide_steps, scale): - batch_size = x.shape[0] - y = None - for i in tqdm.tqdm(self.scheduler.timesteps): - # create batch of timesteps to pass into model - timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long) - for _ in range(n_guide_steps): - with torch.enable_grad(): - x.requires_grad_() - y = self.value_function(x.permute(0, 2, 1), timesteps).sample - grad = torch.autograd.grad([y.sum()], [x])[0] - - posterior_variance = self.scheduler._get_variance(i) - model_std = torch.exp(0.5 * posterior_variance) - grad = model_std * grad - grad[timesteps < 2] = 0 - x = x.detach() - x = x + scale * grad - x = self.reset_x0(x, conditions, self.action_dim) - prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1) - x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"] - - # apply conditions to the trajectory - x = self.reset_x0(x, conditions, self.action_dim) - x = self.to_torch(x) - return x, y - - def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1): - # normalize the observations and create batch dimension - obs = self.normalize(obs, "observations") - obs = obs[None].repeat(batch_size, axis=0) - - conditions = {0: self.to_torch(obs)} - shape = (batch_size, planning_horizon, self.state_dim + self.action_dim) - - # generate initial noise and apply our conditions (to make the trajectories start at current state) - x1 = torch.randn(shape, device=self.unet.device) - x = self.reset_x0(x1, conditions, self.action_dim) - x = self.to_torch(x) - - # run the diffusion process - x, y = self.run_diffusion(x, conditions, n_guide_steps, scale) - - # sort output trajectories by value - sorted_idx = y.argsort(0, descending=True).squeeze() - sorted_values = x[sorted_idx] - actions = sorted_values[:, :, : self.action_dim] - actions = actions.detach().cpu().numpy() - 
denorm_actions = self.de_normalize(actions, key="actions") - - # select the action with the highest value - denorm_actions = denorm_actions[0, 0] - return denorm_actions diff --git a/examples/diffuser/run_diffuser.py b/examples/diffuser/run_diffuser.py deleted file mode 100644 index b29d89992dfc..000000000000 --- a/examples/diffuser/run_diffuser.py +++ /dev/null @@ -1,122 +0,0 @@ -import numpy as np -import torch - -import d4rl # noqa -import gym -import tqdm -import train_diffuser -from diffusers import DDPMScheduler, UNet1DModel - - -env_name = "hopper-medium-expert-v2" -env = gym.make(env_name) -data = env.get_dataset() # dataset is only used for normalization in this colab - -DEVICE = "cpu" -DTYPE = torch.float - -# diffusion model settings -n_samples = 4 # number of trajectories planned via diffusion -horizon = 128 # length of sampled trajectories -state_dim = env.observation_space.shape[0] -action_dim = env.action_space.shape[0] -num_inference_steps = 100 # number of diffusion steps - - -# Two generators for different parts of the diffusion loop to work in colab -generator_cpu = torch.Generator(device="cpu") - -scheduler = DDPMScheduler(num_train_timesteps=100, beta_schedule="squaredcos_cap_v2") - -# 3 different pretrained models are available for this task. -# The horizon represents the length of trajectories used in training. -network = UNet1DModel.from_pretrained("fusing/ddpm-unet-rl-hopper-hor128").to(device=DEVICE) -# network = TemporalUNet.from_pretrained("fusing/ddpm-unet-rl-hopper-hor256").to(device=DEVICE) -# network = TemporalUNet.from_pretrained("fusing/ddpm-unet-rl-hopper-hor512").to(device=DEVICE) - - -# network specific constants for inference -clip_denoised = network.clip_denoised -predict_epsilon = network.predict_epsilon - -# [ observation_dim ] --> [ n_samples x observation_dim ] -obs = env.reset() -total_reward = 0 -done = False -T = 300 -rollout = [obs.copy()] - -try: - for t in tqdm.tqdm(range(T)): - obs_raw = obs - - # normalize observations for forward passes - obs = train_diffuser.normalize(obs, data, "observations") - obs = obs[None].repeat(n_samples, axis=0) - conditions = {0: train_diffuser.to_torch(obs, device=DEVICE)} - - # constants for inference - batch_size = len(conditions[0]) - shape = (batch_size, horizon, state_dim + action_dim) - - # sample random initial noise vector - x1 = torch.randn(shape, device=DEVICE, generator=generator_cpu) - - # this model is conditioned from an initial state, so you will see this function - # multiple times to change the initial state of generated data to the state - # generated via env.reset() above or env.step() below - x = train_diffuser.reset_x0(x1, conditions, action_dim) - - # convert a np observation to torch for model forward pass - x = train_diffuser.to_torch(x) - - eta = 1.0 # noise factor for sampling reconstructed state - - # run the diffusion process - # for i in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps): - for i in tqdm.tqdm(scheduler.timesteps): - # create batch of timesteps to pass into model - timesteps = torch.full((batch_size,), i, device=DEVICE, dtype=torch.long) - - # 1. generate prediction from model - with torch.no_grad(): - residual = network(x, timesteps).sample - - # 2. use the model prediction to reconstruct an observation (de-noise) - obs_reconstruct = scheduler.step(residual, i, x, predict_epsilon=predict_epsilon)["prev_sample"] - - # 3. 
[optional] add posterior noise to the sample - if eta > 0: - noise = torch.randn(obs_reconstruct.shape, generator=generator_cpu).to(obs_reconstruct.device) - posterior_variance = scheduler._get_variance(i) # * noise - # no noise when t == 0 - # NOTE: original implementation missing sqrt on posterior_variance - obs_reconstruct = ( - obs_reconstruct + int(i > 0) * (0.5 * posterior_variance) * eta * noise - ) # MJ had as log var, exponentiated - - # 4. apply conditions to the trajectory - obs_reconstruct_postcond = train_diffuser.reset_x0(obs_reconstruct, conditions, action_dim) - x = train_diffuser.to_torch(obs_reconstruct_postcond) - plans = train_diffuser.helpers.to_np(x[:, :, :action_dim]) - # select random plan - idx = np.random.randint(plans.shape[0]) - # select action at correct time - action = plans[idx, 0, :] - actions = train_diffuser.de_normalize(action, data, "actions") - # execute action in environment - next_observation, reward, terminal, _ = env.step(action) - - # update return - total_reward += reward - print(f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}") - - # save observations for rendering - rollout.append(next_observation.copy()) - obs = next_observation -except KeyboardInterrupt: - pass - -print(f"Total reward: {total_reward}") -render = train_diffuser.MuJoCoRenderer(env) -train_diffuser.show_sample(render, np.expand_dims(np.stack(rollout), axis=0)) diff --git a/examples/diffuser/run_diffuser_value_guided.py b/examples/diffuser/run_diffuser_value_guided.py deleted file mode 100644 index 707663abb3bf..000000000000 --- a/examples/diffuser/run_diffuser_value_guided.py +++ /dev/null @@ -1,69 +0,0 @@ -import d4rl # noqa -import gym -import tqdm -from diffusers import DiffusionPipeline - - -config = dict( - n_samples=64, - horizon=32, - num_inference_steps=20, - n_guide_steps=2, - scale_grad_by_std=True, - scale=0.1, - eta=0.0, - t_grad_cutoff=2, - device="cpu", -) - - -def _run(): - env_name = "hopper-medium-v2" - env = gym.make(env_name) - - pipeline = DiffusionPipeline.from_pretrained( - "bglick13/hopper-medium-v2-value-function-hor32", - env=env, - custom_pipeline="/Users/bglickenhaus/Documents/diffusers/examples/community", - ) - - # add a batch dimension and repeat for multiple samples - # [ observation_dim ] --> [ n_samples x observation_dim ] - env.seed(0) - obs = env.reset() - total_reward = 0 - total_score = 0 - T = 1000 - rollout = [obs.copy()] - try: - for t in tqdm.tqdm(range(T)): - # 1. 
Call the policy - # normalize observations for forward passes - denorm_actions = pipeline(obs, planning_horizon=32) - - # execute action in environment - next_observation, reward, terminal, _ = env.step(denorm_actions) - score = env.get_normalized_score(total_reward) - # update return - total_reward += reward - total_score += score - print( - f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:" - f" {total_score}" - ) - # save observations for rendering - rollout.append(next_observation.copy()) - - obs = next_observation - except KeyboardInterrupt: - pass - - print(f"Total reward: {total_reward}") - - -def run(): - _run() - - -if __name__ == "__main__": - run() diff --git a/examples/diffuser/train_diffuser.py b/examples/diffuser/train_diffuser.py deleted file mode 100644 index b063a0456d97..000000000000 --- a/examples/diffuser/train_diffuser.py +++ /dev/null @@ -1,312 +0,0 @@ -import os -import warnings - -import numpy as np -import torch - -import d4rl # noqa -import gym -import mediapy as media -import mujoco_py as mjc -import tqdm -from diffusers import DDPMScheduler, UNet1DModel - - -# Define some helper functions - - -DTYPE = torch.float - - -def normalize(x_in, data, key): - means = data[key].mean(axis=0) - stds = data[key].std(axis=0) - return (x_in - means) / stds - - -def de_normalize(x_in, data, key): - means = data[key].mean(axis=0) - stds = data[key].std(axis=0) - return x_in * stds + means - - -def to_torch(x_in, dtype=None, device="cuda"): - dtype = dtype or DTYPE - device = device - if type(x_in) is dict: - return {k: to_torch(v, dtype, device) for k, v in x_in.items()} - elif torch.is_tensor(x_in): - return x_in.to(device).type(dtype) - return torch.tensor(x_in, dtype=dtype, device=device) - - -def reset_x0(x_in, cond, act_dim): - for key, val in cond.items(): - x_in[:, key, act_dim:] = val.clone() - return x_in - - -def run_diffusion(x, scheduler, network, unet, conditions, action_dim, config): - y = None - for i in tqdm.tqdm(scheduler.timesteps): - # create batch of timesteps to pass into model - timesteps = torch.full((config["n_samples"],), i, device=config["device"], dtype=torch.long) - # 3. call the sample function - for _ in range(config["n_guide_steps"]): - with torch.enable_grad(): - x.requires_grad_() - y = network(x, timesteps).sample - grad = torch.autograd.grad([y.sum()], [x])[0] - if config["scale_grad_by_std"]: - posterior_variance = scheduler._get_variance(i) - model_std = torch.exp(0.5 * posterior_variance) - grad = model_std * grad - grad[timesteps < config["t_grad_cutoff"]] = 0 - x = x.detach() - x = x + config["scale"] * grad - x = reset_x0(x, conditions, action_dim) - # with torch.no_grad(): - prev_x = unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1) - x = scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"] - - # 3. [optional] add posterior noise to the sample - if config["eta"] > 0: - noise = torch.randn(x.shape).to(x.device) - posterior_variance = scheduler._get_variance(i) # * noise - # no noise when t == 0 - # NOTE: original implementation missing sqrt on posterior_variance - x = x + int(i > 0) * (0.5 * posterior_variance) * config["eta"] * noise # MJ had as log var, exponentiated - - # 4. 
apply conditions to the trajectory - x = reset_x0(x, conditions, action_dim) - x = to_torch(x, device=config["device"]) - # y = network(x, timesteps).sample - return x, y - - -def to_np(x_in): - if torch.is_tensor(x_in): - x_in = x_in.detach().cpu().numpy() - return x_in - - -# from MJ's Diffuser code -# https://github.com/jannerm/diffuser/blob/76ae49ae85ba1c833bf78438faffdc63b8b4d55d/diffuser/utils/colab.py#L79 -def mkdir(savepath): - """ - returns `True` iff `savepath` is created - """ - if not os.path.exists(savepath): - os.makedirs(savepath) - return True - else: - return False - - -def show_sample(renderer, observations, filename="sample.mp4", savebase="videos"): - """ - observations : [ batch_size x horizon x observation_dim ] - """ - - mkdir(savebase) - savepath = os.path.join(savebase, filename) - - images = [] - for rollout in observations: - # [ horizon x height x width x channels ] - img = renderer._renders(rollout, partial=True) - images.append(img) - - # [ horizon x height x (batch_size * width) x channels ] - images = np.concatenate(images, axis=2) - media.write_video(savepath, images, fps=60) - media.show_video(images, codec="h264", fps=60) - return images - - -# Code adapted from Michael Janner -# source: https://github.com/jannerm/diffuser/blob/main/diffuser/utils/rendering.py - - -def env_map(env_name): - """ - map D4RL dataset names to custom fully-observed - variants for rendering - """ - if "halfcheetah" in env_name: - return "HalfCheetahFullObs-v2" - elif "hopper" in env_name: - return "HopperFullObs-v2" - elif "walker2d" in env_name: - return "Walker2dFullObs-v2" - else: - return env_name - - -def get_image_mask(img): - background = (img == 255).all(axis=-1, keepdims=True) - mask = ~background.repeat(3, axis=-1) - return mask - - -def atmost_2d(x): - while x.ndim > 2: - x = x.squeeze(0) - return x - - -def set_state(env, state): - qpos_dim = env.sim.data.qpos.size - qvel_dim = env.sim.data.qvel.size - if not state.size == qpos_dim + qvel_dim: - warnings.warn( - f"[ utils/rendering ] Expected state of size {qpos_dim + qvel_dim}, but got state of size {state.size}" - ) - state = state[: qpos_dim + qvel_dim] - - env.set_state(state[:qpos_dim], state[qpos_dim:]) - - -class MuJoCoRenderer: - """ - default mujoco renderer - """ - - def __init__(self, env): - if type(env) is str: - env = env_map(env) - self.env = gym.make(env) - else: - self.env = env - # - 1 because the envs in renderer are fully-observed - # @TODO : clean up - self.observation_dim = np.prod(self.env.observation_space.shape) - 1 - self.action_dim = np.prod(self.env.action_space.shape) - try: - self.viewer = mjc.MjRenderContextOffscreen(self.env.sim) - except: - print("[ utils/rendering ] Warning: could not initialize offscreen renderer") - self.viewer = None - - def pad_observation(self, observation): - state = np.concatenate( - [ - np.zeros(1), - observation, - ] - ) - return state - - def pad_observations(self, observations): - qpos_dim = self.env.sim.data.qpos.size - # xpos is hidden - xvel_dim = qpos_dim - 1 - xvel = observations[:, xvel_dim] - xpos = np.cumsum(xvel) * self.env.dt - states = np.concatenate( - [ - xpos[:, None], - observations, - ], - axis=-1, - ) - return states - - def render(self, observation, dim=256, partial=False, qvel=True, render_kwargs=None, conditions=None): - if type(dim) == int: - dim = (dim, dim) - - if self.viewer is None: - return np.zeros((*dim, 3), np.uint8) - - if render_kwargs is None: - xpos = observation[0] if not partial else 0 - render_kwargs = {"trackbodyid": 2, 
"distance": 3, "lookat": [xpos, -0.5, 1], "elevation": -20} - - for key, val in render_kwargs.items(): - if key == "lookat": - self.viewer.cam.lookat[:] = val[:] - else: - setattr(self.viewer.cam, key, val) - - if partial: - state = self.pad_observation(observation) - else: - state = observation - - qpos_dim = self.env.sim.data.qpos.size - if not qvel or state.shape[-1] == qpos_dim: - qvel_dim = self.env.sim.data.qvel.size - state = np.concatenate([state, np.zeros(qvel_dim)]) - - set_state(self.env, state) - - self.viewer.render(*dim) - data = self.viewer.read_pixels(*dim, depth=False) - data = data[::-1, :, :] - return data - - def _renders(self, observations, **kwargs): - images = [] - for observation in observations: - img = self.render(observation, **kwargs) - images.append(img) - return np.stack(images, axis=0) - - def renders(self, samples, partial=False, **kwargs): - if partial: - samples = self.pad_observations(samples) - partial = False - - sample_images = self._renders(samples, partial=partial, **kwargs) - - composite = np.ones_like(sample_images[0]) * 255 - - for img in sample_images: - mask = get_image_mask(img) - composite[mask] = img[mask] - - return composite - - def __call__(self, *args, **kwargs): - return self.renders(*args, **kwargs) - - -env_name = "hopper-medium-expert-v2" -env = gym.make(env_name) -data = env.get_dataset() # dataset is only used for normalization in this colab - -# Cuda settings for colab -# torch.cuda.get_device_name(0) -DEVICE = "cpu" -DTYPE = torch.float - -# diffusion model settings -n_samples = 4 # number of trajectories planned via diffusion -horizon = 128 # length of sampled trajectories -state_dim = env.observation_space.shape[0] -action_dim = env.action_space.shape[0] -num_inference_steps = 100 # number of difusion steps - -obs = env.reset() -obs_raw = obs - -# normalize observations for forward passes -obs = normalize(obs, data, "observations") - - -# Two generators for different parts of the diffusion loop to work in colab -generator = torch.Generator(device="cuda") -generator_cpu = torch.Generator(device="cpu") -network = UNet1DModel.from_pretrained("fusing/ddpm-unet-rl-hopper-hor128").to(device=DEVICE) - -scheduler = DDPMScheduler(num_train_timesteps=100, beta_schedule="squaredcos_cap_v2") -optimizer = torch.optim.AdamW( - network.parameters(), - lr=0.001, - betas=(0.95, 0.99), - weight_decay=1e-6, - eps=1e-8, -) - -# TODO: Flesh this out using accelerate library (a la other examples) diff --git a/examples/diffuser/README.md b/examples/rl/README.md similarity index 56% rename from examples/diffuser/README.md rename to examples/rl/README.md index 464ccd57af85..dd8add8aa4ea 100644 --- a/examples/diffuser/README.md +++ b/examples/rl/README.md @@ -1,6 +1,9 @@ # Overview -These examples show how to run (Diffuser)[https://arxiv.org/pdf/2205.09991.pdf] in Diffusers. There are two scripts, `run_diffuser_value_guided.py` and `run_diffuser.py`. +These examples show how to run (Diffuser)[https://arxiv.org/abs/2205.09991] in Diffusers. +There are four scripts, +1. `run_diffuser_locomotion.py` to sample actions and run them in the environment, +2. and `run_diffuser_gen_trajectories.py` to just sample actions from the pre-trained diffusion model. 
You will need some RL specific requirements to run the examples: diff --git a/examples/diffuser/run_diffuser_gen_trajectories.py b/examples/rl/run_diffuser_gen_trajectories.py similarity index 85% rename from examples/diffuser/run_diffuser_gen_trajectories.py rename to examples/rl/run_diffuser_gen_trajectories.py index 3de8521343e3..4f04d3acd704 100644 --- a/examples/diffuser/run_diffuser_gen_trajectories.py +++ b/examples/rl/run_diffuser_gen_trajectories.py @@ -1,7 +1,7 @@ import d4rl # noqa import gym import tqdm -from diffusers import DiffusionPipeline +from diffusers import ValueGuidedRLPipeline config = dict( @@ -17,14 +17,13 @@ ) -def _run(): +if __name__ == "__main__": env_name = "hopper-medium-v2" env = gym.make(env_name) - pipeline = DiffusionPipeline.from_pretrained( + pipeline = ValueGuidedRLPipeline.from_pretrained( "bglick13/hopper-medium-v2-value-function-hor32", env=env, - custom_pipeline="/Users/bglickenhaus/Documents/diffusers/examples/community", ) env.seed(0) @@ -56,11 +55,3 @@ def _run(): pass print(f"Total reward: {total_reward}") - - -def run(): - _run() - - -if __name__ == "__main__": - run() diff --git a/examples/diffuser/run_diffuser_locomotion.py b/examples/rl/run_diffuser_locomotion.py similarity index 85% rename from examples/diffuser/run_diffuser_locomotion.py rename to examples/rl/run_diffuser_locomotion.py index 9ac9df28db81..ad2fc8785f15 100644 --- a/examples/diffuser/run_diffuser_locomotion.py +++ b/examples/rl/run_diffuser_locomotion.py @@ -1,7 +1,7 @@ import d4rl # noqa import gym import tqdm -from diffusers import DiffusionPipeline +from diffusers import ValueGuidedRLPipeline config = dict( @@ -17,14 +17,13 @@ ) -def _run(): +if __name__ == "__main__": env_name = "hopper-medium-v2" env = gym.make(env_name) - pipeline = DiffusionPipeline.from_pretrained( + pipeline = ValueGuidedRLPipeline.from_pretrained( "bglick13/hopper-medium-v2-value-function-hor32", env=env, - custom_pipeline="/Users/bglickenhaus/Documents/diffusers/examples/community", ) env.seed(0) @@ -56,11 +55,3 @@ def _run(): pass print(f"Total reward: {total_reward}") - - -def run(): - _run() - - -if __name__ == "__main__": - run() diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 2c531cf8cee0..2a16132d3d8c 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -17,6 +17,7 @@ if is_torch_available(): + from .experimental import ValueGuidedRLPipeline from .modeling_utils import ModelMixin from .models import AutoencoderKL, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel from .optimization import ( diff --git a/src/diffusers/experimental/README.md b/src/diffusers/experimental/README.md new file mode 100644 index 000000000000..81a9de81c737 --- /dev/null +++ b/src/diffusers/experimental/README.md @@ -0,0 +1,5 @@ +# 🧨 Diffusers Experimental + +We are adding experimental code to support novel applications and usages of the Diffusers library. +Currently, the following experiments are supported: +* Reinforcement learning via an implementation of the [Diffuser](https://arxiv.org/abs/2205.09991) model. 
\ No newline at end of file diff --git a/src/diffusers/experimental/__init__.py b/src/diffusers/experimental/__init__.py new file mode 100644 index 000000000000..ebc815540301 --- /dev/null +++ b/src/diffusers/experimental/__init__.py @@ -0,0 +1 @@ +from .rl import ValueGuidedRLPipeline diff --git a/src/diffusers/experimental/rl/__init__.py b/src/diffusers/experimental/rl/__init__.py new file mode 100644 index 000000000000..7b338d3173e1 --- /dev/null +++ b/src/diffusers/experimental/rl/__init__.py @@ -0,0 +1 @@ +from .value_guided_sampling import ValueGuidedRLPipeline diff --git a/examples/community/pipeline.py b/src/diffusers/experimental/rl/value_guided_sampling.py similarity index 83% rename from examples/community/pipeline.py rename to src/diffusers/experimental/rl/value_guided_sampling.py index 85e359c5c4c9..8d5062e3d4c5 100644 --- a/examples/community/pipeline.py +++ b/src/diffusers/experimental/rl/value_guided_sampling.py @@ -1,13 +1,28 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import numpy as np import torch import tqdm -from diffusers import DiffusionPipeline -from diffusers.models.unet_1d import UNet1DModel -from diffusers.utils.dummy_pt_objects import DDPMScheduler + +from ...models.unet_1d import UNet1DModel +from ...pipeline_utils import DiffusionPipeline +from ...utils.dummy_pt_objects import DDPMScheduler -class ValueGuidedDiffuserPipeline(DiffusionPipeline): +class ValueGuidedRLPipeline(DiffusionPipeline): def __init__( self, value_function: UNet1DModel, diff --git a/tests/test_models_unet_1d.py b/tests/test_models_unet_1d.py index ab86b5b6f202..dd320e8bd655 100644 --- a/tests/test_models_unet_1d.py +++ b/tests/test_models_unet_1d.py @@ -104,6 +104,25 @@ def test_forward_with_norm_groups(self): # Not implemented yet for this UNet pass + @slow + def test_unet_1d_maestro(self): + model_id = "harmonai/maestro-150k" + model = UNet1DModel.from_pretrained(model_id, subfolder="unet") + model.to(torch_device) + + sample_size = 65536 + noise = torch.sin(torch.arange(sample_size)[None, None, :].repeat(1, 2, 1)).to(torch_device) + timestep = torch.tensor([1]).to(torch_device) + + with torch.no_grad(): + output = model(noise, timestep).sample + + output_sum = output.abs().sum() + output_max = output.abs().max() + + assert (output_sum - 224.0896).abs() < 4e-2 + assert (output_max - 0.0607).abs() < 4e-4 + class UNetRLModelTests(ModelTesterMixin, unittest.TestCase): model_class = UNet1DModel @@ -204,24 +223,3 @@ def test_output_pretrained(self): def test_forward_with_norm_groups(self): # Not implemented yet for this UNet pass - - -class UnetModel1DTests(unittest.TestCase): - @slow - def test_unet_1d_maestro(self): - model_id = "harmonai/maestro-150k" - model = UNet1DModel.from_pretrained(model_id, subfolder="unet") - model.to(torch_device) - - sample_size = 65536 - noise = torch.sin(torch.arange(sample_size)[None, None, :].repeat(1, 2, 1)).to(torch_device) - timestep = torch.tensor([1]).to(torch_device) - - with 
torch.no_grad(): - output = model(noise, timestep).sample - - output_sum = output.abs().sum() - output_max = output.abs().max() - - assert (output_sum - 224.0896).abs() < 4e-2 - assert (output_max - 0.0607).abs() < 4e-4
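
For reference, a minimal usage sketch of the `ValueGuidedRLPipeline` this PR promotes from `examples/community` into `diffusers.experimental`, distilled from `examples/rl/run_diffuser_locomotion.py` above. It assumes the RL extras (`d4rl`, `gym`, and a working MuJoCo install) are available; the early `break` on `done` is an addition for brevity, not part of the PR's scripts:

```python
import d4rl  # noqa: F401  # registers the D4RL offline-RL environments with gym
import gym

from diffusers import ValueGuidedRLPipeline

env = gym.make("hopper-medium-v2")
env.seed(0)
obs = env.reset()

# the value function and the diffusion model are loaded from the same Hub repo
pipeline = ValueGuidedRLPipeline.from_pretrained(
    "bglick13/hopper-medium-v2-value-function-hor32",
    env=env,
)

total_reward = 0.0
for t in range(1000):
    # normalizes the observation, samples candidate trajectories with
    # value-function guidance, and returns the de-normalized first action
    # of the highest-value trajectory
    action = pipeline(obs, planning_horizon=32, n_guide_steps=2, scale=0.1)
    obs, reward, done, _ = env.step(action)
    total_reward += reward
    if done:
        break

print(f"Total reward: {total_reward}")
```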