-
Notifications
You must be signed in to change notification settings - Fork 11
Open
Description
From: #39
This only relates to array-valued parameters. At the moment, when a `Vector` is unflattened, a new `Vector` is allocated for each element of the tuple:
```julia
function flatten(::Type{T}, x::Tuple) where {T<:Real}
    x_vecs_and_backs = map(val -> flatten(T, val), x)
    x_vecs, x_backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs)
    lengths = map(length, x_vecs)
    sz = _cumsum(lengths)
    function unflatten_to_Tuple(v::AbstractVector{<:Real})
        map(x_backs, lengths, sz) do x_back, l, s
            return x_back(v[(s - l + 1):s]) #HERE
        end
    end
    return reduce(vcat, x_vecs), unflatten_to_Tuple
end
```

This copy is necessary. If a view were used instead, the unflattened `NamedTuple` would contain a bunch of `SubArray`s. That happens because the `NamedTuple` method below rebuilds the result from whatever values it receives, so it cannot recover the original array types:
```julia
function flatten(::Type{T}, x::NamedTuple{names}) where {T<:Real,names}
    x_vec, unflatten = flatten(T, values(x))
    function unflatten_to_NamedTuple(v::AbstractVector{<:Real})
        v_vec_vec = unflatten(v)
        return NamedTuple{names}(v_vec_vec) #HERE
    end
    return x_vec, unflatten_to_NamedTuple
end
```

Changing `return NamedTuple{names}(v_vec_vec)` to `typeof(x)(v_vec_vec)` would let us use views in the `Tuple` method (the line marked `#HERE` in the first snippet). However, this change would make many AD backends unusable.
Metadata
Metadata
Assignees
Labels
No labels