Skip to content

Commit b83c262

Browse files
committed
renamed getval to getindex_internal and made dist an optional
argument for all the transform-related methods
1 parent 77b835e commit b83c262

File tree

8 files changed

+124
-62
lines changed

8 files changed

+124
-62
lines changed

src/abstract_varinfo.jl

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -779,14 +779,14 @@ end
779779
Invlink `x` and compute the logpdf under `dist` including correction from
780780
the invlink-transformation.
781781
782-
If `x` is not provided, `getval(vi, vn)` will be used.
782+
If `x` is not provided, `getindex_internal(vi, vn)` will be used.
783783
784784
!!! warning
785785
The input value `x` should be according to the internal representation of
786-
`varinfo`, e.g. the value returned by `getval(vi, vn)`.
786+
`varinfo`, e.g. the value returned by `getindex_internal(vi, vn)`.
787787
"""
788788
function invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist)
789-
return invlink_with_logpdf(vi, vn, dist, getval(vi, vn))
789+
return invlink_with_logpdf(vi, vn, dist, getindex_internal(vi, vn))
790790
end
791791
function invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist, y)
792792
f = from_maybe_linked_internal_transform(vi, vn, dist)
@@ -800,26 +800,32 @@ increment_num_produce!(::AbstractVarInfo) = nothing
800800
setgid!(vi::AbstractVarInfo, gid::Selector, vn::VarName) = nothing
801801

802802
"""
803-
from_internal_transform(varinfo::AbstractVarInfo, vn::VarName, dist)
803+
from_internal_transform(varinfo::AbstractVarInfo, vn::VarName[, dist])
804804
805805
Return a transformation that transforms from the internal representation of `vn` with `dist`
806806
in `varinfo` to a representation compatible with `dist`.
807+
808+
If `dist` is not present, then it is assumed that `varinfo` knows the correct output for `vn`.
807809
"""
808810
function from_internal_transform end
809811

810812
"""
811-
from_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName, dist)
813+
from_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName[, dist])
812814
813815
Return a transformation that transforms from the linked internal representation of `vn` with `dist`
814816
in `varinfo` to a representation compatible with `dist`.
817+
818+
If `dist` is not present, then it is assumed that `varinfo` knows the correct output for `vn`.
815819
"""
816820
function from_linked_internal_transform end
817821

818822
"""
819-
from_maybe_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName, dist)
823+
from_maybe_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName[, dist])
820824
821825
Return a transformation that transforms from the possibly linked internal representation of `vn` with `dist`n
822826
in `varinfo` to a representation compatible with `dist`.
827+
828+
If `dist` is not present, then it is assumed that `varinfo` knows the correct output for `vn`.
823829
"""
824830
function from_maybe_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName, dist)
825831
return if istrans(varinfo, vn)
@@ -828,57 +834,94 @@ function from_maybe_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarN
828834
from_internal_transform(varinfo, vn, dist)
829835
end
830836
end
837+
function from_maybe_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName)
838+
return if istrans(varinfo, vn)
839+
from_linked_internal_transform(varinfo, vn)
840+
else
841+
from_internal_transform(varinfo, vn)
842+
end
843+
end
831844

832845
"""
833-
to_internal_transform(varinfo::AbstractVarInfo, vn::VarName, dist)
846+
to_internal_transform(varinfo::AbstractVarInfo, vn::VarName[, dist])
834847
835848
Return a transformation that transforms from a representation compatible with `dist` to the
836849
internal representation of `vn` with `dist` in `varinfo`.
850+
851+
If `dist` is not present, then it is assumed that `varinfo` knows the correct output for `vn`.
837852
"""
838853
function to_internal_transform(varinfo::AbstractVarInfo, vn::VarName, dist)
839854
return inverse(from_internal_transform(varinfo, vn, dist))
840855
end
856+
function to_internal_transform(varinfo::AbstractVarInfo, vn::VarName)
857+
return inverse(from_internal_transform(varinfo, vn))
858+
end
841859

842860
"""
843-
to_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName, dist)
861+
to_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName[, dist])
844862
845863
Return a transformation that transforms from a representation compatible with `dist` to the
846864
linked internal representation of `vn` with `dist` in `varinfo`.
865+
866+
If `dist` is not present, then it is assumed that `varinfo` knows the correct output for `vn`.
847867
"""
848868
function to_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName, dist)
849869
return inverse(from_linked_internal_transform(varinfo, vn, dist))
850870
end
871+
function to_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName)
872+
return inverse(from_linked_internal_transform(varinfo, vn))
873+
end
851874

852875
"""
853-
to_maybe_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName, dist)
876+
to_maybe_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName[, dist])
854877
855878
Return a transformation that transforms from a representation compatible with `dist` to a
856879
possibly linked internal representation of `vn` with `dist` in `varinfo`.
880+
881+
If `dist` is not present, then it is assumed that `varinfo` knows the correct output for `vn`.
857882
"""
858883
function to_maybe_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName, dist)
859884
return inverse(from_maybe_linked_internal_transform(varinfo, vn, dist))
860885
end
886+
function to_maybe_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName)
887+
return inverse(from_maybe_linked_internal_transform(varinfo, vn))
888+
end
861889

862890
"""
863891
internal_to_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName, dist)
864892
865893
Return a transformation that transforms from the internal representation of `vn` with `dist`
866894
in `varinfo` to a _linked_ internal representation of `vn` with `dist` in `varinfo`.
895+
896+
If `dist` is not present, then it is assumed that `varinfo` knows the correct output for `vn`.
867897
"""
868898
function internal_to_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName, dist)
869899
f_from_internal = from_internal_transform(varinfo, vn, dist)
870900
f_to_linked_internal = to_linked_internal_transform(varinfo, vn, dist)
871901
return f_to_linked_internal f_from_internal
872902
end
903+
function internal_to_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName)
904+
f_from_internal = from_internal_transform(varinfo, vn)
905+
f_to_linked_internal = to_linked_internal_transform(varinfo, vn)
906+
return f_to_linked_internal f_from_internal
907+
end
873908

874909
"""
875-
linked_internal_to_internal_transform(varinfo::AbstractVarInfo, vn::VarName, dist)
910+
linked_internal_to_internal_transform(varinfo::AbstractVarInfo, vn::VarName[, dist])
876911
877912
Return a transformation that transforms from a _linked_ internal representation of `vn` with `dist`
878913
in `varinfo` to the internal representation of `vn` with `dist` in `varinfo`.
914+
915+
If `dist` is not present, then it is assumed that `varinfo` knows the correct output for `vn`.
879916
"""
880917
function linked_internal_to_internal_transform(varinfo::AbstractVarInfo, vn::VarName, dist)
881918
f_from_linked_internal = from_linked_internal_transform(varinfo, vn, dist)
882919
f_to_internal = to_internal_transform(varinfo, vn, dist)
883920
return f_to_internal f_from_linked_internal
884921
end
922+
923+
function linked_internal_to_internal_transform(varinfo::AbstractVarInfo, vn::VarName)
924+
f_from_linked_internal = from_linked_internal_transform(varinfo, vn)
925+
f_to_internal = to_internal_transform(varinfo, vn)
926+
return f_to_internal f_from_linked_internal
927+
end

src/context_implementations.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,7 @@ function _link_broadcast_new(vi, vn, dist, r)
477477
end
478478

479479
function _maybe_invlink_broadcast(vi, vn, dist)
480-
xvec = getval(vi, vn)
480+
xvec = getindex_internal(vi, vn)
481481
b = from_maybe_linked_internal_transform(vi, vn, dist)
482482
return b(xvec)
483483
end

src/simple_varinfo.jl

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ function getindex_raw(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribut
332332
end
333333

334334
# HACK: because `VarInfo` isn't ready to implement a proper `getindex_raw`.
335-
getval(vi::SimpleVarInfo, vn::VarName) = getindex_raw(vi, vn)
335+
getindex_internal(vi::SimpleVarInfo, vn::VarName) = getindex_raw(vi, vn)
336336

337337
Base.haskey(vi::SimpleVarInfo, vn::VarName) = hasvalue(vi.values, vn)
338338

@@ -690,16 +690,11 @@ function invlink!!(
690690
end
691691

692692
# With `SimpleVarInfo`, when we're not working with linked variables, there's no need to do anything.
693-
from_internal_transform(::SimpleVarInfo, dist) = identity
694-
function from_internal_transform(vi::SimpleVarInfo, ::VarName, dist)
695-
return from_internal_transform(vi, dist)
696-
end
697-
698-
function from_linked_internal_transform(vi::SimpleVarInfo, dist)
699-
return invlink_transform(dist)
700-
end
693+
from_internal_transform(vi::SimpleVarInfo, ::VarName) = identity
694+
from_internal_transform(vi::SimpleVarInfo, ::VarName, dist) = identity
695+
from_linked_internal_transform(vi::SimpleVarInfo, ::VarName) = identity
701696
function from_linked_internal_transform(vi::SimpleVarInfo, ::VarName, dist)
702-
return from_linked_internal_transform(vi, dist)
697+
return invlink_transform(dist)
703698
end
704699

705700
# Threadsafe stuff.

src/threadsafe.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ end
222222
istrans(vi::ThreadSafeVarInfo, vn::VarName) = istrans(vi.varinfo, vn)
223223
istrans(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) = istrans(vi.varinfo, vns)
224224

225-
getval(vi::ThreadSafeVarInfo, vn::VarName) = getval(vi.varinfo, vn)
225+
getindex_internal(vi::ThreadSafeVarInfo, vn::VarName) = getindex_internal(vi.varinfo, vn)
226226

227227
function unflatten(vi::ThreadSafeVarInfo, x::AbstractVector)
228228
return Setfield.@set vi.varinfo = unflatten(vi.varinfo, x)

src/utils.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,10 +256,14 @@ from_vec_transform(dist::Distribution) = from_vec_transform_for_size(size(dist))
256256
from_vec_transform(dist::LKJCholesky) = ToChol(dist.uplo) FromVec(size(dist))
257257

258258
"""
259-
from_linked_vec_transform(dist)
259+
from_linked_vec_transform(dist::Distribution)
260260
261261
Return the transformation from the unconstrained vector to the constrained
262262
realization of distribution `dist`.
263+
264+
By default, this is just `invlink_transform(dist) ∘ from_vec_transform(dist)`.
265+
266+
See also: [`DynamicPPL.invlink_transform`](@ref), [`DynamicPPL.from_vec_transform`](@ref).
263267
"""
264268
function from_linked_vec_transform(dist::Distribution)
265269
f_vec = from_vec_transform(dist)

0 commit comments

Comments
 (0)