Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/ForwardBackward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ export
UniformUnmasking,
GeneralDiscrete,
PiQ,
HPiQ,
PiNode,
PiLeaf,
#Likelihoods & States
CategoricalLikelihood,
GaussianLikelihood,
Expand All @@ -38,6 +41,9 @@ export
tensor,
sumnorm,
stochastic,
init_leaf_indices!,
add_child!,
init_first_level_parent!,
#Manifolds
ManifoldProcess,
ManifoldState,
Expand Down
265 changes: 261 additions & 4 deletions src/processes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,16 +159,273 @@ process = PiQ(ones(4) ./ 4)
process = PiQ(2.0, [0.1, 0.2, 0.3, 0.4])
```
"""
struct PiQ{T} <: DiscreteProcess
struct PiQ{T,V<:AbstractVector{T}} <: DiscreteProcess
r::T
π::Vector{T}
π::V
β::T
end

function PiQ(r::T,π::Vector{T}; normalize=true) where T <: Real
function PiQ(r::T,π::AbstractVector{T}; normalize=true) where T <: Real
piNormed = π ./ sum(π)
β = normalize ? 1/(1-sum(abs2.(piNormed))) : T(1.0)
PiQ(r, piNormed, β)
end

PiQ(π::Vector{T}; normalize=true) where T <: Real = PiQ(T(1.0), π; normalize=normalize)
PiQ(π::AbstractVector{T}; normalize=true) where T <: Real = PiQ(T(1.0), π; normalize=normalize)

"""
abstract type Nodal <: Nodal


Base type for tree nodes which is used to define the PHiQ process.
"""
abstract type Nodal end

"""
mutable struct PiNode{T} <: Nodal

Internal node type for The PHiQ tree.

# Parameters
- `u`: Rate parameter
- `parent`: Parent node
- `children`: Children nodes
- `leaf_indices`: State indices of descendent leaf nodes
"""
#TODO Make this GPU compatible
mutable struct PiNode{T} <: Nodal
u::T
parent::Union{PiNode{T}, Nothing}
children::Union{Vector{<:Nodal},Nothing}
leaf_indices::Union{Vector{<:Int}, Nothing}
first_level_parent::Union{Bool, Nothing}
PiNode(u::T) where T = new{T}(u, nothing, nothing, nothing, nothing)
end
# mutable struct PiNode{T, S, V<:AbstractVector{S}} <: Nodal
# u::T
# parent::Union{PiNode{T}, Nothing}
# children::Union{Vector{<:Nodal},Nothing}
# leaf_indices::Union{V, Nothing}
# first_level_parent::Union{Bool, Nothing}

# function PiNode{T,S,V}(u, parent, children, leaf_indices, first_level_parent) where {T,S,V}
# new{T,S,V}(u, parent, children, leaf_indices, first_level_parent)
# end

# PiNode(u::T) where T = new{T, Int32, Vector{Int32}}(u, nothing, nothing, nothing, nothing)
# end

"""
mutable struct PiLeaf{T} <: Nodal

A PiLeaf node is a representation of a discrete state.

# Parameters
- `index`: State index
- `parent`: parent node
"""
mutable struct PiLeaf <: Nodal
index::Int32
parent::Union{PiNode, Nothing}
PiLeaf(index) = new(index, nothing)
end

"""
struct HPiQ{T} <: DiscreteProcess

Discrete-state continuous-time process with an equilibrium vector `π` and a hierichal tree structure `tree`, which imposes a hierichal structure where transition events can occur for a subset of the states.
Note, remember to call `init_leaf_indicies!` to correctly collect descendent leaf states for internal nodes, this is needed to call e.g. forward and backward.

# Parameters
- `tree`: Root node of a tree
- `π`: equilibrium vector

# Examples
```julia

# The root
tree = PiNode(1.0)

#Internal Nodes
child1 = PiNode(2.0)
child2 = PiNode(3.0)

add_child!(tree, child1)
add_child!(tree, child2)

# States
leaf1 = PiLeaf(1)
leaf2 = PiLeaf(2)
leaf3 = PiLeaf(3)
leaf4 = PiLeaf(4)

add_child!(child1, leaf1)
add_child!(child1, leaf2)
add_child!(child2, leaf3)
add_child!(child2, leaf4)

init_leaf_indices!(tree)
π = [0.2, 0.3, 0.4, 0.1]

HPiQ_process = HPiQ(tree, π)
```
"""
struct HPiQ{T, V<:AbstractVector{T}} <: DiscreteProcess
tree::PiNode
π::V
end

"""
add_child!(node::PiNode, child::Nodal)

Helper function for the construction of a HPiQ tree.

# Parameters
- `node`: The parent node
- `child`: The child node
"""
function add_child!(node::PiNode, child::Nodal)
if isnothing(node.children)
node.children = Nodal[]
end
push!(node.children, child)
child.parent = node
end

"""
init_leaf_indices!(node::PiNode)

This function assigns the state indices of its descendent leaf nodes to each internal node in a HPiQ tree.

#TODO maybe save indexing as range, e.g. 320:23220, but this restricts the degree of freedom of indexing PiLeafs state

# Parameters
- `node`: The root node of the tree
"""
function init_leaf_indices!(node::PiNode)
indices = Int[]
if isnothing(node.children)
node.leaf_indices = indices
return indices
end
for child in node.children
if isa(child, PiLeaf)
push!(indices, child.index)
elseif isa(child, PiNode)
append!(indices, init_leaf_indices!(child))
end
end
node.leaf_indices = sort!(unique!(indices))
return node.leaf_indices
end
# function init_leaf_indices!(node::PiNode{T, S, V}; gpu_adapted::Bool=false) where {T, S, V}
# if isnothing(node.children)
# if gpu_adapted
# node.leaf_indices = V()
# else
# node.leaf_indices = S[]
# end
# return node.leaf_indices
# end
# collected_indices = S[]
# for child in node.children
# if child isa PiLeaf
# push!(collected_indices, child.index)
# elseif child isa PiNode
# append!(collected_indices, init_leaf_indices!(child))
# end
# end
# unique!(collected_indices)
# sort!(collected_indices)
# if gpu_adapted
# node.leaf_indices = V(collected_indices)
# else
# node.leaf_indices = collected_indices
# end
# return node.leaf_indices
# end

# function init_leaf_indices!(node::PiNode)
# indices = typeof(node).parameters[3][]
# if isnothing(node.children)
# node.leaf_indices = indices
# return indices
# end
# for child in node.children
# if isa(child, PiLeaf)
# push!(indices, child.index)
# elseif isa(child, PiNode)
# append!(indices, init_leaf_indices!(child))
# end
# end
# a=unique!(indices)
# println(typeof(a))
# node.leaf_indices = sort!(a)
# return node.leaf_indices
# end

# function init_first_level_parent!(node::PiNode)
# node.first_level_parent = true
# for child in node.children
# if isa(child, PiNode)
# node.first_level_parent = false
# init_first_level_parent!(child)
# end
# end
# end

function init_first_level_parent!(node::PiNode)
node.first_level_parent = all(child -> isa(child, PiLeaf), node.children)
for child in node.children
if isa(child, PiNode)
init_first_level_parent!(child)
end
end
end


# Gets all internal nodes of a HPiQ tree
function get_all_nodes!(node::PiNode, nodes::Vector{PiNode})
push!(nodes, node)
if !isnothing(node.children)
for child in node.children
if isa(child, PiNode)
get_all_nodes!(child, nodes)
end
end
end
return nodes
end

# This maps HPiQ process to its corresponding transition rate matrix.
function HPiQ_Qmatrix(process::HPiQ)
(; tree, π) = process
N = length(π)
Q = zeros(Float64, N, N)
all_nodes = PiNode[]
get_all_nodes!(tree, all_nodes)

for node in all_nodes
isnothing(node.leaf_indices) && continue
idx = node.leaf_indices
length(idx) <= 1 && continue
u = node.u
π_partition_view = view(π, idx)
sum_π = sum(π_partition_view)
isapprox(sum_π, 0.0) && continue
for i_global in idx
for j_global in idx

if i_global != j_global
Q[i_global, j_global] += u * (π[j_global] / sum_π)
end
end
end
end

for i in 1:N
Q[i, i] = -sum(Q[i, :])
end

return Q
end
Loading