
Commit 95414bd

Experimental: allow fp16 in mps (#961)
* Docs: refer to pre-RC version of PyTorch 1.13.0.
* Remove temporary workaround for unavailable op.
* Update comment to make it less ambiguous.
* Remove use of contiguous in mps. It appears to no longer be necessary.
* Special case: use einsum for much better performance in mps.
* Update mps docs.
* MPS: make pipeline work in half precision.
1 parent a59f999

File tree: 1 file changed (+7, -1 lines)

src/diffusers/models/attention.py

Lines changed: 7 additions & 1 deletion
@@ -376,6 +376,12 @@ def __init__(self, dim_in: int, dim_out: int):
         super().__init__()
         self.proj = nn.Linear(dim_in, dim_out * 2)
 
+    def gelu(self, gate):
+        if gate.device.type != "mps":
+            return F.gelu(gate)
+        # mps: gelu is not implemented for float16
+        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+
     def forward(self, hidden_states):
         hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
-        return hidden_states * F.gelu(gate)
+        return hidden_states * self.gelu(gate)
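
For context, here is a minimal, self-contained sketch of the pattern this hunk introduces: on the mps backend, gelu has no float16 kernel, so the gate is upcast to float32 for the activation and cast back to its original dtype afterwards. The enclosing class (GEGLU in diffusers' attention.py) is reconstructed from the diff context, and the usage snippet with its tensor shapes is an illustrative assumption, not part of the commit.

# Sketch of the mps float16 gelu workaround; not part of the commit itself.
import torch
import torch.nn as nn
import torch.nn.functional as F

class GEGLU(nn.Module):
    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def gelu(self, gate):
        if gate.device.type != "mps":
            return F.gelu(gate)
        # mps: gelu is not implemented for float16, so upcast, apply, downcast
        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

    def forward(self, hidden_states):
        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
        return hidden_states * self.gelu(gate)

# Illustrative usage; dimensions are arbitrary and assume an Apple Silicon
# machine where the mps backend is available.
if torch.backends.mps.is_available():
    layer = GEGLU(320, 1280).to("mps", dtype=torch.float16)
    x = torch.randn(2, 77, 320, device="mps", dtype=torch.float16)
    out = layer(x)  # gate is upcast to float32 only for the gelu call

Note that the upcast touches only the gate half of the projection, so the extra float32 memory is limited to that activation while the rest of the module stays in half precision.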
