
Commit 8a73064

Add AltDiffusion (#1299)
* add conversion script for vae
* up
* up
* some fixes
* add text model
* use the correct config
* add docs
* move model in it's own file
* move model in its own file
* pass attenion mask to text encoder
* pass attn mask to uncond inputs
* quality
* fix image2image
* add imag2image in init
* fix import
* fix one more import
* fix import, dummy objetcs
* fix copied from
* up
* finish

Co-authored-by: patil-suraj <[email protected]>
1 parent 4625f04 commit 8a73064

17 files changed, +1486 -12 lines changed

docs/source/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -80,6 +80,8 @@
 - sections:
   - local: api/pipelines/overview
     title: "Overview"
+  - local: api/pipelines/alt_diffusion
+    title: "AltDiffusion"
   - local: api/pipelines/cycle_diffusion
     title: "Cycle Diffusion"
   - local: api/pipelines/ddim

docs/source/api/pipelines/alt_diffusion.mdx

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# AltDiffusion

AltDiffusion was proposed in [AltCLIP: Altering the Language Encoder in CLIP for Extended Language Capabilities](https://arxiv.org/abs/2211.06679) by Zhongzhi Chen, Guang Liu, Bo-Wen Zhang, Fulong Ye, Qinghong Yang, and Ledell Wu.

The abstract of the paper is the following:

*In this work, we present a conceptually simple and effective method to train a strong bilingual multimodal representation model. Starting from the pretrained multimodal representation model CLIP released by OpenAI, we switched its text encoder with a pretrained multilingual text encoder XLM-R, and aligned both languages and image representations by a two-stage training schema consisting of teacher learning and contrastive learning. We validate our method through evaluations of a wide range of tasks. We set new state-of-the-art performances on a bunch of tasks including ImageNet-CN, Flickr30k-CN, and COCO-CN. Further, we obtain very close performances with CLIP on almost all tasks, suggesting that one can simply alter the text encoder in CLIP for extended capabilities such as multilingual understanding.*

*Overview*:

| Pipeline | Tasks | Colab | Demo
|---|---|:---:|:---:|
| [pipeline_alt_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py) | *Text-to-Image Generation* | - | -
| [pipeline_alt_diffusion_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py) | *Image-to-Image Text-Guided Generation* | - | -

## Tips

- AltDiffusion is conceptually exactly the same as [Stable Diffusion](./api/pipelines/stable_diffusion).

- *Run AltDiffusion*

AltDiffusion can be tested very easily with the [`AltDiffusionPipeline`], [`AltDiffusionImg2ImgPipeline`] and the `"BAAI/AltDiffusion"` checkpoint in exactly the same way as shown in the [Conditional Image Generation Guide](./using-diffusers/conditional_image_generation) and the [Image-to-Image Generation Guide](./using-diffusers/img2img).

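For example, here is a minimal text-to-image sketch; the prompt, `torch_dtype`, device placement, and output filename below are illustrative rather than taken from the guides:

```python
>>> import torch
>>> from diffusers import AltDiffusionPipeline

>>> # Load the AltDiffusion checkpoint; fp16 weights and a CUDA device are optional.
>>> pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion", torch_dtype=torch.float16)
>>> pipe = pipe.to("cuda")

>>> # The XLM-R based text encoder is bilingual, so non-English prompts work as well as English ones.
>>> prompt = "A painting of a squirrel eating a burger"
>>> image = pipe(prompt).images[0]
>>> image.save("squirrel_painting.png")
```
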
- *How to load and use different schedulers.*

The AltDiffusion pipeline uses the [`DDIMScheduler`] by default, but `diffusers` provides many other schedulers that can be used with it, such as [`PNDMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`], etc.
To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] method or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following:

```python
>>> from diffusers import AltDiffusionPipeline, EulerDiscreteScheduler

>>> pipeline = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion")
>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)

>>> # or
>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("BAAI/AltDiffusion", subfolder="scheduler")
>>> pipeline = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion", scheduler=euler_scheduler)
```

- *How to cover all use cases with multiple or a single pipeline*

If you want to cover all possible use cases in a single `DiffusionPipeline`, we recommend using the `components` functionality to instantiate all components in the most memory-efficient way:

```python
>>> from diffusers import (
...     AltDiffusionPipeline,
...     AltDiffusionImg2ImgPipeline,
... )

>>> text2img = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion")
>>> img2img = AltDiffusionImg2ImgPipeline(**text2img.components)

>>> # now you can use text2img(...) and img2img(...) just like the call methods of each respective pipeline
```

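As a follow-up, a hedged sketch of calling the `img2img` pipeline created above; the image URL, prompt, and `strength` value are illustrative, and the input-image argument is assumed to be named `init_image` in this version of `diffusers`:

```python
>>> import requests
>>> from io import BytesIO
>>> from PIL import Image

>>> # Download and prepare an example input image (any RGB image works).
>>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
>>> init_image = Image.open(BytesIO(requests.get(url).content)).convert("RGB").resize((768, 512))

>>> prompt = "A fantasy landscape, trending on artstation"
>>> image = img2img(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images[0]
```
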
## AltDiffusionPipelineOutput
[[autodoc]] pipelines.alt_diffusion.AltDiffusionPipelineOutput

## AltDiffusionPipeline
[[autodoc]] AltDiffusionPipeline
	- __call__
	- enable_attention_slicing
	- disable_attention_slicing

## AltDiffusionImg2ImgPipeline
[[autodoc]] AltDiffusionImg2ImgPipeline
	- __call__
	- enable_attention_slicing
	- disable_attention_slicing

docs/source/api/pipelines/overview.mdx

Lines changed: 1 addition & 0 deletions
@@ -44,6 +44,7 @@ available a colab notebook to directly try them out.
 
 | Pipeline | Paper | Tasks | Colab
 |---|---|:---:|:---:|
+| [alt_diffusion](./api/pipelines/alt_diffusion) | [**AltDiffusion**](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation | -
 | [cycle_diffusion](./api/pipelines/cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
 | [dance_diffusion](./api/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
 | [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |

docs/source/index.mdx

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@ available a colab notebook to directly try them out.
 
 | Pipeline | Paper | Tasks | Colab
 |---|---|:---:|:---:|
+| [alt_diffusion](./api/pipelines/alt_diffusion) | [**AltDiffusion**](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation |
 | [cycle_diffusion](./api/pipelines/cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
 | [dance_diffusion](./api/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
 | [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |

docs/source/using-diffusers/conditional_image_generation.mdx

Lines changed: 0 additions & 2 deletions
@@ -44,5 +44,3 @@ You can save the image by simply calling:
 ```python
 >>> image.save("image_of_squirrel_painting.png")
 ```
-
-

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -65,6 +65,8 @@
 
 if is_torch_available() and is_transformers_available():
     from .pipelines import (
+        AltDiffusionImg2ImgPipeline,
+        AltDiffusionPipeline,
         CycleDiffusionPipeline,
         LDMTextToImagePipeline,
         StableDiffusionImg2ImgPipeline,

src/diffusers/pipelines/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@
     from ..utils.dummy_pt_objects import *  # noqa F403
 
 if is_torch_available() and is_transformers_available():
+    from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
     from .latent_diffusion import LDMTextToImagePipeline
     from .stable_diffusion import (
         CycleDiffusionPipeline,

src/diffusers/pipelines/alt_diffusion/__init__.py

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
from dataclasses import dataclass
from typing import List, Optional, Union

import numpy as np

import PIL
from PIL import Image

from ...utils import BaseOutput, is_torch_available, is_transformers_available


@dataclass
# Copied from diffusers.pipelines.stable_diffusion.__init__.StableDiffusionPipelineOutput with Stable->Alt
class AltDiffusionPipelineOutput(BaseOutput):
    """
    Output class for Alt Diffusion pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
        nsfw_content_detected (`List[bool]`)
            List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, or `None` if safety checking could not be performed.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]
    nsfw_content_detected: Optional[List[bool]]


if is_transformers_available() and is_torch_available():
    from .modeling_roberta_series import RobertaSeriesModelWithTransformation
    from .pipeline_alt_diffusion import AltDiffusionPipeline
    from .pipeline_alt_diffusion_img2img import AltDiffusionImg2ImgPipeline
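
For reference, a minimal sketch of consuming this output class; the prompt is illustrative, and with the default `return_dict=True` the pipeline call returns an `AltDiffusionPipelineOutput`:

```python
>>> from diffusers import AltDiffusionPipeline

>>> pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion")
>>> output = pipe("A painting of a squirrel eating a burger")  # AltDiffusionPipelineOutput
>>> image = output.images[0]                    # List[PIL.Image.Image]
>>> flagged = output.nsfw_content_detected      # Optional[List[bool]] from the safety checker
```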

src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py

Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import nn

from transformers import RobertaPreTrainedModel, XLMRobertaConfig, XLMRobertaModel
from transformers.utils import ModelOutput


@dataclass
class TransformationModelOutput(ModelOutput):
    """
    Base class for text model's outputs that also contains a pooling of the last hidden states.

    Args:
        projection_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, project_dim)`):
            The text embeddings obtained by applying the projection layer to the last hidden state.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    projection_state: Optional[torch.FloatTensor] = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class RobertaSeriesConfig(XLMRobertaConfig):
    def __init__(
        self,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        project_dim=512,
        pooler_fn="cls",
        learn_encoder=False,
        use_attention_mask=True,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
        self.project_dim = project_dim
        self.pooler_fn = pooler_fn
        self.learn_encoder = learn_encoder
        self.use_attention_mask = use_attention_mask


class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
    base_model_prefix = "roberta"
    config_class = RobertaSeriesConfig

    def __init__(self, config):
        super().__init__(config)
        self.roberta = XLMRobertaModel(config)
        self.transformation = nn.Linear(config.hidden_size, config.project_dim)
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ):
        r"""Encode `input_ids` with XLM-R and project the last hidden state to `config.project_dim`."""

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Run the underlying XLM-RoBERTa encoder.
        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Project the token-level hidden states to the dimensionality expected by the UNet's cross-attention.
        projection_state = self.transformation(outputs.last_hidden_state)

        return TransformationModelOutput(
            projection_state=projection_state,
            last_hidden_state=outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
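
To illustrate how this text encoder slots into the pipeline, a hedged usage sketch; the `subfolder` names and the example prompt are assumptions based on the usual layout of the `BAAI/AltDiffusion` repository, not taken from this commit:

```python
>>> from transformers import XLMRobertaTokenizer

>>> from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation

>>> # Assumed subfolders: "tokenizer" and "text_encoder" inside the BAAI/AltDiffusion repo.
>>> tokenizer = XLMRobertaTokenizer.from_pretrained("BAAI/AltDiffusion", subfolder="tokenizer")
>>> text_encoder = RobertaSeriesModelWithTransformation.from_pretrained("BAAI/AltDiffusion", subfolder="text_encoder")

>>> inputs = tokenizer(["A painting of a squirrel eating a burger"], padding=True, return_tensors="pt")
>>> out = text_encoder(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask)
>>> out.projection_state.shape  # (batch_size, sequence_length, config.project_dim)
```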
