[Merged by Bors] - DAG Model interface #47
Closed
be1b846
initiali commit
PavanChaggar e7debfb
adding simpleppl file and test model
PavanChaggar 7b439f4
prototyping model constructor
PavanChaggar fc46d10
removing graphs dependency
PavanChaggar 56c6f32
typo in Model
PavanChaggar 867de59
adding dag struct and topological sorting
PavanChaggar 1066882
working draft of model constructor
PavanChaggar 5ebca19
changing model constructor to a @generator function to fix type insta…
PavanChaggar a510e78
adding docs and tests
PavanChaggar 69e1939
weird typo thing in vs code...
PavanChaggar 8de143f
adding show method
PavanChaggar 9fa1244
fixing typo in docs
PavanChaggar 0f30616
Delete playground.jl
PavanChaggar e58db4f
adding SparseArrays to test env
PavanChaggar 2115817
making suggested changes to adjacency matrix function
PavanChaggar dbf0929
making DAG outer constructor; changing types constraints for docs
PavanChaggar f7a3cf4
permuting A matrix and changing type of sorted vertices
PavanChaggar 8da1156
deleting some old comments; adding reference for dfs sort; changes to…
PavanChaggar b0b4c0c
starting on iterators
PavanChaggar 80b18c2
adding more iterator functions (maybe incorrectly)
PavanChaggar 268eaed
adding more iterator function + dag function
PavanChaggar 5857cd1
adding nicer Model constructor
PavanChaggar 8758fb5
changing adjacency matrix functions to support single input nodes
PavanChaggar 8ca3af8
adding more tests
PavanChaggar 7699ae2
editing adjacency matrix function and adding another test for model w…
PavanChaggar c053b2f
adding links to functions adapted from graphs.jl
PavanChaggar 14e3e98
adding error exception for parent node not found + tests
PavanChaggar 5205ad5
fixing typo
PavanChaggar d51a91d
reverting adjacency matrix function and changing test
PavanChaggar 8a8afa6
changing Model(;kwargs...) constructor to infer node inputs
PavanChaggar 1273fda
merging DAG and ModelState into GraphInfo
PavanChaggar 21192be
renaming simpleppl.jl to graphinfo.jl
PavanChaggar 0ce30f1
adding `Model` type back in
PavanChaggar ebcacae
adding type annotations for Model; adding tests for Model iterator type
PavanChaggar 1267785
putting GraphInfo into submodule GraphPPL
PavanChaggar f2e86da
remove exports; adding checks for input to `Models` are correct
PavanChaggar
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,255 @@ | ||
| using AbstractPPL | ||
| import Base.getindex | ||
| using SparseArrays | ||
| using Setfield | ||
| using Setfield: PropertyLens, get | ||
|
|
||
| """ | ||
| GraphInfo | ||
|
|
||
| Record the state of the model as a struct of NamedTuples, all | ||
| sharing the same key values, namely, those of the model parameters. | ||
| `value` should store the initial/current value of the parameters. | ||
| `input` stores a tuple of inputs for a given node. `eval` are the | ||
| anonymous functions associated with each node. These might typically | ||
| be either deterministic values or some distribution, but could be an | ||
| arbitrary Julia program. `kind` is a tuple of symbols indicating | ||
| whether the node is a logical or stochastic node. Additionally, the | ||
| adjacency matrix and topologically ordered vertex list are stored. | ||
|
|
||
| GraphInfo is instantiated using the `Model` constructor. | ||
| """ | ||
|
|
||
| struct GraphInfo{T} <: AbstractModelTrace | ||
| input::NamedTuple{T} | ||
| value::NamedTuple{T} | ||
| eval::NamedTuple{T} | ||
| kind::NamedTuple{T} | ||
| A::SparseMatrixCSC | ||
| sorted_vertices::Vector{Symbol} | ||
| end | ||
|
|
||
| """ | ||
| Model(;kwargs...) | ||
|
|
||
| `Model` type constructor that takes in named arguments for | ||
| nodes and returns a `Model`. Nodes are pairs of variable names | ||
| and tuples containing default value, an eval function | ||
| and node type. The inputs of each node are inferred from | ||
| their anonymous functions. The returned object has type | ||
| `Model{(sorted_vertices...)}` and wraps a `GraphInfo` with the same parameter. | ||
|
|
||
| # Examples | ||
| ```jl-doctest | ||
| julia> using AbstractPPL | ||
|
|
||
| julia> Model( | ||
| s2 = (0.0, () -> InverseGamma(2.0,3.0), :Stochastic), | ||
| μ = (1.0, () -> 1.0, :Logical), | ||
| y = (0.0, (μ, s2) -> MvNormal(μ, sqrt(s2)), :Stochastic) | ||
| ) | ||
| Nodes: | ||
| μ = (value = 1.0, input = (), eval = var"#6#9"(), kind = :Logical) | ||
| s2 = (value = 0.0, input = (), eval = var"#5#8"(), kind = :Stochastic) | ||
| y = (value = 0.0, input = (:μ, :s2), eval = var"#7#10"(), kind = :Stochastic) | ||
| ``` | ||
| """ | ||
|
|
||
| struct Model{T} <: AbstractProbabilisticProgram | ||
| g::GraphInfo{T} | ||
| end | ||
|
|
||
| function Model(;kwargs...) | ||
| for (i, node) in enumerate(values(kwargs)) | ||
| @assert typeof(node) <: Tuple{Union{Array{Float64}, Float64}, Function, Symbol} "Check input order for node $(i) matches Tuple(value, function, kind)" | ||
| end | ||
| vals = getvals(NamedTuple(kwargs)) | ||
|
||
| args = [argnames(f) for f in vals[2]] | ||
| A, sorted_vertices = dag(NamedTuple{keys(kwargs)}(args)) | ||
| modelinputs = NamedTuple{Tuple(sorted_vertices)}.([Tuple.(args), vals...]) | ||
| Model(GraphInfo(modelinputs..., A, sorted_vertices)) | ||
| end | ||
|
|
||
| """ | ||
| dag(inputs) | ||
|
|
||
| Takes a NamedTuple containing the inputs to each node and | ||
| returns the implied adjacency matrix and topologically ordered | ||
| vertex list. | ||
| """ | ||
| function dag(inputs) | ||
| input_names = Symbol[keys(inputs)...] | ||
| A = adjacency_matrix(inputs) | ||
| sorted_vertices = topological_sort_by_dfs(A) | ||
| sorted_A = permute(A, sorted_vertices, sorted_vertices) # permute rows and columns alike, i.e. A[sorted_vertices, sorted_vertices] | ||
| sorted_A, input_names[sorted_vertices] | ||
| end | ||
|
|
||
| """ | ||
| getvals(nt::NamedTuple{T}) | ||
|
|
||
| Takes in the arguments to Model(;kwargs...) as a NamedTuple and | ||
| reorders them into a tuple of three tuples holding, respectively, the | ||
| value, eval and kind of every node, as required by the GraphInfo type. | ||
| """ | ||
| @generated function getvals(nt::NamedTuple{T}) where T | ||
| values = [:(nt[$i][$j]) for i in 1:length(T), j in 1:3] | ||
|
||
| m = [:($(values[:,i]...), ) for i in 1:3] | ||
| return Expr(:tuple, m...) # :($(m...),) | ||
| end | ||
|
|
||
| """ | ||
| argnames(f::Function) | ||
|
|
||
| Returns a Vector{Symbol} of the inputs to an anonymous function `f`. | ||
| """ | ||
| argnames(f::Function) = Base.method_argnames(first(methods(f)))[2:end] | ||
|
|
||
| """ | ||
| adjacency_matrix(inputs) | ||
|
|
||
| For a NamedTuple{T} with vertices `T` paired with tuples of input nodes, | ||
| `adjacency_matrix` constructs the adjacency matrix using the order | ||
| of variables given by `T`. | ||
|
|
||
| # Examples | ||
| ```jl-doctest | ||
| julia> inputs = (a = (), b = (), c = (:a, :b)) | ||
| (a = (), b = (), c = (:a, :b)) | ||
|
|
||
| julia> AbstractPPL.GraphPPL.adjacency_matrix(inputs) | ||
| 3×3 SparseMatrixCSC{Bool, Int64} with 2 stored entries: | ||
| ⋅ ⋅ ⋅ | ||
| ⋅ ⋅ ⋅ | ||
| 1 1 ⋅ | ||
| ``` | ||
| """ | ||
| function adjacency_matrix(inputs::NamedTuple{nodes}) where {nodes} | ||
| N = length(inputs) | ||
| col_inds = NamedTuple{nodes}(ntuple(identity, N)) | ||
| A = spzeros(Bool, N, N) | ||
| for (row, node) in enumerate(nodes) | ||
| for input in inputs[node] | ||
| if input ∉ nodes | ||
| error("Parent node of $(input) not found in node set: $(nodes)") | ||
| end | ||
| col = col_inds[input] | ||
| A[row, col] = true | ||
| end | ||
| end | ||
| return A | ||
| end | ||
|
|
||
| function outneighbors(A::SparseMatrixCSC, u::T) where T <: Int | ||
| # adapted from Graphs.jl https://github.com/JuliaGraphs/Graphs.jl/blob/06669054ed470bcfe4b2ad90ed974f2e65c84bb6/src/interface.jl#L302 | ||
| inds, _ = findnz(A[:, u]) | ||
| inds | ||
| end | ||
|
|
||
| function topological_sort_by_dfs(A) | ||
| # lifted from Graphs.jl https://github.com/JuliaGraphs/Graphs.jl/blob/06669054ed470bcfe4b2ad90ed974f2e65c84bb6/src/traversals/dfs.jl#L44 | ||
| # Depth first search implementation optimized from http://www.cs.nott.ac.uk/~psznza/G5BADS03/graphs2.pdf | ||
| n_verts = size(A)[1] | ||
| vcolor = zeros(UInt8, n_verts) | ||
| verts = Vector{Int64}() | ||
| for v in 1:n_verts | ||
| vcolor[v] != 0 && continue | ||
| S = Vector{Int64}([v]) | ||
| vcolor[v] = 1 | ||
| while !isempty(S) | ||
| u = S[end] | ||
| w = 0 | ||
| for n in outneighbors(A, u) | ||
| if vcolor[n] == 1 | ||
| error("The input graph contains at least one loop.") # TODO 0.7 should we use a different error? | ||
| elseif vcolor[n] == 0 | ||
| w = n | ||
| break | ||
| end | ||
| end | ||
| if w != 0 | ||
| vcolor[w] = 1 | ||
| push!(S, w) | ||
| else | ||
| vcolor[u] = 2 | ||
| push!(verts, u) | ||
| pop!(S) | ||
| end | ||
| end | ||
| end | ||
| return reverse(verts) | ||
| end | ||
|
|
||
| """ | ||
| Base.getindex(m::Model, vn::VarName{p}) | ||
|
|
||
| Index a Model with a `VarName{p}` lens. Retrieves the `value`, `input`, | ||
| `eval` and `kind` for node `p`. | ||
|
|
||
| # Examples | ||
|
|
||
| ```jl-doctest | ||
| julia> using AbstractPPL | ||
|
|
||
| julia> m = Model( s2 = (0.0, () -> InverseGamma(2.0,3.0), :Stochastic), | ||
| μ = (1.0, () -> 1.0, :Logical), | ||
| y = (0.0, (μ, s2) -> MvNormal(μ, sqrt(s2)), :Stochastic)) | ||
| Nodes: | ||
| μ = (value = 1.0, input = (), eval = var"#44#47"(), kind = :Logical) | ||
| s2 = (value = 0.0, input = (), eval = var"#43#46"(), kind = :Stochastic) | ||
| y = (value = 0.0, input = (:μ, :s2), eval = var"#45#48"(), kind = :Stochastic) | ||
|
|
||
|
|
||
| julia> m[@varname y] | ||
| (value = 0.0, input = (:μ, :s2), eval = var"#45#48"(), kind = :Stochastic) | ||
| ``` | ||
| """ | ||
| @generated function Base.getindex(g::GraphInfo, vn::VarName{p}) where {p} | ||
| fns = fieldnames(GraphInfo)[1:4] | ||
| name_lens = Setfield.PropertyLens{p}() | ||
| field_lenses = [Setfield.PropertyLens{f}() for f in fns] | ||
| values = [:(get(g, Setfield.compose($l, $name_lens, getlens(vn)))) for l in field_lenses] | ||
| return :(NamedTuple{$(fns)}(($(values...),))) | ||
| end | ||
|
|
||
| function Base.getindex(m::Model, vn::VarName) | ||
| return m.g[vn] | ||
| end | ||
|
|
||
| function Base.show(io::IO, m::Model) | ||
| print(io, "Nodes: \n") | ||
| for node in nodes(m) | ||
| print(io, "$node = ", m[VarName{node}()], "\n") | ||
| end | ||
| end | ||
|
|
||
|
|
||
| function Base.iterate(m::Model, state=1) | ||
| state > length(nodes(m)) ? nothing : (m[VarName{m.g.sorted_vertices[state]}()], state+1) | ||
| end | ||
|
|
||
| Base.eltype(m::Model) = NamedTuple{fieldnames(GraphInfo)[1:4]} | ||
| Base.IteratorEltype(m::Model) = Base.HasEltype() # HasEltype is not exported from Base, so it must be qualified | ||
|
|
||
| Base.keys(m::Model) = (VarName{n}() for n in m.g.sorted_vertices) | ||
| Base.values(m::Model) = Base.Generator(identity, m) | ||
| Base.length(m::Model) = length(nodes(m)) | ||
| Base.keytype(m::Model) = eltype(keys(m)) | ||
| Base.valtype(m::Model) = eltype(m) | ||
|
|
||
|
|
||
| """ | ||
| get_dag(m::Model) | ||
|
|
||
| Returns the adjacency matrix of the model as a `SparseMatrixCSC`. | ||
| """ | ||
| get_dag(m::Model) = m.g.A | ||
|
|
||
| """ | ||
| nodes(m::Model) | ||
|
|
||
| Returns a `Vector{Symbol}` containing the sorted vertices | ||
| of the DAG. | ||
| """ | ||
| nodes(m::Model) = m.g.sorted_vertices | ||
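The `dag` step in the diff above (adjacency matrix, topological order, permuted matrix) can be illustrated with a small standalone sketch. It assumes only the `SparseArrays` standard library. The names `sketch_adjacency` and `sketch_toposort` are hypothetical and not part of the PR, and Kahn's algorithm is swapped in for the DFS-based sort purely for brevity. The sketch shows that permuting both rows and columns by the topological order yields a strictly lower-triangular matrix, since every parent precedes its children.

```julia
using SparseArrays

# Hypothetical helper: A[row, col] is true when node `col` is an input of node `row`.
function sketch_adjacency(inputs::NamedTuple{names}) where {names}
    n = length(names)
    index = Dict(name => i for (i, name) in enumerate(names))
    A = spzeros(Bool, n, n)
    for (row, name) in enumerate(names), input in inputs[name]
        haskey(index, input) || error("Parent node of $(input) not found in node set: $(names)")
        A[row, index[input]] = true
    end
    return A
end

# Kahn's algorithm (an assumption for this sketch; the diff uses a DFS-based sort).
function sketch_toposort(A::SparseMatrixCSC)
    n = size(A, 1)
    indegree = [count(!iszero, A[i, :]) for i in 1:n]  # number of inputs per node
    ready = [v for v in 1:n if indegree[v] == 0]       # nodes without inputs
    order = Int[]
    while !isempty(ready)
        u = popfirst!(ready)
        push!(order, u)
        # rows with a stored entry in column u are the children of u
        for child in findnz(A[:, u])[1]
            indegree[child] -= 1
            indegree[child] == 0 && push!(ready, child)
        end
    end
    length(order) == n || error("The input graph contains at least one loop.")
    return order
end

# A model whose declaration order is not topological: `y` is listed first.
inputs = (y = (:μ, :s2), μ = (), s2 = ())
A = sketch_adjacency(inputs)
order = sketch_toposort(A)
sorted_names = collect(keys(inputs))[order]
sorted_A = permute(A, order, order)   # same as A[order, order]
rows, cols, _ = findnz(sorted_A)
```

With a declaration order like this one, permuting only the columns would leave the rows labelled in the original order, so the row for `y` would no longer line up with `sorted_names`.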
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,64 @@ | ||
| using AbstractPPL | ||
| import AbstractPPL.GraphPPL: GraphInfo, Model, get_dag | ||
| using SparseArrays | ||
| using Test | ||
| ## Example taken from Mamba | ||
| line = Dict{Symbol, Any}( | ||
| :x => [1, 2, 3, 4, 5], | ||
| :y => [1, 3, 3, 3, 5] | ||
| ) | ||
| line[:xmat] = [ones(5) line[:x]] | ||
|
|
||
| # just making it a NamedTuple so that the values can be tested later. Constructor should be used as Model(;kwargs...). | ||
| model = ( | ||
| β = (zeros(2), () -> MvNormal(2, sqrt(1000)), :Stochastic), | ||
| xmat = (line[:xmat], () -> line[:xmat], :Logical), | ||
| s2 = (0.0, () -> InverseGamma(2.0,3.0), :Stochastic), | ||
| μ = (zeros(5), (xmat, β) -> xmat * β, :Logical), | ||
| y = (zeros(5), (μ, s2) -> MvNormal(μ, sqrt(s2)), :Stochastic) | ||
| ) | ||
|
|
||
| # construct the model! | ||
| m = Model(; zip(keys(model), values(model))...) # uses Model(; kwargs...) constructor | ||
|
|
||
| # test the type of the model is correct | ||
| @test typeof(m) <: Model | ||
| @test typeof(m) == Model{(:s2, :xmat, :β, :μ, :y)} | ||
| @test typeof(m.g) <: GraphInfo <: AbstractModelTrace | ||
| @test typeof(m.g) == GraphInfo{(:s2, :xmat, :β, :μ, :y)} | ||
|
|
||
| # test the dag is correct | ||
| A = sparse([0 0 0 0 0; 0 0 0 0 0; 0 0 0 0 0; 0 1 1 0 0; 1 0 0 1 0]) | ||
| @test get_dag(m) == A | ||
|
|
||
| @test length(m) == 5 | ||
| @test eltype(m) == valtype(m) | ||
|
|
||
| # check the values from the NamedTuple match the values in the fields of GraphInfo | ||
| vals = AbstractPPL.GraphPPL.getvals(model) | ||
| for (i, field) in enumerate([:value, :eval, :kind]) | ||
| @test eval( :( values(m.g.$field) == vals[$i] ) ) | ||
| end | ||
|
|
||
| for node in m | ||
| @test typeof(node) <: NamedTuple{fieldnames(GraphInfo)[1:4]} | ||
| end | ||
|
|
||
| # test the right inputs have been inferred | ||
| @test m.g.input == (s2 = (), xmat = (), β = (), μ = (:xmat, :β), y = (:μ, :s2)) | ||
|
|
||
| # test keys are VarNames | ||
| for key in keys(m) | ||
| @test typeof(key) <: VarName | ||
| end | ||
|
|
||
| # test Model constructor for model with single parent node | ||
| single_parent_m = Model(μ = (1.0, () -> 3, :Logical), y = (1.0, (μ) -> MvNormal(μ, sqrt(1)), :Stochastic)) | ||
| @test typeof(single_parent_m) == Model{(:μ, :y)} | ||
| @test typeof(single_parent_m.g) == GraphInfo{(:μ, :y)} | ||
|
|
||
| # test ErrorException for parent node not found | ||
| @test_throws ErrorException Model( μ = (1.0, (β) -> 3, :Logical), y = (1.0, (μ) -> MvNormal(μ, sqrt(1)), :Stochastic)) | ||
|
|
||
| # test AssertionError thrown for kwargs with the wrong order of inputs | ||
| @test_throws AssertionError Model( μ = ((β) -> 3, 1.0, :Logical), y = (1.0, (μ) -> MvNormal(μ, sqrt(1)), :Stochastic)) |
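The input-inference test above depends on reading argument names off each node's anonymous function. A minimal sketch of that idea follows. `sketch_argnames` is a hypothetical stand-in for the `argnames` helper in the diff, and it relies on `Base.method_argnames`, an internal and unexported function that may change between Julia versions.

```julia
# Hypothetical stand-in for `argnames`; the first entry returned by
# `Base.method_argnames` refers to the function itself, so it is dropped.
sketch_argnames(f::Function) = Base.method_argnames(first(methods(f)))[2:end]

# Node functions only; the closures are inspected, never called.
node_functions = (
    s2 = () -> 2.0,
    μ = () -> 1.0,
    y = (μ, s2) -> μ + sqrt(s2),
)

# Mapping over a NamedTuple keeps the keys, giving one tuple of inputs per node.
inferred = map(f -> Tuple(sketch_argnames(f)), node_functions)
```

Because the argument names double as parent-node names, renaming an argument silently changes the graph; the `ErrorException` test above covers the case where a name matches no node.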
I don't think I'm allowed to call it `GraphInfo`. Otherwise it prompts the following error: `ERROR: LoadError: LoadError: invalid redefinition of constant GraphInfo`