From 7b46941e7768eabb0ad73d3e540379229bfe7fda Mon Sep 17 00:00:00 2001 From: Jakob Nybo Nissen Date: Wed, 12 Feb 2025 21:13:41 +0100 Subject: [PATCH 1/3] Add fast path in generic matmul This manually adds the critical optimisation investigated in Julia issue 56954. While we could rely on LLVM to continue doing this optimisation, it's more robust to add it manually. --- src/matmul.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/matmul.jl b/src/matmul.jl index 76719c53..1a9ac099 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -1021,6 +1021,7 @@ function _generic_matmatmul_nonadjtrans!(C, A, B, alpha, beta) @inbounds for n in axes(B, 2), k in axes(B, 1) # Balpha = B[k,n] * alpha, but we skip the multiplication in case isone(alpha) Balpha = @stable_muladdmul MulAddMul(alpha, false)(B[k,n]) + iszero(Balpha) && continue @simd for m in axes(A, 1) C[m,n] = muladd(A[m,k], Balpha, C[m,n]) end From 3a5bdbe2952628939315c722f21d9a4e224017bf Mon Sep 17 00:00:00 2001 From: Jakob Nybo Nissen Date: Thu, 13 Feb 2025 08:04:02 +0100 Subject: [PATCH 2/3] Fixup tests --- test/matmul.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/matmul.jl b/test/matmul.jl index 805edeac..afa730f1 100644 --- a/test/matmul.jl +++ b/test/matmul.jl @@ -767,6 +767,7 @@ import LinearAlgebra: Adjoint, Transpose (*)(x::RootInt, y::Integer) = x.i * y adjoint(x::RootInt) = x transpose(x::RootInt) = x +Base.zero(::RootInt) = RootInt(0) @test Base.promote_op(*, RootInt, RootInt) === Int From cbe2f396c08fa75dcbe85511f1a03e795f945cf3 Mon Sep 17 00:00:00 2001 From: Jakob Nybo Nissen Date: Thu, 13 Feb 2025 10:17:15 +0100 Subject: [PATCH 3/3] Handle missing --- src/matmul.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/matmul.jl b/src/matmul.jl index 1a9ac099..21c6e299 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -1021,7 +1021,7 @@ function _generic_matmatmul_nonadjtrans!(C, A, B, alpha, beta) @inbounds for n in axes(B, 2), k in axes(B, 1) # Balpha = B[k,n] * alpha, but we skip the multiplication in case isone(alpha) Balpha = @stable_muladdmul MulAddMul(alpha, false)(B[k,n]) - iszero(Balpha) && continue + !ismissing(Balpha) && iszero(Balpha) && continue @simd for m in axes(A, 1) C[m,n] = muladd(A[m,k], Balpha, C[m,n]) end