@@ -150,6 +150,9 @@ protected virtual void Dispose(bool disposing)
 /// <param name="dtype">The target element type.</param>
 protected internal virtual Module _to(Device device, ScalarType dtype)
 {
+    if (!dtype.IsFloatingPoint() && !dtype.IsComplex())
+        throw new ArgumentException($"nn.Module.to only accepts floating point or complex types, but got desired dtype={dtype.ToString()}");
+
     if (device.type != DeviceType.CUDA) { device = new Device(device.type, -1); };

     if (device.type == DeviceType.CUDA && !torch.cuda.is_available()) throw new InvalidOperationException("CUDA is not available.");
@@ -198,6 +201,9 @@ protected internal virtual Module _to(DeviceType deviceType, int deviceIndex = -
 /// <returns></returns>
 protected internal virtual Module _to(ScalarType dtype)
 {
+    if (!dtype.IsFloatingPoint() && !dtype.IsComplex())
+        throw new ArgumentException($"nn.Module.to only accepts floating point or complex types, but got desired dtype={dtype.ToString()}");
+
     THSNN_Module_to_dtype(handle, (sbyte)dtype);
     CheckForErrors();

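Both `_to` overloads now guard the requested dtype the way PyTorch's `nn.Module.to` does: only floating-point and complex element types are accepted. A minimal sketch of the observable behavior, assuming TorchSharp's public `Module.to(ScalarType)` wrapper forwards to these `_to` overloads (the `lin` module below is illustrative, not part of the patch):

```csharp
using System;
using TorchSharp;
using static TorchSharp.torch;

var lin = nn.Linear(4, 2);

// Floating-point (and complex) targets are accepted as before:
lin.to(ScalarType.Float64);

// Integer targets are now rejected before any native call is made:
try {
    lin.to(ScalarType.Int32);
} catch (ArgumentException e) {
    // "nn.Module.to only accepts floating point or complex types, ..."
    Console.WriteLine(e.Message);
}
```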
@@ -238,11 +244,14 @@ private void _toEpilog(ScalarType? dtype, Device device)
 // Store the requires_grad flag ahead, since we dispose the parameter after moving
 bool requiresGrad = param.requires_grad;
 Parameter p;
+ScalarType paramType =
+    dtype != null && (param.dtype.IsFloatingPoint() || param.dtype.IsComplex()) ? dtype.Value : param.dtype;
+
 // When moving the parameter, we don't want the autograd to track this movement on the graph.
 // In addition, we need the new tensor to be a leaf to accumulate gradients, so if we didn't
 // disable grad we would need to call .detach() on the moved tensor.
 using (var d = torch.no_grad())
-    p = new Parameter(param.to(dtype ?? param.dtype, device ?? param.device, disposeAfter: true), requiresGrad);
+    p = new Parameter(param.to(paramType, device ?? param.device, disposeAfter: true), requiresGrad);
 ConditionallyRegisterParameter(name, p);

 // If this parameter is a field, set it
@@ -253,8 +262,11 @@ private void _toEpilog(ScalarType? dtype, Device device)
 foreach (var (name, buffer) in named_buffers(false).ToList()) {
     if (!buffer.toWillCopy(dtype ?? buffer.dtype, device ?? buffer.device)) continue;

+    ScalarType bufferType =
+        dtype != null && (buffer.dtype.IsFloatingPoint() || buffer.dtype.IsComplex()) ? dtype.Value : buffer.dtype;
+
     // Buffers don't get grads so we don't need to detach them afterwards
-    var t = buffer.to(dtype ?? buffer.dtype, device ?? buffer.device, disposeAfter: true);
+    var t = buffer.to(bufferType, device ?? buffer.device, disposeAfter: true);
     ConditionallyRegisterBuffer(name, t);

     if (fieldsByComponentName.TryGetValue(name, out var field))
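The `_toEpilog` changes are what make the guard above safe for mixed-dtype modules: a requested dtype conversion is applied only to parameters and buffers that are already floating-point or complex, while integer tensors keep their dtype. A sketch of the intended effect, assuming `BatchNorm1d` surfaces its PyTorch-style `num_batches_tracked` buffer through `named_buffers` (module construction here is illustrative):

```csharp
using System;
using TorchSharp;
using static TorchSharp.torch;

var bn = nn.BatchNorm1d(8);
bn.to(ScalarType.Float64);

foreach (var (name, buf) in bn.named_buffers())
    Console.WriteLine($"{name}: {buf.dtype}");

// Expected: running_mean and running_var report Float64, while the integer
// num_batches_tracked counter stays Int64, because _toEpilog now falls back
// to the original dtype for non-floating, non-complex tensors.
```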