@@ -3,8 +3,10 @@ use std::ptr;
33use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffAttrs , AutoDiffItem , DiffActivity , DiffMode } ;
44use rustc_codegen_ssa:: ModuleCodegen ;
55use rustc_codegen_ssa:: back:: write:: ModuleConfig ;
6- use rustc_codegen_ssa:: traits:: BaseTypeCodegenMethods as _;
6+ use rustc_codegen_ssa:: common:: TypeKind ;
7+ use rustc_codegen_ssa:: traits:: BaseTypeCodegenMethods ;
78use rustc_errors:: FatalError ;
9+ use rustc_middle:: bug;
810use tracing:: { debug, trace} ;
911
1012use crate :: back:: write:: llvm_err;
@@ -18,21 +20,42 @@ use crate::value::Value;
1820use crate :: { CodegenContext , LlvmCodegenBackend , ModuleLlvm , attributes, llvm} ;
1921
2022fn get_params ( fnc : & Value ) -> Vec < & Value > {
23+ let param_num = llvm:: LLVMCountParams ( fnc) as usize ;
24+ let mut fnc_args: Vec < & Value > = vec ! [ ] ;
25+ fnc_args. reserve ( param_num) ;
2126 unsafe {
22- let param_num = llvm:: LLVMCountParams ( fnc) as usize ;
23- let mut fnc_args: Vec < & Value > = vec ! [ ] ;
24- fnc_args. reserve ( param_num) ;
2527 llvm:: LLVMGetParams ( fnc, fnc_args. as_mut_ptr ( ) ) ;
2628 fnc_args. set_len ( param_num) ;
27- fnc_args
2829 }
30+ fnc_args
2931}
3032
33+ fn has_sret ( fnc : & Value ) -> bool {
34+ let num_args = llvm:: LLVMCountParams ( fnc) as usize ;
35+ if num_args == 0 {
36+ false
37+ } else {
38+ unsafe { llvm:: LLVMRustHasAttributeAtIndex ( fnc, 0 , llvm:: AttributeKind :: StructRet ) }
39+ }
40+ }
41+
42+ // When we call the `__enzyme_autodiff` or `__enzyme_fwddiff` function, we need to pass all the
43+ // original inputs, as well as metadata and the additional shadow arguments.
44+ // This function matches the arguments from the outer function to the inner enzyme call.
45+ //
46+ // This function also considers that Rust level arguments not always match the llvm-ir level
47+ // arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on
48+ // llvm-ir level. The number of activities matches the number of Rust level arguments, so we
49+ // need to match those.
50+ // FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it
51+ // using iterators and peek()?
3152fn match_args_from_caller_to_enzyme < ' ll > (
3253 cx : & SimpleCx < ' ll > ,
54+ width : u32 ,
3355 args : & mut Vec < & ' ll llvm:: Value > ,
3456 inputs : & [ DiffActivity ] ,
3557 outer_args : & [ & ' ll llvm:: Value ] ,
58+ has_sret : bool ,
3659) {
3760 debug ! ( "matching autodiff arguments" ) ;
3861 // We now handle the issue that Rust level arguments not always match the llvm-ir level
@@ -44,6 +67,14 @@ fn match_args_from_caller_to_enzyme<'ll>(
4467 let mut outer_pos: usize = 0 ;
4568 let mut activity_pos = 0 ;
4669
70+ if has_sret {
71+ // Then the first outer arg is the sret pointer. Enzyme doesn't know about sret, so the
72+ // inner function will still return something. We increase our outer_pos by one,
73+ // and once we're done with all other args we will take the return of the inner call and
74+ // update the sret pointer with it
75+ outer_pos = 1 ;
76+ }
77+
4778 let enzyme_const = cx. create_metadata ( "enzyme_const" . to_string ( ) ) . unwrap ( ) ;
4879 let enzyme_out = cx. create_metadata ( "enzyme_out" . to_string ( ) ) . unwrap ( ) ;
4980 let enzyme_dup = cx. create_metadata ( "enzyme_dup" . to_string ( ) ) . unwrap ( ) ;
@@ -92,39 +123,40 @@ fn match_args_from_caller_to_enzyme<'ll>(
92123 // (..., metadata! enzyme_dup, ptr, ptr, int1, ...).
93124 // FIXME(ZuseZ4): We will upstream a safety check later which asserts that
94125 // int2 >= int1, which means the shadow vector is large enough to store the gradient.
95- assert ! ( unsafe {
96- llvm:: LLVMRustGetTypeKind ( next_outer_ty) == llvm:: TypeKind :: Integer
97- } ) ;
98- let next_outer_arg2 = outer_args[ outer_pos + 2 ] ;
99- let next_outer_ty2 = cx. val_ty ( next_outer_arg2) ;
100- assert ! ( unsafe {
101- llvm:: LLVMRustGetTypeKind ( next_outer_ty2) == llvm:: TypeKind :: Pointer
102- } ) ;
103- let next_outer_arg3 = outer_args[ outer_pos + 3 ] ;
104- let next_outer_ty3 = cx. val_ty ( next_outer_arg3) ;
105- assert ! ( unsafe {
106- llvm:: LLVMRustGetTypeKind ( next_outer_ty3) == llvm:: TypeKind :: Integer
107- } ) ;
108- args. push ( next_outer_arg2) ;
126+ assert_eq ! ( cx. type_kind( next_outer_ty) , TypeKind :: Integer ) ;
127+
128+ for i in 0 ..( width as usize ) {
129+ let next_outer_arg2 = outer_args[ outer_pos + 2 * ( i + 1 ) ] ;
130+ let next_outer_ty2 = cx. val_ty ( next_outer_arg2) ;
131+ assert_eq ! ( cx. type_kind( next_outer_ty2) , TypeKind :: Pointer ) ;
132+ let next_outer_arg3 = outer_args[ outer_pos + 2 * ( i + 1 ) + 1 ] ;
133+ let next_outer_ty3 = cx. val_ty ( next_outer_arg3) ;
134+ assert_eq ! ( cx. type_kind( next_outer_ty3) , TypeKind :: Integer ) ;
135+ args. push ( next_outer_arg2) ;
136+ }
109137 args. push ( cx. get_metadata_value ( enzyme_const) ) ;
110138 args. push ( next_outer_arg) ;
111- outer_pos += 4 ;
139+ outer_pos += 2 + 2 * width as usize ;
112140 activity_pos += 2 ;
113141 } else {
114142 // A duplicated pointer will have the following two outer_fn arguments:
115143 // (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
116144 // (..., metadata! enzyme_dup, ptr, ptr, ...).
117145 if matches ! ( diff_activity, DiffActivity :: Duplicated | DiffActivity :: DuplicatedOnly )
118146 {
119- assert ! (
120- unsafe { llvm:: LLVMRustGetTypeKind ( next_outer_ty) }
121- == llvm:: TypeKind :: Pointer
122- ) ;
147+ assert_eq ! ( cx. type_kind( next_outer_ty) , TypeKind :: Pointer ) ;
123148 }
124149 // In the case of Dual we don't have assumptions, e.g. f32 would be valid.
125150 args. push ( next_outer_arg) ;
126151 outer_pos += 2 ;
127152 activity_pos += 1 ;
153+
154+ // Now, if width > 1, we need to account for that
155+ for _ in 1 ..width {
156+ let next_outer_arg = outer_args[ outer_pos] ;
157+ args. push ( next_outer_arg) ;
158+ outer_pos += 1 ;
159+ }
128160 }
129161 } else {
130162 // We do not differentiate with resprect to this argument.
@@ -135,6 +167,76 @@ fn match_args_from_caller_to_enzyme<'ll>(
135167 }
136168}
137169
170+ // On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
171+ // arguments. We do however need to declare them with their correct return type.
172+ // We already figured the correct return type out in our frontend, when generating the outer_fn,
173+ // so we can now just go ahead and use that. This is not always trivial, e.g. because sret.
174+ // Beyond sret, this article describes our challenges nicely:
175+ // <https://yorickpeterse.com/articles/the-mess-that-is-handling-structure-arguments-and-returns-in-llvm/>
176+ // I.e. (i32, f32) will get merged into i64, but we don't handle that yet.
177+ fn compute_enzyme_fn_ty < ' ll > (
178+ cx : & SimpleCx < ' ll > ,
179+ attrs : & AutoDiffAttrs ,
180+ fn_to_diff : & ' ll Value ,
181+ outer_fn : & ' ll Value ,
182+ ) -> & ' ll llvm:: Type {
183+ let fn_ty = cx. get_type_of_global ( outer_fn) ;
184+ let mut ret_ty = cx. get_return_type ( fn_ty) ;
185+
186+ let has_sret = has_sret ( outer_fn) ;
187+
188+ if has_sret {
189+ // Now we don't just forward the return type, so we have to figure it out based on the
190+ // primal return type, in combination with the autodiff settings.
191+ let fn_ty = cx. get_type_of_global ( fn_to_diff) ;
192+ let inner_ret_ty = cx. get_return_type ( fn_ty) ;
193+
194+ let void_ty = unsafe { llvm:: LLVMVoidTypeInContext ( cx. llcx ) } ;
195+ if inner_ret_ty == void_ty {
196+ // This indicates that even the inner function has an sret.
197+ // Right now I only look for an sret in the outer function.
198+ // This *probably* needs some extra handling, but I never ran
199+ // into such a case. So I'll wait for user reports to have a test case.
200+ bug ! ( "sret in inner function" ) ;
201+ }
202+
203+ if attrs. width == 1 {
204+ todo ! ( "Handle sret for scalar ad" ) ;
205+ } else {
206+ // First we check if we also have to deal with the primal return.
207+ match attrs. mode {
208+ DiffMode :: Forward => match attrs. ret_activity {
209+ DiffActivity :: Dual => {
210+ let arr_ty =
211+ unsafe { llvm:: LLVMArrayType2 ( inner_ret_ty, attrs. width as u64 + 1 ) } ;
212+ ret_ty = arr_ty;
213+ }
214+ DiffActivity :: DualOnly => {
215+ let arr_ty =
216+ unsafe { llvm:: LLVMArrayType2 ( inner_ret_ty, attrs. width as u64 ) } ;
217+ ret_ty = arr_ty;
218+ }
219+ DiffActivity :: Const => {
220+ todo ! ( "Not sure, do we need to do something here?" ) ;
221+ }
222+ _ => {
223+ bug ! ( "unreachable" ) ;
224+ }
225+ } ,
226+ DiffMode :: Reverse => {
227+ todo ! ( "Handle sret for reverse mode" ) ;
228+ }
229+ _ => {
230+ bug ! ( "unreachable" ) ;
231+ }
232+ }
233+ }
234+ }
235+
236+ // LLVM can figure out the input types on it's own, so we take a shortcut here.
237+ unsafe { llvm:: LLVMFunctionType ( ret_ty, ptr:: null ( ) , 0 , True ) }
238+ }
239+
138240/// When differentiating `fn_to_diff`, take a `outer_fn` and generate another
139241/// function with expected naming and calling conventions[^1] which will be
140242/// discovered by the enzyme LLVM pass and its body populated with the differentiated
@@ -197,17 +299,9 @@ fn generate_enzyme_call<'ll>(
197299 // }
198300 // ```
199301 unsafe {
200- // On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
201- // arguments. We do however need to declare them with their correct return type.
202- // We already figured the correct return type out in our frontend, when generating the outer_fn,
203- // so we can now just go ahead and use that. FIXME(ZuseZ4): This doesn't handle sret yet.
204- let fn_ty = llvm:: LLVMGlobalGetValueType ( outer_fn) ;
205- let ret_ty = llvm:: LLVMGetReturnType ( fn_ty) ;
206-
207- // LLVM can figure out the input types on it's own, so we take a shortcut here.
208- let enzyme_ty = llvm:: LLVMFunctionType ( ret_ty, ptr:: null ( ) , 0 , True ) ;
302+ let enzyme_ty = compute_enzyme_fn_ty ( cx, & attrs, fn_to_diff, outer_fn) ;
209303
210- //FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
304+ // FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
211305 // think a bit more about what should go here.
212306 let cc = llvm:: LLVMGetFunctionCallConv ( outer_fn) ;
213307 let ad_fn = declare_simple_fn (
@@ -240,14 +334,27 @@ fn generate_enzyme_call<'ll>(
240334 if matches ! ( attrs. ret_activity, DiffActivity :: Dual | DiffActivity :: Active ) {
241335 args. push ( cx. get_metadata_value ( enzyme_primal_ret) ) ;
242336 }
337+ if attrs. width > 1 {
338+ let enzyme_width = cx. create_metadata ( "enzyme_width" . to_string ( ) ) . unwrap ( ) ;
339+ args. push ( cx. get_metadata_value ( enzyme_width) ) ;
340+ args. push ( cx. get_const_i64 ( attrs. width as u64 ) ) ;
341+ }
243342
343+ let has_sret = has_sret ( outer_fn) ;
244344 let outer_args: Vec < & llvm:: Value > = get_params ( outer_fn) ;
245- match_args_from_caller_to_enzyme ( & cx, & mut args, & attrs. input_activity , & outer_args) ;
345+ match_args_from_caller_to_enzyme (
346+ & cx,
347+ attrs. width ,
348+ & mut args,
349+ & attrs. input_activity ,
350+ & outer_args,
351+ has_sret,
352+ ) ;
246353
247354 let call = builder. call ( enzyme_ty, ad_fn, & args, None ) ;
248355
249356 // This part is a bit iffy. LLVM requires that a call to an inlineable function has some
250- // metadata attachted to it, but we just created this code oota. Given that the
357+ // metadata attached to it, but we just created this code oota. Given that the
251358 // differentiated function already has partly confusing metadata, and given that this
252359 // affects nothing but the auttodiff IR, we take a shortcut and just steal metadata from the
253360 // dummy code which we inserted at a higher level.
@@ -268,7 +375,22 @@ fn generate_enzyme_call<'ll>(
268375 // Now that we copied the metadata, get rid of dummy code.
269376 llvm:: LLVMRustEraseInstUntilInclusive ( entry, last_inst) ;
270377
271- if cx. val_ty ( call) == cx. type_void ( ) {
378+ if cx. val_ty ( call) == cx. type_void ( ) || has_sret {
379+ if has_sret {
380+ // This is what we already have in our outer_fn (shortened):
381+ // define void @_foo(ptr <..> sret([32 x i8]) initializes((0, 32)) %0, <...>) {
382+ // %7 = call [4 x double] (...) @__enzyme_fwddiff_foo(ptr @square, metadata !"enzyme_width", i64 4, <...>)
383+ // <Here we are, we want to add the following two lines>
384+ // store [4 x double] %7, ptr %0, align 8
385+ // ret void
386+ // }
387+
388+ // now store the result of the enzyme call into the sret pointer.
389+ let sret_ptr = outer_args[ 0 ] ;
390+ let call_ty = cx. val_ty ( call) ;
391+ assert_eq ! ( cx. type_kind( call_ty) , TypeKind :: Array ) ;
392+ llvm:: LLVMBuildStore ( & builder. llbuilder , call, sret_ptr) ;
393+ }
272394 builder. ret_void ( ) ;
273395 } else {
274396 builder. ret ( call) ;
@@ -300,8 +422,7 @@ pub(crate) fn differentiate<'ll>(
300422 if !diff_items. is_empty ( )
301423 && !cgcx. opts . unstable_opts . autodiff . contains ( & rustc_session:: config:: AutoDiff :: Enable )
302424 {
303- let dcx = cgcx. create_dcx ( ) ;
304- return Err ( dcx. handle ( ) . emit_almost_fatal ( AutoDiffWithoutEnable ) ) ;
425+ return Err ( diag_handler. handle ( ) . emit_almost_fatal ( AutoDiffWithoutEnable ) ) ;
305426 }
306427
307428 // Before dumping the module, we want all the TypeTrees to become part of the module.
0 commit comments