Skip to content

Commit 48f6727

Browse files
committed
Add colwise for euclidean
1 parent d9b8846 commit 48f6727

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

src/distances/euclidean.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,26 @@ function ChainRulesCore.rrule(::typeof(pairwise), d::Euclidean, x::RowVecs, y::R
2828
return D, pairwise_pullback
2929
end
3030

31+
function colwise(::Euclidean, x::ColVecs, y::ColVecs)
32+
return @tullio out[i] := sqrt <| (x.X[k, i] - y.X[k, i])^2
33+
end
34+
35+
function colwise(::Euclidean, x::RowVecs, y::RowVecs)
36+
return @tullio out[i] := sqrt <| (x.X[i, k] - y.X[i, k])^2
37+
end
38+
3139
function pairwise(::SqEuclidean, x::ColVecs, y::ColVecs)
3240
return @tullio out[i, j] := (x.X[k, i] - y.X[k, j])^2
3341
end
3442

3543
function pairwise(::SqEuclidean, x::RowVecs, y::RowVecs)
3644
return @tullio out[i, j] := (x.X[i, k] - y.X[j, k])^2
3745
end
46+
47+
function colwise(::SqEuclidean, x::ColVecs, y::ColVecs)
48+
return @tullio out[i] := (x.X[k, i] - y.X[k, i])^2
49+
end
50+
51+
function colwise(::SqEuclidean, x::RowVecs, y::RowVecs)
52+
return @tullio out[i] := (x.X[i, k] - y.X[i, k])^2
53+
end

0 commit comments

Comments
 (0)