BNP priors for random partitions #591
That's great, @trappmartin. If you want to have a look at https://github.com/mlomeli1/SMC-MPhilproject/tree/master/MFM_for_Turing , there is some code for SMC for DPM as well. I believe you have access to this private repo. I also have my PhD Matlab code for the Q-class; let me know if you would like access to that repo, if it is useful :)
|
Looking forward to that :)
|
Thanks, @emilemathieu and @mlomeli1! I changed the implementation of BNP priors by separating the representation from the stochastic process. For example, a Pitman-Yor process can now be constructed as follows:

    a = 0.5
    θ = 0.1
    t = 2
    # stick-breaking representation
    d = StickBreakingProcess(PitmanYorProcess(a, θ, t))
    # size-biased sampling representation
    surplus = 2.0
    d = SizeBiasedSamplingProcess(PitmanYorProcess(a, θ, t), surplus)
    # CRP representation
    cluster_counts = [2, 1]
    d = ChineseRestaurantProcess(PitmanYorProcess(a, θ, t), cluster_counts)

Let me know what you think of the new interface. I hope it's easier to use and allows us to have a more flexible interface for BNP priors. Cheers,
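For readers who want to see what the CRP representation encodes, here is a minimal sketch of the textbook Pitman-Yor predictive rule in plain Julia. The function name `py_crp_probs` is hypothetical and is not part of Turing; the sketch ignores the `t` argument used above and only assumes the standard discount/concentration parameterisation, so it may differ from what this PR implements internally.

```julia
# Predictive probabilities of the Pitman-Yor Chinese restaurant process.
# `a` is the discount (0 <= a < 1) and `θ` the concentration (θ > -a).
# Hypothetical helper for illustration only, not Turing's implementation.
function py_crp_probs(a::Real, θ::Real, cluster_counts::AbstractVector{<:Integer})
    n = sum(cluster_counts)       # customers seated so far
    K = length(cluster_counts)    # occupied tables
    # An existing table with c customers is chosen with probability (c - a) / (θ + n).
    existing = [(c - a) / (θ + n) for c in cluster_counts]
    # A new table is opened with probability (θ + a * K) / (θ + n).
    new_table = (θ + a * K) / (θ + n)
    return vcat(existing, new_table)
end

probs = py_crp_probs(0.5, 0.1, [2, 1])
```

The last entry of `probs` is the probability of opening a new cluster, and the entries sum to one by construction.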
|
I believe such an interface is way better!
|
Thanks, I think it should not have much of an influence on the sampling process. I should see soon. :D
|
Chinese Restaurant Process example using the current implementation:

    @model infiniteMM(y; H = Normal(mean(y), std(y) * 2), rpm = DirichletProcess(0.1)) = begin
        # Latent assignments.
        N = length(y)
        z = tzeros(Int, N)
        # Cluster counts.
        cluster_counts = tzeros(Int, N)
        # Cluster locations.
        x = tzeros(Float64, N)
        for i in 1:N
            # Draw assignments using a CRP.
            z[i] ~ ChineseRestaurantProcess(rpm, cluster_counts)
            if cluster_counts[z[i]] == 0
                # Cluster is new, therefore, draw new location.
                x[z[i]] ~ H
            end
            cluster_counts[z[i]] += 1
            # Draw observation.
            y[i] ~ Normal(x[z[i]], 0.5)
        end
        return z
    end
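To make the sequential seating in the model above concrete, the following is a rough forward simulation of the assignments alone, written in plain Julia without Turing. The helper names `sample_index` and `simulate_crp` are made up for this sketch, and it assumes the usual Dirichlet-process CRP weights (occupancy for existing clusters, the concentration α for a new one).

```julia
using Random

# Draw one index from unnormalised, non-negative weights.
function sample_index(rng::AbstractRNG, weights::AbstractVector{<:Real})
    u = rand(rng) * sum(weights)
    acc = 0.0
    for (k, w) in enumerate(weights)
        acc += w
        u < acc && return k
    end
    return length(weights)  # guard against floating-point round-off
end

# Sequentially seat N customers under a Dirichlet-process CRP with concentration α.
# Hypothetical helper for illustration only.
function simulate_crp(rng::AbstractRNG, N::Integer, α::Real)
    z = zeros(Int, N)
    cluster_counts = Int[]
    for i in 1:N
        # Existing clusters are weighted by occupancy, a new cluster by α.
        k = sample_index(rng, vcat(cluster_counts, α))
        if k > length(cluster_counts)
            push!(cluster_counts, 1)   # open a new cluster
        else
            cluster_counts[k] += 1
        end
        z[i] = k
    end
    return z, cluster_counts
end

z, cluster_counts = simulate_crp(MersenneTwister(1), 20, 0.1)
```

Unlike the model, this sketch grows `cluster_counts` on demand instead of preallocating `N` entries, which is convenient outside of Turing's task-local arrays.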
|
Nice! Do you think this interface can easily be extended to NIGP and the like?
|
Yes, I'm pretty certain this should be possible. For this PR the focus is only on DP and PYP but it totally makes sense to extend the code after merging this PR.
|
Oops, I think I broke something. o.O
|
changing |
|
I added the test for the stick-breaking representation. However, I'm a bit unsure whether my implementation is correct / necessary or whether we can use the one by @emilemathieu. Here is my version of a truncated stick-breaking in Turing.

    @model sbimm(y, rpm, trunc) = begin
        # Base distribution.
        H = Normal(mu_0, sigma_0)
        # Latent assignments.
        N = length(y)
        z = tzeros(Int, N)
        # Slice variables, one per observation.
        u = tzeros(Float64, N)
        # Collection of stick pieces and weights, truncated at `trunc`.
        v = tzeros(Float64, trunc)
        w = tzeros(Float64, trunc)
        K = 0
        # Cluster locations.
        x = tzeros(Float64, trunc)
        for i in 1:N
            # Draw a slice ∈ [0,1].
            u[i] ~ Beta(1, 1)
            # Instantiate new cluster.
            while (sum(w) < u[i]) && (K < trunc)
                K += 1
                v[K] ~ StickBreakingProcess(rpm)
                x[K] ~ H
                w[K] = v[K] * prod(1 .- v[1:(K-1)])
            end
            # Find truncation point
            K_ = findfirst(u[i] .< cumsum(w))
            # Sample assignments.
            w_ = w[1:K_] / sum(w[1:K_])
            z[i] ~ Categorical(w_)
            # Draw observation.
            y[i] ~ Normal(x[z[i]], sigma_1)
        end
    end

@emilemathieu and @yebai what are your thoughts?
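For comparison with the model above, a plain fixed-truncation construction of the weights can be sketched outside of Turing as follows. This assumes a Dirichlet process with concentration α, so that each stick proportion is Beta(1, α); the function name `stick_breaking_weights` is hypothetical, and the Beta draw is done by CDF inversion purely to keep the sketch free of package dependencies.

```julia
using Random

# Truncated stick-breaking weights for a Dirichlet process with concentration α.
# Hypothetical helper for illustration only.
function stick_breaking_weights(rng::AbstractRNG, α::Real, trunc::Integer)
    w = zeros(Float64, trunc)
    remaining = 1.0                      # length of the stick not yet broken off
    for k in 1:trunc
        # Beta(1, α) draw by inverting its CDF, F(v) = 1 - (1 - v)^α.
        # The last proportion is fixed to 1 so that the weights sum to one.
        v = k == trunc ? 1.0 : 1 - (1 - rand(rng))^(1 / α)
        w[k] = v * remaining
        remaining *= 1 - v
    end
    return w
end

w = stick_breaking_weights(MersenneTwister(1), 0.1, 10)
```

Setting the final proportion to one is a common way to close off a truncated stick; it puts all leftover mass on the last component rather than discarding it.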
|
Hi Martin! This seems to be a valid implementation of a stick breaking process :)
|
@cpfiffer the tests of this PR seem to be broken due to some bug in displaying MCMCChains. Can you have a look? |
It's solved on the master branch; you need to rebase this PR onto master.
|
@emilemathieu I made some minor adjustments to your code for the SBS. Could you let me know if this is correct for simulation based sampling? I'm still not too familiar with the SBS and probably should read the paper again once I find the time. :) Based on: https://github.com/TuringLang/Turing.jl/blob/project-bnp/test/rpm.jl/imm.jl Thanks!

    @model sbsimm(y, rpm) = begin
        # Base distribution.
        H = Normal(mu_0, sigma_0)
        # Latent assignments.
        N = length(y)
        z = tzeros(Int, N)
        # Cluster locations.
        x = tzeros(Float64, N)
        # Cluster masses.
        J = tzeros(Float64, N)
        # Number of instantiated clusters.
        k = 0
        # Mass not yet assigned to any cluster.
        surplus = 1.0
        for i in 1:N
            ps = vcat(J[1:k], surplus)
            z[i] ~ Categorical(ps)
            if z[i] > k
                # Cluster is new: draw its mass and location.
                k = k + 1
                J[k] ~ SizeBiasedSamplingProcess(rpm, surplus)
                x[k] ~ H
                surplus -= J[k]
            end
            y[i] ~ Normal(x[z[i]], sigma_1)
        end
    end
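The lazy instantiation in the model above can also be traced in a standalone forward simulation. The sketch below assumes a Dirichlet process with concentration α and assumes that a new cluster takes a Beta(1, α) fraction of the current surplus; whether this matches `SizeBiasedSamplingProcess` exactly is not confirmed here, and `simulate_sbs` is a made-up name.

```julia
using Random

# Forward-simulate assignments by size-biased sampling for a Dirichlet process
# with concentration α; cluster masses J are instantiated only when needed.
# Hypothetical helper for illustration only.
function simulate_sbs(rng::AbstractRNG, N::Integer, α::Real)
    z = zeros(Int, N)
    J = Float64[]          # masses of the clusters seen so far
    surplus = 1.0          # mass not yet assigned to any cluster
    for i in 1:N
        # Existing cluster k has probability J[k]; a new cluster has probability surplus.
        u = rand(rng)
        k = something(findfirst(>(u), cumsum(J)), length(J) + 1)
        if k > length(J)
            # New cluster: break a Beta(1, α) fraction off the surplus
            # (drawn by inverting the Beta(1, α) CDF).
            v = 1 - (1 - rand(rng))^(1 / α)
            push!(J, surplus * v)
            surplus -= J[end]
        end
        z[i] = k
    end
    return z, J, surplus
end

z, J, surplus = simulate_sbs(MersenneTwister(1), 20, 0.5)
```

At every step the instantiated masses and the surplus add up to one, which is the invariant the `ps = vcat(J[1:k], surplus)` line in the model relies on.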
|
@yebai once the SBS sampling example is correct, this PR is ready for review and merging.
|
This PR is ready to be merged from my side. |
test/rpm.jl/sb.jl (outdated)

        end

        # Find truncation point
        K_ = findfirst(u[i] .< cumsum(w))

This seems like a non-typical slice sampler for DPs. Do you have a reference for this slice sampling representation?
You are right. After looking at it again, it seems rather odd and is probably not quite correct. I can change it to the retrospective sampler by Papaspiliopoulos and Roberts, which seems straightforward to implement in Turing. Or do you have a preference for another one, e.g. Walker et al.?
I don't have a preference for this; the method in Papaspiliopoulos and Roberts sounds good to me. Or, we can simply implement a basic recursive stick breaking if it's only for testing purposes. We can leave the task of advanced implementations till later, perhaps in a BNP tutorial?
Sounds good to me. I already started a BNP tutorial and will put the Papaspiliopoulos and Roberts code in there. By basic recursive stick breaking, you mean a truncated implementation with a fixed truncation point, right?
I mean Alg 1 from the following paper:
http://www.robots.ox.ac.uk/~twgr/assets/pdf/bloemreddy2017rpm.pdf
This version avoids truncation altogether by using a random, coin-flip-based termination criterion. I still need to read the original paper to have a better understanding of why this is equivalent to the standard stick-breaking, but my guess is that the expectation of the process converges to the standard stick-breaking process.
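One common way to draw an assignment from stick-breaking weights without a truncation point is to walk along the sticks and flip a coin at each one. The sketch below illustrates that idea under the assumption of a Dirichlet process with concentration α; it is not checked against Alg 1 of the linked paper, and `coinflip_assign!` is a hypothetical name.

```julia
using Random

# Draw one assignment from stick-breaking weights without truncation.
# `v` caches the stick proportions instantiated so far and grows on demand.
# Hypothetical helper for illustration only.
function coinflip_assign!(rng::AbstractRNG, v::Vector{Float64}, α::Real)
    k = 0
    while true
        k += 1
        if k > length(v)
            # Instantiate the next Beta(1, α) proportion by inverting its CDF.
            push!(v, 1 - (1 - rand(rng))^(1 / α))
        end
        # Heads with probability v[k]: stop and assign to cluster k.
        rand(rng) < v[k] && return k
    end
end

rng = MersenneTwister(1)
v = Float64[]
z = [coinflip_assign!(rng, v, 0.5) for _ in 1:50]
```

Conditional on `v`, the loop returns `k` with probability `v[k] * prod(1 .- v[1:k-1])`, which is the usual stick-breaking weight. Each call needs a fresh coin at every stick, which is why reusing a cached value under the same variable name would break the recursion.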
I see. Yes, it looks like it does. I'll read the paper more carefully again as I forgot about the recursive coin-flipping.
Algo 1 (coin-flipping based) is complicated to implement in Turing because of the variable name issue we have, i.e. the coin will not be resampled recursively because of:
Turing.jl/src/inference/pgibbs.jl, lines 154 to 167 in c5ca896:

    if ~haskey(vi, vn)
        r = rand(dist)
        push!(vi, vn, r, dist, spl.alg.gid)
        spl.info[:cache_updated] = CACHERESET # sanity flag mask for getidcs and getranges
    elseif is_flagged(vi, vn, "del")
        unset_flag!(vi, vn, "del")
        r = rand(dist)
        vi[vn] = vectorize(dist, r)
        setgid!(vi, spl.alg.gid, vn)
        setorder!(vi, vn, vi.num_produce)
    else
        updategid!(vi, vn, spl)
        r = vi[vn]
    end

I see, perhaps move this into a separate issue?
This should be fixed soon - see related discussion #720 (review).
|
It's probably good to have a dedicated folder for customised distributions in Turing, e.g. |
|
Excellent work - Ready to merge except one minor filename issue (see above). |
This is a work-in-progress PR integrating the existing code from #370 and #374 for random partitions.
TODO:
Turing.Chain to work with missing values.
Changes to code base:
cc: @mlomeli1, @emilemathieu