@@ -21,8 +21,8 @@ GraphInfo is instantiated using the `Model` constctor.
2121"""
2222
2323struct GraphInfo{T} <: AbstractModelTrace
24- value:: NamedTuple{T}
2524 input:: NamedTuple{T}
25+ value:: NamedTuple{T}
2626 eval:: NamedTuple{T}
2727 kind:: NamedTuple{T}
2828 A:: SparseMatrixCSC
@@ -55,29 +55,30 @@ y = (value = 0.0, input = (:μ, :s2), eval = var"#7#10"(), kind = :Stochastic)
5555```
5656"""
5757
58- struct Model{T}
58+ struct Model{T} <: AbstractProbabilisticProgram
5959 g:: GraphInfo{T}
6060end
6161
6262function Model (;kwargs... )
63+ for (i, node) in enumerate (values (kwargs))
64+ @assert typeof (node) <: Tuple{Union{Array{Float64}, Float64}, Function, Symbol} " Check input order for node $(i) matches Tuple(value, function, kind)"
65+ end
6366 vals = getvals (NamedTuple (kwargs))
6467 args = [argnames (f) for f in vals[2 ]]
65- A, sorted_vertices = DAG (NamedTuple {keys(kwargs)} (args))
66- modelinputs = NamedTuple {Tuple(sorted_vertices)} .([vals[ 1 ], Tuple .(args), vals[ 2 ], vals[ 3 ] ])
68+ A, sorted_vertices = dag (NamedTuple {keys(kwargs)} (args))
69+ modelinputs = NamedTuple {Tuple(sorted_vertices)} .([Tuple .(args), vals... ])
6770 Model (GraphInfo (modelinputs... , A, sorted_vertices))
6871end
6972
70-
7173"""
72- DAG (inputs)
74+ dag (inputs)
7375
7476Function taking in a NamedTuple containing the inputs to each node
7577and returns the implied adjacency matrix and topologically ordered
7678vertex list.
7779"""
78- function DAG (inputs)
80+ function dag (inputs)
7981 input_names = Symbol[keys (inputs)... ]
80- println (inputs)
8182 A = adjacency_matrix (inputs)
8283 sorted_vertices = topological_sort_by_dfs (A)
8384 sorted_A = permute (A, collect (1 : length (inputs)), sorted_vertices)
@@ -94,7 +95,7 @@ input, eval and kind, as required by the GraphInfo type.
9495@generated function getvals (nt:: NamedTuple{T} ) where T
9596 values = [:(nt[$ i][$ j]) for i in 1 : length (T), j in 1 : 3 ]
9697 m = [:($ (values[:,i]. .. ), ) for i in 1 : 3 ]
97- return :($ (m... ),)
98+ return Expr ( :tuple , m ... ) # :($(m...),)
9899end
99100
100101"""
@@ -243,7 +244,7 @@ Base.valtype(m::Model) = eltype(m)
243244
244245Returns the adjacency matrix of the model as a SparseArray.
245246"""
246- dag (m:: Model ) = m. g. A
247+ get_dag (m:: Model ) = m. g. A
247248
248249"""
249250 nodes(m::Model)
0 commit comments