Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 42 additions & 34 deletions src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -412,52 +412,60 @@ lmul!(A, B)

_vec_or_mat_str(s::Tuple{Any}) = "vector"
_vec_or_mat_str(s::Tuple{Any,Any}) = "matrix"
@noinline function matmul_size_check(sizeA::Tuple{Integer,Vararg{Integer}}, sizeB::Tuple{Integer,Vararg{Integer}})
function matmul_size_check(sizeA::Tuple{Integer,Vararg{Integer}}, sizeB::Tuple{Integer,Vararg{Integer}})
szA2 = get(sizeA, 2, 1)
if szA2 != sizeB[1]
strA = _vec_or_mat_str(sizeA)
strB = _vec_or_mat_str(sizeB)
B_size_len = length(sizeB) == 1 ? sizeB[1] : sizeB
size_or_len_str_B = B_size_len isa Integer ? "length" : "size"
dim_or_len_str_B = B_size_len isa Integer ? "length" : "first dimension"
pos_str_A = LazyString(length(sizeA) == length(sizeB) ? "first " : "", strA)
pos_str_B = LazyString(length(sizeA) == length(sizeB) ? "second " : "", strB)
throw(DimensionMismatch(
LazyString(
lazy"incompatible dimensions for matrix multiplication: ",
lazy"tried to multiply a $strA of size $sizeA with a $strB of $size_or_len_str_B $B_size_len. ",
lazy"The second dimension of the $pos_str_A: $szA2, does not match the $dim_or_len_str_B of the $pos_str_B: $(sizeB[1])."
)
)
)
matmul_size_check_error(sizeA, sizeB)
end
return nothing
end
@noinline function matmul_size_check(sizeC::Tuple{Integer,Vararg{Integer}}, sizeA::Tuple{Integer,Vararg{Integer}}, sizeB::Tuple{Integer,Vararg{Integer}})
@noinline function matmul_size_check_error(sizeA::Tuple{Integer,Vararg{Integer}}, sizeB::Tuple{Integer,Vararg{Integer}})
strA = _vec_or_mat_str(sizeA)
strB = _vec_or_mat_str(sizeB)
szA2 = get(sizeA, 2, 1)
B_size_len = length(sizeB) == 1 ? sizeB[1] : sizeB
size_or_len_str_B = B_size_len isa Integer ? "length" : "size"
dim_or_len_str_B = B_size_len isa Integer ? "length" : "first dimension"
pos_str_A = LazyString(length(sizeA) == length(sizeB) ? "first " : "", strA)
pos_str_B = LazyString(length(sizeA) == length(sizeB) ? "second " : "", strB)
throw(DimensionMismatch(
LazyString(
"incompatible dimensions for matrix multiplication: ",
lazy"tried to multiply a $strA of size $sizeA with a $strB of $size_or_len_str_B $B_size_len. ",
lazy"The second dimension of the $pos_str_A: $szA2, does not match the $dim_or_len_str_B of the $pos_str_B: $(sizeB[1])."
)
)
)
end
function matmul_size_check(sizeC::Tuple{Integer,Vararg{Integer}}, sizeA::Tuple{Integer,Vararg{Integer}}, sizeB::Tuple{Integer,Vararg{Integer}})
matmul_size_check(sizeA, sizeB)
szB2 = get(sizeB, 2, 1)
szC2 = get(sizeC, 2, 1)
if sizeC[1] != sizeA[1] || szC2 != szB2
strA = _vec_or_mat_str(sizeA)
strB = _vec_or_mat_str(sizeB)
strC = _vec_or_mat_str(sizeC)
C_size_len = length(sizeC) == 1 ? sizeC[1] : sizeC
size_or_len_str_C = C_size_len isa Integer ? "length" : "size"
B_size_len = length(sizeB) == 1 ? sizeB[1] : sizeB
size_or_len_str_B = B_size_len isa Integer ? "length" : "size"
destsize = length(sizeB) == length(sizeC) == 1 ? sizeA[1] : (sizeA[1], szB2)
size_or_len_str_dest = destsize isa Integer ? "length" : "size"
throw(DimensionMismatch(
LazyString(
"incompatible destination size: ",
lazy"the destination $strC of $size_or_len_str_C $C_size_len is incomatible with the product of a $strA of size $sizeA and a $strB of $size_or_len_str_B $B_size_len. ",
lazy"The destination must be of $size_or_len_str_dest $destsize."
)
)
)
matmul_size_check_error(sizeC, sizeA, sizeB)
end
return nothing
end
@noinline function matmul_size_check_error(sizeC::Tuple{Integer,Vararg{Integer}}, sizeA::Tuple{Integer,Vararg{Integer}}, sizeB::Tuple{Integer,Vararg{Integer}})
strA = _vec_or_mat_str(sizeA)
strB = _vec_or_mat_str(sizeB)
strC = _vec_or_mat_str(sizeC)
szB2 = get(sizeB, 2, 1)
C_size_len = length(sizeC) == 1 ? sizeC[1] : sizeC
size_or_len_str_C = C_size_len isa Integer ? "length" : "size"
B_size_len = length(sizeB) == 1 ? sizeB[1] : sizeB
size_or_len_str_B = B_size_len isa Integer ? "length" : "size"
destsize = length(sizeB) == length(sizeC) == 1 ? sizeA[1] : (sizeA[1], szB2)
size_or_len_str_dest = destsize isa Integer ? "length" : "size"
throw(DimensionMismatch(
LazyString(
"incompatible destination size: ",
lazy"the destination $strC of $size_or_len_str_C $C_size_len is incomatible with the product of a $strA of size $sizeA and a $strB of $size_or_len_str_B $B_size_len. ",
lazy"The destination must be of $size_or_len_str_dest $destsize."
)
)
)
end

# We may inline the matmul2x2! and matmul3x3! calls for `α == true`
# to simplify the @stable_muladdmul branches
Expand Down