@@ -54,12 +54,17 @@ s2 = (value = 0.0, input = (), eval = var"#5#8"(), kind = :Stochastic)
5454y = (value = 0.0, input = (:μ, :s2), eval = var"#7#10"(), kind = :Stochastic)
5555```
5656"""
57+
58+ struct Model
59+ g:: GraphInfo
60+ end
61+
5762function Model (;kwargs... )
5863 vals = getvals (NamedTuple (kwargs))
5964 args = [argnames (f) for f in vals[2 ]]
6065 A, sorted_vertices = DAG (NamedTuple {keys(kwargs)} (args))
6166 modelinputs = NamedTuple {Tuple(sorted_vertices)} .([vals[1 ], Tuple .(args), vals[2 ], vals[3 ]])
62- GraphInfo (modelinputs... , A, sorted_vertices)
67+ Model ( GraphInfo (modelinputs... , A, sorted_vertices) )
6368end
6469
6570
@@ -134,8 +139,6 @@ function adjacency_matrix(inputs::NamedTuple{nodes}) where {nodes}
134139 return A
135140end
136141
137- adjacency_matrix (m:: GraphInfo ) = adjacency_matrix (m. input)
138-
139142function outneighbors (A:: SparseMatrixCSC , u:: T ) where T <: Int
140143 # adapted from Graph.jl https://github.com/JuliaGraphs/Graphs.jl/blob/06669054ed470bcfe4b2ad90ed974f2e65c84bb6/src/interface.jl#L302
141144 inds, _ = findnz (A[:, u])
@@ -185,54 +188,67 @@ Index a Model with a `VarName{p}` lens. Retrieves the `value``, `input`,
185188# Examples
186189
187190```jl-doctest
188- # add a model
191+ julia> using AbstractPPL
192+
193+ julia> m = Model( s2 = (0.0, () -> InverseGamma(2.0,3.0), :Stochastic),
194+ μ = (1.0, () -> 1.0, :Logical),
195+ y = (0.0, (μ, s2) -> MvNormal(μ, sqrt(s2)), :Stochastic))
196+ (s2 = Symbol[], μ = Symbol[], y = [:μ, :s2])
197+ Nodes:
198+ μ = (value = 0.0, input = (), eval = var"#43#46"(), kind = :Stochastic)
199+ s2 = (value = 1.0, input = (), eval = var"#44#47"(), kind = :Logical)
200+ y = (value = 0.0, input = (:μ, :s2), eval = var"#45#48"(), kind = :Stochastic)
189201
190- julia> m[@varname y]
191- (value = 0.0, input = (:μ, :s2), eval = var"#35#38"(), kind = :Stochastic)
192202
203+ julia> m[@varname y]
204+ (value = 0.0, input = (:μ, :s2), eval = var"#45#48"(), kind = :Stochastic)
193205```
194206"""
195- @generated function Base. getindex (m :: GraphInfo , vn:: VarName{p} ) where {p}
207+ @generated function Base. getindex (g :: GraphInfo , vn:: VarName{p} ) where {p}
196208 fns = fieldnames (GraphInfo)[1 : 4 ]
197209 name_lens = Setfield. PropertyLens {p} ()
198210 field_lenses = [Setfield. PropertyLens {f} () for f in fns]
199- values = [:(get (m , Setfield. compose ($ l, $ name_lens, getlens (vn)))) for l in field_lenses]
211+ values = [:(get (g , Setfield. compose ($ l, $ name_lens, getlens (vn)))) for l in field_lenses]
200212 return :(NamedTuple {$(fns)} (($ (values... ),)))
201213end
202214
203- function Base. show (io:: IO , m:: GraphInfo )
215+ function Base. getindex (m:: Model , vn:: VarName )
216+ return m. g[vn]
217+ end
218+
219+ function Base. show (io:: IO , m:: Model )
204220 print (io, " Nodes: \n " )
205221 for node in nodes (m)
206222 print (io, " $node = " , m[VarName {node} ()], " \n " )
207223 end
208224end
209225
210226
211- function Base. iterate (m:: GraphInfo , state= 1 )
212- state > length (nodes (m)) ? nothing : (m[VarName {m.sorted_vertices[state]} ()], state+ 1 )
227+ function Base. iterate (m:: Model , state= 1 )
228+ state > length (nodes (m)) ? nothing : (m[VarName {m.g. sorted_vertices[state]} ()], state+ 1 )
213229end
214230
215- Base. eltype (m:: GraphInfo ) = NamedTuple{fieldnames (GraphInfo)[1 : 4 ]}
216- Base. IteratorEltype (m:: GraphInfo ) = HasEltype ()
231+ Base. eltype (m:: Model ) = NamedTuple{fieldnames (GraphInfo)[1 : 4 ]}
232+ Base. IteratorEltype (m:: Model ) = HasEltype ()
217233
218- Base. keys (m:: GraphInfo ) = (VarName {n} () for n in m. sorted_vertices)
219- Base. values (m:: GraphInfo ) = Base. Generator (identity, m)
220- Base. length (m:: GraphInfo ) = length (nodes (m))
221- Base. keytype (m:: GraphInfo ) = eltype (keys (m))
222- Base. valtype (m:: GraphInfo ) = eltype (m)
234+ Base. keys (m:: Model ) = (VarName {n} () for n in m. g . sorted_vertices)
235+ Base. values (m:: Model ) = Base. Generator (identity, m)
236+ Base. length (m:: Model ) = length (nodes (m))
237+ Base. keytype (m:: Model ) = eltype (keys (m))
238+ Base. valtype (m:: Model ) = eltype (m)
223239
224240
225241"""
226242 dag(m::Model)
227243
228244Returns the adjacency matrix of the model as a SparseArray.
229245"""
230- dag (m:: GraphInfo ) = m. A
246+ dag (m:: Model ) = m. g . A
231247
232248"""
233249 nodes(m::Model)
234250
235251Returns a `Vector{Symbol}` containing the sorted vertices
236252of the DAG.
237253"""
238- nodes (m:: GraphInfo ) = m. sorted_vertices
254+ nodes (m:: Model ) = m. g . sorted_vertices
0 commit comments