Commit dda448f

use fused_glu (huggingface#51)
1 parent 99a8c6f commit dda448f

1 file changed: +2 -9 lines

src/diffusers/models/attention_oneflow.py

Lines changed: 2 additions & 9 deletions
@@ -770,15 +770,8 @@ def gelu(self, gate):
         return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
 
     def forward(self, hidden_states):
-        if hasattr(torch._C, "fused_geglu"):
-            x_shape = hidden_states.shape
-            if len(x_shape) != 2:
-                hidden_states = hidden_states.reshape(-1, x_shape[-1])
-            out = torch._C.fused_geglu(hidden_states, self.proj.weight, self.proj.bias)
-            if len(x_shape) != 2:
-                out_shape = x_shape[0:len(x_shape) - 1] + (-1,)
-                out = out.reshape(out_shape)
-            return out
+        if hasattr(torch._C, "fused_glu"):
+            return torch._C.fused_glu(hidden_states, self.proj.weight, self.proj.bias, activation="gelu")
         hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
         return hidden_states * self.gelu(gate)
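For context, the fallback path kept at the end of the hunk is the unfused GEGLU computation: one linear projection is chunked into a value half and a gate half, and the value is scaled by GELU of the gate; the fused_glu call performs the same projection, split, and gating in a single kernel. Below is a minimal self-contained sketch in plain PyTorch (the class name GEGLU, the dimensions, and the float32 cast in the gelu helper mirror the hunk context, but this standalone module is illustrative, not the file's actual class):

import torch
import torch.nn as nn
import torch.nn.functional as F

class GEGLU(nn.Module):
    # Unfused reference for fused_glu(..., activation="gelu"): a single
    # projection to 2 * dim_out features, split into value and gate halves.
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def gelu(self, gate):
        # Compute GELU in float32 for numerical stability, then cast back,
        # matching the helper shown in the hunk context above.
        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

    def forward(self, hidden_states):
        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
        return hidden_states * self.gelu(gate)

x = torch.randn(2, 16, 320)       # (batch, tokens, channels)
print(GEGLU(320, 1280)(x).shape)  # torch.Size([2, 16, 1280])

Note that the removed fused_geglu branch also had to flatten inputs to 2-D before the kernel call and restore the shape afterwards; the new fused_glu call passes hidden_states through unchanged, so that reshape bookkeeping disappears.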
