-
Notifications
You must be signed in to change notification settings - Fork 11
Overloading-AD-Friendly Unflatten #39
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,20 +21,20 @@ flatten(x) = flatten(Float64, x) | |
|
|
||
| function flatten(::Type{T}, x::Integer) where {T<:Real} | ||
| v = T[] | ||
| unflatten_to_Integer(v::Vector{T}) = x | ||
| unflatten_to_Integer(v::AbstractVector{<:Real}) = x | ||
| return v, unflatten_to_Integer | ||
| end | ||
|
|
||
| function flatten(::Type{T}, x::R) where {T<:Real,R<:Real} | ||
| v = T[x] | ||
| unflatten_to_Real(v::Vector{T}) = convert(R, only(v)) | ||
| unflatten_to_Real(v::AbstractVector{<:Real}) = only(v) | ||
| return v, unflatten_to_Real | ||
| end | ||
|
|
||
| flatten(::Type{T}, x::Vector{R}) where {T<:Real,R<:Real} = (Vector{T}(x), Vector{R}) | ||
| flatten(::Type{T}, x::Vector{R}) where {T<:Real,R<:Real} = (Vector{T}(x), identity) | ||
|
|
||
| function _flatten_vector_integer(::Type{T}, x::AbstractVector{<:Integer}) where {T<:Real} | ||
willtebbutt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| unflatten_to_Vector_Integer(x_vec) = x | ||
| unflatten_to_Vector_Integer(::AbstractVector{<:Real}) = x | ||
| return T[], unflatten_to_Vector_Integer | ||
| end | ||
|
|
||
|
|
@@ -44,28 +44,32 @@ function flatten(::Type{T}, x::AbstractVector{<:Integer}) where {T<:Real} | |
| return _flatten_vector_integer(T, x) | ||
| end | ||
|
|
||
| function flatten(::Type{T}, x::AbstractVector) where {T<:Real} | ||
| function flatten(::Type{T}, x::Vector) where {T<:Real} | ||
| x_vecs_and_backs = map(val -> flatten(T, val), x) | ||
| x_vecs, backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs) | ||
| function Vector_from_vec(x_vec) | ||
| function Vector_from_vec(x_vec::AbstractVector{<:Real}) | ||
| sz = _cumsum(map(length, x_vecs)) | ||
| x_Vec = [ | ||
| backs[n](x_vec[(sz[n] - length(x_vecs[n]) + 1):sz[n]]) for n in eachindex(x) | ||
| ] | ||
| return oftype(x, x_Vec) | ||
| return collect(x_Vec) | ||
| end | ||
| return reduce(vcat, x_vecs), Vector_from_vec | ||
| end | ||
|
|
||
| function flatten(::Type{T}, x::AbstractArray) where {T<:Real} | ||
| function flatten(::Type{T}, x::Array) where {T<:Real} | ||
| x_vec, from_vec = flatten(T, vec(x)) | ||
| Array_from_vec(x_vec) = oftype(x, reshape(from_vec(x_vec), size(x))) | ||
| function Array_from_vec(x_vec::AbstractVector{<:Real}) | ||
| return collect(reshape(from_vec(x_vec), size(x))) | ||
| end | ||
| return x_vec, Array_from_vec | ||
| end | ||
|
|
||
| function flatten(::Type{T}, x::SparseMatrixCSC) where {T<:Real} | ||
| x_vec, from_vec = flatten(T, x.nzval) | ||
| Array_from_vec(x_vec) = SparseMatrixCSC(x.m, x.n, x.colptr, x.rowval, from_vec(x_vec)) | ||
| function Array_from_vec(x_vec::AbstractVector{<:Real}) | ||
| return SparseMatrixCSC(x.m, x.n, x.colptr, x.rowval, from_vec(x_vec)) | ||
| end | ||
| return x_vec, Array_from_vec | ||
| end | ||
|
|
||
|
|
@@ -74,26 +78,26 @@ function flatten(::Type{T}, x::Tuple) where {T<:Real} | |
| x_vecs, x_backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs) | ||
| lengths = map(length, x_vecs) | ||
| sz = _cumsum(lengths) | ||
| function unflatten_to_Tuple(v::Vector{T}) | ||
| function unflatten_to_Tuple(v::AbstractVector{<:Real}) | ||
| map(x_backs, lengths, sz) do x_back, l, s | ||
| return x_back(v[(s - l + 1):s]) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It would be amazing if we could somehow find a way to make a …
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Hmm, I agree that that would be nice. Would you mind opening a separate issue to discuss further? I'd rather keep it out of scope for this PR. |
||
| end | ||
| end | ||
| return reduce(vcat, x_vecs), unflatten_to_Tuple | ||
| end | ||
|
|
||
| function flatten(::Type{T}, x::NamedTuple) where {T<:Real} | ||
| function flatten(::Type{T}, x::NamedTuple{names}) where {T<:Real,names} | ||
| x_vec, unflatten = flatten(T, values(x)) | ||
| function unflatten_to_NamedTuple(v::Vector{T}) | ||
| function unflatten_to_NamedTuple(v::AbstractVector{<:Real}) | ||
| v_vec_vec = unflatten(v) | ||
| return typeof(x)(v_vec_vec) | ||
| return NamedTuple{names}(v_vec_vec) | ||
| end | ||
| return x_vec, unflatten_to_NamedTuple | ||
| end | ||
|
|
||
| function flatten(::Type{T}, d::Dict) where {T<:Real} | ||
| d_vec, unflatten = flatten(T, collect(values(d))) | ||
| function unflatten_to_Dict(v::Vector{T}) | ||
| function unflatten_to_Dict(v::AbstractVector{<:Real}) | ||
| v_vec_vec = unflatten(v) | ||
| return Dict(key => v_vec_vec[n] for (n, key) in enumerate(keys(d))) | ||
| end | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -48,7 +48,7 @@ value(x::Positive) = x.transform(x.unconstrained_value) + x.ε | |
| function flatten(::Type{T}, x::Positive) where {T<:Real} | ||
| v, unflatten_to_Real = flatten(T, x.unconstrained_value) | ||
|
|
||
| function unflatten_Positive(v_new::Vector{T}) | ||
| function unflatten_Positive(v_new::AbstractVector{<:Real}) | ||
| return Positive(unflatten_to_Real(v_new), x.transform, x.ε) | ||
| end | ||
|
|
||
|
|
@@ -84,10 +84,10 @@ function bounded(val::Real, lower_bound::Real, upper_bound::Real) | |
| return Bounded(inv_transform(val), lb, ub, transform, ε) | ||
| end | ||
|
|
||
| struct Bounded{T<:Real,V<:Bijector,Tε<:Real} <: AbstractParameter | ||
| struct Bounded{T<:Real,Tbound<:Real,V<:Bijector,Tε<:Real} <: AbstractParameter | ||
| unconstrained_value::T | ||
| lower_bound::T | ||
| upper_bound::T | ||
| lower_bound::Tbound | ||
| upper_bound::Tbound | ||
| transform::V | ||
| ε::Tε | ||
| end | ||
|
|
@@ -97,7 +97,7 @@ value(x::Bounded) = x.transform(x.unconstrained_value) | |
| function flatten(::Type{T}, x::Bounded) where {T<:Real} | ||
| v, unflatten_to_Real = flatten(T, x.unconstrained_value) | ||
|
|
||
| function unflatten_Bounded(v_new::Vector{T}) | ||
| function unflatten_Bounded(v_new::AbstractVector{<:Real}) | ||
| return Bounded( | ||
| unflatten_to_Real(v_new), x.lower_bound, x.upper_bound, x.transform, x.ε | ||
| ) | ||
|
|
@@ -122,7 +122,7 @@ end | |
| value(x::Fixed) = x.value | ||
|
|
||
| function flatten(::Type{T}, x::Fixed) where {T<:Real} | ||
| unflatten_Fixed(v_new::Vector{T}) = x | ||
| unflatten_Fixed(v_new::AbstractVector{<:Real}) = x | ||
| return T[], unflatten_Fixed | ||
| end | ||
|
|
||
|
|
@@ -148,7 +148,7 @@ value(x::Deferred) = x.f(value(x.args)...) | |
|
|
||
| function flatten(::Type{T}, x::Deferred) where {T<:Real} | ||
| v, unflatten = flatten(T, x.args) | ||
| unflatten_Deferred(v_new::Vector{T}) = Deferred(x.f, unflatten(v_new)) | ||
| unflatten_Deferred(v_new::AbstractVector{<:Real}) = Deferred(x.f, unflatten(v_new)) | ||
| return v, unflatten_Deferred | ||
| end | ||
|
|
||
|
|
@@ -188,7 +188,9 @@ value(X::Orthogonal) = nearest_orthogonal_matrix(X.X) | |
|
|
||
| function flatten(::Type{T}, X::Orthogonal) where {T<:Real} | ||
| v, unflatten_to_Array = flatten(T, X.X) | ||
| unflatten_Orthogonal(v_new::Vector{T}) = Orthogonal(unflatten_to_Array(v_new)) | ||
| function unflatten_Orthogonal(v_new::AbstractVector{<:Real}) | ||
| return Orthogonal(unflatten_to_Array(v_new)) | ||
| end | ||
| return v, unflatten_Orthogonal | ||
| end | ||
|
|
||
|
|
@@ -217,7 +219,9 @@ value(X::PositiveDefinite) = A_At(vec_to_tril(X.L)) | |
|
|
||
| function flatten(::Type{T}, X::PositiveDefinite) where {T<:Real} | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is not related to the commit, but I think most statistics packages work with the upper triangle, if I am not mistaken?
using Distributions, LinearAlgebra
Σ1 = UpperTriangular([1. .5; .5 1.])
Σ2 = LowerTriangular([1. .5; .5 1.])
Symmetric(Σ1) # 1 0.5 0.5 1
Symmetric(Σ2) # 1 0.0 0.0 1
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Hmm. I agree with your assertion that most packages use the upper triangle in the Julia ecosystem, but I'm not sure that it's a problem here, because the user should interact with the … Or maybe I've misunderstood where you're coming from with this? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, you are right. It should probably also be the job of the user to check whether their transformations make sense.
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Where "not totally optimal" in this case is "an extra copy() of the entire matrix" (JuliaLang/julia#42920, and you can use workarounds such as |
||
| v, unflatten_v = flatten(T, X.L) | ||
| unflatten_PositiveDefinite(v_new::Vector{T}) = PositiveDefinite(unflatten_v(v_new)) | ||
| function unflatten_PositiveDefinite(v_new::AbstractVector{<:Real}) | ||
| return PositiveDefinite(unflatten_v(v_new)) | ||
| end | ||
| return v, unflatten_PositiveDefinite | ||
| end | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that this line will change the current behavior for unflatten quite drastically:
I don't think there is any other way, though, to facilitate AD while keeping the initial parameter types.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed -- very breaking. From my perspective, given how I tend to use ParameterHandling in practice, the new behaviour is more helpful anyway 🤷
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think there is no way around it if you want to work with AD here. Type changes could be a problem if you define a concrete container to collect samples (e.g. MCMC) of your model parameters.