Skip to content

Jac #9

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Sep 21, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ CUDAnative = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
GPUifyLoops = "ba82f77b-6841-5d2e-bd9f-4daf811aec27"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[compat]
julia = "1"
Expand Down
167 changes: 161 additions & 6 deletions src/DiffEqGPU.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,51 @@
module DiffEqGPU

using GPUifyLoops, CuArrays, CUDAnative, DiffEqBase
using GPUifyLoops, CuArrays, CUDAnative, DiffEqBase, LinearAlgebra

"""
    gpu_kernel(f, du, u, p, t)

Evaluate the in-place function `f` once for every column of `u`, calling
`f(du[:,i], u[:,i], p[:,i], t)` on views of column `i` of `du`, `u` and `p`.

The loop spec pairs the range `1:size(u,2)` with the index expression built from
`blockIdx`, `blockDim` and `threadIdx`. Always returns `nothing`.
"""
function gpu_kernel(f,du,u,p,t)
    @loop for i in (1:size(u,2); (blockIdx().x-1) * blockDim().x + threadIdx().x)
        # Exactly one call per column, using that column's own parameter slice.
        # An earlier extra call that passed the whole `p` has been dropped: it
        # evaluated `f` twice per column and handed it un-sliced parameters.
        @views @inbounds f(du[:,i],u[:,i],p[:,i],t)
        nothing
    end
    nothing
end

"""
    jac_kernel(f, J, u, p, t)

Fill the diagonal blocks of `J` with per-column Jacobians.

For column `i` of `u`, the in-place Jacobian function `f` is called as
`f(J[section,section], u[:,i], p[:,i], t)`, where `section` is the `i`-th
consecutive index range of length `size(u,1)`. Always returns `nothing`.
"""
function jac_kernel(f,J,u,p,t)
    # The range is 1-based so that it matches the 1-based index expression
    # `(blockIdx().x-1) * blockDim().x + threadIdx().x`, exactly as in the other
    # kernels of this file. With the previous 0-based range the two components of
    # the loop spec disagreed by one (index 0 is never produced by that
    # expression). On the CPU path the blocks visited are unchanged.
    @loop for i in (1:size(u,2); (blockIdx().x-1) * blockDim().x + threadIdx().x)
        n = size(u,1)
        section = ((i-1)*n + 1) : (i*n)
        @views @inbounds f(J[section,section],u[:,i],p[:,i],t)
        nothing
    end
    nothing
end

"""
    discrete_condition_kernel(condition, cur, u, t, p)

Evaluate `condition` for every column of `u` and store the result in `cur`.

For column `col`, `condition` receives a view of `u[:,col]`, the time `t`, and a
`FakeIntegrator` built from that state view, `t` and a view of `p[:,col]`.
Always returns `nothing`.
"""
function discrete_condition_kernel(condition,cur,u,t,p)
    @loop for col in (1:size(u,2); (blockIdx().x-1) * blockDim().x + threadIdx().x)
        @views @inbounds begin
            state = u[:,col]
            cur[col] = condition(state,t,FakeIntegrator(state,t,p[:,col]))
        end
        nothing
    end
    nothing
end

"""
    discrete_affect!_kernel(affect!, cur, u, t, p)

Apply `affect!` to every column of `u` whose entry in `cur` is `true`.

`affect!` is handed a `FakeIntegrator` wrapping views of `u[:,col]` and
`p[:,col]` together with `t`. Always returns `nothing`.
"""
function discrete_affect!_kernel(affect!,cur,u,t,p)
    @loop for col in (1:size(u,2); (blockIdx().x-1) * blockDim().x + threadIdx().x)
        @views @inbounds if cur[col]
            affect!(FakeIntegrator(u[:,col],t,p[:,col]))
        end
        nothing
    end
    nothing
end

"""
    continuous_condition_kernel(condition, out, u, t, p)

Write the value of `condition` for every column of `u` into `out`.

Entry `out[col]` is `condition(u[:,col], t, integ)`, where `integ` is a
`FakeIntegrator` holding views of `u[:,col]` and `p[:,col]` plus the time `t`.
Always returns `nothing`.
"""
function continuous_condition_kernel(condition,out,u,t,p)
    @loop for col in (1:size(u,2); (blockIdx().x-1) * blockDim().x + threadIdx().x)
        @views @inbounds begin
            integ = FakeIntegrator(u[:,col],t,p[:,col])
            out[col] = condition(u[:,col],t,integ)
        end
        nothing
    end
    nothing
end

# Apply `affect!` to a single column of `u`.
#
# The range component of the loop spec is the one-element tuple `(event_idx,)`,
# so iterating over it visits only the column named by `event_idx`. `affect!`
# receives a `FakeIntegrator` built from views of `u[:,i]` and `p[:,i]` and the
# time `t`. The kernel itself yields `nothing`.
# NOTE(review): the second component is the same blockIdx/threadIdx index
# expression used by the other kernels — confirm how `@loop` combines it with a
# tuple "range" when launched on the device.
function continuous_affect!_kernel(affect!,event_idx,u,t,p)
@loop for i in ((event_idx,); (blockIdx().x-1) * blockDim().x + threadIdx().x)
@views @inbounds affect!(FakeIntegrator(u[:,i],t,p[:,i]))
nothing
end
nothing
Expand All @@ -15,6 +57,12 @@ function GPUifyLoops.launch_config(::typeof(gpu_kernel),maxthreads,context,g,f,d
(threads=t,blocks=blocks)
end

"""
    FakeIntegrator(u, t, p)

Immutable container with the three fields `u`, `t` and `p`, each stored with its
own concrete type parameter. The kernels in this file build one per column and
pass it to user `condition` / `affect!` callbacks in place of an integrator.
"""
struct FakeIntegrator{U,T,P}
    u::U
    t::T
    p::P
end

# Common supertype of the array-based ensemble algorithms defined here; it is
# itself a subtype of `DiffEqBase.EnsembleAlgorithm`.
abstract type EnsembleArrayAlgorithm <: DiffEqBase.EnsembleAlgorithm end
# Field-less selector types. `batch_solve` tests `ensemblealg isa EnsembleGPUArray`
# to decide whether its callback scratch array is a `CuArray` or a plain array.
struct EnsembleCPUArray <: EnsembleArrayAlgorithm end
struct EnsembleGPUArray <: EnsembleArrayAlgorithm end
Expand Down Expand Up @@ -62,14 +110,121 @@ function batch_solve(ensembleprob,alg,ensemblealg,I;kwargs...)
end
end

prob = ODEProblem(_f,u0,probs[1].tspan,p)
sol = solve(prob,alg; kwargs...)
if DiffEqBase.has_jac(probs[1].f)
_jac = let jac=probs[1].f.jac
function (J,u,p,t)
version = u isa CuArray ? CUDA() : CPU()
@launch version jac_kernel(jac,J,u,p,t)
end
end
else
_jac = nothing
end

if probs[1].f.colorvec !== nothing
colorvec = repeat(probs[1].f.colorvec,length(I))
else
colorvec = repeat(1:length(probs[1].u0),length(I))
end

if :callback ∉ keys(probs[1].kwargs)
_callback = nothing
elseif probs[1].kwargs[:callback] isa DiscreteCallback
if ensemblealg isa EnsembleGPUArray
cur = CuArray([false for i in 1:length(probs)])
else
cur = [false for i in 1:length(probs)]
end
_condition = probs[1].kwargs[:callback].condition
_affect! = probs[1].kwargs[:callback].affect!

condition = function (u,t,integrator)
version = u isa CuArray ? CUDA() : CPU()
@launch version discrete_condition_kernel(_condition,cur,u,t,integrator.p)
any(cur)
end

affect! = function (integrator)
version = integrator.u isa CuArray ? CUDA() : CPU()
@launch version discrete_affect!_kernel(_affect!,cur,integrator.u,integrator.t,integrator.p)
end

_callback = DiscreteCallback(condition,affect!,save_positions=probs[1].kwargs[:callback].save_positions)
elseif probs[1].kwargs[:callback] isa ContinuousCallback
_condition = probs[1].kwargs[:callback].condition
_affect! = probs[1].kwargs[:callback].affect!
_affect_neg! = probs[1].kwargs[:callback].affect_neg!

condition = function (out,u,t,integrator)
version = u isa CuArray ? CUDA() : CPU()
@launch version continuous_condition_kernel(_condition,out,u,t,integrator.p)
nothing
end

affect! = function (integrator,event_idx)
version = integrator.u isa CuArray ? CUDA() : CPU()
@launch version continuous_affect!_kernel(_affect!,event_idx,integrator.u,integrator.t,integrator.p)
end

affect_neg! = function (integrator,event_idx)
version = integrator.u isa CuArray ? CUDA() : CPU()
@launch version continuous_affect!_kernel(_affect_neg!,event_idx,integrator.u,integrator.t,integrator.p)
end

_callback = VectorContinuousCallback(condition,affect!,affect_neg!,length(probs),save_positions=probs[1].kwargs[:callback].save_positions)
end

f_func = ODEFunction(_f,jac=_jac,colorvec=colorvec)
prob = ODEProblem(f_func,u0,probs[1].tspan,p;
probs[1].kwargs...)
sol = solve(prob,alg; callback = _callback, kwargs...)

us = Array.(sol.u)
solus = [[us[i][:,j] for i in 1:length(us)] for j in 1:length(probs)]
[DiffEqBase.build_solution(probs[i],alg,sol.t,solus[i]) for i in 1:length(probs)]
[DiffEqBase.build_solution(probs[i],alg,sol.t,solus[i],destats=sol.destats,retcode=sol.retcode) for i in 1:length(probs)]
end

### GPU Factorization

"""
    LinSolveGPUSplitFactorize()

Linear-solver object holding one `CuQR` factorization per entry of `facts` and
the block length `len`.

`facts` is declared as a `Vector` so that the field has a concrete type; the
previous `Array{...}` annotation left the dimension unspecified, which makes the
field abstractly typed. Every constructor call in this file already passes a
one-dimensional array. The zero-argument constructor creates an empty `Float32`
instance with `len == 0`.
"""
mutable struct LinSolveGPUSplitFactorize{T}
    facts::Vector{CuArrays.CUSOLVER.CuQR{T,CuArray{T,2}}}
    len::Int
end
LinSolveGPUSplitFactorize() = LinSolveGPUSplitFactorize(CuArrays.CUSOLVER.CuQR{Float32,CuArray{Float32,2}}[],0)

# Solve entry point: `p(x, A, b, update_matrix=false; kwargs...)`.
#
# * When `update_matrix` is true, `qr_kernel` is launched to refresh `p.facts`
#   from `A`.
# * `b` is then copied into `x` and `ldiv!_kernel` is launched to apply the
#   stored factorizations to `x` in chunks of length `p.len`.
#
# The backend (`CUDA()` or `CPU()`) is chosen from the type of `b`.
# `kwargs` are accepted and ignored.
#
# Changes from the previous body:
# * `qr_kernel` is declared with three parameters `(facts, W, len)`; it was
#   launched with only two, so `p.len` is now passed.
# * A branch that read `p.A` and `p.factorization` and referenced `SuiteSparse`
#   has been removed. This struct only has the fields `facts` and `len`, and the
#   module does not import `SuiteSparse`, so evaluating that condition could
#   only raise an error.
function (p::LinSolveGPUSplitFactorize)(x,A,b,update_matrix=false;kwargs...)
    version = b isa CuArray ? CUDA() : CPU()
    if update_matrix
        @launch version qr_kernel(p.facts,A,p.len)
    end
    x .= b
    @launch version ldiv!_kernel(p.facts,x,p.len)
end
# Initialisation hook, selected by passing the type `Val{:init}` as first argument.
# Builds a fresh `LinSolveGPUSplitFactorize` sized from `u0_prototype`:
# one uninitialised factorization slot per column (`size(u0_prototype,2)`) and
# `len` equal to the number of rows (`size(u0_prototype,1)`). `f` is not used.
function (p::LinSolveGPUSplitFactorize)(::Type{Val{:init}},f,u0_prototype)
    T = eltype(u0_prototype)
    nrows = size(u0_prototype,1)
    ncols = size(u0_prototype,2)
    slots = Array{CuArrays.CUSOLVER.CuQR{T,CuArray{T,2}}}(undef,ncols)
    return LinSolveGPUSplitFactorize(slots,nrows)
end

"""
    qr_kernel(facts, W, len)

Store in `facts[i]` the QR factorization of the `i`-th `len`×`len` diagonal
block of `W`, for every `i` in `1:length(facts)`. Always returns `nothing`.

Range indexing `W[section,section]` produces a copy of the block, and `qr!`
factorizes that copy in place.
"""
function qr_kernel(facts,W,len)
    # 1-based loop: the previous range started at 0 and wrote `facts[0]`, which
    # is out of bounds, and it did not match the 1-based index expression.
    @loop for i in (1:length(facts); (blockIdx().x-1) * blockDim().x + threadIdx().x)
        section = ((i-1)*len + 1) : (i*len)
        # The block is now actually selected; the previous code computed
        # `section` but then indexed `W[]` with no indices at all.
        # NOTE(review): assumes the diagonal block is the intended operand,
        # mirroring `J[section,section]` in `jac_kernel` — confirm.
        facts[i] = qr!(W[section,section])
        nothing
    end
    nothing
end

"""
    ldiv!_kernel(facts, W, len)

Apply each stored factorization to its chunk of `W` in place: for every `i` in
`1:length(facts)`, call `ldiv!(facts[i], chunk)` where `chunk` is a view of the
`i`-th consecutive run of `len` entries of `W`. Always returns `nothing`.
"""
function ldiv!_kernel(facts,W,len)
    # 1-based loop: the previous range started at 0 and read `facts[0]`, which is
    # out of bounds, and it did not match the 1-based index expression.
    @loop for i in (1:length(facts); (blockIdx().x-1) * blockDim().x + threadIdx().x)
        # Operate on the `W` argument; the previous body referred to an
        # undefined name `x` instead of the parameter.
        @views ldiv!(facts[i],W[((i-1)*len+1):(i*len)])
        nothing
    end
    nothing
end

export EnsembleCPUArray, EnsembleGPUArray
export EnsembleCPUArray, EnsembleGPUArray, LinSolveGPUSplitFactorize

end # module
61 changes: 60 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ u0 = Float32[1.0;0.0;0.0]
# Base problem: Float32 time span and a 3-tuple of Float32 parameters
# (read as σ, ρ, β by `lorenz_jac` further down in this file).
tspan = (0.0f0,100.0f0)
p = (10.0f0,28.0f0,8/3f0)
prob = ODEProblem(lorenz,u0,tspan,p)
# NOTE(review): this first `prob_func` binding draws new random multipliers on
# every call; it is overwritten two lines below and so never takes effect.
prob_func = (prob,i,repeat) -> remake(prob,p=rand(Float32,3).*p)
# Pre-draw 100_000 random multiplier vectors so that trajectory `i` always gets
# the same parameters, `pre_p[i] .* p`, whichever ensemble algorithm runs it.
const pre_p = [rand(Float32,3) for i in 1:100_000]
prob_func = (prob,i,repeat) -> remake(prob,p=pre_p[i].*p)
monteprob = EnsembleProblem(prob, prob_func = prob_func)

#Performance check with nvvp
Expand All @@ -25,3 +26,61 @@ monteprob = EnsembleProblem(prob, prob_func = prob_func)
# Time the same 100_000-trajectory Tsit5 solve under three ensemble algorithms.
@time solve(monteprob,Tsit5(),EnsembleCPUArray(),trajectories=100_000,saveat=1.0f0)
@time solve(monteprob,Tsit5(),EnsembleThreads(), trajectories=100_000,saveat=1.0f0)
@time solve(monteprob,Tsit5(),EnsembleSerial(), trajectories=100_000,saveat=1.0f0)


# TRBDF2 runs with two trajectories on the CPU-array and GPU-array paths
# (no user-supplied Jacobian on `monteprob`).
solve(monteprob,TRBDF2(),EnsembleCPUArray(),dt=0.1,trajectories=2,saveat=1.0f0)
solve(monteprob,TRBDF2(),EnsembleGPUArray(),dt=0.1,trajectories=2,saveat=1.0f0)
# Same GPU run but with the custom `LinSolveGPUSplitFactorize` linear solver;
# recorded as a known-broken test.
@test_broken solve(monteprob,TRBDF2(linsolve=LinSolveGPUSplitFactorize()),EnsembleGPUArray(),dt=0.1,trajectories=2,saveat=1.0f0)

"""
    lorenz_jac(du, u, p, t)

Write a 3×3 Jacobian into `du` in place, using the state `u[1:3]` as `(x, y, z)`
and the parameters `p[1:3]` as `(σ, ρ, β)`:

    [ -σ     σ    0
      ρ-z   -1   -x
       y     x   -β ]

`t` is not used. Returns `nothing`.
"""
function lorenz_jac(du,u,p,t)
    @inbounds begin
        σ, ρ, β = p[1], p[2], p[3]
        x, y, z = u[1], u[2], u[3]

        # first row
        du[1,1] = -σ
        du[1,2] = σ
        du[1,3] = 0

        # second row
        du[2,1] = ρ - z
        du[2,2] = -1
        du[2,3] = -x

        # third row
        du[3,1] = y
        du[3,2] = x
        du[3,3] = -β
    end
    return nothing
end

# Same ensemble as `monteprob`, but the ODE function carries the analytic
# Jacobian `lorenz_jac`.
func = ODEFunction(lorenz,jac=lorenz_jac)
prob_jac = ODEProblem(func,u0,tspan,p)
monteprob_jac = EnsembleProblem(prob_jac, prob_func = prob_func)

# Timed TRBDF2 solves: 2 trajectories on the CPU-array path, 100 on the GPU-array path.
@time solve(monteprob_jac,TRBDF2(),EnsembleCPUArray(),dt=0.1,trajectories=2,saveat=1.0f0)
@time solve(monteprob_jac,TRBDF2(),EnsembleGPUArray(),dt=0.1,trajectories=100,saveat=1.0f0)

# Discrete-callback condition: true when the first state component exceeds 5.
# `t` and `integrator` are accepted but not used.
condition = (u,t,integrator) -> begin
    @inbounds first_component = u[1]
    first_component > 5
end

# Discrete-callback effect: overwrite the first state component with -4.
affect! = integrator -> begin
    @inbounds integrator.u[1] = -4
end

# Attach the discrete callback to the problem itself (as a problem keyword), wrap
# it in an ensemble, and run 100 trajectories with Tsit5 on the GPU-array path.
callback_prob = ODEProblem(lorenz,u0,tspan,p,callback=DiscreteCallback(condition,affect!,save_positions=(false,false)))
callback_monteprob = EnsembleProblem(callback_prob, prob_func = prob_func)
solve(callback_monteprob,Tsit5(),EnsembleGPUArray(),trajectories=100,saveat=1.0f0)

# Continuous-callback condition: signed distance of the first state component
# from 3. `t` and `integrator` are accepted but not used.
c_condition = (u,t,integrator) -> begin
    @inbounds first_component = u[1]
    first_component - 3
end

# Continuous-callback effect: add 20 to the first state component.
c_affect! = integrator -> begin
    @inbounds integrator.u[1] += 20
end

# Rebind `callback_prob` / `callback_monteprob` to a problem carrying the
# continuous callback instead of the discrete one.
callback_prob = ODEProblem(lorenz,u0,tspan,p,callback=ContinuousCallback(c_condition,c_affect!,save_positions=(false,false)))
callback_monteprob = EnsembleProblem(callback_prob, prob_func = prob_func)
# Two-trajectory GPU-array run with scalar indexing of CuArrays explicitly allowed.
CuArrays.@allowscalar solve(callback_monteprob,Tsit5(),EnsembleGPUArray(),trajectories=2,saveat=1.0f0)
# Ten-trajectory run without that allowance; checking the first solution's
# retcode for `:Success` is recorded as a known-broken test.
@test_broken solve(callback_monteprob,Tsit5(),EnsembleGPUArray(),trajectories=10,saveat=1.0f0)[1].retcode == :Success