diff --git a/.travis.yml b/.travis.yml index 111cc8e..cb7b7fa 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,8 +3,7 @@ os: - linux - osx julia: - - 0.5 - - 0.6 + - 1.0 - nightly notifications: email: diff --git a/REQUIRE b/REQUIRE index 8eff6b2..922872e 100644 --- a/REQUIRE +++ b/REQUIRE @@ -1,5 +1,4 @@ -julia 0.5 +julia 1.0 Distances -FunctionalData NearestNeighbors ProgressMeter diff --git a/src/QuickShiftClustering.jl b/src/QuickShiftClustering.jl index 9bc65f8..d67d52c 100644 --- a/src/QuickShiftClustering.jl +++ b/src/QuickShiftClustering.jl @@ -1,13 +1,13 @@ -__precompile__() - module QuickShiftClustering -isinstalled(a) = isa(Pkg.installed(a), VersionNumber) -isinstalled("PyPlot") && begin - using PyPlot +using Pkg + +if "PyPlot" in keys(Pkg.installed()) + import PyPlot end -using NearestNeighbors, ProgressMeter, Distances, FunctionalData +using Statistics +using NearestNeighbors, ProgressMeter, Distances export quickshift, quickshiftlabels, quickshiftplot @@ -22,7 +22,8 @@ function dist(a::Array{Float32,2}, i::Int, b::Array{Float32,2}, j::Int) for d = 1:size(a,1) sum += (a[d,i]-b[d,j])^2 end - sum + + return sum end function link(i, inds, G, data) @@ -37,7 +38,8 @@ function link(i, inds, G, data) end end end - minind, mindist + + return minind, mindist end function gauss(data, n, ind, factor1, factor2) @@ -45,34 +47,36 @@ function gauss(data, n, ind, factor1, factor2) for m = ind s += exp(-dist(data,n,data,m)*factor1) end - factor2 * s + + return factor2 * s end quickshift(data, a...) = quickshift(convert(Array{Float32,2},data), a...) quickshift(data::Array{Float32,2}, sigma) = quickshift(data, convert(Float32,sigma)) -function quickshift(data::Array{Float32,2}, sigma::Float32=convert(Float32,median(pairwise(Euclidean(), randsample(data,1000)))/100)) + +function quickshift(data::Array{Float32,2}, sigma::Float32=convert(Float32, median(pairwise(Euclidean(), data[:,rand(1:size(data, 2), 1000)] / 100, dims=2)))) tree = KDTree(data) # @show sigma - N = len(data) + N = size(data, 2) factor1 = 1f0 / (2*sigma^2) ::Float32 factor2 = 1/(2*pi*sigma^2*N) G = zeros(Float32, 1, N) - nninds = [Array(Int,0) for n in 1:N] + nninds = [Array{Int, 1}() for n in 1:N] @showprogress 1 "Computing kernel distances ... " for n = 1:N - knnind = @p knn tree vec(at(data,n)) 100 true | fst + knnind = knn(tree, vec(data[:,n]), 100, true)[1] ind = length(knnind) > 10 ? knnind : 1:N nninds[n] = ind G[n] = gauss(data, n, ind, factor1, factor2) end - # println("median lenghts:",(@p map nninds length | flatten | median)) + links = [Any[] for i in 1:N] rootind = -1 inflength = typemax(eltype(G)) - Nrange = 1:N + minind = 0 mindist = inflength @showprogress 1 "Linking ... " for i in 1:N - for inds = [nninds[i]; Nrange] + for inds = [nninds[i]; 1:N] minind, mindist = link(i, inds, G, data) if minind != 0 break @@ -88,44 +92,49 @@ function quickshift(data::Array{Float32,2}, sigma::Float32=convert(Float32,media push!(links[minind], (sqrt(mindist), i)) end end - QuickShift(rootind, links, sigma) + + return QuickShift(rootind, links, sigma) end function quickshiftlabels(a::QuickShift, maxlength = 10*a.sigma) - labels = zeros(Int32,length(a.links)) - cut_internal(a.rootind, a.links, labels, maxlength, 1, 2) - labels + labels = zeros(Int,length(a.links)) + cut_internal!(labels, a.rootind, a.links, maxlength, 1, 2) + + return labels end -function cut_internal(ind, links, labels, maxlength, label, maxlabel) +function cut_internal!(labels, ind, links, maxlength, label, maxlabel) labels[ind] = label for x in links[ind] if x[1] > maxlength maxlabel += 1 end - maxlabel = max(label, cut_internal(x[2], links, labels, maxlength, x[1] > maxlength ? maxlabel : label, maxlabel)) + maxlabel = max(label, cut_internal!(labels, x[2], links, maxlength, x[1] > maxlength ? maxlabel : label, maxlabel)) end - maxlabel + + return maxlabel end -function quickshiftplot(a::QuickShift, data, labels) - if !isdefined(:PyPlot) +function quickshiftplot(a::QuickShift, data::Array{T, 2} where T, labels::Array{Int, 1}) + if !isdefined(Main, :PyPlot) error("quickshiftplot needs PyPlot installed and loaded using 'using PyPlot'") end - if size(data,1) != 2 + + if size(data, 1) != 2 error("quickshiftplot only works on 2D data, i.e. size(data,1)==2") end - for i = 1:len(a.links) + for i = 1:length(a.links) from = data[:,i] for x = a.links[i] to = data[:,x[2]] p = hcat(from,to)' - plot(fst(p),snd(p),"b-") + PyPlot.plot(p[:,1],p[:,2],"b-") end end - scatter(data[1,:],data[2,:], c = Array{Int, 1}(labels), edgecolor = "none") + + PyPlot.scatter(data[1,:],data[2,:], c = labels, edgecolor = "none") end end diff --git a/test/REQUIRE b/test/REQUIRE index bc3e234..e69de29 100644 --- a/test/REQUIRE +++ b/test/REQUIRE @@ -1 +0,0 @@ -FactCheck diff --git a/test/runtests.jl b/test/runtests.jl index fb6bcf0..947b5cc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,8 +1,7 @@ println("\n\n\nRunning tests ...") -using QuickShiftClustering, FactCheck +using QuickShiftClustering quickshift(rand(2,1000)) println(" done running tests!") -