diff --git a/torchao/swizzle/swizzle_ops.py b/torchao/swizzle/swizzle_ops.py index 7e62922a02..d8bbad46a1 100644 --- a/torchao/swizzle/swizzle_ops.py +++ b/torchao/swizzle/swizzle_ops.py @@ -30,7 +30,12 @@ def swizzle_mm(aten_op, args, kwargs=None): a = args[0] b = args[1] - if torch.is_floating_point(a) and torch.is_floating_point(b) and a.ndim == 2 and b.ndim == 2: + if ( + torch.is_floating_point(a) + and torch.is_floating_point(b) + and a.ndim == 2 + and b.ndim == 2 + ): a_is_swizzled = False b_is_swizzled = False if isinstance(a, SwizzleTensor):