Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
be1b846
initiali commit
PavanChaggar Nov 29, 2021
e7debfb
adding simpleppl file and test model
PavanChaggar Nov 29, 2021
7b439f4
prototyping model constructor
PavanChaggar Dec 6, 2021
fc46d10
removing graphs dependency
PavanChaggar Dec 6, 2021
56c6f32
typo in Model
PavanChaggar Dec 6, 2021
867de59
adding dag struct and topological sorting
PavanChaggar Dec 13, 2021
1066882
working draft of model constructor
PavanChaggar Jan 10, 2022
5ebca19
changing model constructor to a @generator function to fix type insta…
PavanChaggar Jan 10, 2022
a510e78
adding docs and tests
PavanChaggar Jan 10, 2022
69e1939
weird typo thing in vs code...
PavanChaggar Jan 10, 2022
8de143f
adding show method
PavanChaggar Jan 10, 2022
9fa1244
fixing typo in docs
PavanChaggar Jan 10, 2022
0f30616
Delete playground.jl
PavanChaggar Jan 10, 2022
e58db4f
adding SparseArrays to test env
PavanChaggar Jan 10, 2022
2115817
making suggested changes to adjacency matrix function
PavanChaggar Jan 18, 2022
dbf0929
making DAG outer constructor; changing types constraints for docs
PavanChaggar Jan 18, 2022
f7a3cf4
permuting A matrix and changing type of sorted vertices
PavanChaggar Jan 18, 2022
8da1156
deleting some old comments; adding reference for dfs sort; changes to…
PavanChaggar Jan 18, 2022
b0b4c0c
starting on iterators
PavanChaggar Jan 22, 2022
80b18c2
adding more iterator functions (maybe incorrectly)
PavanChaggar Jan 23, 2022
268eaed
adding more iterator function + dag function
PavanChaggar Jan 29, 2022
5857cd1
adding nicer Model constructor
PavanChaggar Jan 29, 2022
8758fb5
changing adjacency matrix functions to support single input nodes
PavanChaggar Jan 29, 2022
8ca3af8
adding more tests
PavanChaggar Jan 30, 2022
7699ae2
editing adjacency matrix function and adding another test for model w…
PavanChaggar Jan 30, 2022
c053b2f
adding links to functions adapted from graphs.jl
PavanChaggar Jan 31, 2022
14e3e98
adding error exception for parent node not found + tests
PavanChaggar Jan 31, 2022
5205ad5
fixing typo
PavanChaggar Jan 31, 2022
d51a91d
reverting adjacency matrix function and changing test
PavanChaggar Jan 31, 2022
8a8afa6
changing Model(;kwargs...) constructor to infer node inputs
PavanChaggar Feb 6, 2022
1273fda
merging DAG and ModelState into GraphInfo
PavanChaggar Feb 6, 2022
21192be
renaming simpleppl.jl to graphinfo.jl
PavanChaggar Feb 7, 2022
0ce30f1
adding `Model` type back in
PavanChaggar Feb 7, 2022
ebcacae
adding type annotations for Model; adding tests for Model iterator type
PavanChaggar Feb 7, 2022
1267785
putting GraphInfo into submodule GraphPPL
PavanChaggar Feb 7, 2022
f2e86da
remove exports; adding checks for input to `Models` are correct
PavanChaggar Feb 7, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,6 @@ docs/site/
# committed for packages, but should be committed for applications that require a static
# environment.
Manifest.toml

# vs code environment
.vscode
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
keywords = ["probablistic programming"]
license = "MIT"
desc = "Common interfaces for probabilistic programming"
version = "0.4"
version = "0.4.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[compat]
AbstractMCMC = "2, 3"
Expand Down
6 changes: 5 additions & 1 deletion src/AbstractPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ export VarName, getsym, getlens, inspace, subsumes, varname, vsym, @varname, @vs
# Abstract model functions
export AbstractProbabilisticProgram, condition, decondition, logdensityof, densityof


# Abstract traces
export AbstractModelTrace

Expand All @@ -17,4 +16,9 @@ include("abstractmodeltrace.jl")
include("abstractprobprog.jl")
include("deprecations.jl")

# GraphInfo
module GraphPPL
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
module GraphPPL
module GraphInfo

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think I'm allowed to call it GraphInfo. Otherwise it prompts the following error:

ERROR: LoadError: LoadError: invalid redefinition of constant GraphInfo

include("graphinfo.jl")
end

end # module
255 changes: 255 additions & 0 deletions src/graphinfo.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
using AbstractPPL
import Base.getindex
using SparseArrays
using Setfield
using Setfield: PropertyLens, get

"""
GraphInfo

Record the state of the model as a struct of NamedTuples, all
sharing the same key values, namely, those of the model parameters.
`value` should store the initial/current value of the parameters.
`input` stores a tuple of inputs for a given node. `eval` are the
anonymous functions associated with each node. These might typically
be either deterministic values or some distribution, but could an
arbitrary julia program. `kind` is a tuple of symbols indicating
whether the node is a logical or stochastic node. Additionally, the
adjacency matrix and topologically ordered vertex list and stored.

GraphInfo is instantiated using the `Model` constctor.
"""

struct GraphInfo{T} <: AbstractModelTrace
input::NamedTuple{T}
value::NamedTuple{T}
eval::NamedTuple{T}
kind::NamedTuple{T}
A::SparseMatrixCSC
sorted_vertices::Vector{Symbol}
end

"""
Model(;kwargs...)

`Model` type constructor that takes in named arguments for
nodes and returns a `Model`. Nodes are pairs of variable names
and tuples containing default value, an eval function
and node type. The inputs of each node are inferred from
their anonymous functions. The returned object has a type
GraphInfo{(sorted_vertices...)}.

# Examples
```jl-doctest
julia> using AbstractPPL

julia> Model(
s2 = (0.0, () -> InverseGamma(2.0,3.0), :Stochastic),
μ = (1.0, () -> 1.0, :Logical),
y = (0.0, (μ, s2) -> MvNormal(μ, sqrt(s2)), :Stochastic)
)
Nodes:
μ = (value = 1.0, input = (), eval = var"#6#9"(), kind = :Logical)
s2 = (value = 0.0, input = (), eval = var"#5#8"(), kind = :Stochastic)
y = (value = 0.0, input = (:μ, :s2), eval = var"#7#10"(), kind = :Stochastic)
```
"""

struct Model{T} <: AbstractProbabilisticProgram
g::GraphInfo{T}
end

function Model(;kwargs...)
for (i, node) in enumerate(values(kwargs))
@assert typeof(node) <: Tuple{Union{Array{Float64}, Float64}, Function, Symbol} "Check input order for node $(i) matches Tuple(value, function, kind)"
end
vals = getvals(NamedTuple(kwargs))
args = [argnames(f) for f in vals[2]]
A, sorted_vertices = dag(NamedTuple{keys(kwargs)}(args))
modelinputs = NamedTuple{Tuple(sorted_vertices)}.([Tuple.(args), vals...])
Model(GraphInfo(modelinputs..., A, sorted_vertices))
end

"""
dag(inputs)

Function taking in a NamedTuple containing the inputs to each node
and returns the implied adjacency matrix and topologically ordered
vertex list.
"""
function dag(inputs)
input_names = Symbol[keys(inputs)...]
A = adjacency_matrix(inputs)
sorted_vertices = topological_sort_by_dfs(A)
sorted_A = permute(A, collect(1:length(inputs)), sorted_vertices)
sorted_A, input_names[sorted_vertices]
end

"""
getvals(nt::NamedTuple{T})

Takes in the arguments to Model(;kwargs...) as a NamedTuple and
reorders into a tuple of tuples each containing either of value,
input, eval and kind, as required by the GraphInfo type.
"""
@generated function getvals(nt::NamedTuple{T}) where T
values = [:(nt[$i][$j]) for i in 1:length(T), j in 1:3]
m = [:($(values[:,i]...), ) for i in 1:3]
return Expr(:tuple, m...) # :($(m...),)
end

"""
argnames(f::Function)

Returns a Vector{Symbol} of the inputs to an anonymous function `f`.
"""
argnames(f::Function) = Base.method_argnames(first(methods(f)))[2:end]

"""
adjacency_matrix(inputs)

For a NamedTuple{T} with vertices `T` paired with tuples of input nodes,
`adjacency_matrix` constructs the adjacency matrix using the order
of variables given by `T`.

# Examples
```jl-doctest
julia> inputs = (a = (), b = (), c = (:a, :b))
(a = (), b = (), c = (:a, :b))

julia> AbstractPPL.adjacency_matrix(inputs)
3×3 SparseMatrixCSC{Float64, Int64} with 2 stored entries:
⋅ ⋅ ⋅
⋅ ⋅ ⋅
1.0 1.0 ⋅
```
"""
function adjacency_matrix(inputs::NamedTuple{nodes}) where {nodes}
N = length(inputs)
col_inds = NamedTuple{nodes}(ntuple(identity, N))
A = spzeros(Bool, N, N)
for (row, node) in enumerate(nodes)
for input in inputs[node]
if input ∉ nodes
error("Parent node of $(input) not found in node set: $(nodes)")
end
col = col_inds[input]
A[row, col] = true
end
end
return A
end

# Adapted from Graphs.jl https://github.com/JuliaGraphs/Graphs.jl/blob/06669054ed470bcfe4b2ad90ed974f2e65c84bb6/src/interface.jl#L302
"""
    outneighbors(A::SparseMatrixCSC, u)

Return the row indices of the entries stored in column `u` of `A`. With the
layout produced by `adjacency_matrix` (row = node, column = input) these are
the nodes that take node `u` as an input.
"""
function outneighbors(A::SparseMatrixCSC, u::T) where T <: Int
    column = A[:, u]
    return first(findnz(column))
end

# Lifted from Graphs.jl https://github.com/JuliaGraphs/Graphs.jl/blob/06669054ed470bcfe4b2ad90ed974f2e65c84bb6/src/traversals/dfs.jl#L44
# Depth first search implementation optimized from http://www.cs.nott.ac.uk/~psznza/G5BADS03/graphs2.pdf
"""
    topological_sort_by_dfs(A)

Return a vector of vertex indices of the graph with adjacency matrix `A`,
ordered by an iterative depth-first search so that every vertex appears before
the vertices reported for it by `outneighbors`. Throws an error if a vertex
that is still on the search stack is reached again.
"""
function topological_sort_by_dfs(A)
    n = size(A, 1)
    # Vertex state: 0 = not visited, 1 = on the stack, 2 = finished.
    color = zeros(UInt8, n)
    finished = Vector{Int64}()
    for root in 1:n
        color[root] == 0 || continue
        stack = Int64[root]
        color[root] = 1
        while !isempty(stack)
            current = last(stack)
            # First unvisited neighbour of `current`, or 0 when there is none.
            next = 0
            for nbr in outneighbors(A, current)
                if color[nbr] == 1
                    error("The input graph contains at least one loop.") # TODO 0.7 should we use a different error?
                elseif color[nbr] == 0
                    next = nbr
                    break
                end
            end
            if next == 0
                # Nothing left to explore below `current`: it is finished.
                color[current] = 2
                push!(finished, current)
                pop!(stack)
            else
                color[next] = 1
                push!(stack, next)
            end
        end
    end
    # Vertices were recorded in order of completion; reverse for topological order.
    return reverse(finished)
end

"""
Base.getindex(m::Model, vn::VarName{p})

Index a Model with a `VarName{p}` lens. Retrieves the `value``, `input`,
`eval` and `kind` for node `p`.

# Examples

```jl-doctest
julia> using AbstractPPL

julia> m = Model( s2 = (0.0, () -> InverseGamma(2.0,3.0), :Stochastic),
μ = (1.0, () -> 1.0, :Logical),
y = (0.0, (μ, s2) -> MvNormal(μ, sqrt(s2)), :Stochastic))
(s2 = Symbol[], μ = Symbol[], y = [:μ, :s2])
Nodes:
μ = (value = 0.0, input = (), eval = var"#43#46"(), kind = :Stochastic)
s2 = (value = 1.0, input = (), eval = var"#44#47"(), kind = :Logical)
y = (value = 0.0, input = (:μ, :s2), eval = var"#45#48"(), kind = :Stochastic)


julia> m[@varname y]
(value = 0.0, input = (:μ, :s2), eval = var"#45#48"(), kind = :Stochastic)
```
"""
@generated function Base.getindex(g::GraphInfo, vn::VarName{p}) where {p}
fns = fieldnames(GraphInfo)[1:4]
name_lens = Setfield.PropertyLens{p}()
field_lenses = [Setfield.PropertyLens{f}() for f in fns]
values = [:(get(g, Setfield.compose($l, $name_lens, getlens(vn)))) for l in field_lenses]
return :(NamedTuple{$(fns)}(($(values...),)))
end

# Indexing a `Model` is forwarded to the `GraphInfo` it wraps.
Base.getindex(m::Model, vn::VarName) = getindex(m.g, vn)

# Print a header followed by one line per node, in the order given by `nodes(m)`.
function Base.show(io::IO, m::Model)
    print(io, "Nodes: \n")
    for name in nodes(m)
        entry = m[VarName{name}()]
        print(io, "$name = ", entry, "\n")
    end
end


# Iterate over the nodes in the order of `sorted_vertices`; `state` is the
# position of the next node to yield.
function Base.iterate(m::Model, state=1)
    if state > length(nodes(m))
        return nothing
    end
    name = m.g.sorted_vertices[state]
    return (m[VarName{name}()], state + 1)
end

# Iteration traits are declared on the type, as the iteration interface expects;
# `eltype(m)` and `Base.IteratorEltype(m)` for an instance `m` reach these methods
# through Base's generic fallbacks, so existing instance-level calls keep working.
# `HasEltype` is not exported from `Base` and therefore has to be qualified.
Base.eltype(::Type{<:Model}) = NamedTuple{fieldnames(GraphInfo)[1:4]}
Base.IteratorEltype(::Type{<:Model}) = Base.HasEltype()

# Dictionary-style accessors for a `Model`.

# Lazily yields one `VarName` per node, in the order of `sorted_vertices`.
function Base.keys(m::Model)
    return (VarName{name}() for name in m.g.sorted_vertices)
end

# Lazily yields the node entries produced by iterating the model.
function Base.values(m::Model)
    return Base.Generator(identity, m)
end

# Number of nodes in the model.
function Base.length(m::Model)
    return length(nodes(m))
end

function Base.keytype(m::Model)
    return eltype(keys(m))
end

function Base.valtype(m::Model)
    return eltype(m)
end


"""
dag(m::Model)

Returns the adjacency matrix of the model as a SparseArray.
"""
get_dag(m::Model) = m.g.A

"""
nodes(m::Model)

Returns a `Vector{Symbol}` containing the sorted vertices
of the DAG.
"""
nodes(m::Model) = m.g.sorted_vertices
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
Expand Down
64 changes: 64 additions & 0 deletions test/graphinfo.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
using AbstractPPL
import AbstractPPL.GraphPPL: GraphInfo, Model, get_dag
using SparseArrays
using Test
## Example taken from Mamba
line = Dict{Symbol, Any}(
    :x => [1, 2, 3, 4, 5],
    :y => [1, 3, 3, 3, 5]
)
line[:xmat] = [ones(5) line[:x]]

# Kept as a NamedTuple so that the individual node definitions can be compared with
# the constructed model below. In normal use the nodes are passed directly to
# Model(; kwargs...).
model = (
    β = (zeros(2), () -> MvNormal(2, sqrt(1000)), :Stochastic),
    xmat = (line[:xmat], () -> line[:xmat], :Logical),
    s2 = (0.0, () -> InverseGamma(2.0,3.0), :Stochastic),
    μ = (zeros(5), (xmat, β) -> xmat * β, :Logical),
    y = (zeros(5), (μ, s2) -> MvNormal(μ, sqrt(s2)), :Stochastic)
)

# construct the model!
m = Model(; model...) # uses Model(; kwargs...) constructor

# test the type of the model is correct
@test typeof(m) <: Model
@test typeof(m) == Model{(:s2, :xmat, :β, :μ, :y)}
@test typeof(m.g) <: GraphInfo <: AbstractModelTrace
@test typeof(m.g) == GraphInfo{(:s2, :xmat, :β, :μ, :y)}

# test the dag is correct
A = sparse([0 0 0 0 0; 0 0 0 0 0; 0 0 0 0 0; 0 1 1 0 0; 1 0 0 1 0])
@test get_dag(m) == A

@test length(m) == 5
@test eltype(m) == valtype(m)

# getvals regroups the node definitions column-wise, keeping the declaration order
vals = AbstractPPL.GraphPPL.getvals(model)
@test vals == Tuple(map(node -> node[i], values(model)) for i in 1:3)

# Check that every node carries its own value, function and kind. The nodes are
# looked up by name because GraphInfo stores them in topological order, which
# differs from the declaration order of `model`; a positional comparison would
# not notice entries attached to the wrong node.
for name in keys(model)
    value, f, kind = model[name]
    @test m.g.value[name] == value
    @test m.g.eval[name] === f
    @test m.g.kind[name] === kind
end

for node in m
    @test typeof(node) <: NamedTuple{fieldnames(GraphInfo)[1:4]}
end

# test the right inputs have been inferred
@test m.g.input == (s2 = (), xmat = (), β = (), μ = (:xmat, :β), y = (:μ, :s2))

# test keys are VarNames
for key in keys(m)
    @test typeof(key) <: VarName
end

# test Model constructor for model with single parent node
single_parent_m = Model(μ = (1.0, () -> 3, :Logical), y = (1.0, (μ) -> MvNormal(μ, sqrt(1)), :Stochastic))
@test typeof(single_parent_m) == Model{(:μ, :y)}
@test typeof(single_parent_m.g) == GraphInfo{(:μ, :y)}

# test ErrorException for parent node not found
@test_throws ErrorException Model( μ = (1.0, (β) -> 3, :Logical), y = (1.0, (μ) -> MvNormal(μ, sqrt(1)), :Stochastic))

# test AssertionError thrown for kwargs with the wrong order of inputs
@test_throws AssertionError Model( μ = ((β) -> 3, 1.0, :Logical), y = (1.0, (μ) -> MvNormal(μ, sqrt(1)), :Stochastic))
5 changes: 2 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ using Test

@testset "AbstractPPL.jl" begin
include("deprecations.jl")

include("graphinfo.jl")
@testset "doctests" begin
DocMeta.setdocmeta!(
AbstractPPL,
Expand All @@ -22,5 +22,4 @@ using Test
)
doctest(AbstractPPL; manual=false)
end
end

end