diff --git a/Project.toml b/Project.toml index e17b62f..f5c066a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ParameterHandling" uuid = "2412ca09-6db7-441c-8e3a-88d5709968c5" authors = ["Invenia Technical Computing Corporation"] -version = "0.3.8" +version = "0.4.0" [deps] Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" diff --git a/src/flatten.jl b/src/flatten.jl index 717b27b..5779a66 100644 --- a/src/flatten.jl +++ b/src/flatten.jl @@ -21,20 +21,20 @@ flatten(x) = flatten(Float64, x) function flatten(::Type{T}, x::Integer) where {T<:Real} v = T[] - unflatten_to_Integer(v::Vector{T}) = x + unflatten_to_Integer(v::AbstractVector{<:Real}) = x return v, unflatten_to_Integer end function flatten(::Type{T}, x::R) where {T<:Real,R<:Real} v = T[x] - unflatten_to_Real(v::Vector{T}) = convert(R, only(v)) + unflatten_to_Real(v::AbstractVector{<:Real}) = only(v) return v, unflatten_to_Real end -flatten(::Type{T}, x::Vector{R}) where {T<:Real,R<:Real} = (Vector{T}(x), Vector{R}) +flatten(::Type{T}, x::Vector{R}) where {T<:Real,R<:Real} = (Vector{T}(x), identity) function _flatten_vector_integer(::Type{T}, x::AbstractVector{<:Integer}) where {T<:Real} - unflatten_to_Vector_Integer(x_vec) = x + unflatten_to_Vector_Integer(::AbstractVector{<:Real}) = x return T[], unflatten_to_Vector_Integer end @@ -44,28 +44,32 @@ function flatten(::Type{T}, x::AbstractVector{<:Integer}) where {T<:Real} return _flatten_vector_integer(T, x) end -function flatten(::Type{T}, x::AbstractVector) where {T<:Real} +function flatten(::Type{T}, x::Vector) where {T<:Real} x_vecs_and_backs = map(val -> flatten(T, val), x) x_vecs, backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs) - function Vector_from_vec(x_vec) + function Vector_from_vec(x_vec::AbstractVector{<:Real}) sz = _cumsum(map(length, x_vecs)) x_Vec = [ backs[n](x_vec[(sz[n] - length(x_vecs[n]) + 1):sz[n]]) for n in eachindex(x) ] - return oftype(x, x_Vec) + return collect(x_Vec) end return reduce(vcat, x_vecs), Vector_from_vec end -function flatten(::Type{T}, x::AbstractArray) where {T<:Real} +function flatten(::Type{T}, x::Array) where {T<:Real} x_vec, from_vec = flatten(T, vec(x)) - Array_from_vec(x_vec) = oftype(x, reshape(from_vec(x_vec), size(x))) + function Array_from_vec(x_vec::AbstractVector{<:Real}) + return collect(reshape(from_vec(x_vec), size(x))) + end return x_vec, Array_from_vec end function flatten(::Type{T}, x::SparseMatrixCSC) where {T<:Real} x_vec, from_vec = flatten(T, x.nzval) - Array_from_vec(x_vec) = SparseMatrixCSC(x.m, x.n, x.colptr, x.rowval, from_vec(x_vec)) + function Array_from_vec(x_vec::AbstractVector{<:Real}) + return SparseMatrixCSC(x.m, x.n, x.colptr, x.rowval, from_vec(x_vec)) + end return x_vec, Array_from_vec end @@ -74,7 +78,7 @@ function flatten(::Type{T}, x::Tuple) where {T<:Real} x_vecs, x_backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs) lengths = map(length, x_vecs) sz = _cumsum(lengths) - function unflatten_to_Tuple(v::Vector{T}) + function unflatten_to_Tuple(v::AbstractVector{<:Real}) map(x_backs, lengths, sz) do x_back, l, s return x_back(v[(s - l + 1):s]) end @@ -82,18 +86,18 @@ function flatten(::Type{T}, x::Tuple) where {T<:Real} return reduce(vcat, x_vecs), unflatten_to_Tuple end -function flatten(::Type{T}, x::NamedTuple) where {T<:Real} +function flatten(::Type{T}, x::NamedTuple{names}) where {T<:Real,names} x_vec, unflatten = flatten(T, values(x)) - function unflatten_to_NamedTuple(v::Vector{T}) + function unflatten_to_NamedTuple(v::AbstractVector{<:Real}) v_vec_vec = unflatten(v) - return typeof(x)(v_vec_vec) + return NamedTuple{names}(v_vec_vec) end return x_vec, unflatten_to_NamedTuple end function flatten(::Type{T}, d::Dict) where {T<:Real} d_vec, unflatten = flatten(T, collect(values(d))) - function unflatten_to_Dict(v::Vector{T}) + function unflatten_to_Dict(v::AbstractVector{<:Real}) v_vec_vec = unflatten(v) return Dict(key => v_vec_vec[n] for (n, key) in enumerate(keys(d))) end diff --git a/src/parameters.jl b/src/parameters.jl index a3c1dd4..ed7f752 100644 --- a/src/parameters.jl +++ b/src/parameters.jl @@ -48,7 +48,7 @@ value(x::Positive) = x.transform(x.unconstrained_value) + x.ε function flatten(::Type{T}, x::Positive) where {T<:Real} v, unflatten_to_Real = flatten(T, x.unconstrained_value) - function unflatten_Positive(v_new::Vector{T}) + function unflatten_Positive(v_new::AbstractVector{<:Real}) return Positive(unflatten_to_Real(v_new), x.transform, x.ε) end @@ -84,10 +84,10 @@ function bounded(val::Real, lower_bound::Real, upper_bound::Real) return Bounded(inv_transform(val), lb, ub, transform, ε) end -struct Bounded{T<:Real,V<:Bijector,Tε<:Real} <: AbstractParameter +struct Bounded{T<:Real,Tbound<:Real,V<:Bijector,Tε<:Real} <: AbstractParameter unconstrained_value::T - lower_bound::T - upper_bound::T + lower_bound::Tbound + upper_bound::Tbound transform::V ε::Tε end @@ -97,7 +97,7 @@ value(x::Bounded) = x.transform(x.unconstrained_value) function flatten(::Type{T}, x::Bounded) where {T<:Real} v, unflatten_to_Real = flatten(T, x.unconstrained_value) - function unflatten_Bounded(v_new::Vector{T}) + function unflatten_Bounded(v_new::AbstractVector{<:Real}) return Bounded( unflatten_to_Real(v_new), x.lower_bound, x.upper_bound, x.transform, x.ε ) @@ -122,7 +122,7 @@ end value(x::Fixed) = x.value function flatten(::Type{T}, x::Fixed) where {T<:Real} - unflatten_Fixed(v_new::Vector{T}) = x + unflatten_Fixed(v_new::AbstractVector{<:Real}) = x return T[], unflatten_Fixed end @@ -148,7 +148,7 @@ value(x::Deferred) = x.f(value(x.args)...) function flatten(::Type{T}, x::Deferred) where {T<:Real} v, unflatten = flatten(T, x.args) - unflatten_Deferred(v_new::Vector{T}) = Deferred(x.f, unflatten(v_new)) + unflatten_Deferred(v_new::AbstractVector{<:Real}) = Deferred(x.f, unflatten(v_new)) return v, unflatten_Deferred end @@ -188,7 +188,9 @@ value(X::Orthogonal) = nearest_orthogonal_matrix(X.X) function flatten(::Type{T}, X::Orthogonal) where {T<:Real} v, unflatten_to_Array = flatten(T, X.X) - unflatten_Orthogonal(v_new::Vector{T}) = Orthogonal(unflatten_to_Array(v_new)) + function unflatten_Orthogonal(v_new::AbstractVector{<:Real}) + return Orthogonal(unflatten_to_Array(v_new)) + end return v, unflatten_Orthogonal end @@ -217,7 +219,9 @@ value(X::PositiveDefinite) = A_At(vec_to_tril(X.L)) function flatten(::Type{T}, X::PositiveDefinite) where {T<:Real} v, unflatten_v = flatten(T, X.L) - unflatten_PositiveDefinite(v_new::Vector{T}) = PositiveDefinite(unflatten_v(v_new)) + function unflatten_PositiveDefinite(v_new::AbstractVector{<:Real}) + return PositiveDefinite(unflatten_v(v_new)) + end return v, unflatten_PositiveDefinite end diff --git a/src/test_utils.jl b/src/test_utils.jl index 22ce3d3..3a1c5e7 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -6,31 +6,30 @@ using Test using ParameterHandling: AbstractParameter, value -# Handles equality of scalars, functions, tuples or arbitrary types -function default_equality(a::T, b::T; kwargs...) where {T} - vals = fieldvalues(a) - - # If we don't have any fields then we're probably dealing with scalars - if isempty(vals) - # Only call isapprox for numbers, otherwise we fallback to == - return T <: Number ? isapprox(a, b; kwargs...) : a == b - else - return all(t -> default_equality(t...; kwargs...), zip(vals, fieldvalues(b))) - end +# Handles equality of structs / mutable structs. +function default_equality(a::Ta, b::Tb; kwargs...) where {Ta,Tb} + (isstructtype(Ta) && isstructtype(Tb)) || throw(error("Arguments aren't structs")) + return all(t -> default_equality(t...; kwargs...), zip(fieldvalues(a), fieldvalues(b))) end +default_equality(a::Number, b::Number; kwargs...) = isapprox(a, b; kwargs...) + # Handles extracting elements from arrays. # Needed because fieldvalues(a) are empty, but we may need to recurse depending on # the element type -function default_equality(a::T, b::T; kwargs...) where {T<:AbstractArray} +function default_equality(a::AbstractArray, b::AbstractArray; kwargs...) return all(t -> default_equality(t...; kwargs...), zip(a, b)) end # Handles extracting values for any dictionary types -function default_equality(a::T, b::T; kwargs...) where {T<:AbstractDict} +function default_equality(a::AbstractDict, b::AbstractDict; kwargs...) return all(t -> default_equality(t...; kwargs...), zip(values(a), values(b))) end +struct MyReal{T} <: Real + v::T +end + # NOTE: May want to make the equality function a kwarg in the future. function test_flatten_interface(x::T; check_inferred::Bool=true) where {T} @testset "flatten($T)" begin @@ -50,7 +49,9 @@ function test_flatten_interface(x::T; check_inferred::Bool=true) where {T} @test typeof(_v) === Vector{Float64} @test _v == v @test default_equality(x, unflatten(_v)) - @test _unflatten(_v) isa T + + # Check that unflattening works with different reals. + _unflatten(map(MyReal, randn(length(_v)))) # Check that everything infers properly. check_inferred && @inferred flatten(Float64, x) @@ -59,7 +60,9 @@ function test_flatten_interface(x::T; check_inferred::Bool=true) where {T} _v, _unflatten = flatten(Float32, x) @test typeof(_v) === Vector{Float32} @test default_equality(x, _unflatten(_v); atol=1e-5) - @test _unflatten(_v) isa T + + # Check that unflattening works with different precisions. + _unflatten(map(MyReal, randn(length(_v)))) # Check that everything infers properly. check_inferred && @inferred flatten(Float32, x) @@ -68,7 +71,9 @@ function test_flatten_interface(x::T; check_inferred::Bool=true) where {T} _v, _unflatten = flatten(Float16, x) @test typeof(_v) === Vector{Float16} @test default_equality(x, _unflatten(_v); atol=1e-2) - @test _unflatten(_v) isa T + + # Check that unflattening works with different precisions. + _unflatten(map(MyReal, randn(length(_v)))) # Check that everything infers properly. check_inferred && @inferred flatten(Float16, x) diff --git a/test/flatten.jl b/test/flatten.jl index 4a34c99..0d6a5b2 100644 --- a/test/flatten.jl +++ b/test/flatten.jl @@ -26,7 +26,6 @@ @testset "Tuple" begin test_flatten_interface((1.0, 2.0); check_inferred=tuple_infers) - test_flatten_interface((1.0, (2.0, 3.0), randn(5)); check_inferred=tuple_infers) end