diff --git a/Project.toml b/Project.toml index 46f57c63..4d19bc73 100644 --- a/Project.toml +++ b/Project.toml @@ -18,6 +18,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" +SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" diff --git a/src/JumpProcesses.jl b/src/JumpProcesses.jl index 8776cc6f..cd9a4474 100644 --- a/src/JumpProcesses.jl +++ b/src/JumpProcesses.jl @@ -4,7 +4,7 @@ using Reexport: Reexport, @reexport @reexport using DiffEqBase # Explicit imports from standard libraries -using LinearAlgebra: LinearAlgebra, I, mul! +using LinearAlgebra: LinearAlgebra, I, mul!, eigvals using Random: Random, randexp, randexp! # Explicit imports from external packages @@ -20,6 +20,8 @@ using Base.Threads: Threads, @threads using Base.FastMath: add_fast using Setfield: @set, @set! +using SimpleNonlinearSolve + # Import functions we extend from Base import Base: size, getindex, setindex!, length, similar, show, merge!, merge @@ -129,7 +131,7 @@ export SSAStepper # leaping: include("simple_regular_solve.jl") -export SimpleTauLeaping, EnsembleGPUKernel +export SimpleTauLeaping, SimpleAdaptiveTauLeaping, NewtonImplicitSolver, TrapezoidalImplicitSolver, EnsembleGPUKernel # spatial: include("spatial/spatial_massaction_jump.jl") diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 7e3a34fd..02573dbb 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -1,5 +1,26 @@ struct SimpleTauLeaping <: DiffEqBase.DEAlgorithm end +# Define solver type hierarchy +abstract type AbstractImplicitSolver end +struct NewtonImplicitSolver <: AbstractImplicitSolver end +struct TrapezoidalImplicitSolver <: AbstractImplicitSolver end + +# Adaptive tau-leaping solver +struct SimpleAdaptiveTauLeaping{T <: AbstractFloat} <: DiffEqBase.DEAlgorithm + epsilon::T # Error control parameter for tau selection + solver::AbstractImplicitSolver # Solver type for implicit method + eigenvalue_check::Bool # Enable eigenvalue-based stiffness detection + stiffness_ratio_threshold::T # # Stiffness ratio threshold + implicit_epsilon_factor::T # Scaling factor for implicit tau-selection +end + +# Stiffness detection uses a dynamic threshold epsilon * sum(u) for propensity ratios, +# as inspired by Cao et al. (2007), Section III.B. Optional eigenvalue-based check +# uses the Jacobian's eigenvalue ratio. implicit_epsilon_factor=10.0 relaxes tau-selection +# for implicit tau-leaping, per Cao et al. (2007), Section III.A. +SimpleAdaptiveTauLeaping(; epsilon=0.05, solver=NewtonImplicitSolver(), eigenvalue_check=false, stiffness_ratio_threshold=1e4, implicit_epsilon_factor=10.0) = + SimpleAdaptiveTauLeaping(epsilon, solver, eigenvalue_check, stiffness_ratio_threshold, implicit_epsilon_factor) + function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg) if !(jump_prob.aggregator isa PureLeaping) @warn "When using $alg, please pass PureLeaping() as the aggregator to the \ @@ -14,6 +35,20 @@ function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg) jump_prob.regular_jump !== nothing end +# Validation for adaptive tau-leaping +function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping) + if !(jump_prob.aggregator isa PureLeaping) + @warn "When using $alg, please pass PureLeaping() as the aggregator to the \ + JumpProblem, i.e. call JumpProblem(::DiscreteProblem, PureLeaping(),...). \ + Passing $(jump_prob.aggregator) is deprecated and will be removed in the next breaking release." + end + isempty(jump_prob.jump_callback.continuous_callbacks) && + isempty(jump_prob.jump_callback.discrete_callbacks) && + isempty(jump_prob.constant_jumps) && + isempty(jump_prob.variable_jumps) && + jump_prob.massaction_jump !== nothing +end + function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping; seed = nothing, dt = error("dt is required for SimpleTauLeaping.")) validate_pure_leaping_inputs(jump_prob, alg) || @@ -61,6 +96,293 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping; interp = DiffEqBase.ConstantInterpolation(t, u)) end +# Compute highest order of reaction (HOR) +# Reference: Cao et al. (2006), J. Chem. Phys. 124, 044109, Section IV +function compute_hor(reactant_stoch, numjumps) + hor = zeros(Int, numjumps) + for j in 1:numjumps + order = sum(stoch for (spec_idx, stoch) in reactant_stoch[j]; init=0) + if order > 3 + error("Reaction $j has order $order, which is not supported (maximum order is 3).") + end + hor[j] = order + end + return hor +end + +# Precompute max_hor and max_stoich for g_i calculation +# Reference: Cao et al. (2006), Section IV, equation (27) +function precompute_reaction_conditions(reactant_stoch, hor, numspecies, numjumps) + max_hor = zeros(Int, numspecies) + max_stoich = zeros(Int, numspecies) + for j in 1:numjumps + for (spec_idx, stoch) in reactant_stoch[j] + if stoch > 0 + if hor[j] > max_hor[spec_idx] + max_hor[spec_idx] = hor[j] + max_stoich[spec_idx] = stoch + elseif hor[j] == max_hor[spec_idx] + max_stoich[spec_idx] = max(max_stoich[spec_idx], stoch) + end + end + end + end + return max_hor, max_stoich +end + +# Compute g_i to bound propensity changes +# Reference: Cao et al. (2006), Section IV, equation (27) +function compute_gi(u, max_hor, max_stoich, i, t) + if max_hor[i] == 0 + return 1.0 + elseif max_hor[i] == 1 + return 1.0 + elseif max_hor[i] == 2 + if max_stoich[i] == 1 + return 2.0 + elseif max_stoich[i] == 2 + return u[i] > 1 ? 2.0 + 1.0 / (u[i] - 1) : 2.0 + end + elseif max_hor[i] == 3 + if max_stoich[i] == 1 + return 3.0 + elseif max_stoich[i] == 2 + return u[i] > 1 ? 1.5 * (2.0 + 1.0 / (u[i] - 1)) : 3.0 + elseif max_stoich[i] == 3 + return u[i] > 2 ? 3.0 + 1.0 / (u[i] - 1) + 2.0 / (u[i] - 2) : 3.0 + end + end + return 1.0 +end + +# Compute tau for explicit tau-leaping +# Reference: Cao et al. (2006), equation (8), using equations (9a) and (9b) +function compute_tau(u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin, max_hor, max_stoich, numjumps) + rate(rate_cache, u, p, t) + if all(==(0.0), rate_cache) + return dtmin + end + tau = Inf + for i in 1:length(u) + mu = zero(eltype(rate_cache)) # Equation (9a): mu_i(x) = sum_j nu_ij * a_j(x) + sigma2 = zero(eltype(rate_cache)) # Equation (9b): sigma_i^2(x) = sum_j nu_ij^2 * a_j(x) + for j in 1:size(nu, 2) + mu += nu[i, j] * rate_cache[j] + sigma2 += nu[i, j]^2 * rate_cache[j] + end + gi = compute_gi(u, max_hor, max_stoich, i, t) # Equation (27) + bound = max(epsilon * max(u[i], 0.0) / gi, 1.0) + mu_term = abs(mu) > 0 ? bound / abs(mu) : Inf # First term in equation (8) + sigma_term = sigma2 > 0 ? bound^2 / sigma2 : Inf # Second term in equation (8) + tau = min(tau, mu_term, sigma_term) # Equation (8) + end + return max(tau, dtmin) +end + +# Compute tau for implicit tau-leaping with relaxed error control +# Reference: Cao et al. (2007), J. Chem. Phys. 126, 224101, Section III.A, using relaxed epsilon for larger steps +function compute_tau_implicit(u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin, max_hor, max_stoich, numjumps, implicit_epsilon_factor) + tau_explicit = compute_tau(u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin, max_hor, max_stoich, numjumps) + u_predict = float.(copy(u)) # Initialize as Float64 to handle fractional updates + rate(rate_cache, u, p, t) + for j in 1:numjumps + for spec_idx in 1:size(nu, 1) + u_predict[spec_idx] += nu[spec_idx, j] * rate_cache[j] * tau_explicit + end + end + u_predict = max.(u_predict, 0.0) + + relaxed_epsilon = epsilon * implicit_epsilon_factor + # Reuse compute_tau with predicted state, time, and relaxed epsilon + tau = compute_tau(u_predict, rate_cache, nu, hor, p, t + tau_explicit, relaxed_epsilon, rate, dtmin, max_hor, max_stoich, numjumps) + return max(tau, dtmin) +end + +# Define residual for implicit equation +# Reference: Cao et al. (2004), J. Chem. Phys. 121, 4059 +function implicit_equation!(resid, u_new, params) + u_current, rate_cache, nu, p, t, tau, rate, numjumps, solver = params + rate(rate_cache, u_new, p, t + tau) + resid .= u_new .- u_current + for j in 1:numjumps + for spec_idx in 1:size(nu, 1) + if isa(solver, NewtonImplicitSolver) + resid[spec_idx] -= nu[spec_idx, j] * rate_cache[j] * tau # Cao et al. (2004) + else # TrapezoidalImplicitSolver + rate_current = similar(rate_cache) + rate(rate_current, u_current, p, t) + resid[spec_idx] -= nu[spec_idx, j] * 0.5 * (rate_cache[j] + rate_current[j]) * tau + end + end + end + resid .= max.(resid, -u_new) # Ensure non-negative solution +end + +# Solve implicit equation +function solve_implicit(u_current, rate_cache, nu, p, t, tau, rate, numjumps, solver) + u_new = float.(copy(u_current)) + prob = NonlinearProblem(implicit_equation!, u_new, (u_current, rate_cache, nu, p, t, tau, rate, numjumps, solver)) + sol = solve(prob, SimpleNewtonRaphson(autodiff=AutoFiniteDiff()); abstol=1e-6, reltol=1e-6) + return sol.u, sol.retcode == ReturnCode.Success +end + +# Compute Jacobian for eigenvalue-based stiffness detection +# Reference: Cao et al. (2007), Section III.B +function compute_jacobian(u, rate, numjumps, numspecies, p, t) + J = zeros(numjumps, numspecies) + rate_cache = zeros(numjumps) + rate(rate_cache, u, p, t) + h = 1e-6 + for i in 1:numspecies + u_plus = float.(copy(u)) + u_plus[i] += h + rate_plus = zeros(numjumps) + rate(rate_plus, u_plus, p, t) + for j in 1:numjumps + J[j, i] = (rate_plus[j] - rate_cache[j]) / h + end + end + return J +end + +# Stiffness detection using propensity ratio or eigenvalues +# Reference: Cao et al. (2007), Section III.B +function is_stiff(rate_cache, u, epsilon, eigenvalue_check, stiffness_ratio_threshold, p, t, rate, numjumps, numspecies) + non_zero_rates = [rate for rate in rate_cache if rate > 0] + if length(non_zero_rates) <= 1 + return false + end + if eigenvalue_check + J = compute_jacobian(u, rate, numjumps, numspecies, p, t) + eigvals = real.(LinearAlgebra.eigvals(J)) + non_zero_eigvals = [abs(λ) for λ in eigvals if abs(λ) > 1e-10] + if length(non_zero_eigvals) <= 1 + return false + end + max_eig = maximum(non_zero_eigvals) + min_eig = minimum(non_zero_eigvals) + return max_eig / min_eig > stiffness_ratio_threshold # Stiffness ratio threshold, Petzold (1983), SIAM J. Sci. Stat. Comput. 4(1), 136–148 + else + max_rate = maximum(non_zero_rates) + min_rate = minimum(non_zero_rates) + threshold = epsilon * sum(u) + return max_rate / min_rate > threshold # Propensity ratio threshold, Cao et al. (2007), J. Chem. Phys. 126, 224101, Section III.B + end +end + +# Adaptive tau-leaping solver +# Reference: Cao et al. (2007), Cao et al. (2004), Cao et al. (2006) +function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; seed=nothing, dtmin=1e-10, saveat=nothing) + validate_pure_leaping_inputs(jump_prob, alg) || error("SimpleAdaptiveTauLeaping can only be used with PureLeaping JumpProblem with a MassActionJump.") + + @unpack prob, rng = jump_prob + (seed !== nothing) && seed!(rng, seed) + + maj = jump_prob.massaction_jump + numjumps = get_num_majumps(maj) + rate = (out, u, p, t) -> begin + for j in 1:numjumps + out[j] = evalrxrate(u, j, maj) + end + end + u0 = copy(prob.u0) + tspan = prob.tspan + p = prob.p + + u_current = copy(u0) + t_current = tspan[1] + usave = [copy(u0)] + tsave = [tspan[1]] + rate_cache = zeros(Float64, numjumps) + counts = zeros(Int64, numjumps) + du = similar(u0) + t_end = tspan[2] + epsilon = alg.epsilon + solver = alg.solver + eigenvalue_check = alg.eigenvalue_check + stiffness_ratio_threshold = alg.stiffness_ratio_threshold + implicit_epsilon_factor = alg.implicit_epsilon_factor + + nu = zeros(Int64, length(u0), numjumps) + for j in 1:numjumps + for (spec_idx, stoch) in maj.net_stoch[j] + nu[spec_idx, j] = stoch + end + end + reactant_stoch = maj.reactant_stoch + hor = compute_hor(reactant_stoch, numjumps) + max_hor, max_stoich = precompute_reaction_conditions(reactant_stoch, hor, length(u0), numjumps) + numspecies = length(u0) + + saveat_times = isnothing(saveat) ? Vector{Float64}() : + (saveat isa Number ? collect(range(tspan[1], tspan[2], step=saveat)) : collect(saveat)) + save_idx = 1 + + while t_current < t_end + rate(rate_cache, u_current, p, t_current) + use_implicit = is_stiff(rate_cache, u_current, epsilon, eigenvalue_check, stiffness_ratio_threshold, p, t_current, rate, numjumps, numspecies) + tau = use_implicit ? + compute_tau_implicit(u_current, rate_cache, nu, hor, p, t_current, epsilon, rate, dtmin, max_hor, max_stoich, numjumps, implicit_epsilon_factor) : + compute_tau(u_current, rate_cache, nu, hor, p, t_current, epsilon, rate, dtmin, max_hor, max_stoich, numjumps) + tau = min(tau, t_end - t_current) + if !isempty(saveat_times) && save_idx <= length(saveat_times) && t_current + tau > saveat_times[save_idx] + tau = saveat_times[save_idx] - t_current + end + + if use_implicit + u_new_float, converged = solve_implicit(u_current, rate_cache, nu, p, t_current, tau, rate, numjumps, solver) + if !converged + tau /= 2 + continue + end + rate(rate_cache, u_new_float, p, t_current + tau) + counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0)) # Cao et al. (2004) + du .= zero(eltype(u_current)) + for j in 1:numjumps + for spec_idx in 1:size(nu, 1) + if nu[spec_idx, j] != 0 + du[spec_idx] += nu[spec_idx, j] * counts[j] + end + end + end + u_new = u_current + du # Cao et al. (2004) + else + counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0)) # Cao et al. (2006), equation (8) + du .= zero(eltype(u_current)) + for j in 1:numjumps + for spec_idx in 1:size(nu, 1) + if nu[spec_idx, j] != 0 + du[spec_idx] += nu[spec_idx, j] * counts[j] + end + end + end + u_new = u_current + du + end + + if any(<(0), u_new) + tau /= 2 + continue + end + t_new = t_current + tau + + if isempty(saveat_times) || (save_idx <= length(saveat_times) && t_new >= saveat_times[save_idx]) + push!(usave, copy(u_new)) # Ensure integer solutions + push!(tsave, t_new) + if !isempty(saveat_times) && t_new >= saveat_times[save_idx] + save_idx += 1 + end + end + + u_current = u_new + t_current = t_new + end + + sol = DiffEqBase.build_solution(prob, alg, tsave, usave, + calculate_error=false, + interp=DiffEqBase.ConstantInterpolation(tsave, usave)) + return sol +end + struct EnsembleGPUKernel{Backend} <: SciMLBase.EnsembleAlgorithm backend::Backend cpu_offload::Float64 diff --git a/test/regular_jumps.jl b/test/regular_jumps.jl index 3ccc6740..4eb4acf5 100644 --- a/test/regular_jumps.jl +++ b/test/regular_jumps.jl @@ -1,28 +1,142 @@ using JumpProcesses, DiffEqBase -using Test, LinearAlgebra +using Test, LinearAlgebra, Statistics using StableRNGs rng = StableRNG(12345) -function regular_rate(out, u, p, t) - out[1] = (0.1 / 1000.0) * u[1] * u[2] - out[2] = 0.01u[2] -end +Nsims = 1000 + +# SIR model with influx +@testset "SIR Model Correctness" begin + β = 0.1 / 1000.0 + ν = 0.01 + influx_rate = 1.0 + p = (β, ν, influx_rate) + + # ConstantRateJump formulation for SSAStepper + rate1(u, p, t) = p[1] * u[1] * u[2] # β*S*I (infection) + rate2(u, p, t) = p[2] * u[2] # ν*I (recovery) + rate3(u, p, t) = p[3] # influx_rate (S influx) + affect1!(integrator) = (integrator.u[1] -= 1; integrator.u[2] += 1; nothing) + affect2!(integrator) = (integrator.u[2] -= 1; integrator.u[3] += 1; nothing) + affect3!(integrator) = (integrator.u[1] += 1; nothing) + jumps = (ConstantRateJump(rate1, affect1!), ConstantRateJump(rate2, affect2!), ConstantRateJump(rate3, affect3!)) + + u0 = [999.0, 10.0, 0.0] # S, I, R + tspan = (0.0, 250.0) + prob_disc = DiscreteProblem(u0, tspan, p) + jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng=rng) + + # Solve with SSAStepper + sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); trajectories=Nsims, saveat=1.0) + + # RegularJump formulation for SimpleTauLeaping + regular_rate = (out, u, p, t) -> begin + out[1] = p[1] * u[1] * u[2] + out[2] = p[2] * u[2] + out[3] = p[3] + end + regular_c = (dc, u, p, t, counts, mark) -> begin + dc .= 0 + dc[1] = -counts[1] + counts[3] + dc[2] = counts[1] - counts[2] + dc[3] = counts[2] + end + rj = RegularJump(regular_rate, regular_c, 3) + jump_prob_tau = JumpProblem(prob_disc, PureLeaping(), rj; rng=rng) + + # Solve with SimpleTauLeaping + sol_simple = solve(EnsembleProblem(jump_prob_tau), SimpleTauLeaping(), EnsembleSerial(); trajectories=Nsims, dt=0.1) -const dc = zeros(3, 2) -dc[1, 1] = -1 -dc[2, 1] = 1 -dc[2, 2] = -1 -dc[3, 2] = 1 + # MassActionJump formulation for SimpleAdaptiveTauLeaping + reactant_stoich = [[1=>1, 2=>1], [2=>1], Pair{Int,Int}[]] + net_stoich = [[1=>-1, 2=>1], [2=>-1, 3=>1], [1=>1]] + param_idxs = [1, 2, 3] + maj = MassActionJump(reactant_stoich, net_stoich; param_idxs=param_idxs) + jump_prob_maj = JumpProblem(prob_disc, PureLeaping(), maj; rng=rng) -function regular_c(du, u, p, t, counts, mark) - mul!(du, dc, counts) + # Solve with SimpleAdaptiveTauLeaping + sol_adaptive_newton = solve(EnsembleProblem(jump_prob_maj), SimpleAdaptiveTauLeaping(solver=NewtonImplicitSolver()), EnsembleSerial(); trajectories=Nsims, saveat=1.0) + sol_adaptive_trapezoidal = solve(EnsembleProblem(jump_prob_maj), SimpleAdaptiveTauLeaping(solver=TrapezoidalImplicitSolver()), EnsembleSerial(); trajectories=Nsims, saveat=1.0) + + # Compute mean infected (I) trajectories + t_points = 0:1.0:250.0 + max_direct_I = maximum([mean(sol_direct[i](t)[2] for i in 1:Nsims) for t in t_points]) + max_direct_I = maximum([mean(sol_simple[i](t)[2] for i in 1:Nsims) for t in t_points]) + max_adaptive_newton = maximum([mean(sol_adaptive_newton[i](t)[2] for i in 1:Nsims) for t in t_points]) + max_adaptive_trapezoidal = maximum([mean(sol_adaptive_trapezoidal[i](t)[2] for i in 1:Nsims) for t in t_points]) + + # Test mean infected trajectories + @test isapprox(max_direct_I, max_direct_I, rtol=0.05) + @test isapprox(max_direct_I, max_adaptive_newton, rtol=0.05) + @test isapprox(max_direct_I, max_adaptive_trapezoidal, rtol=0.05) end -rj = RegularJump(regular_rate, regular_c, 2) -jumps = JumpSet(rj) -prob = DiscreteProblem([999, 1, 0], (0.0, 250.0)) -jump_prob = JumpProblem(prob, PureLeaping(), rj; rng) -sol = solve(jump_prob, SimpleTauLeaping(); dt = 1.0) +# SEIR model with exposed compartment +@testset "SEIR Model Correctness" begin + β = 0.3 / 1000.0 + σ = 0.2 + ν = 0.01 + p = (β, σ, ν) + + # ConstantRateJump formulation for SSAStepper + rate1(u, p, t) = p[1] * u[1] * u[3] # β*S*I (infection) + rate2(u, p, t) = p[2] * u[2] # σ*E (progression) + rate3(u, p, t) = p[3] * u[3] # ν*I (recovery) + affect1!(integrator) = (integrator.u[1] -= 1; integrator.u[2] += 1; nothing) + affect2!(integrator) = (integrator.u[2] -= 1; integrator.u[3] += 1; nothing) + affect3!(integrator) = (integrator.u[3] -= 1; integrator.u[4] += 1; nothing) + jumps = (ConstantRateJump(rate1, affect1!), ConstantRateJump(rate2, affect2!), ConstantRateJump(rate3, affect3!)) + + u0 = [999.0, 0.0, 10.0, 0.0] # S, E, I, R + tspan = (0.0, 250.0) + prob_disc = DiscreteProblem(u0, tspan, p) + jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng=rng) + + # Solve with SSAStepper + sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); trajectories=Nsims, saveat=1.0) + + # RegularJump formulation for SimpleTauLeaping + regular_rate = (out, u, p, t) -> begin + out[1] = p[1] * u[1] * u[3] + out[2] = p[2] * u[2] + out[3] = p[3] * u[3] + end + regular_c = (dc, u, p, t, counts, mark) -> begin + dc .= 0.0 + dc[1] = -counts[1] + dc[2] = counts[1] - counts[2] + dc[3] = counts[2] - counts[3] + dc[4] = counts[3] + end + rj = RegularJump(regular_rate, regular_c, 3) + jump_prob_tau = JumpProblem(prob_disc, PureLeaping(), rj; rng=rng) + + # Solve with SimpleTauLeaping + sol_simple = solve(EnsembleProblem(jump_prob_tau), SimpleTauLeaping(), EnsembleSerial(); trajectories=Nsims, dt=0.1) + + # MassActionJump formulation for SimpleAdaptiveTauLeaping + reactant_stoich = [[1=>1, 3=>1], [2=>1], [3=>1]] + net_stoich = [[1=>-1, 2=>1], [2=>-1, 3=>1], [3=>-1, 4=>1]] + param_idxs = [1, 2, 3] + maj = MassActionJump(reactant_stoich, net_stoich; param_idxs=param_idxs) + jump_prob_maj = JumpProblem(prob_disc, PureLeaping(), maj; rng=rng) + + # Solve with SimpleAdaptiveTauLeaping + sol_adaptive_newton = solve(EnsembleProblem(jump_prob_maj), SimpleAdaptiveTauLeaping(solver=NewtonImplicitSolver()), EnsembleSerial(); trajectories=Nsims, saveat=1.0) + sol_adaptive_trapezoidal = solve(EnsembleProblem(jump_prob_maj), SimpleAdaptiveTauLeaping(solver=TrapezoidalImplicitSolver()), EnsembleSerial(); trajectories=Nsims, saveat=1.0) + + # Compute mean infected (I) trajectories + t_points = 0:1.0:250.0 + max_direct_I = maximum([mean(sol_direct[i](t)[2] for i in 1:Nsims) for t in t_points]) + max_direct_I = maximum([mean(sol_simple[i](t)[2] for i in 1:Nsims) for t in t_points]) + max_adaptive_newton = maximum([mean(sol_adaptive_newton[i](t)[2] for i in 1:Nsims) for t in t_points]) + max_adaptive_trapezoidal = maximum([mean(sol_adaptive_trapezoidal[i](t)[2] for i in 1:Nsims) for t in t_points]) + + # Test mean infected trajectories + @test isapprox(max_direct_I, max_direct_I, rtol=0.05) + @test isapprox(max_direct_I, max_adaptive_newton, rtol=0.05) + @test isapprox(max_direct_I, max_adaptive_trapezoidal, rtol=0.05) +end # Test PureLeaping aggregator functionality @testset "PureLeaping Aggregator Tests" begin