|
2 | 2 | # API methods # |
3 | 3 | ############### |
4 | 4 |
|
5 | | -function gradient{F}(f::F, x, cfg::AbstractConfig = GradientConfig(x)) |
| 5 | +@compat const AllowedGradientConfig{F,M} = Union{GradientConfig{Tag{F,M}}, GradientConfig{Tag{Void,M}}} |
| 6 | + |
| 7 | +gradient(f, x, cfg::GradientConfig) = throw(ConfigMismatchError(f, cfg)) |
| 8 | +gradient!(out, f, x, cfg::GradientConfig) = throw(ConfigMismatchError(f, cfg)) |
| 9 | + |
| 10 | +function gradient{F,M}(f::F, x, cfg::AllowedGradientConfig{F,M} = GradientConfig(f, x)) |
6 | 11 | if chunksize(cfg) == length(x) |
7 | 12 | return vector_mode_gradient(f, x, cfg) |
8 | 13 | else |
9 | 14 | return chunk_mode_gradient(f, x, cfg) |
10 | 15 | end |
11 | 16 | end |
12 | 17 |
|
13 | | -function gradient!{F}(out, f::F, x, cfg::AbstractConfig = GradientConfig(x)) |
| 18 | +function gradient!{F,M}(out, f::F, x, cfg::AllowedGradientConfig{F,M} = GradientConfig(f, x)) |
14 | 19 | if chunksize(cfg) == length(x) |
15 | 20 | vector_mode_gradient!(out, f, x, cfg) |
16 | 21 | else |
|
72 | 77 | # chunk mode # |
73 | 78 | ############## |
74 | 79 |
|
75 | | -# single threaded # |
76 | | -#-----------------# |
77 | | - |
78 | 80 | function chunk_mode_gradient_expr(out_definition::Expr) |
79 | 81 | return quote |
80 | 82 | @assert length(x) >= N "chunk size cannot be greater than length(x) ($(N) > $(length(x)))" |
@@ -119,80 +121,10 @@ function chunk_mode_gradient_expr(out_definition::Expr) |
119 | 121 | end |
120 | 122 | end |
121 | 123 |
|
122 | | -@eval function chunk_mode_gradient{F,N}(f::F, x, cfg::GradientConfig{N}) |
| 124 | +@eval function chunk_mode_gradient{F,T,V,N}(f::F, x, cfg::GradientConfig{T,V,N}) |
123 | 125 | $(chunk_mode_gradient_expr(:(out = similar(x, valtype(ydual))))) |
124 | 126 | end |
125 | 127 |
|
126 | | -@eval function chunk_mode_gradient!{F,N}(out, f::F, x, cfg::GradientConfig{N}) |
| 128 | +@eval function chunk_mode_gradient!{F,T,V,N}(out, f::F, x, cfg::GradientConfig{T,V,N}) |
127 | 129 | $(chunk_mode_gradient_expr(:())) |
128 | 130 | end |
129 | | - |
130 | | -# multithreaded # |
131 | | -#---------------# |
132 | | - |
133 | | -if IS_MULTITHREADED_JULIA |
134 | | - function multithread_chunk_mode_expr(out_definition::Expr) |
135 | | - return quote |
136 | | - cfg = gradient_config(multi_cfg) |
137 | | - N = chunksize(cfg) |
138 | | - @assert length(x) >= N "chunk size cannot be greater than length(x) ($(N) > $(length(x)))" |
139 | | - |
140 | | - # precalculate loop bounds |
141 | | - xlen = length(x) |
142 | | - remainder = xlen % N |
143 | | - lastchunksize = ifelse(remainder == 0, N, remainder) |
144 | | - lastchunkindex = xlen - lastchunksize + 1 |
145 | | - middlechunks = 2:div(xlen - lastchunksize, N) |
146 | | - |
147 | | - # fetch and seed work vectors |
148 | | - current_cfg = cfg[compat_threadid()] |
149 | | - current_xdual = current_cfg.duals |
150 | | - current_seeds = current_cfg.seeds |
151 | | - |
152 | | - Base.Threads.@threads for t in 1:length(cfg) |
153 | | - seed!(cfg[t].duals, x) |
154 | | - end |
155 | | - |
156 | | - # do first chunk manually to calculate output type |
157 | | - seed!(current_xdual, x, 1, current_seeds) |
158 | | - current_ydual = f(current_xdual) |
159 | | - $(out_definition) |
160 | | - extract_gradient_chunk!(out, current_ydual, 1, N) |
161 | | - seed!(current_xdual, x, 1) |
162 | | - |
163 | | - # do middle chunks |
164 | | - Base.Threads.@threads for c in middlechunks |
165 | | - # see https://github.com/JuliaLang/julia/issues/14948 |
166 | | - local chunk_cfg = cfg[compat_threadid()] |
167 | | - local chunk_xdual = chunk_cfg.duals |
168 | | - local chunk_seeds = chunk_cfg.seeds |
169 | | - local chunk_index = ((c - 1) * N + 1) |
170 | | - seed!(chunk_xdual, x, chunk_index, chunk_seeds) |
171 | | - local chunk_dual = f(chunk_xdual) |
172 | | - extract_gradient_chunk!(out, chunk_dual, chunk_index, N) |
173 | | - seed!(chunk_xdual, x, chunk_index) |
174 | | - end |
175 | | - |
176 | | - # do final chunk |
177 | | - seed!(current_xdual, x, lastchunkindex, current_seeds, lastchunksize) |
178 | | - current_ydual = f(current_xdual) |
179 | | - extract_gradient_chunk!(out, current_ydual, lastchunkindex, lastchunksize) |
180 | | - |
181 | | - # load value, this is a no-op unless `out` is a DiffResult |
182 | | - extract_value!(out, current_ydual) |
183 | | - |
184 | | - return out |
185 | | - end |
186 | | - end |
187 | | - |
188 | | - @eval function chunk_mode_gradient{F}(f::F, x, multi_cfg::MultithreadConfig) |
189 | | - $(multithread_chunk_mode_expr(:(out = similar(x, valtype(current_ydual))))) |
190 | | - end |
191 | | - |
192 | | - @eval function chunk_mode_gradient!{F}(out, f::F, x, multi_cfg::MultithreadConfig) |
193 | | - $(multithread_chunk_mode_expr(:())) |
194 | | - end |
195 | | -else |
196 | | - chunk_mode_gradient(f, x, cfg::Tuple) = error("Multithreading is not enabled for this Julia installation.") |
197 | | - chunk_mode_gradient!(out, f, x, cfg::Tuple) = chunk_mode_gradient!(f, x, cfg) |
198 | | -end |
0 commit comments