Skip to content

Commit 39fa5fb

Browse files
Merge pull request #827 from SciML/ap/refactor
feat: make MLUtils into a weakdep & support MLDataDevices
2 parents 904cac0 + 1f4cba3 commit 39fa5fb

File tree

5 files changed

+78
-6
lines changed

5 files changed

+78
-6
lines changed
Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,25 @@
11
name = "OptimizationOptimisers"
22
uuid = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
33
authors = ["Vaibhav Dixit <[email protected]> and contributors"]
4-
version = "0.3.0"
4+
version = "0.3.1"
55

66
[deps]
7-
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
87
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
98
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
109
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1110
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
1211
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1312

13+
[weakdeps]
14+
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
15+
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
16+
17+
[extensions]
18+
OptimizationOptimisersMLDataDevicesExt = "MLDataDevices"
19+
OptimizationOptimisersMLUtilsExt = "MLUtils"
20+
1421
[compat]
22+
MLDataDevices = "1.1"
1523
MLUtils = "0.4.4"
1624
Optimisers = "0.2, 0.3"
1725
Optimization = "4"
@@ -20,9 +28,14 @@ Reexport = "1.2"
2028
julia = "1"
2129

2230
[extras]
31+
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
2332
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
33+
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
34+
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
35+
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
36+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2437
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2538
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2639

2740
[targets]
28-
test = ["ForwardDiff", "Test", "Zygote"]
41+
test = ["ComponentArrays", "ForwardDiff", "Lux", "MLDataDevices", "MLUtils", "Random", "Test", "Zygote"]
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
module OptimizationOptimisersMLDataDevicesExt

using MLDataDevices
using OptimizationOptimisers

# Mark MLDataDevices' `DeviceIterator` (e.g. a `DataLoader` moved to a device via
# `CPUDevice()(data)` / `gpu_device()(data)`) as a valid minibatch data iterator,
# so `SciMLBase.__solve` iterates it instead of wrapping it in a single-element vector.
# NOTE: the original contained a leftover debug statement (`@show "dkjht"`) that
# printed on every dispatch; the predicate must simply return `true`.
OptimizationOptimisers.isa_dataiterator(::DeviceIterator) = true

end
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
module OptimizationOptimisersMLUtilsExt

using OptimizationOptimisers
using MLUtils

# An `MLUtils.DataLoader` is a data iterator for minibatch training: opting it in
# here lets `SciMLBase.__solve` loop over its batches rather than treating the
# loader as a single opaque parameter value.
function OptimizationOptimisers.isa_dataiterator(::MLUtils.DataLoader)
    return true
end

end

lib/OptimizationOptimisers/src/OptimizationOptimisers.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module OptimizationOptimisers
22

33
using Reexport, Printf, ProgressLogging
44
@reexport using Optimisers, Optimization
5-
using Optimization.SciMLBase, MLUtils
5+
using Optimization.SciMLBase
66

77
SciMLBase.supports_opt_cache_interface(opt::AbstractRule) = true
88
SciMLBase.requiresgradient(opt::AbstractRule) = true
@@ -16,6 +16,8 @@ function SciMLBase.__init(
1616
kwargs...)
1717
end
1818

19+
# Extension hook: returns `false` for arbitrary parameter objects; package extensions
# (MLUtils, MLDataDevices) add methods returning `true` for their iterator types,
# which switches `__solve` into minibatch mode.
isa_dataiterator(data) = false
20+
1921
function SciMLBase.__solve(cache::OptimizationCache{
2022
F,
2123
RC,
@@ -57,13 +59,14 @@ function SciMLBase.__solve(cache::OptimizationCache{
5759
throw(ArgumentError("The number of epochs must be specified as the epochs or maxiters kwarg."))
5860
end
5961

60-
if cache.p isa MLUtils.DataLoader
62+
if isa_dataiterator(cache.p)
6163
data = cache.p
6264
dataiterate = true
6365
else
6466
data = [cache.p]
6567
dataiterate = false
6668
end
69+
6770
opt = cache.opt
6871
θ = copy(cache.u0)
6972
G = copy(θ)
@@ -114,7 +117,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
114117
opt = min_opt
115118
x = min_err
116119
θ = min_θ
117-
cache.f.grad(G, θ, d...)
120+
cache.f.grad(G, θ, d)
118121
opt_state = Optimization.OptimizationState(iter = i,
119122
u = θ,
120123
objective = x[1],

lib/OptimizationOptimisers/test/runtests.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,43 @@ using Zygote
6868

6969
@test_throws ArgumentError sol=solve(prob, Optimisers.Adam())
7070
end
71+
72+
@testset "Minibatching" begin
    # Printf added: the callback uses a printf-style format string, which `@show`
    # would print verbatim instead of formatting.
    using Optimization, OptimizationOptimisers, Lux, Zygote, MLUtils, Random,
          ComponentArrays, Printf

    x = rand(10000)
    y = sin.(x)
    data = MLUtils.DataLoader((x, y), batchsize = 100)

    # Define the neural network
    model = Chain(Dense(1, 32, tanh), Dense(32, 1))
    ps, st = Lux.setup(Random.default_rng(), model)
    ps_ca = ComponentArray(ps)
    smodel = StatefulLuxLayer{true}(model, nothing, st)

    # Log progress every 25 iterations; returning `true` stops the solve early.
    # Fix: `@show` with a format string prints the literal string — use `@printf`.
    function callback(state, l)
        state.iter % 25 == 1 && @printf "Iteration: %5d, Loss: %.6e\n" state.iter l
        return l < 1e-4
    end

    function loss(ps, data)
        ypred = [smodel([data[1][i]], ps)[1] for i in eachindex(data[1])]
        return sum(abs2, ypred .- data[2])
    end

    optf = OptimizationFunction(loss, AutoZygote())
    prob = OptimizationProblem(optf, ps_ca, data)

    res = Optimization.solve(prob, Optimisers.Adam(), callback = callback, epochs = 10000)

    @test res.objective < 1e-4

    # Re-run with the DataLoader wrapped by MLDataDevices' CPUDevice, exercising
    # the `DeviceIterator` path of the MLDataDevices extension.
    using MLDataDevices
    data = CPUDevice()(data)
    optf = OptimizationFunction(loss, AutoZygote())
    prob = OptimizationProblem(optf, ps_ca, data)

    res = Optimization.solve(prob, Optimisers.Adam(), callback = callback, epochs = 10000)

    @test res.objective < 1e-4
end

0 commit comments

Comments
 (0)