Skip to content

Commit f2e86da

Browse files
committed
remove exports; adding checks for input to Models are correct
1 parent 1267785 commit f2e86da

File tree

3 files changed

+20
-16
lines changed

3 files changed

+20
-16
lines changed

src/AbstractPPL.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ include("deprecations.jl")
1919
# GraphInfo
2020
module GraphPPL
2121
include("graphinfo.jl")
22-
export GraphInfo, Model, dag, nodes
2322
end
2423

2524
end # module

src/graphinfo.jl

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ GraphInfo is instantiated using the `Model` constctor.
2121
"""
2222

2323
struct GraphInfo{T} <: AbstractModelTrace
24-
value::NamedTuple{T}
2524
input::NamedTuple{T}
25+
value::NamedTuple{T}
2626
eval::NamedTuple{T}
2727
kind::NamedTuple{T}
2828
A::SparseMatrixCSC
@@ -55,29 +55,30 @@ y = (value = 0.0, input = (:μ, :s2), eval = var"#7#10"(), kind = :Stochastic)
5555
```
5656
"""
5757

58-
struct Model{T}
58+
struct Model{T} <: AbstractProbabilisticProgram
5959
g::GraphInfo{T}
6060
end
6161

6262
function Model(;kwargs...)
63+
for (i, node) in enumerate(values(kwargs))
64+
@assert typeof(node) <: Tuple{Union{Array{Float64}, Float64}, Function, Symbol} "Check input order for node $(i) matches Tuple(value, function, kind)"
65+
end
6366
vals = getvals(NamedTuple(kwargs))
6467
args = [argnames(f) for f in vals[2]]
65-
A, sorted_vertices = DAG(NamedTuple{keys(kwargs)}(args))
66-
modelinputs = NamedTuple{Tuple(sorted_vertices)}.([vals[1], Tuple.(args), vals[2], vals[3]])
68+
A, sorted_vertices = dag(NamedTuple{keys(kwargs)}(args))
69+
modelinputs = NamedTuple{Tuple(sorted_vertices)}.([Tuple.(args), vals...])
6770
Model(GraphInfo(modelinputs..., A, sorted_vertices))
6871
end
6972

70-
7173
"""
72-
DAG(inputs)
74+
dag(inputs)
7375
7476
Function taking in a NamedTuple containing the inputs to each node
7577
and returns the implied adjacency matrix and topologically ordered
7678
vertex list.
7779
"""
78-
function DAG(inputs)
80+
function dag(inputs)
7981
input_names = Symbol[keys(inputs)...]
80-
println(inputs)
8182
A = adjacency_matrix(inputs)
8283
sorted_vertices = topological_sort_by_dfs(A)
8384
sorted_A = permute(A, collect(1:length(inputs)), sorted_vertices)
@@ -94,7 +95,7 @@ input, eval and kind, as required by the GraphInfo type.
9495
@generated function getvals(nt::NamedTuple{T}) where T
9596
values = [:(nt[$i][$j]) for i in 1:length(T), j in 1:3]
9697
m = [:($(values[:,i]...), ) for i in 1:3]
97-
return :($(m...),)
98+
return Expr(:tuple, m...) # :($(m...),)
9899
end
99100

100101
"""
@@ -243,7 +244,7 @@ Base.valtype(m::Model) = eltype(m)
243244
244245
Returns the adjacency matrix of the model as a SparseArray.
245246
"""
246-
dag(m::Model) = m.g.A
247+
get_dag(m::Model) = m.g.A
247248

248249
"""
249250
nodes(m::Model)

test/graphinfo.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
using AbstractPPL.GraphPPL
1+
using AbstractPPL
2+
import AbstractPPL.GraphPPL: GraphInfo, Model, get_dag
23
using SparseArrays
3-
4+
using Test
45
## Example taken from Mamba
56
line = Dict{Symbol, Any}(
67
:x => [1, 2, 3, 4, 5],
@@ -19,7 +20,7 @@ model = (
1920

2021
# construct the model!
2122
m = Model(; zip(keys(model), values(model))...) # uses Model(; kwargs...) constructor
22-
typeof(m)
23+
2324
# test the type of the model is correct
2425
@test typeof(m) <: Model
2526
@test typeof(m) == Model{(:s2, :xmat, , , :y)}
@@ -28,7 +29,7 @@ typeof(m)
2829

2930
# test the dag is correct
3031
A = sparse([0 0 0 0 0; 0 0 0 0 0; 0 0 0 0 0; 0 1 1 0 0; 1 0 0 1 0])
31-
@test dag(m) == A
32+
@test get_dag(m) == A
3233

3334
@test length(m) == 5
3435
@test eltype(m) == valtype(m)
@@ -57,4 +58,7 @@ single_parent_m = Model(μ = (1.0, () -> 3, :Logical), y = (1.0, (μ) -> MvNorma
5758
@test typeof(single_parent_m.g) == GraphInfo{(, :y)}
5859

5960
# test ErrorException for parent node not found
60-
@test_throws ErrorException Model( μ = (1.0, (β) -> 3, :Logical), y = (1.0, (μ) -> MvNormal(μ, sqrt(1)), :Stochastic))
61+
@test_throws ErrorException Model( μ = (1.0, (β) -> 3, :Logical), y = (1.0, (μ) -> MvNormal(μ, sqrt(1)), :Stochastic))
62+
63+
# test AssertionError thrown for kwargs with the wrong order of inputs
64+
@test_throws AssertionError Model( μ = ((β) -> 3, 1.0, :Logical), y = (1.0, (μ) -> MvNormal(μ, sqrt(1)), :Stochastic))

0 commit comments

Comments
 (0)