Skip to content

Commit 4630240

Browse files
pcuenca authored and natolambert committed
Use ONNX / Core ML compatible method to broadcast (#310)
* Use ONNX / Core ML compatible method to broadcast. Unfortunately `tile` could not be used either; it's still not compatible with ONNX. See #284.
* Add comment about why broadcast_to is not used. Also, apply style to changed files.
* Make sure broadcast remains in same device.
1 parent a756e02 commit 4630240

File tree

2 files changed

+4
-8
lines changed

2 files changed

+4
-8
lines changed

src/diffusers/models/unet_2d.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,6 @@ def __init__(
120120
def forward(
121121
self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int]
122122
) -> Dict[str, torch.FloatTensor]:
123-
124123
# 0. center input if necessary
125124
if self.config.center_input_sample:
126125
sample = 2 * sample - 1.0
@@ -132,8 +131,8 @@ def forward(
132131
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
133132
timesteps = timesteps[None].to(sample.device)
134133

135-
# broadcast to batch dimension
136-
timesteps = timesteps.broadcast_to(sample.shape[0])
134+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
135+
timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
137136

138137
t_emb = self.time_proj(timesteps)
139138
emb = self.time_embedding(t_emb)

src/diffusers/models/unet_2d_condition.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,6 @@ def forward(
121121
timestep: Union[torch.Tensor, float, int],
122122
encoder_hidden_states: torch.Tensor,
123123
) -> Dict[str, torch.FloatTensor]:
124-
125124
# 0. center input if necessary
126125
if self.config.center_input_sample:
127126
sample = 2 * sample - 1.0
@@ -133,8 +132,8 @@ def forward(
133132
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
134133
timesteps = timesteps[None].to(sample.device)
135134

136-
# broadcast to batch dimension
137-
timesteps = timesteps.broadcast_to(sample.shape[0])
135+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
136+
timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
138137

139138
t_emb = self.time_proj(timesteps)
140139
emb = self.time_embedding(t_emb)
@@ -145,7 +144,6 @@ def forward(
145144
# 3. down
146145
down_block_res_samples = (sample,)
147146
for downsample_block in self.down_blocks:
148-
149147
if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
150148
sample, res_samples = downsample_block(
151149
hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
@@ -160,7 +158,6 @@ def forward(
160158

161159
# 5. up
162160
for upsample_block in self.up_blocks:
163-
164161
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
165162
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
166163

0 commit comments

Comments
 (0)