From c0cb6089a6c93ce0d3bad34c4704c6e74c5b7727 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Mon, 3 Jul 2023 05:40:57 +0100 Subject: [PATCH 01/15] Bump `AbstractPPL` version. --- Project.toml | 2 +- test/Project.toml | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index eb03e7c7c..425b300a9 100644 --- a/Project.toml +++ b/Project.toml @@ -22,7 +22,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] AbstractMCMC = "2, 3.0, 4" -AbstractPPL = "0.5.3" +AbstractPPL = "0.5.3, 0.6" BangBang = "0.3" Bijectors = "0.13" ChainRulesCore = "0.9.7, 0.10, 1" diff --git a/test/Project.toml b/test/Project.toml index 1f3e6e4fe..a91449e64 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -22,6 +22,8 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] AbstractMCMC = "2.1, 3.0, 4" +AbstractPPL = "0.5, 0.6" +Bijectors = "0.11, 0.12" AbstractPPL = "0.5" Bijectors = "0.13" Distributions = "0.25" From 9332e4022623b4ed6301149c71f52558d61d79f7 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Mon, 3 Jul 2023 05:41:24 +0100 Subject: [PATCH 02/15] Remove version 0.5. --- Project.toml | 2 +- test/Project.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 425b300a9..99c954d19 100644 --- a/Project.toml +++ b/Project.toml @@ -22,7 +22,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] AbstractMCMC = "2, 3.0, 4" -AbstractPPL = "0.5.3, 0.6" +AbstractPPL = "0.6.3" BangBang = "0.3" Bijectors = "0.13" ChainRulesCore = "0.9.7, 0.10, 1" diff --git a/test/Project.toml b/test/Project.toml index a91449e64..49910fa80 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -22,7 +22,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] AbstractMCMC = "2.1, 3.0, 4" -AbstractPPL = "0.5, 0.6" +AbstractPPL = "0.6" Bijectors = "0.11, 0.12" AbstractPPL = "0.5" Bijectors = "0.13" From 5adfb4cc8febec447f20984f03f268329438f49d Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 7 Jul 2023 09:07:13 +0100 Subject: [PATCH 03/15] Concretize `Colon`s --- src/compiler.jl | 21 ++++++++++++++++----- src/test_utils.jl | 3 ++- test/varinfo.jl | 2 +- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index bdd413630..c415bd0c0 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1,5 +1,15 @@ const INTERNALNAMES = (:__model__, :__context__, :__varinfo__) +function need_concretize(expr) + return Setfield.need_dynamic_lens(expr) || begin + flag = false + MacroTools.postwalk(expr) do ex # flag in closure, not ideal + ex == :(:) && (flag = true) && return ex + end + flag + end +end + """ isassumption(expr[, vn]) @@ -16,10 +26,10 @@ When `expr` is not an expression or symbol (i.e., a literal), this expands to `f If `vn` is specified, it will be assumed to refer to a expression which evaluates to a `VarName`, and this will be used in the subsequent checks. -If `vn` is not specified, `AbstractPPL.drop_escape(varname(expr))` will be +If `vn` is not specified, `AbstractPPL.varname(expr, need_concretize(expr))` will be used in its place. """ -function isassumption(expr::Union{Expr,Symbol}, vn=AbstractPPL.drop_escape(varname(expr))) +function isassumption(expr::Union{Expr,Symbol}, vn=AbstractPPL.drop_escape(varname(expr, need_concretize(expr)))) return quote if $(DynamicPPL.contextual_isassumption)(__context__, $vn) # Considered an assumption by `__context__` which means either: @@ -160,7 +170,7 @@ function unwrap_right_left_vns( # for `i = size(left, 2)`. Hence the symbol should be `x[:, i]`, # and we therefore add the `Colon()` below. vns = map(axes(left, 2)) do i - return vn ∘ Setfield.IndexLens((Colon(), i)) + return AbstractPPL.concretize(vn ∘ Setfield.IndexLens((Colon(), i)), left) end return unwrap_right_left_vns(right, left, vns) end @@ -338,7 +348,7 @@ function generate_tilde(left, right) return quote $dist = $right $vn = $(DynamicPPL.resolve_varnames)( - $(AbstractPPL.drop_escape(varname(left))), $dist + $(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), $dist ) $isassumption = $(DynamicPPL.isassumption(left, vn)) if $isassumption @@ -397,7 +407,8 @@ function generate_dot_tilde(left, right) @gensym vn isassumption value return quote $vn = $(DynamicPPL.resolve_varnames)( - $(AbstractPPL.drop_escape(varname(left))), $right + $(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), $right + # $(AbstractPPL.drop_escape(varname(left, true))), $right ) $isassumption = $(DynamicPPL.isassumption(left, vn)) if $isassumption diff --git a/src/test_utils.jl b/src/test_utils.jl index 74ddcc00f..7712f982d 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -543,7 +543,8 @@ function logprior_true_with_logabsdet_jacobian( return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}) - return [@varname(s[:, 1]), @varname(s[:, 2]), @varname(m)] + s = zeros(2) # used for varname concretization only + return [@varname(s[:, 1], true), @varname(s[:, 2], true), @varname(m)] end @model function demo_assume_matrix_dot_observe_matrix( diff --git a/test/varinfo.jl b/test/varinfo.jl index de86a9e05..9fd0018ee 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -55,7 +55,7 @@ @test inspace(@varname(z[1][:]), space) @test inspace(@varname(z[1][2:3:10]), space) @test inspace(@varname(M[[2, 3], 1]), space) - @test inspace(@varname(M[:, 1:4]), space) + @test inspace(@varname(M[:, 1:4]), space) broken=true @test inspace(@varname(M[1, [2, 4, 6]]), space) @test !inspace(@varname(z[2]), space) @test !inspace(@varname(z), space) From 96f6ce774ae7f63dcb2b7458f357eb87c1a71fa3 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 7 Jul 2023 09:11:06 +0100 Subject: [PATCH 04/15] Relax `AbstractPPL` version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 99c954d19..8cc4e7d6e 100644 --- a/Project.toml +++ b/Project.toml @@ -22,7 +22,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] AbstractMCMC = "2, 3.0, 4" -AbstractPPL = "0.6.3" +AbstractPPL = "0.6" BangBang = "0.3" Bijectors = "0.13" ChainRulesCore = "0.9.7, 0.10, 1" From efb45880225e055d9df18b499e156720d16873dc Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 7 Jul 2023 09:15:25 +0100 Subject: [PATCH 05/15] Correct test `Project.toml` --- test/Project.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index 49910fa80..b36a7e23a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -23,8 +23,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] AbstractMCMC = "2.1, 3.0, 4" AbstractPPL = "0.6" -Bijectors = "0.11, 0.12" -AbstractPPL = "0.5" Bijectors = "0.13" Distributions = "0.25" DistributionsAD = "0.6.3" From c692f4045fa9541b6e287aba1c9a614ad9c223b0 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 7 Jul 2023 09:21:18 +0100 Subject: [PATCH 06/15] Formatting --- src/compiler.jl | 11 +++++++++-- test/varinfo.jl | 2 +- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index c415bd0c0..687f08a53 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -29,7 +29,13 @@ evaluates to a `VarName`, and this will be used in the subsequent checks. If `vn` is not specified, `AbstractPPL.varname(expr, need_concretize(expr))` will be used in its place. """ -function isassumption(expr::Union{Expr,Symbol}, vn=AbstractPPL.drop_escape(varname(expr, need_concretize(expr)))) +function isassumption( + expr::Union{Expr,Symbol}, + vn=AbstractPPL.drop_escape(varname(expr, need_concretize(expr))), +) + expr::Union{Expr,Symbol}, + vn=AbstractPPL.drop_escape(varname(expr, need_concretize(expr))), +) return quote if $(DynamicPPL.contextual_isassumption)(__context__, $vn) # Considered an assumption by `__context__` which means either: @@ -401,7 +407,8 @@ Generate the expression that replaces `left .~ right` in the model body. """ function generate_dot_tilde(left, right) isliteral(left) && return generate_tilde_literal(left, right) - + $(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), + $right, # Otherwise it is determined by the model or its value, # if the LHS represents an observation @gensym vn isassumption value diff --git a/test/varinfo.jl b/test/varinfo.jl index 9fd0018ee..746b77804 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -55,7 +55,7 @@ @test inspace(@varname(z[1][:]), space) @test inspace(@varname(z[1][2:3:10]), space) @test inspace(@varname(M[[2, 3], 1]), space) - @test inspace(@varname(M[:, 1:4]), space) broken=true + @test inspace(@varname(M[:, 1:4]), space) broken = true @test inspace(@varname(M[1, [2, 4, 6]]), space) @test !inspace(@varname(z[2]), space) @test !inspace(@varname(z), space) From c03dfb697633db09b87f773f89d4635ee4d61e50 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 7 Jul 2023 09:27:15 +0100 Subject: [PATCH 07/15] Fix formatter mistake. --- src/compiler.jl | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 687f08a53..71aa5f46c 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -32,9 +32,6 @@ used in its place. function isassumption( expr::Union{Expr,Symbol}, vn=AbstractPPL.drop_escape(varname(expr, need_concretize(expr))), -) - expr::Union{Expr,Symbol}, - vn=AbstractPPL.drop_escape(varname(expr, need_concretize(expr))), ) return quote if $(DynamicPPL.contextual_isassumption)(__context__, $vn) @@ -407,15 +404,15 @@ Generate the expression that replaces `left .~ right` in the model body. """ function generate_dot_tilde(left, right) isliteral(left) && return generate_tilde_literal(left, right) - $(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), - $right, + $(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), + $right, # Otherwise it is determined by the model or its value, # if the LHS represents an observation @gensym vn isassumption value return quote $vn = $(DynamicPPL.resolve_varnames)( - $(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), $right - # $(AbstractPPL.drop_escape(varname(left, true))), $right + $(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), + $right, ) $isassumption = $(DynamicPPL.isassumption(left, vn)) if $isassumption From cc9ce6a45837823dafc7938e9135034ce99ab0d4 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 7 Jul 2023 09:31:10 +0100 Subject: [PATCH 08/15] Apply formatting suggestions --- src/compiler.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 71aa5f46c..30f7b9ef6 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -411,8 +411,7 @@ function generate_dot_tilde(left, right) @gensym vn isassumption value return quote $vn = $(DynamicPPL.resolve_varnames)( - $(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), - $right, + $(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), $right ) $isassumption = $(DynamicPPL.isassumption(left, vn)) if $isassumption From ab333720cbfc8c1caae881bfa87151048867c5d7 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 7 Jul 2023 09:59:11 +0100 Subject: [PATCH 09/15] Correct wrong dimension of `s` in test --- src/compiler.jl | 3 +-- src/test_utils.jl | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 30f7b9ef6..e7b21054d 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -404,8 +404,7 @@ Generate the expression that replaces `left .~ right` in the model body. """ function generate_dot_tilde(left, right) isliteral(left) && return generate_tilde_literal(left, right) - $(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), - $right, + # Otherwise it is determined by the model or its value, # if the LHS represents an observation @gensym vn isassumption value diff --git a/src/test_utils.jl b/src/test_utils.jl index 7712f982d..8e00880e9 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -543,7 +543,7 @@ function logprior_true_with_logabsdet_jacobian( return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}) - s = zeros(2) # used for varname concretization only + s = zeros(1, 2) # used for varname concretization only return [@varname(s[:, 1], true), @varname(s[:, 2], true), @varname(m)] end From 5a08f6a5a4cfdbcfe181b2570ccf153251d85d79 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 7 Jul 2023 10:52:33 +0100 Subject: [PATCH 10/15] use `@test_throws` instead of `@test` --- test/varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/varinfo.jl b/test/varinfo.jl index 746b77804..35ab30dcd 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -55,7 +55,7 @@ @test inspace(@varname(z[1][:]), space) @test inspace(@varname(z[1][2:3:10]), space) @test inspace(@varname(M[[2, 3], 1]), space) - @test inspace(@varname(M[:, 1:4]), space) broken = true + @test_throws ErrorException inspace(@varname(M[:, 1:4]), space) @test inspace(@varname(M[1, [2, 4, 6]]), space) @test !inspace(@varname(z[2]), space) @test !inspace(@varname(z), space) From d7cf4ca3ab2441a15c2c78ec2a70848da676fa4f Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Sat, 8 Jul 2023 14:18:01 +0100 Subject: [PATCH 11/15] use `Setfield.set` in `set!!` function, experiment --- src/utils.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index b1076daf4..bdb4e90c8 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -332,12 +332,14 @@ collectmaybe(x::Base.AbstractSet) = collect(x) # BangBang.jl related # ####################### function set!!(obj, lens::Setfield.Lens, value) - lensmut = BangBang.prefermutation(lens) - return Setfield.set(obj, lensmut, value) + # lensmut = BangBang.prefermutation(lens) + # return Setfield.set(obj, lensmut, value) + return Setfield.set(obj, lens, value) end function set!!(obj, vn::VarName{sym}, value) where {sym} - lens = BangBang.prefermutation(Setfield.PropertyLens{sym}() ∘ AbstractPPL.getlens(vn)) - return Setfield.set(obj, lens, value) + # lens = BangBang.prefermutation(Setfield.PropertyLens{sym}() ∘ AbstractPPL.getlens(vn)) + # return Setfield.set(obj, lens, value) + return AbstractPPL.set(obj, vn, value) end ############################# From c2a87aab6b34adf9270dd2157c93368e10332851 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Tue, 18 Jul 2023 10:52:04 +0100 Subject: [PATCH 12/15] Revert "use `Setfield.set` in `set!!` function, experiment" This reverts commit d7cf4ca3ab2441a15c2c78ec2a70848da676fa4f. --- src/utils.jl | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index bdb4e90c8..b1076daf4 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -332,14 +332,12 @@ collectmaybe(x::Base.AbstractSet) = collect(x) # BangBang.jl related # ####################### function set!!(obj, lens::Setfield.Lens, value) - # lensmut = BangBang.prefermutation(lens) - # return Setfield.set(obj, lensmut, value) - return Setfield.set(obj, lens, value) + lensmut = BangBang.prefermutation(lens) + return Setfield.set(obj, lensmut, value) end function set!!(obj, vn::VarName{sym}, value) where {sym} - # lens = BangBang.prefermutation(Setfield.PropertyLens{sym}() ∘ AbstractPPL.getlens(vn)) - # return Setfield.set(obj, lens, value) - return AbstractPPL.set(obj, vn, value) + lens = BangBang.prefermutation(Setfield.PropertyLens{sym}() ∘ AbstractPPL.getlens(vn)) + return Setfield.set(obj, lens, value) end ############################# From c4de922f0e46fef20016ed237f756faca20b96e4 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Wed, 19 Jul 2023 10:12:16 +0100 Subject: [PATCH 13/15] hacky `possible` with `ConcretizedSlice` --- src/utils.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/utils.jl b/src/utils.jl index b1076daf4..3f1a4261d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -494,6 +494,12 @@ function BangBang.possible( return BangBang.implements(setindex!, C) && promote_type(eltype(C), eltype(T)) <: eltype(C) end +function BangBang.possible( + ::typeof(BangBang._setindex!), ::C, ::T, ::AbstractPPL.ConcretizedSlice, ::Integer +) where {C<:AbstractMatrix,T<:AbstractVector} + return BangBang.implements(setindex!, C) && + promote_type(eltype(C), eltype(T)) <: eltype(C) +end # HACK(torfjelde): This makes it so it works on iterators, etc. by default. # TODO(torfjelde): Do better. From 6d482fa20d61464570182469946137a0a788e694 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Wed, 19 Jul 2023 13:58:35 +0100 Subject: [PATCH 14/15] add some doc for `need --- src/compiler.jl | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 9a855069f..96e98938b 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1,9 +1,28 @@ const INTERNALNAMES = (:__model__, :__context__, :__varinfo__) +""" + need_concretize(expr) + +Return `true` if `expr` needs to be concretized, i.e., if it contains a colon `:` or +requires a dynamic lens. + +# Examples + +```jldoctest; setup=:(using Setfield) +julia> DynamicPPL.need_concretize(:(x[1, :])) +true + +julia> DynamicPPL.need_concretize(:(x[1, end])) +true + +julia> DynamicPPL.need_concretize(:(x[1, 1])) +false +""" function need_concretize(expr) return Setfield.need_dynamic_lens(expr) || begin flag = false - MacroTools.postwalk(expr) do ex # flag in closure, not ideal + MacroTools.postwalk(expr) do ex + # Concretise colon by default ex == :(:) && (flag = true) && return ex end flag From dddda4d72166d9b45fe20ef61a979a5d9c5ef61a Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Fri, 21 Jul 2023 12:21:48 +0100 Subject: [PATCH 15/15] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 872026756..6b8abb913 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.23.6" +version = "0.23.7" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"