@@ -46,18 +46,24 @@ protected override (string name, Tensor buffer)[] _named_buffers()
4646 return ptrArray . Select ( ( x , i ) => ( Marshal . PtrToStringAnsi ( strArray [ i ] ) , new Tensor ( x ) ) ) . ToArray ( ) ;
4747 }
4848
49- public ( string name , Tensor buffer ) [ ] named_attributes ( )
49+ public ( string name , Tensor buffer ) [ ] named_attributes ( bool recurse = true )
5050 {
5151 using var pa = new PinnedArray < IntPtr > ( ) ;
5252 using var sa = new PinnedArray < IntPtr > ( ) ;
53- THSJIT_Module_named_attributes ( handle , pa . CreateArray , sa . CreateArray ) ;
53+ THSJIT_Module_named_attributes ( handle , recurse , pa . CreateArray , sa . CreateArray ) ;
5454 CheckForErrors ( ) ;
5555 var ptrArray = pa . Array ;
5656 var strArray = sa . Array ;
5757
5858 return ptrArray . Select ( ( x , i ) => ( Marshal . PtrToStringAnsi ( strArray [ i ] ) , new Tensor ( x ) ) ) . ToArray ( ) ;
5959 }
6060
61+ public void set_attribute ( string name , Tensor buffer )
62+ {
63+ THSJIT_Module_set_attribute ( handle , name , buffer . Handle ) ;
64+ CheckForErrors ( ) ;
65+ }
66+
6167 /// <summary>
6268 /// Returns an enumerable of all modules in the network, yielding both the name of the module as well as the module itself.
6369 /// </summary>
@@ -149,7 +155,7 @@ protected internal override nn.Module _to(Device device, ScalarType dtype)
149155 CheckForErrors ( ) ;
150156
151157 _toEpilog ( device , dtype ) ;
152-
158+ _toScriptEpilog ( device , dtype ) ;
153159 return this ;
154160 }
155161
@@ -172,6 +178,7 @@ protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex
172178 CheckForErrors ( ) ;
173179
174180 _toEpilog ( deviceType , deviceIndex ) ;
181+ _toScriptEpilog ( deviceType , deviceIndex ) ;
175182 }
176183
177184 Debug . Assert ( _deviceType == DeviceType . CUDA || _deviceIndex == - 1 ) ;
@@ -192,10 +199,35 @@ protected internal override nn.Module _to(ScalarType dtype)
192199 CheckForErrors ( ) ;
193200
194201 _toEpilog ( dtype ) ;
202+ _toScriptEpilog ( dtype ) ;
195203
196204 return this ;
197205 }
198206
207+ protected void _toScriptEpilog ( ScalarType dtype )
208+ {
209+ _toScriptEpilog ( dtype , null ) ;
210+ }
211+
212+ protected void _toScriptEpilog ( Device device , ScalarType dtype )
213+ {
214+ _toScriptEpilog ( dtype , device ) ;
215+ }
216+
217+ protected void _toScriptEpilog ( DeviceType deviceType , int deviceIndex )
218+ {
219+ _toScriptEpilog ( null , new Device ( deviceType , deviceIndex ) ) ;
220+ }
221+
222+ private void _toScriptEpilog ( ScalarType ? dtype , Device device )
223+ {
224+ foreach ( var ( name , buffer ) in named_attributes ( recurse : false ) ) {
225+ if ( name is null || ! buffer . toWillCopy ( dtype ?? buffer . dtype , device ?? buffer . device ) ) continue ;
226+
227+ set_attribute ( name , buffer . to ( dtype ?? buffer . dtype , device ?? buffer . device , disposeAfter : true ) ) ;
228+ }
229+ }
230+
199231#if false // These functions "work," but the native code doesn't seem to find any interesting information.
200232
201233 public Type GetInputType ( int index )
0 commit comments