Commit 304104c

Add minibatching tests
1 parent 8d7cd3a commit 304104c

4 files changed (+46 -6 lines)

lib/OptimizationOptimisers/Project.toml

Lines changed: 4 additions & 4 deletions
@@ -10,14 +10,14 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"

-[extensions]
-OptimizationOptimisersMLDataDevicesExt = "MLDataDevices"
-OptimizationOptimisersMLUtilsExt = "MLUtils"
-
 [weakdeps]
 MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"

+[extensions]
+OptimizationOptimisersMLDataDevicesExt = "MLDataDevices"
+OptimizationOptimisersMLUtilsExt = "MLUtils"
+
 [compat]
 MLDataDevices = "1.1"
 MLUtils = "0.4.4"
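
The hunk above only reorders the [weakdeps] and [extensions] sections of Project.toml; the entries themselves are unchanged, so MLDataDevices and MLUtils stay optional and the extension modules keep their names. A minimal sketch of how these two sections behave at runtime, assuming Julia's package-extension mechanism (1.9 and later); the Base.get_extension check is purely illustrative:

    using OptimizationOptimisers
    # The extension is not loaded yet because the weak dependency has not been imported.
    Base.get_extension(OptimizationOptimisers, :OptimizationOptimisersMLUtilsExt) === nothing

    using MLUtils
    # Importing the weak dependency triggers the module listed under [extensions].
    Base.get_extension(OptimizationOptimisers, :OptimizationOptimisersMLUtilsExt) isa Module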

lib/OptimizationOptimisers/ext/OptimizationOptimisersMLDataDevicesExt.jl

Lines changed: 1 addition & 1 deletion
@@ -3,6 +3,6 @@ module OptimizationOptimisersMLDataDevicesExt
 using MLDataDevices
 using OptimizationOptimisers

-OptimizationOptimisers.isa_dataiterator(::DeviceIterator) = true
+OptimizationOptimisers.isa_dataiterator(::DeviceIterator) = (@show "dkjht"; true)

 end
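
This one-line method is what lets solve recognize device-wrapped loaders as minibatch iterators. A sketch of the trait pattern it extends, where the generic fallback in OptimizationOptimisers and the MLUtils method are assumptions, not part of this diff:

    # Assumed fallback in the parent package: by default nothing is a data iterator.
    # OptimizationOptimisers.isa_dataiterator(_) = false
    # The package extensions then opt specific iterator types in:
    OptimizationOptimisers.isa_dataiterator(::MLUtils.DataLoader) = true   # assumed MLUtils extension
    OptimizationOptimisers.isa_dataiterator(::DeviceIterator) = true       # this extension
    # When the trait returns true, the solve loop iterates the object batch by
    # batch instead of treating it as a fixed parameter object.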

lib/OptimizationOptimisers/src/OptimizationOptimisers.jl

Lines changed: 1 addition & 1 deletion
@@ -117,7 +117,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
 opt = min_opt
 x = min_err
 θ = min_θ
-cache.f.grad(G, θ, d...)
+cache.f.grad(G, θ, d)
 opt_state = Optimization.OptimizationState(iter = i,
     u = θ,
     objective = x[1],
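
The only functional change here is dropping the splat: each minibatch d produced by the data iterator is now handed to the cached gradient as a single argument. A hypothetical illustration of the loss signature this calling convention implies (predict is a placeholder, not a package function):

    # With cache.f.grad(G, θ, d) the whole (x, y) batch tuple arrives as one argument,
    # matching the loss(ps, data) form used in the new testset below:
    loss(θ, batch) = sum(abs2, predict(batch[1], θ) .- batch[2])
    # The previous cache.f.grad(G, θ, d...) splatted the batch into separate
    # positional arguments, i.e. a loss of the form loss(θ, x, y).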

lib/OptimizationOptimisers/test/runtests.jl

Lines changed: 40 additions & 0 deletions
@@ -68,3 +68,43 @@ using Zygote

 @test_throws ArgumentError sol=solve(prob, Optimisers.Adam())
 end
+
+@testset "Minibatching" begin
+    using Optimization, OptimizationOptimisers, Lux, Zygote, MLUtils, Statistics, Plots, Random, ComponentArrays
+
+    x = rand(10000)
+    y = sin.(x)
+    data = MLUtils.DataLoader((x, y), batchsize = 100)
+
+    # Define the neural network
+    model = Chain(Dense(1, 32, tanh), Dense(32, 1))
+    ps, st = Lux.setup(Random.default_rng(), model)
+    ps_ca = ComponentArray(ps)
+    smodel = StatefulLuxLayer{true}(model, nothing, st)
+
+    function callback(state, l)
+        state.iter % 25 == 1 && @show "Iteration: %5d, Loss: %.6e\n" state.iter l
+        return l < 1e-4
+    end
+
+    function loss(ps, data)
+        ypred = [smodel([data[1][i]], ps)[1] for i in eachindex(data[1])]
+        return sum(abs2, ypred .- data[2])
+    end
+
+    optf = OptimizationFunction(loss, AutoZygote())
+    prob = OptimizationProblem(optf, ps_ca, data)
+
+    res = Optimization.solve(prob, Optimisers.Adam(), callback = callback, epochs = 100)
+
+    @test res.objective < 1e-4
+
+    using MLDataDevices
+    data = CPUDevice()(data)
+    optf = OptimizationFunction(loss, AutoZygote())
+    prob = OptimizationProblem(optf, ps_ca, data)
+
+    res = Optimization.solve(prob, Optimisers.Adam(), callback = callback, epochs = 100)
+
+    @test res.objective < 1e-4
+end
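
The second half of the testset rewraps the DataLoader with CPUDevice() so that the solve path guarded by the MLDataDevices extension (the DeviceIterator method above) is exercised without needing a GPU. A sketch of the same run on an accelerator, assuming a functional GPU backend such as CUDA.jl is installed so that gpu_device() from MLDataDevices picks it up; it starts from the original MLUtils.DataLoader defined at the top of the testset:

    using MLDataDevices
    gdev = gpu_device()                       # selects a functional GPU backend if one is loaded
    gpu_data = gdev(data)                     # device-wrapped loader; batches are moved lazily
    gpu_ps = gdev(ps_ca)                      # move the parameters to the same device
    prob_gpu = OptimizationProblem(optf, gpu_ps, gpu_data)
    res_gpu = Optimization.solve(prob_gpu, Optimisers.Adam(), callback = callback, epochs = 100)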
