Description
Describe the bug
This is my first time creating an issue and I'm just starting to use the diffusers library. Really enjoying it! Also, apologies if I'm using the DiffusionPipeline incorrectly.
I'm trying to follow along with the unconditional image generation example notebook, but replace the butterfly dataset (3-channel images) with MNIST digits (1-channel images).
When using the DDPMPipeline, the Image.fromarray(image) call in DiffusionPipeline's numpy_to_pil method fails to create the PIL image because the image has a trailing length-one channel axis, i.e. shape (M, N, 1). If the image had three channels (M, N, 3), this would work and create an RGB image. It would also work if the length-one axis were removed, giving shape (M, N).
Here's an example with randomly generated ndarrays:
import numpy as np
from PIL import Image
Works for 3-channel images:
Image.fromarray((np.random.rand(32,32,3)*255).astype('uint8'))
Current behavior, fails for grayscale/single-channel images:
Image.fromarray((np.random.rand(32,32,1)*255).astype('uint8'))
The single-channel case works after applying np.squeeze, which removes the length-one axis from the ndarray (the 3-channel case is unaffected by it):
Image.fromarray(np.squeeze((np.random.rand(32,32,1)*255).astype('uint8')))
I assume using np.squeeze, or some other check for the last dimension having length one, would fix the issue.
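To make the suggested fix concrete, here is a minimal sketch of what a channel-aware conversion could look like. The function name numpy_to_pil_any_channels is hypothetical and not part of diffusers; the first two steps mirror the lines of numpy_to_pil visible in the traceback in the Logs section, and the only assumption is that input values are floats in [0, 1] with shape (H, W, C) or (B, H, W, C).

```python
import numpy as np
from PIL import Image


def numpy_to_pil_any_channels(images):
    """Hypothetical helper: convert float images in [0, 1] to a list of PIL images."""
    # Promote a single (H, W, C) image to a batch of one, as numpy_to_pil does.
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # Drop only the trailing channel axis so each image is 2D (grayscale).
        # squeeze(axis=-1) avoids accidentally removing a height or width of 1.
        return [Image.fromarray(image.squeeze(axis=-1)) for image in images]
    return [Image.fromarray(image) for image in images]


grayscale = numpy_to_pil_any_channels(np.random.rand(2, 32, 32, 1))
color = numpy_to_pil_any_channels(np.random.rand(32, 32, 3))
print([im.mode for im in grayscale], [im.mode for im in color])
```

Squeezing only the last axis is a deliberate choice here: a bare np.squeeze would also collapse a batch of one or a one-pixel-wide image. Judging from the traceback, passing an output_type other than "pil" to the pipeline appears to skip numpy_to_pil and return the raw array under "sample", which could then be fed to a helper like this as a workaround; I have not verified that.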
Reproduction
Here the model isn't trained, but I think the pipeline should still work
import diffusers
model = diffusers.UNet2DModel(
    sample_size=32,
    in_channels=1,
    out_channels=1,
    layers_per_block=2,
    block_out_channels=(128, 128, 256, 512),
    down_block_types=(
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "AttnUpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
)
noise_scheduler = diffusers.DDPMScheduler(num_train_timesteps=200, tensor_format='pt')
pipeline = diffusers.DDPMPipeline(unet=model, scheduler=noise_scheduler)
pipeline()["sample"]
Logs
Here's what I get when I run the above code in a jupyter notebook.
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
File ~/miniconda3/envs/deep_learning/lib/python3.8/site-packages/PIL/Image.py:2953, in fromarray(obj, mode)
2952 try:
-> 2953 mode, rawmode = _fromarray_typemap[typekey]
2954 except KeyError as e:
KeyError: ((1, 1, 1), '|u1')
The above exception was the direct cause of the following exception:
TypeError Traceback (most recent call last)
Input In [7], in <cell line: 1>()
----> 1 pipeline()
File ~/miniconda3/envs/deep_learning/lib/python3.8/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
24 @functools.wraps(func)
25 def decorate_context(*args, **kwargs):
26 with self.clone():
---> 27 return func(*args, **kwargs)
File ~/miniconda3/envs/deep_learning/lib/python3.8/site-packages/diffusers/pipelines/ddpm/pipeline_ddpm.py:66, in DDPMPipeline.__call__(self, batch_size, generator, output_type, **kwargs)
64 image = image.cpu().permute(0, 2, 3, 1).numpy()
65 if output_type == "pil":
---> 66 image = self.numpy_to_pil(image)
68 return {"sample": image}
File ~/miniconda3/envs/deep_learning/lib/python3.8/site-packages/diffusers/pipeline_utils.py:261, in DiffusionPipeline.numpy_to_pil(images)
259 images = images[None, ...]
260 images = (images * 255).round().astype("uint8")
--> 261 pil_images = [Image.fromarray(image) for image in images]
263 return pil_images
File ~/miniconda3/envs/deep_learning/lib/python3.8/site-packages/diffusers/pipeline_utils.py:261, in <listcomp>(.0)
259 images = images[None, ...]
260 images = (images * 255).round().astype("uint8")
--> 261 pil_images = [Image.fromarray(image) for image in images]
263 return pil_images
File ~/miniconda3/envs/deep_learning/lib/python3.8/site-packages/PIL/Image.py:2955, in fromarray(obj, mode)
2953 mode, rawmode = _fromarray_typemap[typekey]
2954 except KeyError as e:
-> 2955 raise TypeError("Cannot handle this data type: %s, %s" % typekey) from e
2956 else:
2957 rawmode = mode
TypeError: Cannot handle this data type: (1, 1, 1), |u1
System Info
| Name | Version |
|---|---|
| numpy | 1.21.5 |
| diffusers | 0.2.3 |
| pillow | 9.2.0 |
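The failure can be isolated from diffusers entirely. The sketch below probes which array shapes Image.fromarray accepts for uint8 data; describe_fromarray is a hypothetical helper written for this report. With the versions listed above, the (32, 32, 1) case is expected to report the same TypeError as in the logs, though other Pillow versions may behave differently.

```python
import numpy as np
from PIL import Image


def describe_fromarray(shape):
    """Hypothetical helper: return the PIL mode for a uint8 array of `shape`,
    or the error message if Image.fromarray rejects it."""
    array = np.zeros(shape, dtype="uint8")
    try:
        return Image.fromarray(array).mode
    except TypeError as error:
        return f"TypeError: {error}"


for shape in [(32, 32), (32, 32, 1), (32, 32, 3), (32, 32, 4)]:
    print(shape, "->", describe_fromarray(shape))
```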