Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ os:
- linux
- osx
julia:
- 0.5
- 0.6
- 1.0
- nightly
notifications:
email:
Expand Down
3 changes: 1 addition & 2 deletions REQUIRE
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
julia 0.5
julia 1.0
Distances
FunctionalData
NearestNeighbors
ProgressMeter
67 changes: 38 additions & 29 deletions src/QuickShiftClustering.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
__precompile__()

module QuickShiftClustering

isinstalled(a) = isa(Pkg.installed(a), VersionNumber)
isinstalled("PyPlot") && begin
using PyPlot
using Pkg

if "PyPlot" in keys(Pkg.installed())
import PyPlot
end

using NearestNeighbors, ProgressMeter, Distances, FunctionalData
using Statistics
using NearestNeighbors, ProgressMeter, Distances

export quickshift, quickshiftlabels, quickshiftplot

Expand All @@ -22,7 +22,8 @@ function dist(a::Array{Float32,2}, i::Int, b::Array{Float32,2}, j::Int)
for d = 1:size(a,1)
sum += (a[d,i]-b[d,j])^2
end
sum

return sum
end

function link(i, inds, G, data)
Expand All @@ -37,42 +38,45 @@ function link(i, inds, G, data)
end
end
end
minind, mindist

return minind, mindist
end

function gauss(data, n, ind, factor1, factor2)
s = 0f0
for m = ind
s += exp(-dist(data,n,data,m)*factor1)
end
factor2 * s

return factor2 * s
end

quickshift(data, a...) = quickshift(convert(Array{Float32,2},data), a...)
quickshift(data::Array{Float32,2}, sigma) = quickshift(data, convert(Float32,sigma))
function quickshift(data::Array{Float32,2}, sigma::Float32=convert(Float32,median(pairwise(Euclidean(), randsample(data,1000)))/100))

function quickshift(data::Array{Float32,2}, sigma::Float32=convert(Float32, median(pairwise(Euclidean(), data[:,rand(1:size(data, 2), 1000)] / 100, dims=2))))
tree = KDTree(data)
# @show sigma
N = len(data)
N = size(data, 2)
factor1 = 1f0 / (2*sigma^2) ::Float32
factor2 = 1/(2*pi*sigma^2*N)
G = zeros(Float32, 1, N)
nninds = [Array(Int,0) for n in 1:N]
nninds = [Array{Int, 1}() for n in 1:N]
@showprogress 1 "Computing kernel distances ... " for n = 1:N
knnind = @p knn tree vec(at(data,n)) 100 true | fst
knnind = knn(tree, vec(data[:,n]), 100, true)[1]
ind = length(knnind) > 10 ? knnind : 1:N
nninds[n] = ind
G[n] = gauss(data, n, ind, factor1, factor2)
end
# println("median lenghts:",(@p map nninds length | flatten | median))

links = [Any[] for i in 1:N]
rootind = -1
inflength = typemax(eltype(G))
Nrange = 1:N

minind = 0
mindist = inflength
@showprogress 1 "Linking ... " for i in 1:N
for inds = [nninds[i]; Nrange]
for inds = [nninds[i]; 1:N]
minind, mindist = link(i, inds, G, data)
if minind != 0
break
Expand All @@ -88,44 +92,49 @@ function quickshift(data::Array{Float32,2}, sigma::Float32=convert(Float32,media
push!(links[minind], (sqrt(mindist), i))
end
end
QuickShift(rootind, links, sigma)

return QuickShift(rootind, links, sigma)
end

function quickshiftlabels(a::QuickShift, maxlength = 10*a.sigma)
labels = zeros(Int32,length(a.links))
cut_internal(a.rootind, a.links, labels, maxlength, 1, 2)
labels
labels = zeros(Int,length(a.links))
cut_internal!(labels, a.rootind, a.links, maxlength, 1, 2)

return labels
end

function cut_internal(ind, links, labels, maxlength, label, maxlabel)
function cut_internal!(labels, ind, links, maxlength, label, maxlabel)
labels[ind] = label
for x in links[ind]
if x[1] > maxlength
maxlabel += 1
end
maxlabel = max(label, cut_internal(x[2], links, labels, maxlength, x[1] > maxlength ? maxlabel : label, maxlabel))
maxlabel = max(label, cut_internal!(labels, x[2], links, maxlength, x[1] > maxlength ? maxlabel : label, maxlabel))
end
maxlabel

return maxlabel
end

function quickshiftplot(a::QuickShift, data, labels)
if !isdefined(:PyPlot)
function quickshiftplot(a::QuickShift, data::Array{T, 2} where T, labels::Array{Int, 1})
if !isdefined(Main, :PyPlot)
error("quickshiftplot needs PyPlot installed and loaded using 'using PyPlot'")
end
if size(data,1) != 2

if size(data, 1) != 2
error("quickshiftplot only works on 2D data, i.e. size(data,1)==2")
end

for i = 1:len(a.links)
for i = 1:length(a.links)
from = data[:,i]
for x = a.links[i]
to = data[:,x[2]]
p = hcat(from,to)'
plot(fst(p),snd(p),"b-")
PyPlot.plot(p[:,1],p[:,2],"b-")
end

end
scatter(data[1,:],data[2,:], c = Array{Int, 1}(labels), edgecolor = "none")

PyPlot.scatter(data[1,:],data[2,:], c = labels, edgecolor = "none")
end

end
1 change: 0 additions & 1 deletion test/REQUIRE
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@
FactCheck
3 changes: 1 addition & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
println("\n\n\nRunning tests ...")

using QuickShiftClustering, FactCheck
using QuickShiftClustering

quickshift(rand(2,1000))

println(" done running tests!")