Skip to content

Commit 5eb628a

Browse files
committed
also do linalg
1 parent 79e851e commit 5eb628a

File tree

1 file changed

+31
-56
lines changed

1 file changed

+31
-56
lines changed

src/linalg.jl

Lines changed: 31 additions & 56 deletions
Original file line number | Diff line number | Diff line change
@@ -224,87 +224,65 @@ end
224224

225225
# Dot product of two static vectors. The size object produced by
# `same_size(a, b)` is handed to the `_vecdot` kernel together with both vectors.
@inline function dot(a::StaticVector, b::StaticVector)
    sz = same_size(a, b)
    return _vecdot(sz, a, b)
end
226226
# _vecdot: generated kernel that `dot` forwards to.
# This span is a diff hunk: a marker such as `227-` precedes a removed line,
# `229+` precedes an added line, and a fused pair such as `236227` (old line 236,
# new line 227) precedes an unchanged line.
#   ::Size{S} - static size; prod(S) is the number of elements that are summed
#   a, b      - the two arrays; elements are read by linear index
#   returns   - the sum over j of conj(a[j]) * b[j]
@generated function _vecdot(::Size{S}, a::StaticArray, b::StaticArray) where {S}
# Removed implementation: for an empty size it generated only
# zero(promote_op(*, eltype(a), eltype(b))); otherwise it built one fully
# unrolled sum expression at generation time.
227-
if prod(S) == 0
228-
return :(zero(promote_op(*, eltype(a), eltype(b))))
229-
end
230-
231-
expr = :(conj(a[1]) * b[1])
232-
for j = 2:prod(S)
233-
expr = :($expr + conj(a[$j]) * b[$j])
234-
end
235-
236227
return quote
237228
@_inline_meta
238-
@inbounds return $expr
# Added implementation: a runtime loop over 1:prod(S), with the element count
# interpolated as a literal. The accumulator starts at the zero of the promoted
# product type, so an empty array returns that zero without entering the loop.
# NOTE(review): unlike the removed code, zero(promote_op(...)) is now evaluated
# for non-empty inputs as well — confirm promote_op yields a type with a `zero`
# method for every element type that must be supported.
229+
s = zero(promote_op(*, eltype(a), eltype(b)))
230+
@inbounds @simd for j = 1:$(prod(S))
231+
s += conj(a[j]) * b[j]
232+
end
233+
return s
239234
end
240235
end
241236

242237
# Entry point for the non-conjugating dot product: obtains the size object from
# `same_size(a, b)` and forwards it, with both arrays, to `_bilinear_vecdot`.
@inline function bilinear_vecdot(a::StaticArray, b::StaticArray)
    sz = same_size(a, b)
    return _bilinear_vecdot(sz, a, b)
end
243238
# _bilinear_vecdot: generated kernel that `bilinear_vecdot` forwards to.
# This span is a diff hunk: a marker such as `244-` precedes a removed line,
# `241+` precedes an added line, and a fused pair such as `253239` (old line 253,
# new line 239) precedes an unchanged line.
#   ::Size{S} - static size; prod(S) is the number of elements that are summed
#   a, b      - the two arrays; elements are read by linear index
#   returns   - the sum over j of a[j] * b[j] (no conjugation is applied)
@generated function _bilinear_vecdot(::Size{S}, a::StaticArray, b::StaticArray) where {S}
# Removed implementation: for an empty size it generated only
# zero(promote_op(*, eltype(a), eltype(b))); otherwise it built one fully
# unrolled sum expression at generation time.
244-
if prod(S) == 0
245-
return :(zero(promote_op(*, eltype(a), eltype(b))))
246-
end
247-
248-
expr = :(a[1] * b[1])
249-
for j = 2:prod(S)
250-
expr = :($expr + a[$j] * b[$j])
251-
end
252-
253239
return quote
254240
@_inline_meta
255-
@inbounds return $expr
# Added implementation: a runtime loop over 1:prod(S), starting from the zero of
# the promoted product type; an empty array returns that zero unchanged.
# NOTE(review): zero(promote_op(...)) is now evaluated for non-empty inputs as
# well, which the removed code never did — confirm it is defined for every
# element type that must be supported.
241+
s = zero(promote_op(*, eltype(a), eltype(b)))
242+
@inbounds @simd for j = 1:$(prod(S))
243+
s += a[j] * b[j]
244+
end
245+
return s
256246
end
257247
end
258248

259249
# Sum of abs2 over the elements of a static vector. The reduction is seeded with
# the zero of the real element type, so an empty vector yields that zero.
@inline function LinearAlgebra.norm_sqr(v::StaticVector)
    seed = zero(real(eltype(v)))
    return mapreduce(abs2, +, v; init=seed)
end
260250

261251
# One-argument norm of a static array: dispatches on `Size(a)` to the `_norm`
# kernel that takes only the size and the array.
@inline function norm(a::StaticArray)
    sz = Size(a)
    return _norm(sz, a)
end
262252
# _norm (two-argument method): generated kernel behind `norm(a)`.
# This span is a diff hunk: a marker such as `263-` precedes a removed line,
# `254+` precedes an added line, and a fused pair such as `272253` (old line 272,
# new line 253) precedes an unchanged line.
#   ::Size{S} - static size; prod(S) is the number of elements that are summed
#   a         - the array; elements are read by linear index
#   returns   - sqrt of the sum of abs2(a[j]) over all elements
@generated function _norm(::Size{S}, a::StaticArray) where {S}
# Removed implementation: returned zero(real(eltype(a))) at generation time for
# an empty size; otherwise built one fully unrolled sum of abs2 terms.
263-
if prod(S) == 0
264-
return zero(real(eltype(a)))
265-
end
266-
267-
expr = :(abs2(a[1]))
268-
for j = 2:prod(S)
269-
expr = :($expr + abs2(a[$j]))
270-
end
271-
272253
return quote
273-
$(Expr(:meta, :inline))
274-
@inbounds return sqrt($expr)
# Added implementation: uses @_inline_meta in place of the explicit
# Expr(:meta, :inline), then accumulates abs2 in a runtime loop starting from
# the zero of the real element type.
# NOTE(review): for an empty array the result is now sqrt(zero(real(eltype(a))))
# rather than zero(real(eltype(a))) itself — confirm the return type for
# non-floating-point element types is still the intended one.
254+
@_inline_meta
255+
s = zero(real(eltype(a)))
256+
@inbounds @simd for j = 1:$(prod(S))
257+
s += abs2(a[j])
258+
end
259+
return sqrt(s)
275260
end
276261
end
277262

278263
# Per-element term of the p == 0 norm: zero(x) when x compares equal to 0,
# one(x) for any other value.
function _norm_p0(x)
    if x == 0
        return zero(x)
    end
    return one(x)
end
279264

280265
# p-norm of a static array: dispatches on `Size(a)` to the `_norm` kernel that
# takes the size, the array and the exponent `p`.
@inline function norm(a::StaticArray, p::Real)
    sz = Size(a)
    return _norm(sz, a, p)
end
281266
# _norm (three-argument method): generated kernel behind `norm(a, p)`.
# This span is a diff hunk: a marker such as `282-` precedes a removed line,
# `268+` precedes an added line, and a fused pair such as `296267` (old line 296,
# new line 267) precedes an unchanged line.
#   ::Size{S} - static size; prod(S) is the number of elements that are visited
#   a         - the array; elements are read by linear index
#   p         - exponent, tested at run time against Inf, 1, 2 and 0
#   returns   - p == Inf: mapreduce(abs, max, a); p == 1: sum of abs(a[j]);
#               p == 2: norm(a); p == 0: sum of _norm_p0 over a;
#               otherwise (sum of abs(a[j])^p)^(inv(p))
@generated function _norm(::Size{S}, a::StaticArray, p::Real) where {S}
# Removed implementation: returned zero(real(eltype(a))) at generation time for
# an empty size, and pre-built two unrolled sums (`expr` for the general case,
# `expr_p1` for p == 1).
282-
if prod(S) == 0
283-
return zero(real(eltype(a)))
284-
end
285-
286-
expr = :(abs(a[1])^p)
287-
for j = 2:prod(S)
288-
expr = :($expr + abs(a[$j])^p)
289-
end
290-
291-
expr_p1 = :(abs(a[1]))
292-
for j = 2:prod(S)
293-
expr_p1 = :($expr_p1 + abs(a[$j]))
294-
end
295-
296267
return quote
297-
$(Expr(:meta, :inline))
# Added: @_inline_meta replaces the explicit Expr(:meta, :inline), and a single
# accumulator is initialised up front; only the p == 1 and general branches use it.
# NOTE(review): with the generation-time empty check removed, an empty array now
# goes through the same p-dependent branches as any other input.
268+
@_inline_meta
269+
s = zero(real(eltype(a)))
298270
if p == Inf
299271
return mapreduce(abs, max, a; init=$(zero(real(eltype(a)))))
300272
elseif p == 1
301-
@inbounds return $expr_p1
# Added: runtime loop summing abs(a[j]) in place of the unrolled `expr_p1`.
273+
@inbounds @simd for j = 1:$(prod(S))
274+
s += abs(a[j])
275+
end
276+
return s
302277
elseif p == 2
303278
return norm(a)
304279
elseif p == 0
305280
return mapreduce(_norm_p0, +, a; init=$(zero(real(eltype(a)))))
306281
else
307-
@inbounds return ($expr)^(inv(p))
# Added: runtime loop summing abs(a[j])^p in place of the unrolled `expr`.
# NOTE(review): `s` starts as zero(real(eltype(a))); if abs(a[j])^p has a
# different type (presumably e.g. integer elements with a non-integer p), `s`
# changes type inside the loop — verify inference and the @simd behaviour.
282+
@inbounds @simd for j = 1:$(prod(S))
283+
s += abs(a[j])^p
284+
end
285+
return s^(inv(p))
308286
end
309287
end
310288
end
@@ -321,16 +299,13 @@ end
321299
throw(DimensionMismatch("matrix is not square"))
322300
end
323301

324-
if S[1] == 0
325-
return zero(eltype(a))
326-
end
327-
328-
exprs = [:(a[$(LinearIndices(S)[i, i])]) for i = 1:S[1]]
329-
total = reduce((ex1, ex2) -> :($ex1 + $ex2), exprs)
330-
331302
return quote
332303
@_inline_meta
333-
@inbounds return $total
304+
s = zero(eltype(a))
305+
@inbounds @simd for i in 1:$(S[1])
306+
s += a[i,i]
307+
end
308+
return s
334309
end
335310
end
336311

0 commit comments

Comments
 (0)