Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ParameterHandling"
uuid = "2412ca09-6db7-441c-8e3a-88d5709968c5"
authors = ["Invenia Technical Computing Corporation"]
version = "0.3.8"
version = "0.4.0"

[deps]
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Expand Down
34 changes: 19 additions & 15 deletions src/flatten.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,20 @@ flatten(x) = flatten(Float64, x)

# Flatten an integer. Integers contribute nothing to the flat vector, so the flat part is
# an empty `Vector{T}` and the returned closure gives back `x` unchanged.
function flatten(::Type{T}, x::Integer) where {T<:Real}
    v = T[]
    # One method on `AbstractVector{<:Real}` also covers `Vector{T}` inputs, so a separate
    # `Vector{T}` method with the same body is not needed.
    unflatten_to_Integer(::AbstractVector{<:Real}) = x
    return v, unflatten_to_Integer
end

function flatten(::Type{T}, x::R) where {T<:Real,R<:Real}
v = T[x]
unflatten_to_Real(v::Vector{T}) = convert(R, only(v))
unflatten_to_Real(v::AbstractVector{<:Real}) = only(v)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that this line will change the current behavior for unflatten quite drastically:

x = (a = 1., b = [2., 3.], c = [4 5 ; 6 7])
typeof(x.b) #Vector{Float64}
xvec, unflat = flatten(Float16, x)
x2 = unflat(xvec)
typeof(x2.b) #Vector{Float16}

I don't think there is any other way though to facilitate AD while keeping initial parameter types.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed -- very breaking. From my perspective in terms of how I tend to use ParameterHandling in practice, the new behaviour is more helpful anyway though 🤷

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is no way around it if you want to work with AD here. Type changes could be a problem if you define a concrete container to collect samples (e.g. MCMC) of your model parameter.

return v, unflatten_to_Real
end

# Flatten a vector of reals: the flat representation is `x` converted to element type `T`.
# Unflattening is `identity`, so the result keeps the element type of whatever vector is
# passed in rather than being converted back to `R`.
function flatten(::Type{T}, x::Vector{R}) where {T<:Real,R<:Real}
    return Vector{T}(x), identity
end

# Flatten a vector of integers. Integers are not represented in the flat vector, so the
# flat part is an empty `Vector{T}` and unflattening returns the original `x` itself.
function _flatten_vector_integer(::Type{T}, x::AbstractVector{<:Integer}) where {T<:Real}
    # Restricting the argument to real-valued vectors matches the other unflatten closures.
    unflatten_to_Vector_Integer(::AbstractVector{<:Real}) = x
    return T[], unflatten_to_Vector_Integer
end

Expand All @@ -44,28 +44,32 @@ function flatten(::Type{T}, x::AbstractVector{<:Integer}) where {T<:Real}
return _flatten_vector_integer(T, x)
end

# Flatten a vector of arbitrary elements by flattening every element with `flatten(T, _)`
# and concatenating the pieces. The returned closure slices a flat vector back into
# per-element chunks and applies each element's own unflatten function to its chunk.
function flatten(::Type{T}, x::Vector) where {T<:Real}
    x_vecs_and_backs = map(val -> flatten(T, val), x)
    x_vecs, backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs)
    function Vector_from_vec(x_vec::AbstractVector{<:Real})
        # `sz[n]` is used as the index of the last flat entry belonging to element `n`.
        sz = _cumsum(map(length, x_vecs))
        x_Vec = [
            backs[n](x_vec[(sz[n] - length(x_vecs[n]) + 1):sz[n]]) for n in eachindex(x)
        ]
        return collect(x_Vec)
    end
    return reduce(vcat, x_vecs), Vector_from_vec
end

# Flatten an `Array` by flattening `vec(x)`. The returned closure unflattens to a vector,
# reshapes it to `size(x)` and collects the result into a new array.
function flatten(::Type{T}, x::Array) where {T<:Real}
    x_vec, from_vec = flatten(T, vec(x))
    function Array_from_vec(v_new::AbstractVector{<:Real})
        return collect(reshape(from_vec(v_new), size(x)))
    end
    return x_vec, Array_from_vec
end

# Flatten a sparse matrix through its stored values `x.nzval` only. The returned closure
# rebuilds a `SparseMatrixCSC` with the dimensions, `colptr` and `rowval` of `x` and the
# unflattened stored values.
function flatten(::Type{T}, x::SparseMatrixCSC) where {T<:Real}
    x_vec, from_vec = flatten(T, x.nzval)
    function Array_from_vec(v_new::AbstractVector{<:Real})
        return SparseMatrixCSC(x.m, x.n, x.colptr, x.rowval, from_vec(v_new))
    end
    return x_vec, Array_from_vec
end

Expand All @@ -74,26 +78,26 @@ function flatten(::Type{T}, x::Tuple) where {T<:Real}
x_vecs, x_backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs)
lengths = map(length, x_vecs)
sz = _cumsum(lengths)
function unflatten_to_Tuple(v::Vector{T})
function unflatten_to_Tuple(v::AbstractVector{<:Real})
map(x_backs, lengths, sz) do x_back, l, s
return x_back(v[(s - l + 1):s])
Copy link

@paschermayr paschermayr Sep 10, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be amazing if we could somehow use a @view here, so that we do not generate a new vector for each element of the tuple. The problem is that if we have to call NamedTuple{names}(v_vec_vec) instead of typeof(x)(v_vec_vec) in the NamedTuple dispatch below, we will get back different types for everything except scalar parameters.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm I agree that that would be nice. Would you mind opening a separate issue to discuss further? I'd rather keep it out of scope for this PR.

end
end
return reduce(vcat, x_vecs), unflatten_to_Tuple
end

# Flatten a `NamedTuple` by flattening its values. The returned closure rebuilds a
# `NamedTuple` with the same field names from the unflattened values; it is constructed
# via `NamedTuple{names}` rather than `typeof(x)`, so the value types are those produced
# by unflattening rather than the original field types.
function flatten(::Type{T}, x::NamedTuple{names}) where {T<:Real,names}
    x_vec, unflatten = flatten(T, values(x))
    function unflatten_to_NamedTuple(v::AbstractVector{<:Real})
        v_vec_vec = unflatten(v)
        return NamedTuple{names}(v_vec_vec)
    end
    return x_vec, unflatten_to_NamedTuple
end

function flatten(::Type{T}, d::Dict) where {T<:Real}
d_vec, unflatten = flatten(T, collect(values(d)))
function unflatten_to_Dict(v::Vector{T})
function unflatten_to_Dict(v::AbstractVector{<:Real})
v_vec_vec = unflatten(v)
return Dict(key => v_vec_vec[n] for (n, key) in enumerate(keys(d)))
end
Expand Down
22 changes: 13 additions & 9 deletions src/parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ value(x::Positive) = x.transform(x.unconstrained_value) + x.ε
function flatten(::Type{T}, x::Positive) where {T<:Real}
v, unflatten_to_Real = flatten(T, x.unconstrained_value)

function unflatten_Positive(v_new::Vector{T})
function unflatten_Positive(v_new::AbstractVector{<:Real})
return Positive(unflatten_to_Real(v_new), x.transform, x.ε)
end

Expand Down Expand Up @@ -84,10 +84,10 @@ function bounded(val::Real, lower_bound::Real, upper_bound::Real)
return Bounded(inv_transform(val), lb, ub, transform, ε)
end

# Parameter stored as an unconstrained value together with the bounds and the transform
# that maps the unconstrained value back to the constrained one. The bounds have their own
# type parameter `Tbound`, so the unconstrained value and the bounds may differ in type.
struct Bounded{T<:Real,Tbound<:Real,V<:Bijector,Tε<:Real} <: AbstractParameter
    unconstrained_value::T
    lower_bound::Tbound
    upper_bound::Tbound
    transform::V
    ε::Tε
end
Expand All @@ -97,7 +97,7 @@ value(x::Bounded) = x.transform(x.unconstrained_value)
function flatten(::Type{T}, x::Bounded) where {T<:Real}
v, unflatten_to_Real = flatten(T, x.unconstrained_value)

function unflatten_Bounded(v_new::Vector{T})
function unflatten_Bounded(v_new::AbstractVector{<:Real})
return Bounded(
unflatten_to_Real(v_new), x.lower_bound, x.upper_bound, x.transform, x.ε
)
Expand All @@ -122,7 +122,7 @@ end
value(x::Fixed) = x.value

# Flatten a `Fixed` parameter. It contributes nothing to the flat vector, so the flat part
# is an empty `Vector{T}` and unflattening returns the original `x`.
function flatten(::Type{T}, x::Fixed) where {T<:Real}
    # One method on `AbstractVector{<:Real}` also covers `Vector{T}` inputs.
    unflatten_Fixed(::AbstractVector{<:Real}) = x
    return T[], unflatten_Fixed
end

Expand All @@ -148,7 +148,7 @@ value(x::Deferred) = x.f(value(x.args)...)

# Flatten a `Deferred` by flattening its arguments `x.args`. Unflattening rebuilds a
# `Deferred` around the same function `x.f` with the unflattened arguments.
function flatten(::Type{T}, x::Deferred) where {T<:Real}
    v, unflatten = flatten(T, x.args)
    # One method on `AbstractVector{<:Real}` also covers `Vector{T}` inputs.
    unflatten_Deferred(v_new::AbstractVector{<:Real}) = Deferred(x.f, unflatten(v_new))
    return v, unflatten_Deferred
end

Expand Down Expand Up @@ -188,7 +188,9 @@ value(X::Orthogonal) = nearest_orthogonal_matrix(X.X)

# Flatten an `Orthogonal` by flattening the wrapped matrix `X.X`. Unflattening wraps the
# unflattened array in a new `Orthogonal`.
function flatten(::Type{T}, X::Orthogonal) where {T<:Real}
    v, unflatten_to_Array = flatten(T, X.X)
    # One method on `AbstractVector{<:Real}` also covers `Vector{T}` inputs.
    function unflatten_Orthogonal(v_new::AbstractVector{<:Real})
        return Orthogonal(unflatten_to_Array(v_new))
    end
    return v, unflatten_Orthogonal
end

Expand Down Expand Up @@ -217,7 +219,9 @@ value(X::PositiveDefinite) = A_At(vec_to_tril(X.L))

function flatten(::Type{T}, X::PositiveDefinite) where {T<:Real}

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not related to the commit, but I think most Statistics packages work with the upper triangular, if I am not mistaken?

using Distributions, LinearAlgebra
Σ1 = UpperTriangular([1. .5; .5 1.])
Σ2 = LowerTriangular([1. .5; .5 1.])
Symmetric(Σ1) # 1 0.5 0.5 1
Symmetric(Σ2) # 1 0.0 0.0 1

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm. I agree with your assertion that most packages use the upper triangle in the Julia ecosystem, but I'm not sure that it's a problem here, because the user should interact with the PositiveDefinite type via the positive_definite function, which just requires that the user provide a StridedMatrix which is positive definite (we should probably widen that to include Symmetric matrices...). Once inside that functionality, asking for the L field of a Cholesky is fine, albeit it may not be totally optimal.

Or maybe I've misunderstood where you're coming from with this?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, you are right. It also should probably be the job of the user to check if his transformations make sense.

Copy link
Member

@st-- st-- Nov 5, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

asking for the L field of a Cholesky is fine, albeit it may not be totally optimal.

Where "not totally optimal" in this case is "an extra copy() of the entire matrix" (JuliaLang/julia#42920, and you can use workarounds such as PDMats.chol_lower, though note that the bugfix for JuliaStats/PDMats.jl#143 resulted in yet another AD issue unfortunately)...

v, unflatten_v = flatten(T, X.L)
unflatten_PositiveDefinite(v_new::Vector{T}) = PositiveDefinite(unflatten_v(v_new))
function unflatten_PositiveDefinite(v_new::AbstractVector{<:Real})
return PositiveDefinite(unflatten_v(v_new))
end
return v, unflatten_PositiveDefinite
end

Expand Down
37 changes: 21 additions & 16 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,30 @@ using Test

using ParameterHandling: AbstractParameter, value

# Handles equality of structs / mutable structs by comparing their field values pairwise.
# The two arguments may have different types. Throws an `ErrorException` if either
# argument's type is not a struct type.
function default_equality(a::Ta, b::Tb; kwargs...) where {Ta,Tb}
    # `error` already throws, so it does not need to be wrapped in `throw`.
    (isstructtype(Ta) && isstructtype(Tb)) || error("Arguments aren't structs")
    field_pairs = zip(fieldvalues(a), fieldvalues(b))
    return all(t -> default_equality(t...; kwargs...), field_pairs)
end

# Numbers are compared approximately; keyword arguments are forwarded to `isapprox`.
default_equality(a::Number, b::Number; kwargs...) = isapprox(a, b; kwargs...)

# Handles extracting elements from arrays.
# Needed because fieldvalues(a) are empty, but we may need to recurse depending on
# the element type.
function default_equality(a::AbstractArray, b::AbstractArray; kwargs...)
    # `zip` stops at the shorter input, so arrays of different sizes must be rejected
    # up front instead of being compared on their common prefix.
    size(a) == size(b) || return false
    return all(t -> default_equality(t...; kwargs...), zip(a, b))
end

# Handles extracting values for any dictionary types.
# Values are matched by key rather than by iteration order, and dictionaries with
# different key sets are unequal.
function default_equality(a::AbstractDict, b::AbstractDict; kwargs...)
    length(a) == length(b) || return false
    return all(keys(a)) do k
        haskey(b, k) && default_equality(a[k], b[k]; kwargs...)
    end
end

# Minimal `Real` subtype wrapping a single value. `test_flatten_interface` maps it over
# random vectors to check that unflattening accepts element types other than the one
# used when flattening.
struct MyReal{T} <: Real
v::T
end

# NOTE: May want to make the equality function a kwarg in the future.
function test_flatten_interface(x::T; check_inferred::Bool=true) where {T}
@testset "flatten($T)" begin
Expand All @@ -50,7 +49,9 @@ function test_flatten_interface(x::T; check_inferred::Bool=true) where {T}
@test typeof(_v) === Vector{Float64}
@test _v == v
@test default_equality(x, unflatten(_v))
@test _unflatten(_v) isa T

# Check that unflattening works with different reals.
_unflatten(map(MyReal, randn(length(_v))))

# Check that everything infers properly.
check_inferred && @inferred flatten(Float64, x)
Expand All @@ -59,7 +60,9 @@ function test_flatten_interface(x::T; check_inferred::Bool=true) where {T}
_v, _unflatten = flatten(Float32, x)
@test typeof(_v) === Vector{Float32}
@test default_equality(x, _unflatten(_v); atol=1e-5)
@test _unflatten(_v) isa T

# Check that unflattening works with different precisions.
_unflatten(map(MyReal, randn(length(_v))))

# Check that everything infers properly.
check_inferred && @inferred flatten(Float32, x)
Expand All @@ -68,7 +71,9 @@ function test_flatten_interface(x::T; check_inferred::Bool=true) where {T}
_v, _unflatten = flatten(Float16, x)
@test typeof(_v) === Vector{Float16}
@test default_equality(x, _unflatten(_v); atol=1e-2)
@test _unflatten(_v) isa T

# Check that unflattening works with different precisions.
_unflatten(map(MyReal, randn(length(_v))))

# Check that everything infers properly.
check_inferred && @inferred flatten(Float16, x)
Expand Down
1 change: 0 additions & 1 deletion test/flatten.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

# Exercise the flatten interface on a flat tuple and on a tuple that nests another tuple
# and a vector. Whether inference is checked is controlled by `tuple_infers`.
@testset "Tuple" begin
test_flatten_interface((1.0, 2.0); check_inferred=tuple_infers)

test_flatten_interface((1.0, (2.0, 3.0), randn(5)); check_inferred=tuple_infers)
end

Expand Down