
DDPM pipeline with black-and-white 1-channel image (mode 'L') #1280

@zepmck

I am trying to generate synthetic images with the DDPMPipeline, training on a dataset of black-and-white, single-channel images.

This is my UNet2DModel:

from diffusers import UNet2DModel

model = UNet2DModel(
    sample_size=config.image_size,  # the target image resolution
    in_channels=1,  # the number of input channels: 1 for grayscale, 3 for RGB images
    out_channels=1,  # the number of output channels (1 to match the grayscale input)
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(128, 128, 256, 256, 512, 512),  # the number of output channels for each UNet block
    # block_out_channels=(132, 132, 256, 256, 512, 512),  # the number of output channels for each UNet block
    down_block_types=(
        "DownBlock2D",  # a regular ResNet downsampling block
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",  # a regular ResNet upsampling block
        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
)
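For reference, with `in_channels=1` every training batch fed to this model has to be shaped `(N, 1, H, W)`. Below is a minimal preprocessing sketch for a single image. It assumes Pillow and NumPy and the [-1, 1] pixel scaling commonly used for DDPM training. The function name `to_model_input` is hypothetical, and in a real training loop the array would still be converted to a torch tensor before batching:

```python
import numpy as np
from PIL import Image


def to_model_input(img: Image.Image, size: int) -> np.ndarray:
    """Hypothetical helper: PIL image -> (1, size, size) float32 array in [-1, 1]."""
    # Force 8-bit single-channel grayscale ("L" mode), whatever the source mode is.
    gray = img.convert("L").resize((size, size))
    arr = np.asarray(gray, dtype=np.float32) / 255.0  # (size, size), values in [0, 1]
    arr = arr * 2.0 - 1.0  # rescale to [-1, 1]
    return arr[None, :, :]  # add the channel axis -> (1, size, size)


# Example: an RGB source still ends up with exactly one channel.
sample = to_model_input(Image.new("RGB", (64, 48)), 32)
print(sample.shape)  # (1, 32, 32)
```

Calling `.convert("L")` first means RGB or palette files in the dataset still yield a single channel, which keeps the input consistent with `in_channels=1`.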

However, when I evaluate the model (sampling images from the pipeline), I get the following error:

Input In [134], in train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)
     66 pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
     68 if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
---> 69     evaluate(config, epoch, pipeline)
     71 if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
     72     if config.push_to_hub:

Input In [133], in evaluate(config, epoch, pipeline)
     12 def evaluate(config, epoch, pipeline):
     13     # Sample some images from random noise (this is the backward diffusion process).
     14     # The default pipeline output type is `List[PIL.Image]`
---> 15     images = pipeline(
     16         batch_size = config.eval_batch_size, 
     17         generator=torch.manual_seed(config.seed),
     18     )["sample"]
     20     # Make a grid out of the images
     21     image_grid = make_grid(images, rows=4, cols=4)

File /opt/conda/lib/python3.8/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     24 @functools.wraps(func)
     25 def decorate_context(*args, **kwargs):
     26     with self.clone():
---> 27         return func(*args, **kwargs)

File /opt/conda/lib/python3.8/site-packages/diffusers/pipelines/ddpm/pipeline_ddpm.py:57, in DDPMPipeline.__call__(self, batch_size, generator, torch_device, output_type)
     55 image = image.cpu().permute(0, 2, 3, 1).numpy()
     56 if output_type == "pil":
---> 57     image = self.numpy_to_pil(image)
     59 return {"sample": image}

File /opt/conda/lib/python3.8/site-packages/diffusers/pipeline_utils.py:202, in DiffusionPipeline.numpy_to_pil(images)
    200     images = images[None, ...]
    201 images = (images * 255).round().astype("uint8")
--> 202 pil_images = [Image.fromarray(image) for image in images]
    204 return pil_images

File /opt/conda/lib/python3.8/site-packages/diffusers/pipeline_utils.py:202, in <listcomp>(.0)
    200     images = images[None, ...]
    201 images = (images * 255).round().astype("uint8")
--> 202 pil_images = [Image.fromarray(image) for image in images]
    204 return pil_images

File /opt/conda/lib/python3.8/site-packages/PIL/Image.py:2815, in fromarray(obj, mode)
   2813         mode, rawmode = _fromarray_typemap[typekey]
   2814     except KeyError as e:
-> 2815         raise TypeError("Cannot handle this data type: %s, %s" % typekey) from e
   2816 else:
   2817     rawmode = mode

TypeError: Cannot handle this data type: (1, 1, 1), |u1
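The last frames point at the conversion step. The pipeline permutes the sample to `(N, H, W, C)`, and `numpy_to_pil` passes each `(H, W, 1)` uint8 array to `Image.fromarray`, which is where the `TypeError` is raised. Here is a minimal reproduction using only NumPy and Pillow, with random data standing in for the pipeline output (the sizes are arbitrary). On the Pillow version in the traceback, the first call fails, while squeezing away the size-1 channel axis leaves a 2-D array that Pillow accepts as a grayscale "L" image:

```python
import numpy as np
from PIL import Image

# Fake batch shaped like the pipeline output after .permute(0, 2, 3, 1):
# (batch, height, width, channels) with one grayscale channel, values in [0, 1).
rng = np.random.default_rng(0)
images = rng.random((2, 32, 32, 1))
images = (images * 255).round().astype("uint8")  # same scaling as numpy_to_pil

try:
    # Same call as the failing line: each item is an (H, W, 1) uint8 array.
    Image.fromarray(images[0])
except TypeError as err:
    print("fromarray failed:", err)

# Dropping the size-1 channel axis leaves an (H, W) uint8 array,
# which Pillow turns into a mode "L" image.
pil_images = [Image.fromarray(image.squeeze(-1)) for image in images]
print(pil_images[0].mode, pil_images[0].size)  # L (32, 32)
```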

Any idea why this happens?

Thanks!
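A possible workaround, sketched under the assumption that the installed diffusers version matches the traceback. In that version, DDPMPipeline.__call__ accepts `output_type` and only calls `numpy_to_pil` when `output_type == "pil"`. Requesting another output type should therefore return the `(N, H, W, 1)` NumPy array, assumed to hold values in [0, 1] as `numpy_to_pil` expects, which can then be converted with a grayscale-aware helper. The name `numpy_to_pil_any_channels` is hypothetical. It mirrors the scaling in `numpy_to_pil` but drops the channel axis for single-channel batches:

```python
import numpy as np
from PIL import Image


def numpy_to_pil_any_channels(images: np.ndarray) -> list:
    """Hypothetical helper: like numpy_to_pil, but also accepts (N, H, W, 1) batches."""
    if images.ndim == 3:
        images = images[None, ...]  # a single (H, W, C) image -> batch of one
    images = (images * 255).round().astype("uint8")  # [0, 1] floats -> 0..255 bytes
    if images.shape[-1] == 1:
        # Pillow maps a 2-D uint8 array to mode "L", so drop the channel axis.
        return [Image.fromarray(image[..., 0]) for image in images]
    # 3- or 4-channel batches convert directly (RGB / RGBA).
    return [Image.fromarray(image) for image in images]


# Assumed usage inside evaluate() (not run here; needs the trained pipeline):
# images = pipeline(
#     batch_size=config.eval_batch_size,
#     generator=torch.manual_seed(config.seed),
#     output_type="numpy",  # any value other than "pil" skips numpy_to_pil
# )["sample"]
# images = numpy_to_pil_any_channels(images)
```

The resulting "L" images can then be passed to make_grid as before.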

Labels: stale (issues that haven't received updates)
