diff --git a/src/matmul.jl b/src/matmul.jl index cf229364..efb9d0fe 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -412,52 +412,60 @@ lmul!(A, B) _vec_or_mat_str(s::Tuple{Any}) = "vector" _vec_or_mat_str(s::Tuple{Any,Any}) = "matrix" -@noinline function matmul_size_check(sizeA::Tuple{Integer,Vararg{Integer}}, sizeB::Tuple{Integer,Vararg{Integer}}) +function matmul_size_check(sizeA::Tuple{Integer,Vararg{Integer}}, sizeB::Tuple{Integer,Vararg{Integer}}) szA2 = get(sizeA, 2, 1) if szA2 != sizeB[1] - strA = _vec_or_mat_str(sizeA) - strB = _vec_or_mat_str(sizeB) - B_size_len = length(sizeB) == 1 ? sizeB[1] : sizeB - size_or_len_str_B = B_size_len isa Integer ? "length" : "size" - dim_or_len_str_B = B_size_len isa Integer ? "length" : "first dimension" - pos_str_A = LazyString(length(sizeA) == length(sizeB) ? "first " : "", strA) - pos_str_B = LazyString(length(sizeA) == length(sizeB) ? "second " : "", strB) - throw(DimensionMismatch( - LazyString( - lazy"incompatible dimensions for matrix multiplication: ", - lazy"tried to multiply a $strA of size $sizeA with a $strB of $size_or_len_str_B $B_size_len. ", - lazy"The second dimension of the $pos_str_A: $szA2, does not match the $dim_or_len_str_B of the $pos_str_B: $(sizeB[1])." - ) - ) - ) + matmul_size_check_error(sizeA, sizeB) end return nothing end -@noinline function matmul_size_check(sizeC::Tuple{Integer,Vararg{Integer}}, sizeA::Tuple{Integer,Vararg{Integer}}, sizeB::Tuple{Integer,Vararg{Integer}}) +@noinline function matmul_size_check_error(sizeA::Tuple{Integer,Vararg{Integer}}, sizeB::Tuple{Integer,Vararg{Integer}}) + strA = _vec_or_mat_str(sizeA) + strB = _vec_or_mat_str(sizeB) + szA2 = get(sizeA, 2, 1) + B_size_len = length(sizeB) == 1 ? sizeB[1] : sizeB + size_or_len_str_B = B_size_len isa Integer ? "length" : "size" + dim_or_len_str_B = B_size_len isa Integer ? "length" : "first dimension" + pos_str_A = LazyString(length(sizeA) == length(sizeB) ? "first " : "", strA) + pos_str_B = LazyString(length(sizeA) == length(sizeB) ? "second " : "", strB) + throw(DimensionMismatch( + LazyString( + "incompatible dimensions for matrix multiplication: ", + lazy"tried to multiply a $strA of size $sizeA with a $strB of $size_or_len_str_B $B_size_len. ", + lazy"The second dimension of the $pos_str_A: $szA2, does not match the $dim_or_len_str_B of the $pos_str_B: $(sizeB[1])." + ) + ) + ) +end +function matmul_size_check(sizeC::Tuple{Integer,Vararg{Integer}}, sizeA::Tuple{Integer,Vararg{Integer}}, sizeB::Tuple{Integer,Vararg{Integer}}) matmul_size_check(sizeA, sizeB) szB2 = get(sizeB, 2, 1) szC2 = get(sizeC, 2, 1) if sizeC[1] != sizeA[1] || szC2 != szB2 - strA = _vec_or_mat_str(sizeA) - strB = _vec_or_mat_str(sizeB) - strC = _vec_or_mat_str(sizeC) - C_size_len = length(sizeC) == 1 ? sizeC[1] : sizeC - size_or_len_str_C = C_size_len isa Integer ? "length" : "size" - B_size_len = length(sizeB) == 1 ? sizeB[1] : sizeB - size_or_len_str_B = B_size_len isa Integer ? "length" : "size" - destsize = length(sizeB) == length(sizeC) == 1 ? sizeA[1] : (sizeA[1], szB2) - size_or_len_str_dest = destsize isa Integer ? "length" : "size" - throw(DimensionMismatch( - LazyString( - "incompatible destination size: ", - lazy"the destination $strC of $size_or_len_str_C $C_size_len is incomatible with the product of a $strA of size $sizeA and a $strB of $size_or_len_str_B $B_size_len. ", - lazy"The destination must be of $size_or_len_str_dest $destsize." - ) - ) - ) + matmul_size_check_error(sizeC, sizeA, sizeB) end return nothing end +@noinline function matmul_size_check_error(sizeC::Tuple{Integer,Vararg{Integer}}, sizeA::Tuple{Integer,Vararg{Integer}}, sizeB::Tuple{Integer,Vararg{Integer}}) + strA = _vec_or_mat_str(sizeA) + strB = _vec_or_mat_str(sizeB) + strC = _vec_or_mat_str(sizeC) + szB2 = get(sizeB, 2, 1) + C_size_len = length(sizeC) == 1 ? sizeC[1] : sizeC + size_or_len_str_C = C_size_len isa Integer ? "length" : "size" + B_size_len = length(sizeB) == 1 ? sizeB[1] : sizeB + size_or_len_str_B = B_size_len isa Integer ? "length" : "size" + destsize = length(sizeB) == length(sizeC) == 1 ? sizeA[1] : (sizeA[1], szB2) + size_or_len_str_dest = destsize isa Integer ? "length" : "size" + throw(DimensionMismatch( + LazyString( + "incompatible destination size: ", + lazy"the destination $strC of $size_or_len_str_C $C_size_len is incomatible with the product of a $strA of size $sizeA and a $strB of $size_or_len_str_B $B_size_len. ", + lazy"The destination must be of $size_or_len_str_dest $destsize." + ) + ) + ) +end # We may inline the matmul2x2! and matmul3x3! calls for `α == true` # to simplify the @stable_muladdmul branches