diff --git a/src/JumpProcesses.jl b/src/JumpProcesses.jl index 8776cc6f..22c51c09 100644 --- a/src/JumpProcesses.jl +++ b/src/JumpProcesses.jl @@ -129,7 +129,7 @@ export SSAStepper # leaping: include("simple_regular_solve.jl") -export SimpleTauLeaping, EnsembleGPUKernel +export SimpleTauLeaping, SimpleExplicitTauLeaping, EnsembleGPUKernel # spatial: include("spatial/spatial_massaction_jump.jl") diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 7e3a34fd..78b2a14c 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -1,5 +1,11 @@ struct SimpleTauLeaping <: DiffEqBase.DEAlgorithm end +struct SimpleExplicitTauLeaping{T <: AbstractFloat} <: DiffEqBase.DEAlgorithm + epsilon::T # Error control parameter +end + +SimpleExplicitTauLeaping(; epsilon=0.05) = SimpleExplicitTauLeaping(epsilon) + function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg) if !(jump_prob.aggregator isa PureLeaping) @warn "When using $alg, please pass PureLeaping() as the aggregator to the \ @@ -14,6 +20,19 @@ function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg) jump_prob.regular_jump !== nothing end +function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg::SimpleExplicitTauLeaping) + if !(jump_prob.aggregator isa PureLeaping) + @warn "When using $alg, please pass PureLeaping() as the aggregator to the \ + JumpProblem, i.e. call JumpProblem(::DiscreteProblem, PureLeaping(),...). \ + Passing $(jump_prob.aggregator) is deprecated and will be removed in the next breaking release." + end + isempty(jump_prob.jump_callback.continuous_callbacks) && + isempty(jump_prob.jump_callback.discrete_callbacks) && + isempty(jump_prob.constant_jumps) && + isempty(jump_prob.variable_jumps) && + jump_prob.massaction_jump !== nothing +end + function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping; seed = nothing, dt = error("dt is required for SimpleTauLeaping.")) validate_pure_leaping_inputs(jump_prob, alg) || @@ -61,6 +80,213 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping; interp = DiffEqBase.ConstantInterpolation(t, u)) end +function compute_hor(reactant_stoch, numjumps) + # Compute the highest order of reaction (HOR) for each reaction j, as per Cao et al. (2006), Section IV. + # HOR is the sum of stoichiometric coefficients of reactants in reaction j. + hor = zeros(Int, numjumps) + for j in 1:numjumps + order = sum(stoch for (spec_idx, stoch) in reactant_stoch[j]; init=0) + if order > 3 + error("Reaction $j has order $order, which is not supported (maximum order is 3).") + end + hor[j] = order + end + return hor +end + +function precompute_reaction_conditions(reactant_stoch, hor, numspecies, numjumps) + # Precompute reaction conditions for each species i, including: + # - max_hor: the highest order of reaction (HOR) where species i is a reactant. + # - max_stoich: the maximum stoichiometry (nu_ij) in reactions with max_hor. + # Used to optimize compute_gi, as per Cao et al. (2006), Section IV, equation (27). + max_hor = zeros(Int, numspecies) + max_stoich = zeros(Int, numspecies) + for j in 1:numjumps + for (spec_idx, stoch) in reactant_stoch[j] + if stoch > 0 # Species is a reactant + if hor[j] > max_hor[spec_idx] + max_hor[spec_idx] = hor[j] + max_stoich[spec_idx] = stoch + elseif hor[j] == max_hor[spec_idx] + max_stoich[spec_idx] = max(max_stoich[spec_idx], stoch) + end + end + end + end + return max_hor, max_stoich +end + +function compute_gi(u, max_hor, max_stoich, i, t) + # Compute g_i for species i to bound the relative change in propensity functions, + # as per Cao et al. (2006), Section IV, equation (27). + # g_i is determined by the highest order of reaction (HOR) and maximum stoichiometry (nu_ij) where species i is a reactant: + # - HOR = 1 (first-order, e.g., S_i -> products): g_i = 1 + # - HOR = 2 (second-order): + # - nu_ij = 1 (e.g., S_i + S_k -> products): g_i = 2 + # - nu_ij = 2 (e.g., 2S_i -> products): g_i = 2 + 1/(x_i - 1) + # - HOR = 3 (third-order): + # - nu_ij = 1 (e.g., S_i + S_k + S_m -> products): g_i = 3 + # - nu_ij = 2 (e.g., 2S_i + S_k -> products): g_i = (3/2) * (2 + 1/(x_i - 1)) + # - nu_ij = 3 (e.g., 3S_i -> products): g_i = 3 + 1/(x_i - 1) + 2/(x_i - 2) + # Uses precomputed max_hor and max_stoich to reduce work to O(num_species) per timestep. + if max_hor[i] == 0 # No reactions involve species i as a reactant + return 1.0 + elseif max_hor[i] == 1 + return 1.0 + elseif max_hor[i] == 2 + if max_stoich[i] == 1 + return 2.0 + elseif max_stoich[i] == 2 + return u[i] > 1 ? 2.0 + 1.0 / (u[i] - 1) : 2.0 # Fallback to 2.0 if x_i <= 1 + end + elseif max_hor[i] == 3 + if max_stoich[i] == 1 + return 3.0 + elseif max_stoich[i] == 2 + return u[i] > 1 ? 1.5 * (2.0 + 1.0 / (u[i] - 1)) : 3.0 # Fallback to 3.0 if x_i <= 1 + elseif max_stoich[i] == 3 + return u[i] > 2 ? 3.0 + 1.0 / (u[i] - 1) + 2.0 / (u[i] - 2) : 3.0 # Fallback to 3.0 if x_i <= 2 + end + end + return 1.0 # Default case +end + +function compute_tau(u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin, max_hor, max_stoich, numjumps) + # Compute the tau-leaping step-size using equation (20) from Cao et al. (2006): + # tau = min_{i in I_rs} { max(epsilon * x_i / g_i, 1) / |mu_i(x)|, max(epsilon * x_i / g_i, 1)^2 / sigma_i^2(x) } + # where mu_i(x) and sigma_i^2(x) are defined in equations (9a) and (9b): + # mu_i(x) = sum_j nu_ij * a_j(x), sigma_i^2(x) = sum_j nu_ij^2 * a_j(x) + # I_rs is the set of reactant species (assumed to be all species here, as critical reactions are not specified). + rate(rate_cache, u, p, t) + if all(==(0.0), rate_cache) # Handle case where all rates are zero + return dtmin + end + tau = Inf + for i in 1:length(u) + mu = zero(eltype(u)) + sigma2 = zero(eltype(u)) + for j in 1:size(nu, 2) + mu += nu[i, j] * rate_cache[j] # Equation (9a) + sigma2 += nu[i, j]^2 * rate_cache[j] # Equation (9b) + end + gi = compute_gi(u, max_hor, max_stoich, i, t) + bound = max(epsilon * u[i] / gi, 1.0) # max(epsilon * x_i / g_i, 1) + mu_term = abs(mu) > 0 ? bound / abs(mu) : Inf # First term in equation (8) + sigma_term = sigma2 > 0 ? bound^2 / sigma2 : Inf # Second term in equation (8) + tau = min(tau, mu_term, sigma_term) # Equation (8) + end + return max(tau, dtmin) +end + +function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleExplicitTauLeaping; + seed = nothing, + dtmin = 1e-10, + saveat = nothing) + validate_pure_leaping_inputs(jump_prob, alg) || + error("SimpleAdaptiveTauLeaping can only be used with PureLeaping JumpProblem with a MassActionJump.") + + @unpack prob, rng = jump_prob + (seed !== nothing) && seed!(rng, seed) + + maj = jump_prob.massaction_jump + numjumps = get_num_majumps(maj) + rj = jump_prob.regular_jump + # Extract rates + rate = rj !== nothing ? rj.rate : + (out, u, p, t) -> begin + for j in 1:numjumps + out[j] = evalrxrate(u, j, maj) + end + end + c = rj !== nothing ? rj.c : nothing + u0 = copy(prob.u0) + tspan = prob.tspan + p = prob.p + + # Initialize current state and saved history + u_current = copy(u0) + t_current = tspan[1] + usave = [copy(u0)] + tsave = [tspan[1]] + rate_cache = zeros(float(eltype(u0)), numjumps) + counts = zero(rate_cache) + du = similar(u0) + t_end = tspan[2] + epsilon = alg.epsilon + + # Extract net stoichiometry for state updates + nu = zeros(float(eltype(u0)), length(u0), numjumps) + for j in 1:numjumps + for (spec_idx, stoch) in maj.net_stoch[j] + nu[spec_idx, j] = stoch + end + end + # Extract reactant stoichiometry for hor and gi + reactant_stoch = maj.reactant_stoch + hor = compute_hor(reactant_stoch, numjumps) + max_hor, max_stoich = precompute_reaction_conditions(reactant_stoch, hor, length(u0), numjumps) + + # Set up saveat_times + saveat_times = nothing + if isnothing(saveat) + saveat_times = Vector{typeof(tspan[1])}() + elseif saveat isa Number + saveat_times = collect(range(tspan[1], tspan[2], step=saveat)) + else + saveat_times = collect(saveat) + end + + save_idx = 1 + + while t_current < t_end + rate(rate_cache, u_current, p, t_current) + tau = compute_tau(u_current, rate_cache, nu, hor, p, t_current, epsilon, rate, dtmin, max_hor, max_stoich, numjumps) + tau = min(tau, t_end - t_current) + if !isempty(saveat_times) && save_idx <= length(saveat_times) && t_current + tau > saveat_times[save_idx] + tau = saveat_times[save_idx] - t_current + end + counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0)) + du .= 0 + if c !== nothing + c(du, u_current, p, t_current, counts, nothing) + else + for j in 1:numjumps + for (spec_idx, stoch) in maj.net_stoch[j] + du[spec_idx] += stoch * counts[j] + end + end + end + u_new = u_current + du + if any(<(0), u_new) + # Halve tau to avoid negative populations, as per Cao et al. (2006), Section 3.3 + tau /= 2 + continue + end + # Ensure non-negativity, as per Cao et al. (2006), Section 3.3 + for i in eachindex(u_new) + u_new[i] = max(u_new[i], 0) + end + t_new = t_current + tau + + # Save state if at a saveat time or if saveat is empty + if isempty(saveat_times) || (save_idx <= length(saveat_times) && t_new >= saveat_times[save_idx]) + push!(usave, copy(u_new)) + push!(tsave, t_new) + if !isempty(saveat_times) && t_new >= saveat_times[save_idx] + save_idx += 1 + end + end + + u_current = u_new + t_current = t_new + end + + sol = DiffEqBase.build_solution(prob, alg, tsave, usave, + calculate_error=false, + interp=DiffEqBase.ConstantInterpolation(tsave, usave)) + return sol +end + struct EnsembleGPUKernel{Backend} <: SciMLBase.EnsembleAlgorithm backend::Backend cpu_offload::Float64 diff --git a/test/regular_jumps.jl b/test/regular_jumps.jl index 3ccc6740..8db0566a 100644 --- a/test/regular_jumps.jl +++ b/test/regular_jumps.jl @@ -1,28 +1,154 @@ using JumpProcesses, DiffEqBase -using Test, LinearAlgebra +using Test, LinearAlgebra, Statistics using StableRNGs rng = StableRNG(12345) -function regular_rate(out, u, p, t) - out[1] = (0.1 / 1000.0) * u[1] * u[2] - out[2] = 0.01u[2] +Nsims = 1000 + +# SIR model with influx +@testset "SIR Model Correctness" begin + β = 0.1 / 1000.0 + ν = 0.01 + influx_rate = 1.0 + p = (β, ν, influx_rate) + + # ConstantRateJump formulation for SSAStepper + rate1(u, p, t) = p[1] * u[1] * u[2] # β*S*I (infection) + rate2(u, p, t) = p[2] * u[2] # ν*I (recovery) + rate3(u, p, t) = p[3] # influx_rate (S influx) + affect1!(integrator) = (integrator.u[1] -= 1; integrator.u[2] += 1; nothing) + affect2!(integrator) = (integrator.u[2] -= 1; integrator.u[3] += 1; nothing) + affect3!(integrator) = (integrator.u[1] += 1; nothing) + jumps = (ConstantRateJump(rate1, affect1!), ConstantRateJump(rate2, affect2!), ConstantRateJump(rate3, affect3!)) + + u0 = [999.0, 10.0, 0.0] # S, I, R + tspan = (0.0, 250.0) + prob_disc = DiscreteProblem(u0, tspan, p) + jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng=rng) + + # Solve with SSAStepper + sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); trajectories=Nsims, saveat=1.0) + + # RegularJump formulation for SimpleTauLeaping + regular_rate = (out, u, p, t) -> begin + out[1] = p[1] * u[1] * u[2] + out[2] = p[2] * u[2] + out[3] = p[3] + end + regular_c = (dc, u, p, t, counts, mark) -> begin + dc .= 0 + dc[1] = -counts[1] + counts[3] + dc[2] = counts[1] - counts[2] + dc[3] = counts[2] + end + rj = RegularJump(regular_rate, regular_c, 3) + jump_prob_tau = JumpProblem(prob_disc, PureLeaping(), rj; rng=rng) + + # Solve with SimpleTauLeaping + sol_simple = solve(EnsembleProblem(jump_prob_tau), SimpleTauLeaping(), EnsembleSerial(); trajectories=Nsims, dt=0.1) + + # MassActionJump formulation for SimpleExplicitTauLeaping + reactant_stoich = [[1=>1, 2=>1], [2=>1], Pair{Int,Int}[]] + net_stoich = [[1=>-1, 2=>1], [2=>-1, 3=>1], [1=>1]] + param_idxs = [1, 2, 3] + maj = MassActionJump(reactant_stoich, net_stoich; param_idxs=param_idxs) + jump_prob_maj = JumpProblem(prob_disc, PureLeaping(), maj; rng=rng) + + # Solve with SimpleExplicitTauLeaping + sol_adaptive = solve(EnsembleProblem(jump_prob_maj), SimpleExplicitTauLeaping(), EnsembleSerial(); trajectories=Nsims, saveat=1.0) + + # Compute mean infected (I) trajectories + t_points = 0:1.0:250.0 + max_direct_I = maximum([mean(sol_direct[i](t)[3] for i in 1:Nsims) for t in t_points]) + max_simple_I = maximum([mean(sol_simple[i](t)[3] for i in 1:Nsims) for t in t_points]) + max_explicit_I = maximum([mean(sol_adaptive[i](t)[3] for i in 1:Nsims) for t in t_points]) + + @test isapprox(max_direct_I, max_simple_I, rtol=0.05) + @test isapprox(max_direct_I, max_explicit_I, rtol=0.05) end -const dc = zeros(3, 2) -dc[1, 1] = -1 -dc[2, 1] = 1 -dc[2, 2] = -1 -dc[3, 2] = 1 +# SEIR model with exposed compartment +@testset "SEIR Model Correctness" begin + β = 0.3 / 1000.0 + σ = 0.2 + ν = 0.01 + p = (β, σ, ν) + + # ConstantRateJump formulation for SSAStepper + rate1(u, p, t) = p[1] * u[1] * u[3] # β*S*I (infection) + rate2(u, p, t) = p[2] * u[2] # σ*E (progression) + rate3(u, p, t) = p[3] * u[3] # ν*I (recovery) + affect1!(integrator) = (integrator.u[1] -= 1; integrator.u[2] += 1; nothing) + affect2!(integrator) = (integrator.u[2] -= 1; integrator.u[3] += 1; nothing) + affect3!(integrator) = (integrator.u[3] -= 1; integrator.u[4] += 1; nothing) + jumps = (ConstantRateJump(rate1, affect1!), ConstantRateJump(rate2, affect2!), ConstantRateJump(rate3, affect3!)) -function regular_c(du, u, p, t, counts, mark) - mul!(du, dc, counts) + u0 = [999.0, 0.0, 10.0, 0.0] # S, E, I, R + tspan = (0.0, 250.0) + prob_disc = DiscreteProblem(u0, tspan, p) + jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng=rng) + + # Solve with SSAStepper + sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); trajectories=Nsims, saveat=1.0) + + # RegularJump formulation for SimpleTauLeaping + regular_rate = (out, u, p, t) -> begin + out[1] = p[1] * u[1] * u[3] + out[2] = p[2] * u[2] + out[3] = p[3] * u[3] + end + regular_c = (dc, u, p, t, counts, mark) -> begin + dc .= 0.0 + dc[1] = -counts[1] + dc[2] = counts[1] - counts[2] + dc[3] = counts[2] - counts[3] + dc[4] = counts[3] + end + rj = RegularJump(regular_rate, regular_c, 3) + jump_prob_tau = JumpProblem(prob_disc, PureLeaping(), rj; rng=rng) + + # Solve with SimpleTauLeaping + sol_simple = solve(EnsembleProblem(jump_prob_tau), SimpleTauLeaping(), EnsembleSerial(); trajectories=Nsims, dt=0.1) + + # MassActionJump formulation for SimpleExplicitTauLeaping + reactant_stoich = [[1=>1, 3=>1], [2=>1], [3=>1]] + net_stoich = [[1=>-1, 2=>1], [2=>-1, 3=>1], [3=>-1, 4=>1]] + param_idxs = [1, 2, 3] + maj = MassActionJump(reactant_stoich, net_stoich; param_idxs=param_idxs) + jump_prob_maj = JumpProblem(prob_disc, PureLeaping(), maj; rng=rng) + + # Solve with SimpleExplicitTauLeaping + sol_adaptive = solve(EnsembleProblem(jump_prob_maj), SimpleExplicitTauLeaping(), EnsembleSerial(); trajectories=Nsims, saveat=1.0) + + # Compute mean infected (I) trajectories + t_points = 0:1.0:250.0 + max_direct_I = maximum([mean(sol_direct[i](t)[3] for i in 1:Nsims) for t in t_points]) + max_simple_I = maximum([mean(sol_simple[i](t)[3] for i in 1:Nsims) for t in t_points]) + max_explicit_I = maximum([mean(sol_adaptive[i](t)[3] for i in 1:Nsims) for t in t_points]) + + @test isapprox(max_direct_I, max_simple_I, rtol=0.05) + @test isapprox(max_direct_I, max_explicit_I, rtol=0.05) end -rj = RegularJump(regular_rate, regular_c, 2) -jumps = JumpSet(rj) -prob = DiscreteProblem([999, 1, 0], (0.0, 250.0)) -jump_prob = JumpProblem(prob, PureLeaping(), rj; rng) -sol = solve(jump_prob, SimpleTauLeaping(); dt = 1.0) +# Test zero-rate case for SimpleExplicitTauLeaping +@testset "Zero Rates Test for SimpleExplicitTauLeaping" begin + # SIR model: S + I -> 2I, I -> R + reactant_stoch = [[1=>1, 2=>1], [2=>1], Pair{Int,Int}[]] + net_stoch = [[1=>-1, 2=>1], [2=>-1, 3=>1], []] + rates = [0.1/1000, 0.05, 0.0] # beta/N, gamma, dummy rate for empty reaction + maj = MassActionJump(rates, reactant_stoch, net_stoch) + u0 = [0, 0, 0] # All populations zero + tspan = (0.0, 250.0) + prob = DiscreteProblem(u0, tspan) + jump_prob = JumpProblem(prob, PureLeaping(), maj) + + sol = solve(jump_prob, SimpleExplicitTauLeaping(); dtmin = 0.1, saveat=1.0) + + # Check that solution completes and covers tspan + @test sol.t[end] ≈ 250.0 atol=1e-6 + # Check that state remains zero + @test all(u == [0, 0, 0] for u in sol.u) +end # Test PureLeaping aggregator functionality @testset "PureLeaping Aggregator Tests" begin