@@ -163,66 +163,6 @@ protected internal virtual Module _to(Device device, ScalarType dtype)
163163 return this ;
164164 }
165165
// Removed by this change (every line below carries the diff's `-` marker): the
// former device+dtype epilog. It forwards the move to each named child, converts
// Parameter/Tensor fields found through reflection, converts any remaining named
// parameters and buffers, and finally records the target device on the module.
166- protected void _toEpilog ( Device device , ScalarType dtype )
167- {
// Child modules are converted first, through their own _to(device, dtype).
168- foreach ( var ( _, sm ) in named_children ( ) ) sm . _to ( device , dtype ) ;
169-
// Handles of tensors dealt with in the field pass, so the name-based passes below can skip them.
170- var alreadyHandled = new HashSet < IntPtr > ( ) ;
171-
// Field pass: inspect every public and non-public instance field of the runtime type.
172- foreach ( var field in GetType ( ) . GetFields ( BindingFlags . NonPublic | BindingFlags . Public | BindingFlags . Instance ) ) {
173-
174- var fieldName = field . ComponentName ( ) ;
175- var value = field . GetValue ( this ) ;
176-
177- switch ( value ) {
178- // This order in which these cases are arranged is significant.
// Parameter already at the requested dtype and device: remember its handle and
// go on to the next field (`continue` applies to the enclosing foreach).
179- case Parameter param when dtype == param . dtype && device . type == param . device_type && device . index == param . device_index :
180- alreadyHandled . Add ( param . handle ) ;
181- continue ;
182-
// Any other Parameter: convert it, call retain_grad() on the result, wrap that in a
// new Parameter carrying the old requires_grad flag, then store it back into the
// field and register it under the field's component name.
183- case Parameter param : {
184- var t = param . to ( dtype , device ) ;
185- t . retain_grad ( ) ;
186- var p = new Parameter ( t , param . requires_grad ) ;
187- field . SetValue ( this , p ) ;
188- ConditionallyRegisterParameter ( fieldName , p ) ;
189- alreadyHandled . Add ( p . handle ) ;
190- break ;
191- }
192-
// Tensor field located on a different device: convert it and register it as a buffer.
// NOTE(review): this guard compares the device only. A Tensor field already on the
// target device but with a different dtype falls through to the final case, where it
// is marked handled without conversion; if a registered buffer shares that handle the
// buffer loop below skips it too. Confirm whether that was intended.
193- case Tensor tensor when ( device . type != tensor . device_type || device . index != tensor . device_index ) : {
194- var t = tensor . to ( dtype , device ) ;
195- field . SetValue ( this , t ) ;
196- ConditionallyRegisterBuffer ( fieldName , t ) ;
197- alreadyHandled . Add ( t . handle ) ;
198- break ;
199- }
200-
// Every remaining Tensor field is recorded as handled and left unchanged.
201- case Tensor tensor :
202- alreadyHandled . Add ( tensor . handle ) ;
203- break ;
204- }
205- }
206-
// Name-based pass over parameters whose handles were not recorded above. ToList()
// copies the sequence before iterating (presumably because registration modifies the
// underlying collection; confirm). The result of to() is registered as-is, without
// the `new Parameter(...)` wrapping used in the field pass.
207- foreach ( var ( name , param ) in named_parameters ( false ) . ToList ( ) ) {
208- if ( alreadyHandled . Contains ( param . handle ) ) continue ;
209- var t = param . to ( dtype , device ) ;
210- ConditionallyRegisterParameter ( name , t ) ;
211- }
212-
// Same name-based pass for buffers.
213- foreach ( var ( name , buffer ) in named_buffers ( false ) . ToList ( ) ) {
214- if ( alreadyHandled . Contains ( buffer . handle ) ) continue ;
215- var t = buffer . to ( dtype , device ) ;
216- ConditionallyRegisterBuffer ( name , t ) ;
217- }
218-
// Record the target device on the module itself.
219- _deviceType = device . type ;
220- _deviceIndex = device . index ;
221-
// Debug-build check: any target that is not CUDA must have index -1.
222- Debug . Assert ( _deviceType == DeviceType . CUDA || _deviceIndex == - 1 ) ;
223- }
224-
225-
226166 /// <summary>
227167 /// Moves the parameters and buffers.
228168 /// </summary>
@@ -249,63 +189,6 @@ protected internal virtual Module _to(DeviceType deviceType, int deviceIndex = -
249189 return this ;
250190 }
251191
// Removed by this change (every line below carries the diff's `-` marker): the
// former device-only epilog. It forwards the move to each named child, moves
// Parameter/Tensor fields found through reflection, moves any remaining named
// parameters and buffers, and records the target device. No dtype is involved.
252- protected void _toEpilog ( DeviceType deviceType , int deviceIndex )
253- {
// Child modules are moved first, through their own _to(deviceType, deviceIndex).
254- foreach ( var ( _, sm ) in named_children ( ) ) sm . _to ( deviceType , deviceIndex ) ;
255-
// Handles of tensors dealt with in the field pass, so the name-based passes below can skip them.
256- var alreadyHandled = new HashSet < IntPtr > ( ) ;
257-
// Field pass: inspect every public and non-public instance field of the runtime type.
258- foreach ( var field in GetType ( ) . GetFields ( BindingFlags . NonPublic | BindingFlags . Public | BindingFlags . Instance ) ) {
259-
260- var fieldName = field . ComponentName ( ) ;
261- var value = field . GetValue ( this ) ;
262-
263- switch ( value ) {
264- // This order in which these cases are arranged is significant.
// Parameter already on the requested device: remember its handle and go on to the
// next field (`continue` applies to the enclosing foreach).
265- case Parameter param when deviceType == param . device_type && deviceIndex == param . device_index :
266- alreadyHandled . Add ( param . handle ) ;
267- continue ;
268-
// Any other Parameter: move it, call retain_grad() on the result, wrap that in a new
// Parameter carrying the old requires_grad flag, then store it back into the field
// and register it under the field's component name.
269- case Parameter param : {
270- var t = param . to ( deviceType , deviceIndex ) ;
271- t . retain_grad ( ) ;
272- var p = new Parameter ( t , param . requires_grad ) ;
273- field . SetValue ( this , p ) ;
274- ConditionallyRegisterParameter ( fieldName , p ) ;
275- alreadyHandled . Add ( p . handle ) ;
276- break ;
277- }
278-
// Tensor field located on a different device: move it and register it as a buffer.
279- case Tensor tensor when ( deviceType != tensor . device_type || deviceIndex != tensor . device_index ) : {
280- var t = tensor . to ( deviceType , deviceIndex ) ;
281- field . SetValue ( this , t ) ;
282- ConditionallyRegisterBuffer ( fieldName , t ) ;
283- alreadyHandled . Add ( t . handle ) ;
284- break ;
285- }
286-
// Tensor field already on the requested device: only record it as handled.
287- case Tensor tensor :
288- alreadyHandled . Add ( tensor . handle ) ;
289- break ;
290- }
291- }
292-
// Name-based pass over parameters whose handles were not recorded above. ToList()
// copies the sequence before iterating (presumably because registration modifies the
// underlying collection; confirm). The result of to() is registered as-is, without
// the `new Parameter(...)` wrapping used in the field pass.
293- foreach ( var ( name , param ) in named_parameters ( false ) . ToList ( ) ) {
294- if ( alreadyHandled . Contains ( param . handle ) ) continue ;
295- var t = param . to ( deviceType , deviceIndex ) ;
296- ConditionallyRegisterParameter ( name , t ) ;
297- }
298-
// Same name-based pass for buffers.
299- foreach ( var ( name , buffer ) in named_buffers ( false ) . ToList ( ) ) {
300- if ( alreadyHandled . Contains ( buffer . handle ) ) continue ;
301- var t = buffer . to ( deviceType , deviceIndex ) ;
302- ConditionallyRegisterBuffer ( name , t ) ;
303- }
304-
// Record the target device on the module itself.
305- _deviceType = deviceType ;
306- _deviceIndex = deviceIndex ;
307- }
308-
// Device type and index recorded for this module; the _toEpilog overloads assign
// them when a target device is supplied. Initial values are CPU and -1.
309192 private DeviceType _deviceType = DeviceType . CPU ;
310193 private int _deviceIndex = - 1 ;
311194
@@ -325,55 +208,62 @@ protected internal virtual Module _to(ScalarType dtype)
325208
// Diff hunk: lines marked `-` are the removed body of the dtype-only epilog, lines
// marked `+` are its replacement. After the change, three protected overloads
// (dtype only, device + dtype, device type + index) forward to one private
// implementation that takes an optional dtype and an optional device.
326209 protected void _toEpilog ( ScalarType dtype )
327210 {
328- foreach ( var ( _, sm ) in named_children ( ) ) sm . _to ( dtype ) ;
// dtype-only entry point: no device is passed, so tensors keep their current device.
211+ _toEpilog ( dtype , null ) ;
212+ }
329213
330- var alreadyHandled = new HashSet < IntPtr > ( ) ;
// device + dtype entry point: forwards both values, swapping the argument order.
214+ protected void _toEpilog ( Device device , ScalarType dtype )
215+ {
216+ _toEpilog ( dtype , device ) ;
217+ }
331218
332- foreach ( var field in GetType ( ) . GetFields ( BindingFlags . NonPublic | BindingFlags . Public | BindingFlags . Instance ) ) {
// device-only entry point: wraps the type/index pair in a Device and passes no dtype.
219+ protected void _toEpilog ( DeviceType deviceType , int deviceIndex )
220+ {
221+ _toEpilog ( null , new Device ( deviceType , deviceIndex ) ) ;
222+ }
333223
334- var fieldName = field . ComponentName ( ) ;
335- var value = field . GetValue ( this ) ;
// Shared implementation. A null dtype keeps each tensor's current dtype and a null
// device keeps its current device (see the `??` fallbacks below).
// NOTE(review): if both arguments were null, `dtype.Value` in the first branch would
// throw. The three overloads above never pass a null dtype together with a null
// device, so this is not reachable from the callers visible here.
224+ private void _toEpilog ( ScalarType ? dtype , Device device )
225+ {
// Forward to each child through the _to overload matching the supplied arguments.
226+ foreach ( var ( _, sm ) in named_children ( ) ) {
227+ if ( device is null ) sm . _to ( dtype . Value ) ;
228+ else if ( dtype is null ) sm . _to ( device . type , device . index ) ;
229+ else sm . _to ( device , dtype . Value ) ;
230+ }
336231
337- switch ( value ) {
338- // This order in which these cases are arranged is significant.
339- case Parameter param when dtype == param . dtype :
340- alreadyHandled . Add ( param . handle ) ;
341- continue ;
342-
343- case Parameter param : {
344- var t = param . to ( dtype ) ;
345- t . retain_grad ( ) ;
346- var p = new Parameter ( t , param . requires_grad ) ;
347- field . SetValue ( this , p ) ;
348- ConditionallyRegisterParameter ( fieldName , p ) ;
349- alreadyHandled . Add ( p . handle ) ;
350- break ;
351- }
// Index the instance fields by component name so a converted parameter or buffer can
// be written back to the field of the same name.
// NOTE(review): ToDictionary throws on duplicate keys; this presumably relies on
// ComponentName() being unique across the fields of a module type. Confirm.
232+ var fieldsByComponentName = GetType ( ) . GetFields ( BindingFlags . NonPublic | BindingFlags . Public | BindingFlags . Instance )
233+ . ToDictionary ( field => field . ComponentName ( ) ) ;
352234
353- case Tensor tensor when dtype == tensor . dtype :
354- alreadyHandled . Add ( tensor . handle ) ;
355- continue ;
// Parameter pass. ToList() copies the sequence before iterating. Entries for which
// toWillCopy(...) is false are skipped (presumably those where to() would return the
// same tensor; confirm against toWillCopy).
235+ foreach ( var ( name , param ) in named_parameters ( false ) . ToList ( ) ) {
236+ if ( ! param . toWillCopy ( dtype ?? param . dtype , device ?? param . device ) ) continue ;
356237
357- case Tensor tensor : {
358- var t = tensor . to ( dtype ) ;
359- field . SetValue ( this , t ) ;
360- ConditionallyRegisterBuffer ( fieldName , t ) ;
361- alreadyHandled . Add ( t . handle ) ;
362- break ;
363- }
364- }
365- }
238+ // Store the requires_grad flag ahead, since we dispose the parameter after moving
239+ bool requiresGrad = param . requires_grad ;
240+ Parameter p ;
241+ // When moving the parameter, we don't want the autograd to track this movement on the graph.
242+ // In addition, we need the new tensor to be a leaf to accumulate gradients, so if we didn't
243+ // disable grad we would need to call .detach() on the moved tensor.
// The using scope covers only the single statement that builds p.
244+ using ( var d = torch . no_grad ( ) )
245+ p = new Parameter ( param . to ( dtype ?? param . dtype , device ?? param . device , disposeAfter : true ) , requiresGrad ) ;
// Re-register the replacement under the same name.
246+ ConditionallyRegisterParameter ( name , p ) ;
366247
367- foreach ( var ( name , param ) in named_parameters ( false ) . ToList ( ) ) {
368- if ( alreadyHandled . Contains ( param . handle ) ) continue ;
369- var t = param . to ( dtype ) ;
370- ConditionallyRegisterParameter ( name , t ) ;
248+ // If this parameter is a field, set it
249+ if ( fieldsByComponentName . TryGetValue ( name , out var field ) )
250+ field . SetValue ( this , p ) ;
371251 }
372252
// Buffer pass: same skip test, conversion, registration and field write-back as the
// parameter pass, but without the no_grad scope or the Parameter wrapper.
373253 foreach ( var ( name , buffer ) in named_buffers ( false ) . ToList ( ) ) {
374- if ( alreadyHandled . Contains ( buffer . handle ) ) continue ;
375- var t = buffer . to ( dtype ) ;
254+ if ( ! buffer . toWillCopy ( dtype ?? buffer . dtype , device ?? buffer . device ) ) continue ;
255+
256+ // Buffers don't get grads so we don't need to detach them afterwards
257+ var t = buffer . to ( dtype ?? buffer . dtype , device ?? buffer . device , disposeAfter : true ) ;
376258 ConditionallyRegisterBuffer ( name , t ) ;
259+
260+ if ( fieldsByComponentName . TryGetValue ( name , out var field ) )
261+ field . SetValue ( this , t ) ;
262+ }
263+
// The recorded device changes only when a device was supplied; a dtype-only call
// leaves _deviceType and _deviceIndex as they were.
// NOTE(review): the added lines contain no Debug.Assert on the device index for
// non-CUDA targets. Confirm that dropping such a check is intended.
264+ if ( device is not null ) {
265+ _deviceType = device . type ;
266+ _deviceIndex = device . index ;
377267 }
378268 }
379269
0 commit comments