Skip to content

Commit 0ce30f1

Browse files
committed
adding Model type back in
1 parent 21192be commit 0ce30f1

File tree

4 files changed

+47
-30
lines changed

4 files changed

+47
-30
lines changed

src/AbstractPPL.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ export VarName, getsym, getlens, inspace, subsumes, varname, vsym, @varname, @vs
77
# Abstract model functions
88
export AbstractProbabilisticProgram, condition, decondition, logdensityof, densityof
99

10-
# SimplePPL
10+
# GraphInfo
1111
export GraphInfo, Model, dag, nodes
1212

1313
# Abstract traces
@@ -18,5 +18,5 @@ include("varname.jl")
1818
include("abstractmodeltrace.jl")
1919
include("abstractprobprog.jl")
2020
include("deprecations.jl")
21-
include("simpleppl.jl")
21+
include("graphinfo.jl")
2222
end # module

src/graphinfo.jl

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,17 @@ s2 = (value = 0.0, input = (), eval = var"#5#8"(), kind = :Stochastic)
5454
y = (value = 0.0, input = (:μ, :s2), eval = var"#7#10"(), kind = :Stochastic)
5555
```
5656
"""
57+
58+
struct Model
59+
g::GraphInfo
60+
end
61+
5762
function Model(;kwargs...)
5863
vals = getvals(NamedTuple(kwargs))
5964
args = [argnames(f) for f in vals[2]]
6065
A, sorted_vertices = DAG(NamedTuple{keys(kwargs)}(args))
6166
modelinputs = NamedTuple{Tuple(sorted_vertices)}.([vals[1], Tuple.(args), vals[2], vals[3]])
62-
GraphInfo(modelinputs..., A, sorted_vertices)
67+
Model(GraphInfo(modelinputs..., A, sorted_vertices))
6368
end
6469

6570

@@ -134,8 +139,6 @@ function adjacency_matrix(inputs::NamedTuple{nodes}) where {nodes}
134139
return A
135140
end
136141

137-
adjacency_matrix(m::GraphInfo) = adjacency_matrix(m.input)
138-
139142
function outneighbors(A::SparseMatrixCSC, u::T) where T <: Int
140143
#adapted from Graph.jl https://github.com/JuliaGraphs/Graphs.jl/blob/06669054ed470bcfe4b2ad90ed974f2e65c84bb6/src/interface.jl#L302
141144
inds, _ = findnz(A[:, u])
@@ -185,54 +188,67 @@ Index a Model with a `VarName{p}` lens. Retrieves the `value``, `input`,
185188
# Examples
186189
187190
```jl-doctest
188-
# add a model
191+
julia> using AbstractPPL
192+
193+
julia> m = Model( s2 = (0.0, () -> InverseGamma(2.0,3.0), :Stochastic),
194+
μ = (1.0, () -> 1.0, :Logical),
195+
y = (0.0, (μ, s2) -> MvNormal(μ, sqrt(s2)), :Stochastic))
196+
(s2 = Symbol[], μ = Symbol[], y = [:μ, :s2])
197+
Nodes:
198+
μ = (value = 0.0, input = (), eval = var"#43#46"(), kind = :Stochastic)
199+
s2 = (value = 1.0, input = (), eval = var"#44#47"(), kind = :Logical)
200+
y = (value = 0.0, input = (:μ, :s2), eval = var"#45#48"(), kind = :Stochastic)
189201
190-
julia> m[@varname y]
191-
(value = 0.0, input = (:μ, :s2), eval = var"#35#38"(), kind = :Stochastic)
192202
203+
julia> m[@varname y]
204+
(value = 0.0, input = (:μ, :s2), eval = var"#45#48"(), kind = :Stochastic)
193205
```
194206
"""
195-
@generated function Base.getindex(m::GraphInfo, vn::VarName{p}) where {p}
207+
@generated function Base.getindex(g::GraphInfo, vn::VarName{p}) where {p}
196208
fns = fieldnames(GraphInfo)[1:4]
197209
name_lens = Setfield.PropertyLens{p}()
198210
field_lenses = [Setfield.PropertyLens{f}() for f in fns]
199-
values = [:(get(m, Setfield.compose($l, $name_lens, getlens(vn)))) for l in field_lenses]
211+
values = [:(get(g, Setfield.compose($l, $name_lens, getlens(vn)))) for l in field_lenses]
200212
return :(NamedTuple{$(fns)}(($(values...),)))
201213
end
202214

203-
function Base.show(io::IO, m::GraphInfo)
215+
function Base.getindex(m::Model, vn::VarName)
216+
return m.g[vn]
217+
end
218+
219+
function Base.show(io::IO, m::Model)
204220
print(io, "Nodes: \n")
205221
for node in nodes(m)
206222
print(io, "$node = ", m[VarName{node}()], "\n")
207223
end
208224
end
209225

210226

211-
function Base.iterate(m::GraphInfo, state=1)
212-
state > length(nodes(m)) ? nothing : (m[VarName{m.sorted_vertices[state]}()], state+1)
227+
function Base.iterate(m::Model, state=1)
228+
state > length(nodes(m)) ? nothing : (m[VarName{m.g.sorted_vertices[state]}()], state+1)
213229
end
214230

215-
Base.eltype(m::GraphInfo) = NamedTuple{fieldnames(GraphInfo)[1:4]}
216-
Base.IteratorEltype(m::GraphInfo) = HasEltype()
231+
Base.eltype(m::Model) = NamedTuple{fieldnames(GraphInfo)[1:4]}
232+
Base.IteratorEltype(m::Model) = HasEltype()
217233

218-
Base.keys(m::GraphInfo) = (VarName{n}() for n in m.sorted_vertices)
219-
Base.values(m::GraphInfo) = Base.Generator(identity, m)
220-
Base.length(m::GraphInfo) = length(nodes(m))
221-
Base.keytype(m::GraphInfo) = eltype(keys(m))
222-
Base.valtype(m::GraphInfo) = eltype(m)
234+
Base.keys(m::Model) = (VarName{n}() for n in m.g.sorted_vertices)
235+
Base.values(m::Model) = Base.Generator(identity, m)
236+
Base.length(m::Model) = length(nodes(m))
237+
Base.keytype(m::Model) = eltype(keys(m))
238+
Base.valtype(m::Model) = eltype(m)
223239

224240

225241
"""
226242
dag(m::Model)
227243
228244
Returns the adjacency matrix of the model as a SparseArray.
229245
"""
230-
dag(m::GraphInfo) = m.A
246+
dag(m::Model) = m.g.A
231247

232248
"""
233249
nodes(m::Model)
234250
235251
Returns a `Vector{Symbol}` containing the sorted vertices
236252
of the DAG.
237253
"""
238-
nodes(m::GraphInfo) = m.sorted_vertices
254+
nodes(m::Model) = m.g.sorted_vertices

test/graphinfo.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@ model = (
2121
m = Model(; zip(keys(model), values(model))...) # uses Model(; kwargs...) constructor
2222

2323
# test the type of the model is correct
24-
@test typeof(m) <: GraphInfo <: AbstractModelTrace
25-
@test typeof(m) == GraphInfo{(:s2, :xmat, , , :y)}
24+
@test typeof(m) == Model
25+
@test typeof(m.g) <: GraphInfo <: AbstractModelTrace
26+
@test typeof(m.g) == GraphInfo{(:s2, :xmat, , , :y)}
2627

2728
# test the dag is correct
2829
A = sparse([0 0 0 0 0; 0 0 0 0 0; 0 0 0 0 0; 0 1 1 0 0; 1 0 0 1 0])
@@ -34,11 +35,11 @@ A = sparse([0 0 0 0 0; 0 0 0 0 0; 0 0 0 0 0; 0 1 1 0 0; 1 0 0 1 0])
3435
# check the values from the NamedTuple match the values in the fields of GraphInfo
3536
vals = AbstractPPL.getvals(model)
3637
for (i, field) in enumerate([:value, :eval, :kind])
37-
@test eval( :(values(m.$field) == vals[$i]) )
38+
@test eval( :(values(m.g.$field) == vals[$i]) )
3839
end
3940

4041
# test the right inputs have been inferred
41-
@test m.input == (s2 = (), xmat = (), β = (), μ = (:xmat, ), y = (, :s2))
42+
@test m.g.input == (s2 = (), xmat = (), β = (), μ = (:xmat, ), y = (, :s2))
4243

4344
# test keys are VarNames
4445
for key in keys(m)
@@ -47,9 +48,9 @@ end
4748

4849

4950
# test Model constructor for model with single parent node
50-
@test typeof(
51-
Model= (1.0, () -> 3, :Logical), y = (1.0, (μ) -> MvNormal(μ, sqrt(1)), :Stochastic))
52-
) == GraphInfo{(, :y)}
51+
single_parent_m = Model= (1.0, () -> 3, :Logical), y = (1.0, (μ) -> MvNormal(μ, sqrt(1)), :Stochastic))
52+
@test typeof(single_parent_m) == Model
53+
@test typeof(single_parent_m.g) == GraphInfo{(, :y)}
5354

5455
# test ErrorException for parent node not found
5556
@test_throws ErrorException Model( μ = (1.0, (β) -> 3, :Logical), y = (1.0, (μ) -> MvNormal(μ, sqrt(1)), :Stochastic))

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ using Test
1212

1313
@testset "AbstractPPL.jl" begin
1414
include("deprecations.jl")
15-
include("simpleppl.jl")
15+
include("graphinfo.jl")
1616
@testset "doctests" begin
1717
DocMeta.setdocmeta!(
1818
AbstractPPL,

0 commit comments

Comments
 (0)