Skip to content

Conversation

@torfjelde
Copy link
Member

@torfjelde torfjelde commented Mar 27, 2021

EDIT: This PR needs some rework as the process can be simplifed significantly now that we have proper macroexpansion within models.

It's Saturday. Saturday is the day to be a bit wild and "let loose" as the kids say.

Therefore I tried coming up with a decent solution to this problem, and I think I've arrived at something that will at least leave some hair on @devmotion's head.

Result

Now it's possible to do stuff like

julia> using DynamicPPL, Distributions, Bijectors

julia> @model function demo(x)
           @reparam exp m ~ Normal()

           for i in eachindex(x)
               @reparam Bijectors.Shift(m) x[i] ~ Normal()
           end

           return (m = m, x = x, lp = DynamicPPL.getlogp(_varinfo))
       end
┌ Warning: you are using the internal variable `_varinfo`
└ @ DynamicPPL ~/Projects/public/DynamicPPL.jl/src/compiler.jl:171
demo (generic function with 1 method)

julia> demo([missing])()
(m = 0.707759569697004, x = Union{Missing, Float64}[2.2558762896658413], lp = -3.095947005181318)

julia> demo([missing])()
(m = 0.3357617167878797, x = Union{Missing, Float64}[-0.9206668753706337], lp = -3.222709752095152)

julia> demo([1.0])()
(m = 0.5493840362537875, x = [1.0], lp = -2.118779520599513)

julia> demo([1.0])()
(m = 0.6311868436969137, x = [1.0], lp = -2.0117591926730154)

where the internal variable which is kept track of is denoted by _:

julia> DynamicPPL.VarInfo(demo([1.0])).metadata.m_
DynamicPPL.Metadata{Dict{VarName{:m_, Tuple{}}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:m_, Tuple{}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}(Dict(m_ => 1), [m_], UnitRange{Int64}[1:1], [0.2300906444320414], Normal{Float64}[Normal{Float64}=0.0, σ=1.0)], Set{DynamicPPL.Selector}[Set()], [0], Dict{String, BitVector}("del" => [0], "trans" => [0]))

And for the sake of completeness:

julia> using Turing
[ Info: Precompiling Turing [fce5fe82-541a-59a6-adf8-730c64b5f9a0]

julia> @model function demo(x)
           @reparam exp m ~ Normal()

           for i in eachindex(x)
               @reparam Bijectors.Shift(m) x[i] ~ Normal()
           end

           return (m = m, x = x, lp = DynamicPPL.getlogp(_varinfo))
       end
┌ Warning: you are using the internal variable `_varinfo`
└ @ DynamicPPL ~/Projects/public/DynamicPPL.jl/src/compiler.jl:171
demo (generic function with 1 method)

julia> chain = sample(m, NUTS(), 1000)
┌ Info: Found initial step size
└   ϵ = 0.0234375
Sampling 100%|██████████████████████████████████████████████████████| Time: 0:00:06
Chains MCMC chain (1000×13×1 Array{Float64,3}):

Iterations        = 1:1000
Thinning interval = 1
Chains            = 1
Samples per chain = 1000
parameters        = m_
internals         = acceptance_rate, hamiltonian_energy, hamiltonian_energy_error, is_accept, log_density, lp, max_hamiltonian_energy_error, n_steps, nom_step_size, numerical_error, step_size, tree_depth

Summary Statistics
  parameters      mean       std   naive_se      mcse        ess      rhat 
      Symbol   Float64   Float64    Float64   Float64    Float64   Float64 

          m_    1.6174    0.0199     0.0006    0.0004   485.9247    0.9998

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

          m_    1.5762    1.6046    1.6172    1.6304    1.6549
          
          
julia> mean(exp.(chain.value[:, :m_, :]))
5.040949482223531

To see the above example's expanded code, see the bottom of this PR.

What does the people want?

Desire for such a feature have been expressed several times, e.g. #94 TuringLang/Turing.jl#1444 and loads of times in our Slack channel.

The people want the ability to:

  1. Reparameterize distributions, instead of writing x ~ Normal(μ, σ) they want to write
    x = begin
        x_ ~ Normal()
        μ + σ * x_
    end
    so that Turing instead sees the variable x_, since this can make the geometry of the posterior nicer (i.e. more numerically stable and/or better suited for the metric chosen in something like HMC).
  2. They also want the ability to observe transformed distributions when it makes sense. This only makes sense when you can actually convert from the observation to the "inner" distribution (Normal() in the above case), i.e. only when the transformation f is invertible. What's that you say? Come on now, everyone say it together. 1...2...3..A BIJECTOR! Very good. So we can do this for transformations present in Bijectors.jl.
    • If you say, "Wait, why do we need reparameterizations for observations? Why can't we just use transformed(dist, bijector) and observe using this?". Sometimes people have variables present in the arguments of the model which are not always observed; sometimes it's instead missing. In those case you want to sample, and we're back to case (1).

Why isn't TransformedDistribution enough?

  • Importantly it doesn't support arbitrary transformations, e.g. abs ain't invertible.
  • But even then, DynamicPPL.jl only sees the transformed variable, not the "internal"/"untransformed" variable, and thus we don't get the benefit of the reparameterization.
    • We could mess around with overloading assume and whatnot to make this possible, but we'd still have the issue that we could not support non-bijective transformations.

Why not introduce a ReparamDistribution which allows non-invertible transformations?

  • We can't compute the logpdf, and so representing it as a Distribution will be disingenious.
  • It will also require some heavy work on the insides and we'd have to be very careful in how we overload assume and deal with link and invlink.

Solution?

Why not just introduce a @reparam macro that isn't even a macro?:)

Wait! I know what you're thinking "We literally went through this before, where we replaced all those fake macros like @varinfo with 'internal variables', e.g._varinfo_." Yeah, yeah, I know. BUT this is different! Kind of.

By introducing a "fake" macro we really don't have to do much to make things just work. The idea is to take the following representation of a model:

@model function demo(x)
    @reparam identity m ~ Normal()

    for i in eachindex(x)
        @reparam Bijectors.Shift(m) x[i] ~ Normal()
    end

    return (m = m, x = x, lp = DynamicPPL.getlogp(_varinfo))
end

and convert into

@model function demo(x)
    m = begin
        m_ ~ Normal()
        f(m)
    end

    for i in eachindex(x)
        f = Bijectors.Shift(m) # <= `f` will be created using `@gensym`
        # If `x[i]` is `missing`, we get:
        x[i] = begin
            x_[i] ~ Normal() # <= `x_[i]` will be created using `@gensym`
            f(x_[i])
        end
        # If `x[i]` is NOT `missing`, we get:
        if f isa Bijectors.AbstractBijector
            logpdf(Normal(), inv(f)(x[i]))
        else
            throw(ArgumentError("You fool! You can't observe using a non-bijective reparameterization!!!"))
        end
    end

    return (m = m, x = x, lp = DynamicPPL.getlogp(_varinfo))
end

This means that we won't have to worry about assume, link, invlink, etc. These are all handled as usual for the "base" distribution.

Caveats

  • @reparam isn't really a macro since it will have to be captured in DynamicPPL.generate_mainbody! before it's expanded.
  • The resulting chain will instead of the untransformed variables with the underscore behind them. This is of course non-ideal, but IMO is a small cost to pay vs. not having the ability to do reparameterizations at all. This approach is also taken by other PPLs, e.g. pymc3, so users will likely be familiar with the idea of transformed variables ending up in the chain with an underscore behind it.

Alternatives/To discuss

  1. We could use some other mechanism than @reparam, e.g. use x ~ f dist or something. So what the user sees is def something we can discuss more. We could even go a bit crazy and do

TODOs

The todo's are straight-forward but I just didn't bother until we've discussed the change.

  • Implement generate_dot_tilde_with_reparam (super-easy)
  • Tests

Example of expanded model:

julia> expr = @macroexpand @model function demo(x)
           @reparam exp m ~ Normal()

           for i in eachindex(x)
               @reparam Bijectors.Shift(m) x[i] ~ Normal()
           end
       end;

julia> expr |> Base.remove_linenums!
quote
    $(Expr(:meta, :doc))
    function demo(x; )
        var"##evaluator#376" = ((_rng::Random.AbstractRNG, _model::Model, _varinfo::AbstractVarInfo, _sampler::AbstractMCMC.AbstractSampler, _context::DynamicPPL.AbstractContext, x)->begin
                    begin
                        begin
                            var"##tmpright#362" = Normal()
                            var"##tmpright#362" isa Union{Distribution, AbstractVector{<:Distribution}} || throw(ArgumentError("Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions."))
                            var"##vn#364" = (DynamicPPL.varname2intermediate)(m)
                            var"##inds#365" = ()
                            var"##f#367" = exp
                            m = begin
                                    var"##left_intermediate#366" = (DynamicPPL.tilde_assume)(_rng, _context, _sampler, var"##tmpright#362", var"##vn#364", var"##inds#365", _varinfo)
                                    var"##f#367"(var"##left_intermediate#366")
                                end
                        end
                        for i = eachindex(x)
                            begin
                                var"##tmpright#368" = Normal()
                                var"##tmpright#368" isa Union{Distribution, AbstractVector{<:Distribution}} || throw(ArgumentError("Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions."))
                                var"##vn#370" = (DynamicPPL.varname2intermediate)((VarName)(:x, ((i,),)))
                                var"##inds#371" = ((i,),)
                                var"##f#373" = Bijectors.Shift(m)
                                var"##isassumption#374" = begin
                                        let var"##vn#375" = (VarName)(:x, ((i,),))
                                            if !((DynamicPPL.inargnames)(var"##vn#375", _model)) || (DynamicPPL.inmissings)(var"##vn#375", _model)
                                                true
                                            else
                                                x[i] === missing
                                            end
                                        end
                                    end
                                if var"##isassumption#374"
                                    x[i] = begin
                                            var"##left_intermediate#372" = (DynamicPPL.tilde_assume)(_rng, _context, _sampler, var"##tmpright#368", var"##vn#370", var"##inds#371", _varinfo)
                                            var"##f#373"(var"##left_intermediate#372")
                                        end
                                else
                                    if var"##f#373" isa Bijectors.AbstractBijector
                                        var"##left_intermediate#372" = (inv(var"##f#373"))(x[i])
                                        (DynamicPPL.tilde_observe)(_context, _sampler, var"##tmpright#368", var"##left_intermediate#372", var"##vn#370", var"##inds#371", _varinfo)
                                    else
                                        throw(ArgumentError("cannot observe non-invertible reparameterization!!!"))
                                    end
                                end
                            end
                        end
                    end
                end)
        return (Model)(:demo, var"##evaluator#376", (DynamicPPL.namedtuple)(NamedTuple{(:x,), Tuple{Core.Typeof(x)}}, (x,)), NamedTuple())
    end
end

@devmotion
Copy link
Member

I like the general idea 👍

I have to admit though that I did not get what @reparam does when I read the examples at the top. So I think the syntax might be too unintuitive - maybe it is clear to everyone else except me? I assume that something like

x ~ transform(f, Normal())

or

x ~ f  Normal()

would be more intuitive. And I noticed that you already mentioned such an alternative at the end of your post 🙂

Additionally, I felt it would be more consistent with the typical use of _ for private/unexported methods or variables if the internal variable would start with an underscore and not end with one, i.e., if the internal variable would be _m in your example. (BTW I think it would also be more consistent with the Julia ecosystem if the macro internals such as _varinfo would be called __varinfo__ etc - this seems the common way to denote such internal variables e.g. in Julia base (__source__ etc.) and Zygote (__context__).)

@torfjelde
Copy link
Member Author

I have to admit though that I did not get what @reparam does when I read the examples at the top.

Yeah the more I look at it, the more I think there ought to be a better approach. But IMO it needs to be something more special than just transform(f, dist) since this already "overlaps" with TransformedDistribution from Bijectors.jl + it gives the user an indication that this is "normal Julia code" (which it isn't going to be, at least with the method proposed in this PR).

So I'm more a fan of the f ∘ Normal() or something along those lines, but this would also be considered "regular" Julia-code, no? I was thinking maybe even

x ~ f Normal()

would just do it, since ~ is already "special" in Turing and not standard Julia-code, we might as well make use of that for this "transformed"-syntax?

Additionally, I felt it would be more consistent with the typical use of _ for private/unexported methods or variables if the internal variable would start with an underscore and not end with one, i.e., if the internal variable would be _m in your example.

Very okay with this! I think _m makes more sense given we already use _varinfo, though I do think people coming from other PPLs, e.g. pymc3, will be familiar with m_ version. To me it's whatever:)

And regarding the __varinfo__, I agree that this should probably be used for those internal variables as you said! Though I don't think it should be used for reparameterized variables as it's going to be annoying to access in the chain afterwards 😕 But I'm not certain you even suggested that, just figured I'd mention it 😅

@devmotion
Copy link
Member

Hmm yes, I guess f ∘ Normal() is not a good idea either since Distributions actually implements linear transformations of location scale distributions, so we would hijack this behaviour. I am a bit worried that f Normal() is not clear enough and easy to miss if f is a more complicated expression. Some other possible alternatives:

x ~ Normal() transform=f
x ~ f @ Normal()
x ~ @transform f Normal()
x ~ @reparam f Normal()

I am not completely sure anymore though to what extent the proposal here would actually address the issues and user requests mentioned in the OP. IIRC usually users wanted to track the transformed x, i.e., they wanted that x ends up in the sampled chain. The PR doesn't solve this problem since the chain would still only contain the untransformed _x (which is also hidden in the model implementation), it seems?

Regarding the names, I wanted to suggest _x instead of x_ but not __x__ 🙂 Personally I would prefer _x since it felt more "Julian" but I don't have a strong opinion there.

@torfjelde
Copy link
Member Author

x ~ Normal() transform=f

I honestly kind of like this! Explicit and simple.

I am not completely sure anymore though to what extent the proposal here would actually address the issues and user requests mentioned in the OP.

Yes and no. It indeed doesn't solve that particular part of the issue, but IMO it's a step towards it. Of course we can always do this by hand, but it quickly becomes annoying.
Currently working on a project where we're looking to do very high-dimensional inference, to the point where performing adapation for HMC is infeasible. As a result we might want to replace a truncated(Normal(), 0, Inf) with x ~ abs Normal() to ensure that the sampler sees the "untransformed" variable, but it's quite annoying to have to do x_ ~ Normal(); x = abs(x_) everywhere and it really clutters the model definition.

Regarding the names, I wanted to suggest x instead of x but not x slightly_smiling_face Personally I would prefer _x since it felt more "Julian" but I don't have a strong opinion there.

Ah, good:) Yeah, we can see what people think.

EDIT: Actually, I wonder if maybe it would be a good idea to introduce a tracking mechanism while I'm at it. I'll have to think a bit.

@torfjelde
Copy link
Member Author

x ~ Normal() transform=f

I honestly kind of like this! Explicit and simple.

This doesn't work btw. We need a , between the two, or something like that. Hmm.

@torfjelde
Copy link
Member Author

torfjelde commented Mar 29, 2021

Wait..why don't we just use your Dirac distribution for tracking? So instead of just returning the variable, I'll make it an assume statement with a Dirac on the RHS. Then we'll get tracking for the transformed variable too.

EDIT: Nvm, I guess it messes up stuff like HMC sampling since HMC will attempt to change the parameter.

bors bot pushed a commit that referenced this pull request Apr 4, 2021
At the moment we will actually call `generate_mainbody!` on inputs to macros inside the model, e.g. in a model `@mymacro x ~ Normal()` will actually result in code `@mymacro $(generate_mainbody!(:(x ~ Normal())))` (or something, you get the idea). 

IMO, this shouldn't be done for the following reasons:
1. Breaks with what you'd expect in Julia, IMO, which is that a macro eats the "raw" code.
2. Means that if we want to do stuff like `@reparam` from #220  (and a bunch of others, see #221 for a small list of possibilities), we need touch the compiler rather than just make a small macro that will perform transformations *after* the compiler has done it's job (referring to DynamicPPL compiler here). 
3. If the user wants to use a macro on some variables, but they want the actual variable rather than messing around with the sample-statement, they can just separate it into two lines, e.g. `x ~ Normal(); @mymacro ...`. 

Also, to be completely honest, for the longest time I've just assumed that I'm not even allowed to do `@mymacro x ~ Normal()` and have things work 😅 I bet a lot of people have the same impression by default (though this might of course just not be true:) )
bors bot pushed a commit that referenced this pull request Apr 4, 2021
At the moment we will actually call `generate_mainbody!` on inputs to macros inside the model, e.g. in a model `@mymacro x ~ Normal()` will actually result in code `@mymacro $(generate_mainbody!(:(x ~ Normal())))` (or something, you get the idea). 

IMO, this shouldn't be done for the following reasons:
1. Breaks with what you'd expect in Julia, IMO, which is that a macro eats the "raw" code.
2. Means that if we want to do stuff like `@reparam` from #220  (and a bunch of others, see #221 for a small list of possibilities), we need touch the compiler rather than just make a small macro that will perform transformations *after* the compiler has done it's job (referring to DynamicPPL compiler here). 
3. If the user wants to use a macro on some variables, but they want the actual variable rather than messing around with the sample-statement, they can just separate it into two lines, e.g. `x ~ Normal(); @mymacro ...`. 

Also, to be completely honest, for the longest time I've just assumed that I'm not even allowed to do `@mymacro x ~ Normal()` and have things work 😅 I bet a lot of people have the same impression by default (though this might of course just not be true:) )
bors bot pushed a commit that referenced this pull request Apr 7, 2021
## Overview
At the moment, we perform a check at model-expansion as to whether or not `vsym(left) in args`, where `args` is the arguments of the model. 
1. If `true`, we return a block of code which uses `DynamicPPL.isassumption` to check whether or not to call `assume` or `observe` for the the variable present in `args`. 
2. Otherwise, we generate a block which is identical to the `assume` block in the if-statement mentioned in (1).

The thing is, `DynamicPPL.isassumption` performs exactly the same check as above but using `DynamicPPL.inargnames`, i.e. at runtime. So if we're using  `TypedVarInfo`, the check at macro-expansion vs. at runtime is completely redundant since all the information necessary to determine `DynamicPPL.inargnames` is available at compile-time.

Therefore I suggest we remove this check at model-expansion, and simply handle it using `DynamicPPL.isassumption`.

## Pros & cons
Pros:
- No need to pass `args` around everywhere
- `generate_tilde` and `generate_dot_tilde` are much simpler: two possible blocks we can generate, either a) assume/observe, or b) observe literal.

Cons:
- We need to perform _one_ more check at runtime when using `UntypedVarInfo`.


**IMO, this is really worth it.**

## Motivation (sort of)
The main motivation behind this PR is simplification, but there's a different reason why I came across this.

I came to this because I was thinking about trying to "customize" the behavior of `~`, and I was thinking of using a macro to do it, e.g. `@mymacro x ~ Normal()`. Atm we're actually performing model-expansion on the code passed to the macro and thus trying to alter the way DynamicPPL treats `~` using a macro is veeeery difficult since you actually have to work with the *expanded* code, but let's ignore that issue for now (and take that discussion somewhere else, because IMO we shouldn't do this). 

Suppose we didn't perform model-expansions of the code fed to the macros, then you can just copy-paste `generate_tilde`, customize it do what you want, and BAM, you got yourself a working `@mymacro x ~ Normal()` which can do neat stuff! This is *not* possible atm because we don't have access to `args`, and so you have to take the approach in this PR to get there. That means that it's of course possible to do atm, but it's a bit icky since it ends up looking fundamentally different from `generate_tilde` rather than just slightly different.

Then we can implement things like a `@tilde` which will expand to `generate_tilde` which can be used *internally* in functions (if the "internal" variables are present in the functions of course, but we can also simplify this in different ways), actually allowing people to modularize their models a bit, and `@reparam` from #220 using very similar pieces of code, a `@track` macro can be introduced to deal with the explicit tracking of variables rather than putting this directly in the compiler, etc. Endless opportunities! (Of course, I'm not suggesting we add these, but this makes it a bit easier to explore.)

Co-authored-by: David Widmann <[email protected]>
@torfjelde
Copy link
Member Author

Closing this as there are now much easier ways to add this in:)

@torfjelde torfjelde closed this Jul 10, 2021
@yebai yebai deleted the tor/reparameterizations branch January 28, 2022 20:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants