Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
f5ad900
Implemented SimpleAdaptiveTauLeaping and SimpleImplicitTauLeaping
sivasathyaseeelan Aug 9, 2025
aae5d9f
update project.toml
sivasathyaseeelan Aug 10, 2025
230d508
test changes
sivasathyaseeelan Aug 10, 2025
25b8c16
refactor
sivasathyaseeelan Aug 12, 2025
07c429e
test refactor
sivasathyaseeelan Aug 12, 2025
885ac59
refactor
sivasathyaseeelan Aug 13, 2025
0ec9d39
added saveat in SimpleAdaptiveTauLeaping
sivasathyaseeelan Aug 16, 2025
0e7ff26
update
sivasathyaseeelan Aug 24, 2025
89378c6
Update src/simple_regular_solve.jl
sivasathyaseeelan Aug 24, 2025
14e0be7
Update src/simple_regular_solve.jl
sivasathyaseeelan Aug 24, 2025
e7f975e
update
sivasathyaseeelan Aug 24, 2025
6e789cd
test update
sivasathyaseeelan Aug 24, 2025
a8999f4
using maj for adaptive tauleaping
sivasathyaseeelan Aug 24, 2025
0125e21
project.toml update
sivasathyaseeelan Aug 24, 2025
1af190d
saveat logic change
sivasathyaseeelan Aug 24, 2025
10f4ce3
test change
sivasathyaseeelan Aug 24, 2025
6d3d900
saveat optimization
sivasathyaseeelan Aug 25, 2025
2c03d67
refactor
sivasathyaseeelan Aug 25, 2025
3f90750
memory optimization
sivasathyaseeelan Aug 25, 2025
fb72149
validate_pure_leaping_inputs extended for adaptive version
sivasathyaseeelan Aug 25, 2025
7a7232a
some
sivasathyaseeelan Aug 25, 2025
fe7cec0
space optimized in compute_tau_explicit
sivasathyaseeelan Aug 25, 2025
8e7ff16
computegi and comutehor changes
sivasathyaseeelan Aug 25, 2025
bc770d1
reactant_stoch in hor
sivasathyaseeelan Aug 25, 2025
b5f77f5
compute_gi update
sivasathyaseeelan Aug 26, 2025
0b72d4c
added references
sivasathyaseeelan Aug 26, 2025
5415947
added unpack
sivasathyaseeelan Aug 27, 2025
b39390f
test changes
sivasathyaseeelan Aug 27, 2025
822562f
test changes
sivasathyaseeelan Aug 27, 2025
b572987
export changes
sivasathyaseeelan Aug 27, 2025
785266b
test changes
sivasathyaseeelan Aug 27, 2025
b47df7c
some change in gi calculation
sivasathyaseeelan Aug 28, 2025
e02d432
changed compute_gi as per paper
sivasathyaseeelan Aug 28, 2025
092d361
some
sivasathyaseeelan Aug 28, 2025
7217cf0
some
sivasathyaseeelan Aug 28, 2025
cc3a78a
optimized compute_gi
sivasathyaseeelan Aug 29, 2025
48fece2
zero rates case for SimpleAdaptiveTauLeaping is added
sivasathyaseeelan Aug 30, 2025
12f84ba
SimpleExplicitTauLeaping
sivasathyaseeelan Sep 5, 2025
98d64f3
test update
sivasathyaseeelan Sep 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/JumpProcesses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ export SSAStepper

# leaping:
include("simple_regular_solve.jl")
export SimpleTauLeaping, EnsembleGPUKernel
export SimpleTauLeaping, SimpleExplicitTauLeaping, EnsembleGPUKernel

# spatial:
include("spatial/spatial_massaction_jump.jl")
Expand Down
226 changes: 226 additions & 0 deletions src/simple_regular_solve.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
struct SimpleTauLeaping <: DiffEqBase.DEAlgorithm end

struct SimpleExplicitTauLeaping{T <: AbstractFloat} <: DiffEqBase.DEAlgorithm
epsilon::T # Error control parameter
end

SimpleExplicitTauLeaping(; epsilon=0.05) = SimpleExplicitTauLeaping(epsilon)

function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg)
if !(jump_prob.aggregator isa PureLeaping)
@warn "When using $alg, please pass PureLeaping() as the aggregator to the \
Expand All @@ -14,6 +20,19 @@ function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg)
jump_prob.regular_jump !== nothing
end

function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg::SimpleExplicitTauLeaping)
if !(jump_prob.aggregator isa PureLeaping)
@warn "When using $alg, please pass PureLeaping() as the aggregator to the \
JumpProblem, i.e. call JumpProblem(::DiscreteProblem, PureLeaping(),...). \
Passing $(jump_prob.aggregator) is deprecated and will be removed in the next breaking release."
end
isempty(jump_prob.jump_callback.continuous_callbacks) &&
isempty(jump_prob.jump_callback.discrete_callbacks) &&
isempty(jump_prob.constant_jumps) &&
isempty(jump_prob.variable_jumps) &&
jump_prob.massaction_jump !== nothing
end

function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping;
seed = nothing, dt = error("dt is required for SimpleTauLeaping."))
validate_pure_leaping_inputs(jump_prob, alg) ||
Expand Down Expand Up @@ -61,6 +80,213 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping;
interp = DiffEqBase.ConstantInterpolation(t, u))
end

function compute_hor(reactant_stoch, numjumps)
# Compute the highest order of reaction (HOR) for each reaction j, as per Cao et al. (2006), Section IV.
# HOR is the sum of stoichiometric coefficients of reactants in reaction j.
Comment on lines +83 to +85
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
function compute_hor(reactant_stoch, numjumps)
# Compute the highest order of reaction (HOR) for each reaction j, as per Cao et al. (2006), Section IV.
# HOR is the sum of stoichiometric coefficients of reactants in reaction j.
# Compute the highest order of reaction (HOR) for each reaction j, as per Cao et al. (2006), Section IV.
# HOR is the sum of stoichiometric coefficients of reactants in reaction j.
function compute_hor(reactant_stoch, numjumps)

And please update other functions accordingly. Place these kind of design comments before the function, not right inside it.

hor = zeros(Int, numjumps)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
hor = zeros(Int, numjumps)
hor = zeros(Int, numjumps)

Instead of Int, extract the type from reactant_stoch via a parametric input to this function.

for j in 1:numjumps
order = sum(stoch for (spec_idx, stoch) in reactant_stoch[j]; init=0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
order = sum(stoch for (spec_idx, stoch) in reactant_stoch[j]; init=0)
order = sum(stoch for (spec_idx, stoch) in reactant_stoch[j])

sum can figure this out and avoid type assumptions.

if order > 3
error("Reaction $j has order $order, which is not supported (maximum order is 3).")
end
hor[j] = order
end
return hor
end

function precompute_reaction_conditions(reactant_stoch, hor, numspecies, numjumps)
# Precompute reaction conditions for each species i, including:
# - max_hor: the highest order of reaction (HOR) where species i is a reactant.
# - max_stoich: the maximum stoichiometry (nu_ij) in reactions with max_hor.
# Used to optimize compute_gi, as per Cao et al. (2006), Section IV, equation (27).
max_hor = zeros(Int, numspecies)
max_stoich = zeros(Int, numspecies)
Comment on lines +102 to +103
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't assume Int, just match the eltype(hor).

for j in 1:numjumps
for (spec_idx, stoch) in reactant_stoch[j]
if stoch > 0 # Species is a reactant
if hor[j] > max_hor[spec_idx]
max_hor[spec_idx] = hor[j]
max_stoich[spec_idx] = stoch
elseif hor[j] == max_hor[spec_idx]
max_stoich[spec_idx] = max(max_stoich[spec_idx], stoch)
end
end
end
end
return max_hor, max_stoich
end

function compute_gi(u, max_hor, max_stoich, i, t)
# Compute g_i for species i to bound the relative change in propensity functions,
# as per Cao et al. (2006), Section IV, equation (27).
# g_i is determined by the highest order of reaction (HOR) and maximum stoichiometry (nu_ij) where species i is a reactant:
# - HOR = 1 (first-order, e.g., S_i -> products): g_i = 1
# - HOR = 2 (second-order):
# - nu_ij = 1 (e.g., S_i + S_k -> products): g_i = 2
# - nu_ij = 2 (e.g., 2S_i -> products): g_i = 2 + 1/(x_i - 1)
# - HOR = 3 (third-order):
# - nu_ij = 1 (e.g., S_i + S_k + S_m -> products): g_i = 3
# - nu_ij = 2 (e.g., 2S_i + S_k -> products): g_i = (3/2) * (2 + 1/(x_i - 1))
# - nu_ij = 3 (e.g., 3S_i -> products): g_i = 3 + 1/(x_i - 1) + 2/(x_i - 2)
# Uses precomputed max_hor and max_stoich to reduce work to O(num_species) per timestep.
Comment on lines +119 to +131
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
function compute_gi(u, max_hor, max_stoich, i, t)
# Compute g_i for species i to bound the relative change in propensity functions,
# as per Cao et al. (2006), Section IV, equation (27).
# g_i is determined by the highest order of reaction (HOR) and maximum stoichiometry (nu_ij) where species i is a reactant:
# - HOR = 1 (first-order, e.g., S_i -> products): g_i = 1
# - HOR = 2 (second-order):
# - nu_ij = 1 (e.g., S_i + S_k -> products): g_i = 2
# - nu_ij = 2 (e.g., 2S_i -> products): g_i = 2 + 1/(x_i - 1)
# - HOR = 3 (third-order):
# - nu_ij = 1 (e.g., S_i + S_k + S_m -> products): g_i = 3
# - nu_ij = 2 (e.g., 2S_i + S_k -> products): g_i = (3/2) * (2 + 1/(x_i - 1))
# - nu_ij = 3 (e.g., 3S_i -> products): g_i = 3 + 1/(x_i - 1) + 2/(x_i - 2)
# Uses precomputed max_hor and max_stoich to reduce work to O(num_species) per timestep.
# Compute g_i for species i to bound the relative change in propensity functions,
# as per Cao et al. (2006), Section IV, equation (27).
# g_i is determined by the highest order of reaction (HOR) and maximum stoichiometry (nu_ij) where species i is a reactant:
# - HOR = 1 (first-order, e.g., S_i -> products): g_i = 1
# - HOR = 2 (second-order):
# - nu_ij = 1 (e.g., S_i + S_k -> products): g_i = 2
# - nu_ij = 2 (e.g., 2S_i -> products): g_i = 2 + 1/(x_i - 1)
# - HOR = 3 (third-order):
# - nu_ij = 1 (e.g., S_i + S_k + S_m -> products): g_i = 3
# - nu_ij = 2 (e.g., 2S_i + S_k -> products): g_i = (3/2) * (2 + 1/(x_i - 1))
# - nu_ij = 3 (e.g., 3S_i -> products): g_i = 3 + 1/(x_i - 1) + 2/(x_i - 2)
# Uses precomputed max_hor and max_stoich to reduce work to O(num_species) per timestep.
function compute_gi(u, max_hor, max_stoich, i, t)

if max_hor[i] == 0 # No reactions involve species i as a reactant
return 1.0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't hard-code types like this. At the top of the function you could do something like

one_max_hor = one( 1 / one(eltype(u)) )

to get the type you need. And then return integer multiples of one_max_hor when you need hardcoded values.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please get rid of all the hardcoded floating point values in this code. Replace things like 1.5 by 3/2 etc.

elseif max_hor[i] == 1
return 1.0
elseif max_hor[i] == 2
if max_stoich[i] == 1
return 2.0
elseif max_stoich[i] == 2
return u[i] > 1 ? 2.0 + 1.0 / (u[i] - 1) : 2.0 # Fallback to 2.0 if x_i <= 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return u[i] > 1 ? 2.0 + 1.0 / (u[i] - 1) : 2.0 # Fallback to 2.0 if x_i <= 1
# Fallback to 2 if x_i <= 1
return u[i] > 1 ? 2 + 1 / (u[i] - 1) : 2*one_max_hor

If x_i <= 1 why fall back to 2?

It seems like when this happens we know the reaction actually can't occur as the propensity (i.e. rate) would be non-positive. In this situation I would think the reaction should be turned off and not included in the time-step calculation?

end
elseif max_hor[i] == 3
if max_stoich[i] == 1
return 3.0
elseif max_stoich[i] == 2
return u[i] > 1 ? 1.5 * (2.0 + 1.0 / (u[i] - 1)) : 3.0 # Fallback to 3.0 if x_i <= 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return u[i] > 1 ? 1.5 * (2.0 + 1.0 / (u[i] - 1)) : 3.0 # Fallback to 3.0 if x_i <= 1
return u[i] > 1 ? 1 (3 + 3 / (2*(u[i] - 1))) : 3*one_max_hor # Fallback to 3.0 if x_i <= 1

elseif max_stoich[i] == 3
return u[i] > 2 ? 3.0 + 1.0 / (u[i] - 1) + 2.0 / (u[i] - 2) : 3.0 # Fallback to 3.0 if x_i <= 2
Comment on lines +143 to +148
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the last elseif need to be an elseif or can it just be an else? Similar comment for the max_hor[i] == 2 case.

end
end
return 1.0 # Default case
end

function compute_tau(u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin, max_hor, max_stoich, numjumps)
# Compute the tau-leaping step-size using equation (20) from Cao et al. (2006):
# tau = min_{i in I_rs} { max(epsilon * x_i / g_i, 1) / |mu_i(x)|, max(epsilon * x_i / g_i, 1)^2 / sigma_i^2(x) }
# where mu_i(x) and sigma_i^2(x) are defined in equations (9a) and (9b):
# mu_i(x) = sum_j nu_ij * a_j(x), sigma_i^2(x) = sum_j nu_ij^2 * a_j(x)
# I_rs is the set of reactant species (assumed to be all species here, as critical reactions are not specified).
rate(rate_cache, u, p, t)
if all(==(0.0), rate_cache) # Handle case where all rates are zero
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if all(==(0.0), rate_cache) # Handle case where all rates are zero
if all(<=(0), rate_cache) # Handle case where all rates are zero

If a rate is negative then that reaction can't happen either.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At least for pure-mass action systems, if all the rates are currently zero then you can never have another reaction occur, so we can just step directly to the final time. That would need to be changed if we later allow non-MassActionJumps.

return dtmin
end
tau = Inf
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
tau = Inf
tau = typemax(typeof(t))

for i in 1:length(u)
mu = zero(eltype(u))
sigma2 = zero(eltype(u))
for j in 1:size(nu, 2)
mu += nu[i, j] * rate_cache[j] # Equation (9a)
sigma2 += nu[i, j]^2 * rate_cache[j] # Equation (9b)
end
gi = compute_gi(u, max_hor, max_stoich, i, t)
bound = max(epsilon * u[i] / gi, 1.0) # max(epsilon * x_i / g_i, 1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't use explicitly typed 1.0.

mu_term = abs(mu) > 0 ? bound / abs(mu) : Inf # First term in equation (8)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
mu_term = abs(mu) > 0 ? bound / abs(mu) : Inf # First term in equation (8)
mu_term = abs(mu) > 0 ? bound / abs(mu) : typemax(t) # First term in equation (8)

Similar for the sigma_term.

sigma_term = sigma2 > 0 ? bound^2 / sigma2 : Inf # Second term in equation (8)
tau = min(tau, mu_term, sigma_term) # Equation (8)
end
return max(tau, dtmin)
end

function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleExplicitTauLeaping;
seed = nothing,
dtmin = 1e-10,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
dtmin = 1e-10,
dtmin = 1e-10,

Use the type of tspan[2] from within jump_prob.prob.

saveat = nothing)
validate_pure_leaping_inputs(jump_prob, alg) ||
error("SimpleAdaptiveTauLeaping can only be used with PureLeaping JumpProblem with a MassActionJump.")

@unpack prob, rng = jump_prob
(seed !== nothing) && seed!(rng, seed)

maj = jump_prob.massaction_jump
numjumps = get_num_majumps(maj)
rj = jump_prob.regular_jump
# Extract rates
rate = rj !== nothing ? rj.rate :
(out, u, p, t) -> begin
for j in 1:numjumps
out[j] = evalrxrate(u, j, maj)
end
end
Comment on lines +196 to +200
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make this a standalone function instead of inline (you can use a functor design if you need to save internal state).

c = rj !== nothing ? rj.c : nothing
u0 = copy(prob.u0)
tspan = prob.tspan
p = prob.p

# Initialize current state and saved history
u_current = copy(u0)
t_current = tspan[1]
usave = [copy(u0)]
tsave = [tspan[1]]
rate_cache = zeros(float(eltype(u0)), numjumps)
counts = zero(rate_cache)
du = similar(u0)
t_end = tspan[2]
epsilon = alg.epsilon

# Extract net stoichiometry for state updates
nu = zeros(float(eltype(u0)), length(u0), numjumps)
for j in 1:numjumps
for (spec_idx, stoch) in maj.net_stoch[j]
nu[spec_idx, j] = stoch
end
end
# Extract reactant stoichiometry for hor and gi
reactant_stoch = maj.reactant_stoch
hor = compute_hor(reactant_stoch, numjumps)
max_hor, max_stoich = precompute_reaction_conditions(reactant_stoch, hor, length(u0), numjumps)

# Set up saveat_times
saveat_times = nothing
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
saveat_times = nothing

Not needed.

if isnothing(saveat)
saveat_times = Vector{typeof(tspan[1])}()
elseif saveat isa Number
saveat_times = collect(range(tspan[1], tspan[2], step=saveat))
else
saveat_times = collect(saveat)
end

save_idx = 1

while t_current < t_end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For type stability reasons, this while loop should probably be a separate function you call.

rate(rate_cache, u_current, p, t_current)
tau = compute_tau(u_current, rate_cache, nu, hor, p, t_current, epsilon, rate, dtmin, max_hor, max_stoich, numjumps)
tau = min(tau, t_end - t_current)
if !isempty(saveat_times) && save_idx <= length(saveat_times) && t_current + tau > saveat_times[save_idx]
tau = saveat_times[save_idx] - t_current
end
counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0))
counts .= pois_rand.(rng, max.(rate_cache * tau, zero(tau)))

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In fact, if a particular rate is <= 0 why not just set the count to zero directly and avoid the call to pois_rand for it? That seems a better approach than using max in this way.

du .= 0
if c !== nothing
c(du, u_current, p, t_current, counts, nothing)
else
for j in 1:numjumps
for (spec_idx, stoch) in maj.net_stoch[j]
du[spec_idx] += stoch * counts[j]
end
end
end
u_new = u_current + du
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is allocating... Have separate pre-declatred vectors for u_new and u_current and broadcast here.

if any(<(0), u_new)
# Halve tau to avoid negative populations, as per Cao et al. (2006), Section 3.3
tau /= 2
continue
end
# Ensure non-negativity, as per Cao et al. (2006), Section 3.3
for i in eachindex(u_new)
u_new[i] = max(u_new[i], 0)
end
t_new = t_current + tau
Comment on lines +266 to +269
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't the step rejection if any component of u_new is negative make this unneeded? i.e. at this point don't you know that all the entries in u_new are non-negative?


# Save state if at a saveat time or if saveat is empty
if isempty(saveat_times) || (save_idx <= length(saveat_times) && t_new >= saveat_times[save_idx])
push!(usave, copy(u_new))
push!(tsave, t_new)
if !isempty(saveat_times) && t_new >= saveat_times[save_idx]
save_idx += 1
end
end

u_current = u_new
t_current = t_new
end

sol = DiffEqBase.build_solution(prob, alg, tsave, usave,
calculate_error=false,
interp=DiffEqBase.ConstantInterpolation(tsave, usave))
return sol
end

struct EnsembleGPUKernel{Backend} <: SciMLBase.EnsembleAlgorithm
backend::Backend
cpu_offload::Float64
Expand Down
Loading
Loading