Skip to content

Commit d51a91d

Browse files
committed
reverting adjacency matrix function and changing test
1 parent 5205ad5 commit d51a91d

File tree

2 files changed

+17
-49
lines changed

2 files changed

+17
-49
lines changed

src/simpleppl.jl

Lines changed: 12 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,9 @@ end
8282

8383
function Model(;kwargs...)
8484
Model(values(kwargs))
85-
end
85+
end
86+
# add thing here to extract inputs from anon functions then change into different NamedTuple
87+
# add docstring because it will behave differently
8688

8789
function Base.show(io::IO, m::Model)
8890
print(io, "Nodes: \n")
@@ -118,29 +120,15 @@ function adjacency_matrix(inputs::NamedTuple{nodes}) where {nodes}
118120
col_inds = NamedTuple{nodes}(ntuple(identity, N))
119121
A = spzeros(Bool, N, N)
120122
for (row, node) in enumerate(nodes)
121-
v_inputs = inputs[node]
122-
setinput!(A, row, col_inds, nodes, v_inputs)
123-
end
124-
return A
125-
end
126-
127-
function setinput!(A::SparseMatrixCSC{Bool, Int64}, row, col_inds, nodes, v_input::Symbol)
128-
if v_input nodes
129-
error("Parent node of $(v_input) not found in node set: $(nodes)")
130-
end
131-
col = col_inds[v_input]
132-
A[row, col] = true
133-
end
134-
135-
function setinput!(A::SparseMatrixCSC{Bool, Int64}, row, col_inds, nodes, v_inputs)
136-
for input in v_inputs
137-
if input nodes
138-
error("Parent node of $(input) not found in node set: $(nodes)")
123+
for input in inputs[node]
124+
if input nodes
125+
error("Parent node of $(input) not found in node set: $(nodes)")
126+
end
127+
col = col_inds[input]
128+
A[row, col] = true
139129
end
140-
col = col_inds[input]
141-
A[row, col] = true
142130
end
143-
A
131+
return A
144132
end
145133

146134
adjacency_matrix(m::Model) = adjacency_matrix(m.ModelState.input)
@@ -226,6 +214,7 @@ Base.length(m::Model) = length(nodes(m))
226214
Base.keytype(m::Model) = eltype(keys(m))
227215
Base.valtype(m::Model) = eltype(m)
228216

217+
229218
"""
230219
dag(m::Model)
231220
@@ -239,26 +228,4 @@ dag(m::Model) = m.DAG.A
239228
Returns a `Vector{Symbol}` containing the sorted vertices
240229
of the DAG.
241230
"""
242-
nodes(m::Model) = m.DAG.sorted_vertices
243-
244-
# # General eval function
245-
# function evalf(f::Function, m::Model)
246-
# nodes = m.DAG.sorted_vertex_list
247-
# symlist = keys(m.ModelState.input)
248-
# vals = (;)
249-
# for (i, n) in enumerate(nodes)
250-
# node = symlist[n]
251-
# input_nodes = m.ModelState.input[node]
252-
# if m.ModelState.kind[node] == :Stochastic
253-
# if length(input_nodes) == 0
254-
# vals = merge(vals, [node=>f(m.ModelState.eval[node]())])
255-
# elseif length(input_nodes) > 0
256-
# inputs = [vals[n] for n in input_nodes]
257-
# vals = merge(vals, [node=>f(m.ModelState.eval[node](inputs...))])
258-
# end
259-
# else
260-
# vals = merge(vals, [node=>m.ModelState.eval[node]()])
261-
# end
262-
# end
263-
# vals
264-
# end
231+
nodes(m::Model) = m.DAG.sorted_vertices

test/simpleppl.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,19 @@ model = (
1717
y = (zeros(5), (, :s2), (μ, s2) -> MvNormal(μ, sqrt(s2)), :Stochastic)
1818
)
1919

20+
2021
# test Model constructor for model with single parent node
2122
@test typeof(
2223
Model(
23-
μ = (zeros(5), (), () -> 3, :Logical),
24-
y = (zeros(5), (), (μ) -> MvNormal(μ, sqrt(1)), :Stochastic)
24+
μ = (1.0, (), () -> 3, :Logical),
25+
y = (1.0, (,), (μ) -> MvNormal(μ, sqrt(1)), :Stochastic)
2526
)
2627
) == Model
2728

2829
# test ErrorException for parent node not being found
2930
@test_throws ErrorException Model(
30-
μ = (zeros(5), (), () -> 3, :Logical),
31-
y = (zeros(5), (), (μ) -> MvNormal(μ, sqrt(1)), :Stochastic)
31+
μ = (zeros(5), (,), () -> 3, :Logical),
32+
y = (zeros(5), (,), (μ) -> MvNormal(μ, sqrt(1)), :Stochastic)
3233
)
3334

3435
m = Model(; zip(keys(model), values(model))...) # uses Model(; kwargs...) constructor

0 commit comments

Comments
 (0)