@@ -198,112 +198,126 @@ end
198
198
199
199
# ## Operator Forms
200
200
201
- struct JacVec {F, T1, T2, xType}
201
+ struct FwdModeAutoDiffVecProd {F,U,C,V,V!} <: AbstractAutoDiffVecProd
202
202
f:: F
203
- cache1 :: T1
204
- cache2 :: T2
205
- x :: xType
206
- autodiff :: Bool
203
+ u :: U
204
+ cache :: C
205
+ vecprod :: V
206
+ vecprod! :: V!
207
207
end
208
208
209
- function JacVec (f, x:: AbstractArray , tag = DeivVecTag (); autodiff = true )
210
- if autodiff
211
- cache1 = Dual{typeof (ForwardDiff. Tag (tag, eltype (x))), eltype (x), 1
212
- }. (x, ForwardDiff. Partials .(tuple .(x)))
213
- cache2 = Dual{typeof (ForwardDiff. Tag (tag, eltype (x))), eltype (x), 1
214
- }. (x, ForwardDiff. Partials .(tuple .(x)))
215
- else
216
- cache1 = similar (x)
217
- cache2 = similar (x)
218
- end
219
- JacVec (f, cache1, cache2, x, autodiff)
209
+ function update_coefficients (L:: FwdModeAutoDiffVecProd , u, p, t)
210
+ FwdModeAutoDiffVecProd (L. f, u, L. vecprod, L. vecprod!, L. cache)
220
211
end
221
212
222
- Base. eltype (L:: JacVec ) = eltype (L. x)
223
- Base. size (L:: JacVec ) = (length (L. cache1), length (L. cache1))
224
- Base. size (L:: JacVec , i:: Int ) = length (L. cache1)
225
- function Base.:* (L:: JacVec , v:: AbstractVector )
226
- L. autodiff ? auto_jacvec (_x -> L. f (_x), L. x, v) :
227
- num_jacvec (_x -> L. f (_x), L. x, v)
213
+ function update_coefficients! (L:: FwdModeAutoDiffVecProd , u, p, t)
214
+ copy! (L. u, u)
215
+ L
228
216
end
229
217
230
- function LinearAlgebra. mul! (dy:: AbstractVector , L:: JacVec , v:: AbstractVector )
231
- if L. autodiff
232
- auto_jacvec! (dy, (_y, _x) -> L. f (_y, _x), L. x, v, L. cache1, L. cache2)
233
- else
234
- num_jacvec! (dy, (_y, _x) -> L. f (_y, _x), L. x, v, L. cache1, L. cache2)
235
- end
218
+ function (L:: FwdModeAutoDiffVecProd )(v, p, t)
219
+ L. vecprod (L. f, L. u, v)
236
220
end
237
221
238
- struct HesVec{F, T1, T2, xType}
239
- f:: F
240
- cache1:: T1
241
- cache2:: T2
242
- cache3:: T2
243
- x:: xType
244
- autodiff:: Bool
222
+ function (L:: FwdModeAutoDiffVecProd )(dv, v, p, t)
223
+ L. vecprod! (dv, L. f, L. u, v, L. cache... )
245
224
end
246
225
247
- function HesVec (f, x:: AbstractArray ; autodiff = true )
226
+ function JacVec (f, u:: AbstractArray , p = nothing , t = nothing ; autodiff = true )
227
+
248
228
if autodiff
249
- cache1 = ForwardDiff. GradientConfig (f, x)
250
- cache2 = similar (x)
251
- cache3 = similar (x)
229
+ cache1 = Dual{
230
+ typeof (ForwardDiff. Tag (DeivVecTag (),eltype (u))), eltype (u), 1
231
+ }. (u, ForwardDiff. Partials .(tuple .(u)))
232
+
233
+ cache2 = copy (cache1)
252
234
else
253
- cache1 = similar (x)
254
- cache2 = similar (x)
255
- cache3 = similar (x)
235
+ cache1 = similar (u)
236
+ cache2 = similar (u)
256
237
end
257
- HesVec (f, cache1, cache2, cache3, x, autodiff)
258
- end
259
238
260
- Base. size (L:: HesVec ) = (length (L. cache2), length (L. cache2))
261
- Base. size (L:: HesVec , i:: Int ) = length (L. cache2)
262
- function Base.:* (L:: HesVec , v:: AbstractVector )
263
- L. autodiff ? numauto_hesvec (L. f, L. x, v) : num_hesvec (L. f, L. x, v)
264
- end
239
+ cache = (cache1, cache2,)
265
240
266
- function LinearAlgebra. mul! (dy:: AbstractVector , L:: HesVec , v:: AbstractVector )
267
- if L. autodiff
268
- numauto_hesvec! (dy, L. f, L. x, v, L. cache1, L. cache2, L. cache3)
269
- else
270
- num_hesvec! (dy, L. f, L. x, v, L. cache1, L. cache2, L. cache3)
241
+ vecprod = autodiff ? auto_jacvec : num_jacvec
242
+ vecprod! = autodiff ? auto_jacvec! : num_jacvec!
243
+
244
+ outofplace = static_hasmethod (f, typeof ((u,)))
245
+ isinplace = static_hasmethod (f, typeof ((u, u,)))
246
+
247
+ if ! (isinplace) & ! (outofplace)
248
+ error (" $f must have signature f(u), or f(du, u)." )
271
249
end
272
- end
273
250
274
- struct HesVecGrad{G, T1, T2, uType}
275
- g :: G
276
- cache1 :: T1
277
- cache2 :: T2
278
- x :: uType
279
- autodiff :: Bool
251
+ L = FwdModeAutoDiffVecProd (f, u, cache, vecprod, vecprod!)
252
+
253
+ FunctionOperator (L, u, u;
254
+ isinplace = isinplace, outofplace = outofplace,
255
+ p = p, t = t, islinear = true ,
256
+ )
280
257
end
281
258
282
- function HesVecGrad (g, x:: AbstractArray , tag = DeivVecTag (); autodiff = false )
259
+ function HesVec (f, u:: AbstractArray , p = nothing , t = nothing ; autodiff = true )
260
+
283
261
if autodiff
284
- cache1 = Dual{typeof (ForwardDiff. Tag (tag, eltype (x))), eltype (x), 1
285
- }. (x, ForwardDiff. Partials .(tuple .(x)))
286
- cache2 = Dual{typeof (ForwardDiff. Tag (tag, eltype (x))), eltype (x), 1
287
- }. (x, ForwardDiff. Partials .(tuple .(x)))
262
+ cache1 = ForwardDiff. GradientConfig (f, u)
263
+ cache2 = similar (u)
264
+ cache3 = similar (u)
288
265
else
289
- cache1 = similar (x)
290
- cache2 = similar (x)
266
+ cache1 = similar (u)
267
+ cache2 = similar (u)
268
+ cache3 = similar (u)
291
269
end
292
- HesVecGrad (g, cache1, cache2, x, autodiff)
293
- end
294
270
295
- Base. size (L:: HesVecGrad ) = (length (L. cache2), length (L. cache2))
296
- Base. size (L:: HesVecGrad , i:: Int ) = length (L. cache2)
297
- function Base.:* (L:: HesVecGrad , v:: AbstractVector )
298
- L. autodiff ? auto_hesvecgrad (L. g, L. x, v) : num_hesvecgrad (L. g, L. x, v)
271
+ cache = (cache1, cache2, cache3,)
272
+
273
+ vecprod = autodiff ? numauto_hesvec : num_hesvec
274
+ vecprod! = autodiff ? numauto_hesvec! : num_hesvec!
275
+
276
+ outofplace = static_hasmethod (f, typeof ((u,)))
277
+ isinplace = static_hasmethod (f, typeof ((u,)))
278
+
279
+ if ! (isinplace) & ! (outofplace)
280
+ error (" $f must have signature f(u)." )
281
+ end
282
+
283
+ L = FwdModeAutoDiffVecProd (f, u, cache, vecprod, vecprod!)
284
+
285
+ FunctionOperator (L, u, u;
286
+ isinplace = isinplace, outofplace = outofplace,
287
+ p = p, t = t, islinear = true ,
288
+ )
299
289
end
300
290
301
- function LinearAlgebra. mul! (dy:: AbstractVector ,
302
- L:: HesVecGrad ,
303
- v:: AbstractVector )
304
- if L. autodiff
305
- auto_hesvecgrad! (dy, L. g, L. x, v, L. cache1, L. cache2)
291
+ function HesVecGrad (f, u:: AbstractArray , p = nothing , t = nothing ; autodiff = true )
292
+
293
+ if autodiff
294
+ cache1 = Dual{
295
+ typeof (ForwardDiff. Tag (DeivVecTag (), eltype (u))), eltype (u), 1
296
+ }. (u, ForwardDiff. Partials .(tuple .(u)))
297
+
298
+ cache2 = copy (cache1)
306
299
else
307
- num_hesvecgrad! (dy, L. g, L. x, v, L. cache1, L. cache2)
300
+ cache1 = similar (u)
301
+ cache2 = similar (u)
302
+ end
303
+
304
+ cache = (cache1, cache2,)
305
+
306
+ vecprod = autodiff ? auto_hesvecgrad : num_hesvecgrad
307
+ vecprod! = autodiff ? auto_hesvecgrad! : num_hesvecgrad!
308
+
309
+ outofplace = static_hasmethod (f, typeof ((u,)))
310
+ isinplace = static_hasmethod (f, typeof ((u, u,)))
311
+
312
+ if ! (isinplace) & ! (outofplace)
313
+ error (" $f must have signature f(u), or f(du, u)." )
308
314
end
315
+
316
+ L = FwdModeAutoDiffVecProd (f, u, cache, vecprod, vecprod!)
317
+
318
+ FunctionOperator (L, u, u;
319
+ isinplace = isinplace, outofplace = outofplace,
320
+ p = p, t = t, islinear = true ,
321
+ )
309
322
end
323
+ #
0 commit comments