diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index df1b1eb60e18f..a4a038d876e23 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -16,8 +16,9 @@ mod llvm_enzyme { use rustc_ast::tokenstream::*; use rustc_ast::visit::AssocCtxt::*; use rustc_ast::{ - self as ast, AssocItemKind, BindingMode, ExprKind, FnRetTy, FnSig, Generics, ItemKind, - MetaItemInner, PatKind, QSelf, TyKind, Visibility, + self as ast, AngleBracketedArg, AngleBracketedArgs, AssocItemKind, BindingMode, FnRetTy, + FnSig, GenericArg, GenericArgs, Generics, ItemKind, MetaItemInner, PatKind, Path, + PathSegment, TyKind, Visibility, }; use rustc_expand::base::{Annotatable, ExtCtxt}; use rustc_span::{Ident, Span, Symbol, kw, sym}; @@ -73,10 +74,12 @@ mod llvm_enzyme { } // Get information about the function the macro is applied to - fn extract_item_info(iitem: &P) -> Option<(Visibility, FnSig, Ident, Generics)> { + fn extract_item_info( + iitem: &P, + ) -> Option<(Visibility, FnSig, Ident, Generics, bool)> { match &iitem.kind { ItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => { - Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone())) + Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone(), false)) } _ => None, } @@ -228,16 +231,20 @@ mod llvm_enzyme { // first get information about the annotable item: visibility, signature, name and generic // parameters. // these will be used to generate the differentiated version of the function - let Some((vis, sig, primal, generics)) = (match &item { + let Some((vis, sig, primal, generics, impl_of_trait)) = (match &item { Annotatable::Item(iitem) => extract_item_info(iitem), Annotatable::Stmt(stmt) => match &stmt.kind { ast::StmtKind::Item(iitem) => extract_item_info(iitem), _ => None, }, - Annotatable::AssocItem(assoc_item, Impl { .. }) => match &assoc_item.kind { - ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => { - Some((assoc_item.vis.clone(), sig.clone(), ident.clone(), generics.clone())) - } + Annotatable::AssocItem(assoc_item, Impl { of_trait }) => match &assoc_item.kind { + ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => Some(( + assoc_item.vis.clone(), + sig.clone(), + ident.clone(), + generics.clone(), + *of_trait, + )), _ => None, }, _ => None, @@ -255,7 +262,6 @@ mod llvm_enzyme { }; let has_ret = has_ret(&sig.decl.output); - let sig_span = ecx.with_call_site_ctxt(sig.span); // create TokenStream from vec elemtents: // meta_item doesn't have a .tokens field @@ -324,15 +330,18 @@ mod llvm_enzyme { } let span = ecx.with_def_site_ctxt(expand_span); - let n_active: u32 = x - .input_activity - .iter() - .filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly) - .count() as u32; - let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span); + let (d_sig, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span); + let d_body = gen_enzyme_body( - ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored, + ecx, + &d_sig, + primal, + span, + idents, + errored, + first_ident(&meta_item_vec[0]), &generics, + impl_of_trait, ); // The first element of it is the name of the function to be generated @@ -429,12 +438,13 @@ mod llvm_enzyme { tokens: ts, }); + let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id(); let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span); let d_annotatable = match &item { Annotatable::AssocItem(_, _) => { let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf); let d_fn = P(ast::AssocItem { - attrs: thin_vec![d_attr, inline_never], + attrs: thin_vec![d_attr], id: ast::DUMMY_NODE_ID, span, vis, @@ -444,13 +454,13 @@ mod llvm_enzyme { Annotatable::AssocItem(d_fn, Impl { of_trait: false }) } Annotatable::Item(_) => { - let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf)); + let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf)); d_fn.vis = vis; Annotatable::Item(d_fn) } Annotatable::Stmt(_) => { - let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf)); + let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf)); d_fn.vis = vis; Annotatable::Stmt(P(ast::Stmt { @@ -485,76 +495,111 @@ mod llvm_enzyme { ty } - // Will generate a body of the type: + // Generate `enzyme_autodiff` intrinsic call // ``` - // { - // unsafe { - // asm!("NOP"); - // } - // ::core::hint::black_box(primal(args)); - // ::core::hint::black_box((args, ret)); - // - // } + // std::intrinsics::enzyme_autodiff(source, diff, (args)) // ``` - fn init_body_helper( + fn call_enzyme_autodiff( ecx: &ExtCtxt<'_>, - span: Span, primal: Ident, - new_names: &[String], - sig_span: Span, - new_decl_span: Span, - idents: &[Ident], - errored: bool, + diff: Ident, + span: Span, + d_sig: &FnSig, generics: &Generics, - ) -> (P, P, P, P) { - let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]); - let noop = ast::InlineAsm { - asm_macro: ast::AsmMacro::Asm, - template: vec![ast::InlineAsmTemplatePiece::String("NOP".into())], - template_strs: Box::new([]), - operands: vec![], - clobber_abis: vec![], - options: ast::InlineAsmOptions::PURE | ast::InlineAsmOptions::NOMEM, - line_spans: vec![], - }; - let noop_expr = ecx.expr_asm(span, P(noop)); - let unsf = ast::BlockCheckMode::Unsafe(ast::UnsafeSource::CompilerGenerated); - let unsf_block = ast::Block { - stmts: thin_vec![ecx.stmt_semi(noop_expr)], - id: ast::DUMMY_NODE_ID, - tokens: None, - rules: unsf, + is_impl: bool, + ) -> rustc_ast::Stmt { + let primal_path_expr = gen_turbofish_expr(ecx, primal, generics, span, is_impl); + let diff_path_expr = gen_turbofish_expr(ecx, diff, generics, span, is_impl); + + let tuple_expr = ecx.expr_tuple( span, - }; - let unsf_expr = ecx.expr_block(P(unsf_block)); - let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path)); - let primal_call = gen_primal_call(ecx, span, primal, idents, generics); - let black_box_primal_call = ecx.expr_call( - new_decl_span, - blackbox_call_expr.clone(), - thin_vec![primal_call.clone()], + d_sig + .decl + .inputs + .iter() + .map(|arg| match arg.pat.kind { + PatKind::Ident(_, ident, _) => ecx.expr_path(ecx.path_ident(span, ident)), + _ => todo!(), + }) + .collect::>() + .into(), ); - let tup_args = new_names - .iter() - .map(|arg| ecx.expr_path(ecx.path_ident(span, Ident::from_str(arg)))) - .collect(); - let black_box_remaining_args = ecx.expr_call( - sig_span, - blackbox_call_expr.clone(), - thin_vec![ecx.expr_tuple(sig_span, tup_args)], + let enzyme_path = ecx.path( + span, + vec![ + Ident::from_str("std"), + Ident::from_str("intrinsics"), + Ident::with_dummy_span(sym::enzyme_autodiff), + ], + ); + let call_expr = ecx.expr_call( + span, + ecx.expr_path(enzyme_path), + vec![primal_path_expr, diff_path_expr, tuple_expr].into(), ); - let mut body = ecx.block(span, ThinVec::new()); - body.stmts.push(ecx.stmt_semi(unsf_expr)); + ecx.stmt_expr(call_expr) + } - // This uses primal args which won't be available if we errored before - if !errored { - body.stmts.push(ecx.stmt_semi(black_box_primal_call.clone())); - } - body.stmts.push(ecx.stmt_semi(black_box_remaining_args)); + // Generate turbofish expression from fn name and generics + // Given `foo` and `` params, gen `foo::` + fn gen_turbofish_expr( + ecx: &ExtCtxt<'_>, + ident: Ident, + generics: &Generics, + span: Span, + is_impl: bool, + ) -> P { + let generic_args = generics + .params + .iter() + .map(|p| { + let path = ast::Path::from_ident(p.ident); + let ty = ecx.ty_path(path); + AngleBracketedArg::Arg(GenericArg::Type(ty)) + }) + .collect::>(); - (body, primal_call, black_box_primal_call, blackbox_call_expr) + let args: AngleBracketedArgs = AngleBracketedArgs { span, args: generic_args }; + + let segment = PathSegment { + ident, + id: ast::DUMMY_NODE_ID, + args: Some(P(GenericArgs::AngleBracketed(args))), + }; + + let segments = if is_impl { + thin_vec![ + PathSegment { ident: Ident::from_str("Self"), id: ast::DUMMY_NODE_ID, args: None }, + segment, + ] + } else { + thin_vec![segment] + }; + + let path = Path { span, segments, tokens: None }; + + ecx.expr_path(path) + } + + // Will generate a body of the type: + // ``` + // primal(args); + // std::intrinsics::enzyme_autodiff(primal, diff, (args)) + // } + // ``` + fn init_body_helper( + ecx: &ExtCtxt<'_>, + span: Span, + primal: Ident, + idents: &[Ident], + _errored: bool, + generics: &Generics, + ) -> P { + let _primal_call = gen_primal_call(ecx, span, primal, idents, generics); + let body = ecx.block(span, ThinVec::new()); + body } /// We only want this function to type-check, since we will replace the body @@ -567,132 +612,30 @@ mod llvm_enzyme { /// from optimizing any arguments away. fn gen_enzyme_body( ecx: &ExtCtxt<'_>, - x: &AutoDiffAttrs, - n_active: u32, - sig: &ast::FnSig, d_sig: &ast::FnSig, primal: Ident, - new_names: &[String], span: Span, - sig_span: Span, idents: Vec, errored: bool, + diff_ident: Ident, generics: &Generics, + is_impl: bool, ) -> P { let new_decl_span = d_sig.span; - // Just adding some default inline-asm and black_box usages to prevent early inlining - // and optimizations which alter the function signature. - // - // The bb_primal_call is the black_box call of the primal function. We keep it around, - // since it has the convenient property of returning the type of the primal function, - // Remember, we only care to match types here. - // No matter which return we pick, we always wrap it into a std::hint::black_box call, - // to prevent rustc from propagating it into the caller. - let (mut body, primal_call, bb_primal_call, bb_call_expr) = init_body_helper( + // Add a call to the primal function to prevent it from being inlined + // and call `enzyme_autodiff` intrinsic (this also covers the return type) + let mut body = init_body_helper(ecx, span, primal, &idents, errored, generics); + + body.stmts.push(call_enzyme_autodiff( ecx, - span, primal, - new_names, - sig_span, + diff_ident, new_decl_span, - &idents, - errored, + d_sig, generics, - ); - - if !has_ret(&d_sig.decl.output) { - // there is no return type that we have to match, () works fine. - return body; - } - - // Everything from here onwards just tries to fullfil the return type. Fun! - - // having an active-only return means we'll drop the original return type. - // So that can be treated identical to not having one in the first place. - let primal_ret = has_ret(&sig.decl.output) && !x.has_active_only_ret(); - - if primal_ret && n_active == 0 && x.mode.is_rev() { - // We only have the primal ret. - body.stmts.push(ecx.stmt_expr(bb_primal_call)); - return body; - } - - if !primal_ret && n_active == 1 { - // Again no tuple return, so return default float val. - let ty = match d_sig.decl.output { - FnRetTy::Ty(ref ty) => ty.clone(), - FnRetTy::Default(span) => { - panic!("Did not expect Default ret ty: {:?}", span); - } - }; - let arg = ty.kind.is_simple_path().unwrap(); - let tmp = ecx.def_site_path(&[arg, kw::Default]); - let default_call_expr = ecx.expr_path(ecx.path(span, tmp)); - let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]); - body.stmts.push(ecx.stmt_expr(default_call_expr)); - return body; - } - - let mut exprs: P = primal_call; - let d_ret_ty = match d_sig.decl.output { - FnRetTy::Ty(ref ty) => ty.clone(), - FnRetTy::Default(span) => { - panic!("Did not expect Default ret ty: {:?}", span); - } - }; - if x.mode.is_fwd() { - // Fwd mode is easy. If the return activity is Const, we support arbitrary types. - // Otherwise, we only support a scalar, a pair of scalars, or an array of scalars. - // We checked that (on a best-effort base) in the preceding gen_enzyme_decl function. - // In all three cases, we can return `std::hint::black_box(::default())`. - if x.ret_activity == DiffActivity::Const { - // Here we call the primal function, since our dummy function has the same return - // type due to the Const return activity. - exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]); - } else { - let q = QSelf { ty: d_ret_ty, path_span: span, position: 0 }; - let y = ExprKind::Path( - Some(P(q)), - ecx.path_ident(span, Ident::with_dummy_span(kw::Default)), - ); - let default_call_expr = ecx.expr(span, y); - let default_call_expr = - ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]); - exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![default_call_expr]); - } - } else if x.mode.is_rev() { - if x.width == 1 { - // We either have `-> ArbitraryType` or `-> (ArbitraryType, repeated_float_scalars)`. - match d_ret_ty.kind { - TyKind::Tup(ref args) => { - // We have a tuple return type. We need to create a tuple of the same size - // and fill it with default values. - let mut exprs2 = thin_vec![exprs]; - for arg in args.iter().skip(1) { - let arg = arg.kind.is_simple_path().unwrap(); - let tmp = ecx.def_site_path(&[arg, kw::Default]); - let default_call_expr = ecx.expr_path(ecx.path(span, tmp)); - let default_call_expr = - ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]); - exprs2.push(default_call_expr); - } - exprs = ecx.expr_tuple(new_decl_span, exprs2); - } - _ => { - // Interestingly, even the `-> ArbitraryType` case - // ends up getting matched and handled correctly above, - // so we don't have to handle any other case for now. - panic!("Unsupported return type: {:?}", d_ret_ty); - } - } - } - exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]); - } else { - unreachable!("Unsupported mode: {:?}", x.mode); - } - - body.stmts.push(ecx.stmt_expr(exprs)); + is_impl, + )); body } @@ -779,7 +722,7 @@ mod llvm_enzyme { sig: &ast::FnSig, x: &AutoDiffAttrs, span: Span, - ) -> (ast::FnSig, Vec, Vec, bool) { + ) -> (ast::FnSig, Vec, bool) { let dcx = ecx.sess.dcx(); let has_ret = has_ret(&sig.decl.output); let sig_args = sig.decl.inputs.len() + if has_ret { 1 } else { 0 }; @@ -791,7 +734,7 @@ mod llvm_enzyme { found: num_activities, }); // This is not the right signature, but we can continue parsing. - return (sig.clone(), vec![], vec![], true); + return (sig.clone(), vec![], true); } assert!(sig.decl.inputs.len() == x.input_activity.len()); assert!(has_ret == x.has_ret_activity()); @@ -834,7 +777,7 @@ mod llvm_enzyme { if errors { // This is not the right signature, but we can continue parsing. - return (sig.clone(), new_inputs, idents, true); + return (sig.clone(), idents, true); } let unsafe_activities = x @@ -1042,7 +985,7 @@ mod llvm_enzyme { } let d_sig = FnSig { header: d_header, decl: d_decl, span }; trace!("Generated signature: {:?}", d_sig); - (d_sig, new_inputs, idents, false) + (d_sig, idents, false) } } diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index b07d9a5cfca8c..4d7515b669906 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -1,43 +1,18 @@ use std::ptr; -use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode}; -use rustc_codegen_ssa::ModuleCodegen; -use rustc_codegen_ssa::back::write::ModuleConfig; +use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode}; use rustc_codegen_ssa::common::TypeKind; -use rustc_codegen_ssa::traits::BaseTypeCodegenMethods; -use rustc_errors::FatalError; +use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods}; use rustc_middle::bug; -use tracing::{debug, trace}; +use tracing::debug; -use crate::back::write::llvm_err; -use crate::builder::{SBuilder, UNNAMED}; +use crate::builder::{Builder, PlaceRef, UNNAMED}; use crate::context::SimpleCx; use crate::declare::declare_simple_fn; -use crate::errors::{AutoDiffWithoutEnable, LlvmError}; use crate::llvm::AttributePlace::Function; -use crate::llvm::{Metadata, True}; +use crate::llvm::{Metadata, True, Type}; use crate::value::Value; -use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm}; - -fn get_params(fnc: &Value) -> Vec<&Value> { - let param_num = llvm::LLVMCountParams(fnc) as usize; - let mut fnc_args: Vec<&Value> = vec![]; - fnc_args.reserve(param_num); - unsafe { - llvm::LLVMGetParams(fnc, fnc_args.as_mut_ptr()); - fnc_args.set_len(param_num); - } - fnc_args -} - -fn has_sret(fnc: &Value) -> bool { - let num_args = llvm::LLVMCountParams(fnc) as usize; - if num_args == 0 { - false - } else { - unsafe { llvm::LLVMRustHasAttributeAtIndex(fnc, 0, llvm::AttributeKind::StructRet) } - } -} +use crate::{attributes, llvm}; // When we call the `__enzyme_autodiff` or `__enzyme_fwddiff` function, we need to pass all the // original inputs, as well as metadata and the additional shadow arguments. @@ -49,14 +24,13 @@ fn has_sret(fnc: &Value) -> bool { // need to match those. // FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it // using iterators and peek()? -fn match_args_from_caller_to_enzyme<'ll>( +fn match_args_from_caller_to_enzyme<'ll, 'tcx>( cx: &SimpleCx<'ll>, - builder: &SBuilder<'ll, 'll>, + builder: &mut Builder<'_, 'll, 'tcx>, width: u32, args: &mut Vec<&'ll llvm::Value>, inputs: &[DiffActivity], outer_args: &[&'ll llvm::Value], - has_sret: bool, ) { debug!("matching autodiff arguments"); // We now handle the issue that Rust level arguments not always match the llvm-ir level @@ -68,14 +42,6 @@ fn match_args_from_caller_to_enzyme<'ll>( let mut outer_pos: usize = 0; let mut activity_pos = 0; - if has_sret { - // Then the first outer arg is the sret pointer. Enzyme doesn't know about sret, so the - // inner function will still return something. We increase our outer_pos by one, - // and once we're done with all other args we will take the return of the inner call and - // update the sret pointer with it - outer_pos = 1; - } - let enzyme_const = cx.create_metadata("enzyme_const".to_string()).unwrap(); let enzyme_out = cx.create_metadata("enzyme_out".to_string()).unwrap(); let enzyme_dup = cx.create_metadata("enzyme_dup".to_string()).unwrap(); @@ -194,92 +160,6 @@ fn match_args_from_caller_to_enzyme<'ll>( } } -// On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input -// arguments. We do however need to declare them with their correct return type. -// We already figured the correct return type out in our frontend, when generating the outer_fn, -// so we can now just go ahead and use that. This is not always trivial, e.g. because sret. -// Beyond sret, this article describes our challenges nicely: -// -// I.e. (i32, f32) will get merged into i64, but we don't handle that yet. -fn compute_enzyme_fn_ty<'ll>( - cx: &SimpleCx<'ll>, - attrs: &AutoDiffAttrs, - fn_to_diff: &'ll Value, - outer_fn: &'ll Value, -) -> &'ll llvm::Type { - let fn_ty = cx.get_type_of_global(outer_fn); - let mut ret_ty = cx.get_return_type(fn_ty); - - let has_sret = has_sret(outer_fn); - - if has_sret { - // Now we don't just forward the return type, so we have to figure it out based on the - // primal return type, in combination with the autodiff settings. - let fn_ty = cx.get_type_of_global(fn_to_diff); - let inner_ret_ty = cx.get_return_type(fn_ty); - - let void_ty = unsafe { llvm::LLVMVoidTypeInContext(cx.llcx) }; - if inner_ret_ty == void_ty { - // This indicates that even the inner function has an sret. - // Right now I only look for an sret in the outer function. - // This *probably* needs some extra handling, but I never ran - // into such a case. So I'll wait for user reports to have a test case. - bug!("sret in inner function"); - } - - if attrs.width == 1 { - // Enzyme returns a struct of style: - // `{ original_ret(if requested), float, float, ... }` - let mut struct_elements = vec![]; - if attrs.has_primal_ret() { - struct_elements.push(inner_ret_ty); - } - // Next, we push the list of active floats, since they will be lowered to `enzyme_out`, - // and therefore part of the return struct. - let param_tys = cx.func_params_types(fn_ty); - for (act, param_ty) in attrs.input_activity.iter().zip(param_tys) { - if matches!(act, DiffActivity::Active) { - // Now find the float type at position i based on the fn_ty, - // to know what (f16/f32/f64/...) to add to the struct. - struct_elements.push(param_ty); - } - } - ret_ty = cx.type_struct(&struct_elements, false); - } else { - // First we check if we also have to deal with the primal return. - match attrs.mode { - DiffMode::Forward => match attrs.ret_activity { - DiffActivity::Dual => { - let arr_ty = - unsafe { llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64 + 1) }; - ret_ty = arr_ty; - } - DiffActivity::DualOnly => { - let arr_ty = - unsafe { llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64) }; - ret_ty = arr_ty; - } - DiffActivity::Const => { - todo!("Not sure, do we need to do something here?"); - } - _ => { - bug!("unreachable"); - } - }, - DiffMode::Reverse => { - todo!("Handle sret for reverse mode"); - } - _ => { - bug!("unreachable"); - } - } - } - } - - // LLVM can figure out the input types on it's own, so we take a shortcut here. - unsafe { llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True) } -} - /// When differentiating `fn_to_diff`, take a `outer_fn` and generate another /// function with expected naming and calling conventions[^1] which will be /// discovered by the enzyme LLVM pass and its body populated with the differentiated @@ -289,11 +169,15 @@ fn compute_enzyme_fn_ty<'ll>( /// [^1]: // FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to // cover some assumptions of enzyme/autodiff, which could lead to UB otherwise. -fn generate_enzyme_call<'ll>( +pub(crate) fn generate_enzyme_call<'ll, 'tcx>( + builder: &mut Builder<'_, 'll, 'tcx>, cx: &SimpleCx<'ll>, fn_to_diff: &'ll Value, - outer_fn: &'ll Value, + outer_name: &str, + ret_ty: &'ll Type, + fn_args: &[&'ll Value], attrs: AutoDiffAttrs, + dest: PlaceRef<'tcx, &'ll Value>, ) { // We have to pick the name depending on whether we want forward or reverse mode autodiff. let mut ad_name: String = match attrs.mode { @@ -303,11 +187,9 @@ fn generate_enzyme_call<'ll>( } .to_string(); - // add outer_fn name to ad_name to make it unique, in case users apply autodiff to multiple + // add outer_name to ad_name to make it unique, in case users apply autodiff to multiple // functions. Unwrap will only panic, if LLVM gave us an invalid string. - let name = llvm::get_value_name(outer_fn); - let outer_fn_name = std::str::from_utf8(name).unwrap(); - ad_name.push_str(outer_fn_name); + ad_name.push_str(outer_name); // Let us assume the user wrote the following function square: // @@ -341,176 +223,49 @@ fn generate_enzyme_call<'ll>( // ret double %0 // } // ``` - unsafe { - let enzyme_ty = compute_enzyme_fn_ty(cx, &attrs, fn_to_diff, outer_fn); - - // FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and - // think a bit more about what should go here. - let cc = llvm::LLVMGetFunctionCallConv(outer_fn); - let ad_fn = declare_simple_fn( - cx, - &ad_name, - llvm::CallConv::try_from(cc).expect("invalid callconv"), - llvm::UnnamedAddr::No, - llvm::Visibility::Default, - enzyme_ty, - ); - - // Otherwise LLVM might inline our temporary code before the enzyme pass has a chance to - // do it's work. - let attr = llvm::AttributeKind::NoInline.create_attr(cx.llcx); - attributes::apply_to_llfn(ad_fn, Function, &[attr]); - - // We add a made-up attribute just such that we can recognize it after AD to update - // (no)-inline attributes. We'll then also remove this attribute. - let enzyme_marker_attr = llvm::CreateAttrString(cx.llcx, "enzyme_marker"); - attributes::apply_to_llfn(outer_fn, Function, &[enzyme_marker_attr]); - - // first, remove all calls from fnc - let entry = llvm::LLVMGetFirstBasicBlock(outer_fn); - let br = llvm::LLVMRustGetTerminator(entry); - llvm::LLVMRustEraseInstFromParent(br); - - let last_inst = llvm::LLVMRustGetLastInstruction(entry).unwrap(); - let mut builder = SBuilder::build(cx, entry); - - let num_args = llvm::LLVMCountParams(&fn_to_diff); - let mut args = Vec::with_capacity(num_args as usize + 1); - args.push(fn_to_diff); - - let enzyme_primal_ret = cx.create_metadata("enzyme_primal_return".to_string()).unwrap(); - if matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) { - args.push(cx.get_metadata_value(enzyme_primal_ret)); - } - if attrs.width > 1 { - let enzyme_width = cx.create_metadata("enzyme_width".to_string()).unwrap(); - args.push(cx.get_metadata_value(enzyme_width)); - args.push(cx.get_const_int(cx.type_i64(), attrs.width as u64)); - } - - let has_sret = has_sret(outer_fn); - let outer_args: Vec<&llvm::Value> = get_params(outer_fn); - match_args_from_caller_to_enzyme( - &cx, - &builder, - attrs.width, - &mut args, - &attrs.input_activity, - &outer_args, - has_sret, - ); - - let call = builder.call(enzyme_ty, ad_fn, &args, None); - - // This part is a bit iffy. LLVM requires that a call to an inlineable function has some - // metadata attached to it, but we just created this code oota. Given that the - // differentiated function already has partly confusing metadata, and given that this - // affects nothing but the auttodiff IR, we take a shortcut and just steal metadata from the - // dummy code which we inserted at a higher level. - // FIXME(ZuseZ4): Work with Enzyme core devs to clarify what debug metadata issues we have, - // and how to best improve it for enzyme core and rust-enzyme. - let md_ty = cx.get_md_kind_id("dbg"); - if llvm::LLVMRustHasMetadata(last_inst, md_ty) { - let md = llvm::LLVMRustDIGetInstMetadata(last_inst) - .expect("failed to get instruction metadata"); - let md_todiff = cx.get_metadata_value(md); - llvm::LLVMSetMetadata(call, md_ty, md_todiff); - } else { - // We don't panic, since depending on whether we are in debug or release mode, we might - // have no debug info to copy, which would then be ok. - trace!("no dbg info"); - } - - // Now that we copied the metadata, get rid of dummy code. - llvm::LLVMRustEraseInstUntilInclusive(entry, last_inst); - - if cx.val_ty(call) == cx.type_void() || has_sret { - if has_sret { - // This is what we already have in our outer_fn (shortened): - // define void @_foo(ptr <..> sret([32 x i8]) initializes((0, 32)) %0, <...>) { - // %7 = call [4 x double] (...) @__enzyme_fwddiff_foo(ptr @square, metadata !"enzyme_width", i64 4, <...>) - // - // store [4 x double] %7, ptr %0, align 8 - // ret void - // } - - // now store the result of the enzyme call into the sret pointer. - let sret_ptr = outer_args[0]; - let call_ty = cx.val_ty(call); - if attrs.width == 1 { - assert_eq!(cx.type_kind(call_ty), TypeKind::Struct); - } else { - assert_eq!(cx.type_kind(call_ty), TypeKind::Array); - } - llvm::LLVMBuildStore(&builder.llbuilder, call, sret_ptr); - } - builder.ret_void(); - } else { - builder.ret(call); - } - - // Let's crash in case that we messed something up above and generated invalid IR. - llvm::LLVMRustVerifyFunction( - outer_fn, - llvm::LLVMRustVerifierFailureAction::LLVMAbortProcessAction, - ); + let enzyme_ty = unsafe { llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True) }; + + // FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and + // think a bit more about what should go here. + let cc = unsafe { llvm::LLVMGetFunctionCallConv(fn_to_diff) }; + let ad_fn = declare_simple_fn( + cx, + &ad_name, + llvm::CallConv::try_from(cc).expect("invalid callconv"), + llvm::UnnamedAddr::No, + llvm::Visibility::Default, + enzyme_ty, + ); + + // Otherwise LLVM might inline our temporary code before the enzyme pass has a chance to + // do it's work. + let attr = llvm::AttributeKind::NoInline.create_attr(cx.llcx); + attributes::apply_to_llfn(ad_fn, Function, &[attr]); + + let num_args = llvm::LLVMCountParams(&fn_to_diff); + let mut args = Vec::with_capacity(num_args as usize + 1); + args.push(fn_to_diff); + + let enzyme_primal_ret = cx.create_metadata("enzyme_primal_return".to_string()).unwrap(); + if matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) { + args.push(cx.get_metadata_value(enzyme_primal_ret)); } -} - -pub(crate) fn differentiate<'ll>( - module: &'ll ModuleCodegen, - cgcx: &CodegenContext, - diff_items: Vec, - _config: &ModuleConfig, -) -> Result<(), FatalError> { - for item in &diff_items { - trace!("{}", item); - } - - let diag_handler = cgcx.create_dcx(); - - let cx = SimpleCx::new(module.module_llvm.llmod(), module.module_llvm.llcx, cgcx.pointer_size); - - // First of all, did the user try to use autodiff without using the -Zautodiff=Enable flag? - if !diff_items.is_empty() - && !cgcx.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable) - { - return Err(diag_handler.handle().emit_almost_fatal(AutoDiffWithoutEnable)); - } - - // Here we replace the placeholder code with the actual autodiff code, which calls Enzyme. - for item in diff_items.iter() { - let name = item.source.clone(); - let fn_def: Option<&llvm::Value> = cx.get_function(&name); - let Some(fn_def) = fn_def else { - return Err(llvm_err( - diag_handler.handle(), - LlvmError::PrepareAutoDiff { - src: item.source.clone(), - target: item.target.clone(), - error: "could not find source function".to_owned(), - }, - )); - }; - debug!(?item.target); - let fn_target: Option<&llvm::Value> = cx.get_function(&item.target); - let Some(fn_target) = fn_target else { - return Err(llvm_err( - diag_handler.handle(), - LlvmError::PrepareAutoDiff { - src: item.source.clone(), - target: item.target.clone(), - error: "could not find target function".to_owned(), - }, - )); - }; - - generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone()); + if attrs.width > 1 { + let enzyme_width = cx.create_metadata("enzyme_width".to_string()).unwrap(); + args.push(cx.get_metadata_value(enzyme_width)); + args.push(cx.get_const_int(cx.type_i64(), attrs.width as u64)); } - // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts + match_args_from_caller_to_enzyme( + &cx, + builder, + attrs.width, + &mut args, + &attrs.input_activity, + fn_args, + ); - trace!("done with differentiate()"); + let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None); - Ok(()) + builder.store_to_place(call, dest.val); } diff --git a/compiler/rustc_codegen_llvm/src/context.rs b/compiler/rustc_codegen_llvm/src/context.rs index 0324dff6ff256..e5aabaa8b76a8 100644 --- a/compiler/rustc_codegen_llvm/src/context.rs +++ b/compiler/rustc_codegen_llvm/src/context.rs @@ -652,7 +652,7 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> { } } impl<'ll> SimpleCx<'ll> { - pub(crate) fn get_return_type(&self, ty: &'ll Type) -> &'ll Type { + pub(crate) fn _get_return_type(&self, ty: &'ll Type) -> &'ll Type { assert_eq!(self.type_kind(ty), TypeKind::Function); unsafe { llvm::LLVMGetReturnType(ty) } } diff --git a/compiler/rustc_codegen_llvm/src/errors.rs b/compiler/rustc_codegen_llvm/src/errors.rs index d50ad8a1a9cb4..d2ef3039107af 100644 --- a/compiler/rustc_codegen_llvm/src/errors.rs +++ b/compiler/rustc_codegen_llvm/src/errors.rs @@ -37,6 +37,8 @@ impl Diagnostic<'_, G> for ParseTargetMachineConfig<'_> { } } +// TODO(Sa4dUs): we will need to reintroduce these errors somewhere +/* #[derive(Diagnostic)] #[diag(codegen_llvm_autodiff_without_lto)] pub(crate) struct AutoDiffWithoutLTO; @@ -44,6 +46,7 @@ pub(crate) struct AutoDiffWithoutLTO; #[derive(Diagnostic)] #[diag(codegen_llvm_autodiff_without_enable)] pub(crate) struct AutoDiffWithoutEnable; +*/ #[derive(Diagnostic)] #[diag(codegen_llvm_lto_disallowed)] diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index f7f062849a8b5..2724b834fa820 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -3,23 +3,26 @@ use std::cmp::Ordering; use rustc_abi::{Align, BackendRepr, ExternAbi, Float, HasDataLayout, Primitive, Size}; use rustc_codegen_ssa::base::{compare_simd_types, wants_msvc_seh, wants_wasm_eh}; +use rustc_codegen_ssa::codegen_attrs::autodiff_attrs; use rustc_codegen_ssa::common::{IntPredicate, TypeKind}; use rustc_codegen_ssa::errors::{ExpectedPointerMutability, InvalidMonomorphization}; use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue}; use rustc_codegen_ssa::mir::place::{PlaceRef, PlaceValue}; use rustc_codegen_ssa::traits::*; -use rustc_hir as hir; +use rustc_hir::def_id::LOCAL_CRATE; +use rustc_hir::{self as hir}; use rustc_middle::mir::BinOp; use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, HasTypingEnv, LayoutOf}; -use rustc_middle::ty::{self, GenericArgsRef, Ty}; +use rustc_middle::ty::{self, GenericArgsRef, Instance, Ty, TyCtxt}; use rustc_middle::{bug, span_bug}; use rustc_span::{Span, Symbol, sym}; -use rustc_symbol_mangling::mangle_internal_symbol; +use rustc_symbol_mangling::{mangle_internal_symbol, symbol_name_for_instance_in_crate}; use rustc_target::spec::PanicStrategy; use tracing::debug; use crate::abi::FnAbiLlvmExt; use crate::builder::Builder; +use crate::builder::autodiff::generate_enzyme_call; use crate::context::CodegenCx; use crate::llvm::{self, Metadata}; use crate::type_::Type; @@ -187,6 +190,10 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { &[ptr, args[1].immediate()], ) } + sym::enzyme_autodiff => { + codegen_enzyme_autodiff(self, tcx, instance, args, result); + return Ok(()); + } sym::is_val_statically_known => { if let OperandValue::Immediate(imm) = args[0].val { self.call_intrinsic( @@ -1121,6 +1128,86 @@ fn get_rust_try_fn<'a, 'll, 'tcx>( rust_try } +fn codegen_enzyme_autodiff<'ll, 'tcx>( + bx: &mut Builder<'_, 'll, 'tcx>, + tcx: TyCtxt<'tcx>, + instance: ty::Instance<'tcx>, + args: &[OperandRef<'tcx, &'ll Value>], + result: PlaceRef<'tcx, &'ll Value>, +) { + let fn_args = instance.args; + let callee_ty = instance.ty(tcx, bx.typing_env()); + + let sig = callee_ty.fn_sig(tcx); + let sig = tcx.normalize_erasing_late_bound_regions(bx.typing_env(), sig); + + let ret_ty = sig.output(); + let llret_ty = bx.layout_of(ret_ty).llvm_type(bx); + + let val_arr: Vec<&'ll Value> = get_args_from_tuple(bx, args[2]); + + // Get source, diff, and attrs + let (source_id, source_args) = match fn_args.into_type_list(tcx)[0].kind() { + ty::FnDef(def_id, source_params) => (def_id, source_params), + _ => bug!("invalid args"), + }; + let fn_source = + Instance::try_resolve(tcx, bx.cx.typing_env(), *source_id, source_args).unwrap().unwrap(); + let source_symbol = symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE); + let fn_to_diff: Option<&'ll llvm::Value> = bx.cx.get_function(&source_symbol); + let Some(fn_to_diff) = fn_to_diff else { bug!("could not find source function") }; + + let (diff_id, diff_args) = match fn_args.into_type_list(tcx)[1].kind() { + ty::FnDef(def_id, diff_args) => (def_id, diff_args), + _ => bug!("invalid args"), + }; + let fn_diff = + Instance::try_resolve(tcx, bx.cx.typing_env(), *diff_id, diff_args).unwrap().unwrap(); + let diff_symbol = symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE); + + let diff_attrs = autodiff_attrs(tcx, fn_diff.def_id()); + let Some(diff_attrs) = diff_attrs else { bug!("could not find autodiff attrs") }; + + // Build body + generate_enzyme_call( + bx, + bx.cx, + fn_to_diff, + &diff_symbol, + llret_ty, + &val_arr, + diff_attrs.clone(), + result, + ); +} + +fn get_args_from_tuple<'ll, 'tcx>( + bx: &mut Builder<'_, 'll, 'tcx>, + op: OperandRef<'tcx, &'ll Value>, +) -> Vec<&'ll Value> { + match op.val { + OperandValue::Ref(ref place_value) => { + let mut ret_arr = vec![]; + let tuple_place = PlaceRef { val: *place_value, layout: op.layout }; + + for i in 0..tuple_place.layout.layout.0.fields.count() { + let field_place = tuple_place.project_field(bx, i); + let field_layout = tuple_place.layout.field(bx, i); + let llvm_ty = field_layout.llvm_type(bx.cx); + + let field_val = bx.load(llvm_ty, field_place.val.llval, field_place.val.align); + + ret_arr.push(field_val) + } + + ret_arr + } + OperandValue::Pair(v1, v2) => vec![v1, v2], + OperandValue::Immediate(v) => vec![v], + OperandValue::ZeroSized => bug!("unexpected `ZeroSized` arg"), + } +} + fn generic_simd_intrinsic<'ll, 'tcx>( bx: &mut Builder<'_, 'll, 'tcx>, name: Symbol, diff --git a/compiler/rustc_codegen_llvm/src/lib.rs b/compiler/rustc_codegen_llvm/src/lib.rs index cdfffbe47bfa5..6682f44dc72c4 100644 --- a/compiler/rustc_codegen_llvm/src/lib.rs +++ b/compiler/rustc_codegen_llvm/src/lib.rs @@ -26,10 +26,9 @@ use std::mem::ManuallyDrop; use back::owned_target_machine::OwnedTargetMachine; use back::write::{create_informational_target_machine, create_target_machine}; use context::SimpleCx; -use errors::{AutoDiffWithoutLTO, ParseTargetMachineConfig}; +use errors::ParseTargetMachineConfig; use llvm_util::target_config; use rustc_ast::expand::allocator::AllocatorKind; -use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_codegen_ssa::back::lto::{LtoModuleCodegen, SerializedModule, ThinModule}; use rustc_codegen_ssa::back::write::{ CodegenContext, FatLtoInput, ModuleConfig, TargetMachineFactoryConfig, TargetMachineFactoryFn, @@ -43,7 +42,7 @@ use rustc_middle::dep_graph::{WorkProduct, WorkProductId}; use rustc_middle::ty::TyCtxt; use rustc_middle::util::Providers; use rustc_session::Session; -use rustc_session::config::{Lto, OptLevel, OutputFilenames, PrintKind, PrintRequest}; +use rustc_session::config::{OptLevel, OutputFilenames, PrintKind, PrintRequest}; use rustc_span::Symbol; mod back { @@ -227,19 +226,6 @@ impl WriteBackendMethods for LlvmCodegenBackend { fn serialize_module(module: ModuleCodegen) -> (String, Self::ModuleBuffer) { (module.name, back::lto::ModuleBuffer::new(module.module_llvm.llmod())) } - /// Generate autodiff rules - fn autodiff( - cgcx: &CodegenContext, - module: &ModuleCodegen, - diff_fncs: Vec, - config: &ModuleConfig, - ) -> Result<(), FatalError> { - if cgcx.lto != Lto::Fat { - let dcx = cgcx.create_dcx(); - return Err(dcx.handle().emit_almost_fatal(AutoDiffWithoutLTO)); - } - builder::autodiff::differentiate(module, cgcx, diff_fncs, config) - } } impl LlvmCodegenBackend { diff --git a/compiler/rustc_codegen_ssa/src/back/lto.rs b/compiler/rustc_codegen_ssa/src/back/lto.rs index ce6fe8a191b3b..5815f15c0fa83 100644 --- a/compiler/rustc_codegen_ssa/src/back/lto.rs +++ b/compiler/rustc_codegen_ssa/src/back/lto.rs @@ -1,13 +1,11 @@ use std::ffi::CString; use std::sync::Arc; -use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_data_structures::memmap::Mmap; use rustc_errors::FatalError; use super::write::CodegenContext; use crate::ModuleCodegen; -use crate::back::write::ModuleConfig; use crate::traits::*; pub struct ThinModule { @@ -78,23 +76,6 @@ impl LtoModuleCodegen { LtoModuleCodegen::Thin(ref m) => m.cost(), } } - - /// Run autodiff on Fat LTO module - pub fn autodiff( - self, - cgcx: &CodegenContext, - diff_fncs: Vec, - config: &ModuleConfig, - ) -> Result, FatalError> { - match &self { - LtoModuleCodegen::Fat(module) => { - B::autodiff(cgcx, &module, diff_fncs, config)?; - } - _ => panic!("autodiff called with non-fat LTO module"), - } - - Ok(self) - } } pub enum SerializedModule { diff --git a/compiler/rustc_codegen_ssa/src/back/write.rs b/compiler/rustc_codegen_ssa/src/back/write.rs index c3bfe4c13cdf7..588904921e22f 100644 --- a/compiler/rustc_codegen_ssa/src/back/write.rs +++ b/compiler/rustc_codegen_ssa/src/back/write.rs @@ -408,12 +408,8 @@ fn generate_lto_work( if !needs_fat_lto.is_empty() { assert!(needs_thin_lto.is_empty()); - let mut module = + let module = B::run_fat_lto(cgcx, needs_fat_lto, import_only_modules).unwrap_or_else(|e| e.raise()); - if cgcx.lto == Lto::Fat && !autodiff.is_empty() { - let config = cgcx.config(ModuleKind::Regular); - module = module.autodiff(cgcx, autodiff, config).unwrap_or_else(|e| e.raise()); - } // We are adding a single work item, so the cost doesn't matter. vec![(WorkItem::LTO(module), 0)] } else { diff --git a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs index 7bd27eb3ef1cd..fc753af6fe8c2 100644 --- a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs +++ b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs @@ -720,7 +720,7 @@ impl<'a> MixedExportNameAndNoMangleState<'a> { /// being differentiated). The other form is #[rustc_autodiff(Mode, ActivityList)] on top of the /// placeholder functions. We wrote the rustc_autodiff attributes ourself, so this should never /// panic, unless we introduced a bug when parsing the autodiff macro. -fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option { +pub fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option { let attrs = tcx.get_attrs(id, sym::rustc_autodiff); let attrs = attrs.filter(|attr| attr.has_name(sym::rustc_autodiff)).collect::>(); diff --git a/compiler/rustc_codegen_ssa/src/traits/write.rs b/compiler/rustc_codegen_ssa/src/traits/write.rs index 07a0609fda1a1..40fd70fa5ad60 100644 --- a/compiler/rustc_codegen_ssa/src/traits/write.rs +++ b/compiler/rustc_codegen_ssa/src/traits/write.rs @@ -1,4 +1,3 @@ -use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_errors::{DiagCtxtHandle, FatalError}; use rustc_middle::dep_graph::WorkProduct; @@ -62,12 +61,6 @@ pub trait WriteBackendMethods: Clone + 'static { want_summary: bool, ) -> (String, Self::ThinBuffer); fn serialize_module(module: ModuleCodegen) -> (String, Self::ModuleBuffer); - fn autodiff( - cgcx: &CodegenContext, - module: &ModuleCodegen, - diff_fncs: Vec, - config: &ModuleConfig, - ) -> Result<(), FatalError>; } pub trait ThinBufferMethods: Send + Sync { diff --git a/compiler/rustc_hir_analysis/src/check/intrinsic.rs b/compiler/rustc_hir_analysis/src/check/intrinsic.rs index 060fc51b7bda7..305a215544ce2 100644 --- a/compiler/rustc_hir_analysis/src/check/intrinsic.rs +++ b/compiler/rustc_hir_analysis/src/check/intrinsic.rs @@ -134,6 +134,7 @@ fn intrinsic_operation_unsafety(tcx: TyCtxt<'_>, intrinsic_id: LocalDefId) -> hi | sym::round_ties_even_f32 | sym::round_ties_even_f64 | sym::round_ties_even_f128 + | sym::enzyme_autodiff | sym::const_eval_select => hir::Safety::Safe, _ => hir::Safety::Unsafe, }; @@ -197,6 +198,7 @@ pub(crate) fn check_intrinsic_type( let safety = intrinsic_operation_unsafety(tcx, intrinsic_id); let n_lts = 0; let (n_tps, n_cts, inputs, output) = match intrinsic_name { + sym::enzyme_autodiff => (4, 0, vec![param(0), param(1), param(2)], param(3)), sym::abort => (0, 0, vec![], tcx.types.never), sym::unreachable => (0, 0, vec![], tcx.types.never), sym::breakpoint => (0, 0, vec![], tcx.types.unit), diff --git a/compiler/rustc_monomorphize/src/collector.rs b/compiler/rustc_monomorphize/src/collector.rs index e90e32ebebb9f..e1ec4dadee0f0 100644 --- a/compiler/rustc_monomorphize/src/collector.rs +++ b/compiler/rustc_monomorphize/src/collector.rs @@ -205,6 +205,8 @@ //! this is not implemented however: a mono item will be produced //! regardless of whether it is actually needed or not. +mod autodiff; + use std::cell::OnceCell; use std::path::PathBuf; @@ -237,6 +239,8 @@ use rustc_span::source_map::{Spanned, dummy_spanned, respan}; use rustc_span::{DUMMY_SP, Span}; use tracing::{debug, instrument, trace}; +#[cfg(llvm_enzyme)] +use crate::collector::autodiff::collect_enzyme_autodiff_source_fn; use crate::errors::{self, EncounteredErrorWhileInstantiating, NoOptimizedMir, RecursionLimit}; #[derive(PartialEq)] @@ -916,6 +920,9 @@ fn visit_instance_use<'tcx>( return; } if let Some(intrinsic) = tcx.intrinsic(instance.def_id()) { + #[cfg(llvm_enzyme)] + collect_enzyme_autodiff_source_fn(tcx, instance, intrinsic, output); + if let Some(_requirement) = ValidityRequirement::from_intrinsic(intrinsic.name) { // The intrinsics assert_inhabited, assert_zero_valid, and assert_mem_uninitialized_valid will // be lowered in codegen to nothing or a call to panic_nounwind. So if we encounter any diff --git a/compiler/rustc_monomorphize/src/collector/autodiff.rs b/compiler/rustc_monomorphize/src/collector/autodiff.rs new file mode 100644 index 0000000000000..d062302ae53a6 --- /dev/null +++ b/compiler/rustc_monomorphize/src/collector/autodiff.rs @@ -0,0 +1,38 @@ +use rustc_middle::bug; +use rustc_middle::ty::{self, IntrinsicDef, TyCtxt}; +use tracing::debug; + +use crate::collector::{MonoItems, create_fn_mono_item}; + +pub(crate) fn collect_enzyme_autodiff_source_fn<'tcx>( + tcx: TyCtxt<'tcx>, + instance: ty::Instance<'tcx>, + intrinsic: IntrinsicDef, + output: &mut MonoItems<'tcx>, +) { + if intrinsic.name != rustc_span::sym::enzyme_autodiff { + return; + }; + + debug!("enzyme_autodiff found"); + let (primal, span) = match instance.args[0].kind() { + rustc_middle::infer::canonical::ir::GenericArgKind::Type(ty) => match ty.kind() { + ty::FnDef(def_id, substs) => { + let span = tcx.def_span(def_id); + let instance = ty::Instance::expect_resolve( + tcx, + ty::TypingEnv::non_body_analysis(tcx, def_id), + *def_id, + substs, + span, + ); + + (instance, span) + } + _ => bug!("expected function"), + }, + _ => bug!("expected type"), + }; + + output.push(create_fn_mono_item(tcx, primal, span)); +} diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index c9262d24a1717..f7f95ac7ee512 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -915,6 +915,7 @@ symbols! { enumerate_method, env, env_CFG_RELEASE: env!("CFG_RELEASE"), + enzyme_autodiff, eprint_macro, eprintln_macro, eq, diff --git a/library/core/src/intrinsics/mod.rs b/library/core/src/intrinsics/mod.rs index b5c3e91d04687..1aecc62ac315e 100644 --- a/library/core/src/intrinsics/mod.rs +++ b/library/core/src/intrinsics/mod.rs @@ -3114,6 +3114,10 @@ pub const unsafe fn copysignf64(x: f64, y: f64) -> f64; #[rustc_intrinsic] pub const unsafe fn copysignf128(x: f128, y: f128) -> f128; +#[rustc_nounwind] +#[rustc_intrinsic] +pub const fn enzyme_autodiff(f: F, df: G, args: T) -> R; + /// Inform Miri that a given pointer definitely has a certain alignment. #[cfg(miri)] #[rustc_allow_const_fn_unstable(const_eval_select)] diff --git a/tests/codegen/autodiff/batched.rs b/tests/codegen/autodiff/batched.rs index d27aed50e6cc4..5e94c7bb9b8e6 100644 --- a/tests/codegen/autodiff/batched.rs +++ b/tests/codegen/autodiff/batched.rs @@ -10,6 +10,7 @@ // reduce this test to only match the first lines and the ret instructions. #![feature(autodiff)] +#![feature(core_intrinsics)] use std::autodiff::autodiff_forward; @@ -21,7 +22,7 @@ fn square(x: &f32) -> f32 { x * x } -// d_sqaure2 +// d_square2 // CHECK: define internal fastcc [4 x float] @fwddiffe4square(float %x.0.val, [4 x ptr] %"x'") // CHECK-NEXT: start: // CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0 @@ -32,24 +33,20 @@ fn square(x: &f32) -> f32 { // CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4 // CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3 // CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4 -// CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0 -// CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1 -// CHECK-NEXT: %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2 -// CHECK-NEXT: %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3 -// CHECK-NEXT: %8 = fadd fast <4 x float> %7, %7 -// CHECK-NEXT: %9 = insertelement <4 x float> poison, float %x.0.val, i64 0 -// CHECK-NEXT: %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer -// CHECK-NEXT: %11 = fmul fast <4 x float> %8, %10 -// CHECK-NEXT: %12 = extractelement <4 x float> %11, i64 0 -// CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0 -// CHECK-NEXT: %14 = extractelement <4 x float> %11, i64 1 -// CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1 -// CHECK-NEXT: %16 = extractelement <4 x float> %11, i64 2 -// CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2 -// CHECK-NEXT: %18 = extractelement <4 x float> %11, i64 3 -// CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3 -// CHECK-NEXT: ret [4 x float] %19 -// CHECK-NEXT: } +// CHECK-NEXT: %4 = fadd fast float %"_2'ipl", %"_2'ipl" +// CHECK-NEXT: %5 = fmul fast float %4, %x.0.val +// CHECK-NEXT: %6 = insertvalue [4 x float] undef, float %5, 0 +// CHECK-NEXT: %7 = fadd fast float %"_2'ipl1", %"_2'ipl1" +// CHECK-NEXT: %8 = fmul fast float %7, %x.0.val +// CHECK-NEXT: %9 = insertvalue [4 x float] %6, float %8, 1 +// CHECK-NEXT: %10 = fadd fast float %"_2'ipl2", %"_2'ipl2" +// CHECK-NEXT: %11 = fmul fast float %10, %x.0.val +// CHECK-NEXT: %12 = insertvalue [4 x float] %9, float %11, 2 +// CHECK-NEXT: %13 = fadd fast float %"_2'ipl3", %"_2'ipl3" +// CHECK-NEXT: %14 = fmul fast float %13, %x.0.val +// CHECK-NEXT: %15 = insertvalue [4 x float] %12, float %14, 3 +// CHECK-NEXT: ret [4 x float] %15 +// CHECK-NEXT: } // d_square3, the extra float is the original return value (x * x) // CHECK: define internal fastcc { float, [4 x float] } @fwddiffe4square.1(float %x.0.val, [4 x ptr] %"x'") @@ -63,26 +60,22 @@ fn square(x: &f32) -> f32 { // CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3 // CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4 // CHECK-NEXT: %_0 = fmul float %x.0.val, %x.0.val -// CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0 -// CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1 -// CHECK-NEXT: %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2 -// CHECK-NEXT: %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3 -// CHECK-NEXT: %8 = fadd fast <4 x float> %7, %7 -// CHECK-NEXT: %9 = insertelement <4 x float> poison, float %x.0.val, i64 0 -// CHECK-NEXT: %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer -// CHECK-NEXT: %11 = fmul fast <4 x float> %8, %10 -// CHECK-NEXT: %12 = extractelement <4 x float> %11, i64 0 -// CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0 -// CHECK-NEXT: %14 = extractelement <4 x float> %11, i64 1 -// CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1 -// CHECK-NEXT: %16 = extractelement <4 x float> %11, i64 2 -// CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2 -// CHECK-NEXT: %18 = extractelement <4 x float> %11, i64 3 -// CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3 -// CHECK-NEXT: %20 = insertvalue { float, [4 x float] } undef, float %_0, 0 -// CHECK-NEXT: %21 = insertvalue { float, [4 x float] } %20, [4 x float] %19, 1 -// CHECK-NEXT: ret { float, [4 x float] } %21 -// CHECK-NEXT: } +// CHECK-NEXT: %4 = fadd fast float %"_2'ipl", %"_2'ipl" +// CHECK-NEXT: %5 = fmul fast float %4, %x.0.val +// CHECK-NEXT: %6 = insertvalue [4 x float] undef, float %5, 0 +// CHECK-NEXT: %7 = fadd fast float %"_2'ipl1", %"_2'ipl1" +// CHECK-NEXT: %8 = fmul fast float %7, %x.0.val +// CHECK-NEXT: %9 = insertvalue [4 x float] %6, float %8, 1 +// CHECK-NEXT: %10 = fadd fast float %"_2'ipl2", %"_2'ipl2" +// CHECK-NEXT: %11 = fmul fast float %10, %x.0.val +// CHECK-NEXT: %12 = insertvalue [4 x float] %9, float %11, 2 +// CHECK-NEXT: %13 = fadd fast float %"_2'ipl3", %"_2'ipl3" +// CHECK-NEXT: %14 = fmul fast float %13, %x.0.val +// CHECK-NEXT: %15 = insertvalue [4 x float] %12, float %14, 3 +// CHECK-NEXT: %16 = insertvalue { float, [4 x float] } undef, float %_0, 0 +// CHECK-NEXT: %17 = insertvalue { float, [4 x float] } %16, [4 x float] %15, 1 +// CHECK-NEXT: ret { float, [4 x float] } %17 +// CHECK-NEXT: } fn main() { let x = std::hint::black_box(3.0); diff --git a/tests/codegen/autodiff/generic.rs b/tests/codegen/autodiff/generic.rs index 2f674079be021..9553ef3760e74 100644 --- a/tests/codegen/autodiff/generic.rs +++ b/tests/codegen/autodiff/generic.rs @@ -2,6 +2,7 @@ //@ no-prefer-dynamic //@ needs-enzyme #![feature(autodiff)] +#![feature(core_intrinsics)] use std::autodiff::autodiff_reverse; diff --git a/tests/codegen/autodiff/identical_fnc.rs b/tests/codegen/autodiff/identical_fnc.rs index 1c25b3d09ab0d..1b4edf1d954b4 100644 --- a/tests/codegen/autodiff/identical_fnc.rs +++ b/tests/codegen/autodiff/identical_fnc.rs @@ -10,6 +10,7 @@ // We also explicetly test that we keep running merge_function after AD, by checking for two // identical function calls in the LLVM-IR, while having two different calls in the Rust code. #![feature(autodiff)] +#![feature(core_intrinsics)] use std::autodiff::autodiff_reverse; @@ -29,10 +30,8 @@ fn square2(x: &f64) -> f64 { // CHECK-NEXT:start: // CHECK-NOT:br // CHECK-NOT:ret -// CHECK:; call identical_fnc::d_square -// CHECK-NEXT: call fastcc void @_ZN13identical_fnc8d_square17h4c364207a2f8e06dE(double %x.val, ptr noalias noundef nonnull align 8 dereferenceable(8) %dx1) -// CHECK-NEXT:; call identical_fnc::d_square -// CHECK-NEXT: call fastcc void @_ZN13identical_fnc8d_square17h4c364207a2f8e06dE(double %x.val, ptr noalias noundef nonnull align 8 dereferenceable(8) %dx2) +// CHECK:call fastcc void @diffe_ZN13identical_fnc6square17hdfa1c645848284b7E(double %x.val, ptr %dx1) +// CHECK-NEXT:call fastcc void @diffe_ZN13identical_fnc6square17hdfa1c645848284b7E(double %x.val, ptr %dx2) fn main() { let x = std::hint::black_box(3.0); diff --git a/tests/codegen/autodiff/inline.rs b/tests/codegen/autodiff/inline.rs index 65bed170207cc..1aa1b8a912be1 100644 --- a/tests/codegen/autodiff/inline.rs +++ b/tests/codegen/autodiff/inline.rs @@ -3,6 +3,7 @@ //@ needs-enzyme #![feature(autodiff)] +#![feature(core_intrinsics)] use std::autodiff::autodiff_reverse; diff --git a/tests/codegen/autodiff/scalar.rs b/tests/codegen/autodiff/scalar.rs index 096b4209e84ad..745b03ee0ed8f 100644 --- a/tests/codegen/autodiff/scalar.rs +++ b/tests/codegen/autodiff/scalar.rs @@ -2,6 +2,7 @@ //@ no-prefer-dynamic //@ needs-enzyme #![feature(autodiff)] +#![feature(core_intrinsics)] use std::autodiff::autodiff_reverse; diff --git a/tests/codegen/autodiff/sret.rs b/tests/codegen/autodiff/sret.rs index d2fa85e3e3787..e2272fd4df7d3 100644 --- a/tests/codegen/autodiff/sret.rs +++ b/tests/codegen/autodiff/sret.rs @@ -8,6 +8,7 @@ // We therefore use this test to verify some of our sret handling. #![feature(autodiff)] +#![feature(core_intrinsics)] use std::autodiff::autodiff_reverse; @@ -17,26 +18,25 @@ fn primal(x: f32, y: f32) -> f64 { (x * x * y) as f64 } -// CHECK:define internal fastcc void @_ZN4sret2df17h93be4316dd8ea006E(ptr dead_on_unwind noalias nocapture noundef nonnull writable writeonly align 8 dereferenceable(16) initializes((0, 16)) %_0, float noundef %x, float noundef %y) -// CHECK-NEXT:start: -// CHECK-NEXT: %0 = tail call fastcc { double, float, float } @diffeprimal(float %x, float %y) -// CHECK-NEXT: %.elt = extractvalue { double, float, float } %0, 0 -// CHECK-NEXT: store double %.elt, ptr %_0, align 8 -// CHECK-NEXT: %_0.repack1 = getelementptr inbounds nuw i8, ptr %_0, i64 8 -// CHECK-NEXT: %.elt2 = extractvalue { double, float, float } %0, 1 -// CHECK-NEXT: store float %.elt2, ptr %_0.repack1, align 8 -// CHECK-NEXT: %_0.repack3 = getelementptr inbounds nuw i8, ptr %_0, i64 12 -// CHECK-NEXT: %.elt4 = extractvalue { double, float, float } %0, 2 -// CHECK-NEXT: store float %.elt4, ptr %_0.repack3, align 4 -// CHECK-NEXT: ret void -// CHECK-NEXT:} +// CHECK: define internal fastcc { double, float, float } @diffeprimal(float noundef %x, float noundef %y) +// CHECK-NEXT: invertstart: +// CHECK-NEXT: %_4 = fmul float %x, %x +// CHECK-NEXT: %_3 = fmul float %_4, %y +// CHECK-NEXT: %_0 = fpext float %_3 to double +// CHECK-NEXT: %0 = fadd fast float %y, %y +// CHECK-NEXT: %1 = fmul fast float %0, %x +// CHECK-NEXT: %2 = insertvalue { double, float, float } undef, double %_0, 0 +// CHECK-NEXT: %3 = insertvalue { double, float, float } %2, float %1, 1 +// CHECK-NEXT: %4 = insertvalue { double, float, float } %3, float %_4, 2 +// CHECK-NEXT: ret { double, float, float } %4 +// CHECK-NEXT: } fn main() { let x = std::hint::black_box(3.0); let y = std::hint::black_box(2.5); let scalar = std::hint::black_box(1.0); let (r1, r2, r3) = df(x, y, scalar); - // 3*3*1.5 = 22.5 + // 3*3*2.5 = 22.5 assert_eq!(r1, 22.5); // 2*x*y = 2*3*2.5 = 15.0 assert_eq!(r2, 15.0); diff --git a/tests/codegen/autodiff/trait.rs b/tests/codegen/autodiff/trait.rs new file mode 100644 index 0000000000000..988e9145087b2 --- /dev/null +++ b/tests/codegen/autodiff/trait.rs @@ -0,0 +1,31 @@ +//@ compile-flags: -Zautodiff=Enable -Zautodiff=NoPostopt -C opt-level=3 -Clto=fat +//@ no-prefer-dynamic +//@ needs-enzyme + +// Just check it does not crash for now +// CHECK: ; +#![feature(autodiff)] +#![feature(core_intrinsics)] + +use std::autodiff::autodiff_reverse; + +struct Foo { + a: f64, +} + +trait MyTrait { + fn f(&self, x: f64) -> f64; + fn df(&self, x: f64, seed: f64) -> (f64, f64); +} + +impl MyTrait for Foo { + #[autodiff_reverse(df, Const, Active, Active)] + fn f(&self, x: f64) -> f64 { + self.a * 0.25 * (x * x - 1.0 - 2.0 * x.ln()) + } +} + +fn main() { + let foo = Foo { a: 3.0f64 }; + dbg!(foo.df(1.0, 1.0)); +} diff --git a/tests/pretty/autodiff/autodiff_forward.pp b/tests/pretty/autodiff/autodiff_forward.pp index a2525abc83207..787c2e517492c 100644 --- a/tests/pretty/autodiff/autodiff_forward.pp +++ b/tests/pretty/autodiff/autodiff_forward.pp @@ -3,6 +3,7 @@ //@ needs-enzyme #![feature(autodiff)] +#![feature(intrinsics)] #[prelude_import] use ::std::prelude::rust_2015::*; #[macro_use] @@ -36,78 +37,44 @@ ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 1, Dual, Const, Dual)] -#[inline(never)] -pub fn df1(x: &[f64], bx_0: &[f64], y: f64) -> (f64, f64) { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f1(x, y)); - ::core::hint::black_box((bx_0,)); - ::core::hint::black_box(<(f64, f64)>::default()) -} +#[rustc_intrinsic] +pub fn df1(x: &[f64], bx_0: &[f64], y: f64) -> (f64, f64); #[rustc_autodiff] #[inline(never)] pub fn f2(x: &[f64], y: f64) -> f64 { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 1, Dual, Const, Const)] -#[inline(never)] -pub fn df2(x: &[f64], bx_0: &[f64], y: f64) -> f64 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f2(x, y)); - ::core::hint::black_box((bx_0,)); - ::core::hint::black_box(f2(x, y)) -} +#[rustc_intrinsic] +pub fn df2(x: &[f64], bx_0: &[f64], y: f64) -> f64; #[rustc_autodiff] #[inline(never)] pub fn f3(x: &[f64], y: f64) -> f64 { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 1, Dual, Const, Const)] -#[inline(never)] -pub fn df3(x: &[f64], bx_0: &[f64], y: f64) -> f64 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f3(x, y)); - ::core::hint::black_box((bx_0,)); - ::core::hint::black_box(f3(x, y)) -} +#[rustc_intrinsic] +pub fn df3(x: &[f64], bx_0: &[f64], y: f64) -> f64; #[rustc_autodiff] #[inline(never)] pub fn f4() {} #[rustc_autodiff(Forward, 1, None)] -#[inline(never)] -pub fn df4() -> () { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f4()); - ::core::hint::black_box(()); -} +#[rustc_intrinsic] +pub fn df4() -> (); #[rustc_autodiff] #[inline(never)] pub fn f5(x: &[f64], y: f64) -> f64 { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 1, Const, Dual, Const)] -#[inline(never)] -pub fn df5_y(x: &[f64], y: f64, by_0: f64) -> f64 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f5(x, y)); - ::core::hint::black_box((by_0,)); - ::core::hint::black_box(f5(x, y)) -} +#[rustc_intrinsic] +pub fn df5_y(x: &[f64], y: f64, by_0: f64) -> f64; #[rustc_autodiff(Forward, 1, Dual, Const, Const)] -#[inline(never)] -pub fn df5_x(x: &[f64], bx_0: &[f64], y: f64) -> f64 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f5(x, y)); - ::core::hint::black_box((bx_0,)); - ::core::hint::black_box(f5(x, y)) -} +#[rustc_intrinsic] +pub fn df5_x(x: &[f64], bx_0: &[f64], y: f64) -> f64; #[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)] -#[inline(never)] -pub fn df5_rev(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f5(x, y)); - ::core::hint::black_box((dx_0, dret)); - ::core::hint::black_box(f5(x, y)) -} +#[rustc_intrinsic] +pub fn df5_rev(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64; struct DoesNotImplDefault; #[rustc_autodiff] #[inline(never)] @@ -115,84 +82,47 @@ ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 1, Const)] -#[inline(never)] -pub fn df6() -> DoesNotImplDefault { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f6()); - ::core::hint::black_box(()); - ::core::hint::black_box(f6()) -} +#[rustc_intrinsic] +pub fn df6() -> DoesNotImplDefault; #[rustc_autodiff] #[inline(never)] pub fn f7(x: f32) -> () {} #[rustc_autodiff(Forward, 1, Const, None)] -#[inline(never)] -pub fn df7(x: f32) -> () { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f7(x)); - ::core::hint::black_box(()); -} +#[rustc_intrinsic] +pub fn df7(x: f32) -> (); #[no_mangle] #[rustc_autodiff] #[inline(never)] fn f8(x: &f32) -> f32 { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 4, Dual, Dual)] -#[inline(never)] +#[rustc_intrinsic] fn f8_3(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32) - -> [f32; 5usize] { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f8(x)); - ::core::hint::black_box((bx_0, bx_1, bx_2, bx_3)); - ::core::hint::black_box(<[f32; 5usize]>::default()) -} +-> [f32; 5usize]; #[rustc_autodiff(Forward, 4, Dual, DualOnly)] -#[inline(never)] +#[rustc_intrinsic] fn f8_2(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32) - -> [f32; 4usize] { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f8(x)); - ::core::hint::black_box((bx_0, bx_1, bx_2, bx_3)); - ::core::hint::black_box(<[f32; 4usize]>::default()) -} +-> [f32; 4usize]; #[rustc_autodiff(Forward, 1, Dual, DualOnly)] -#[inline(never)] -fn f8_1(x: &f32, bx_0: &f32) -> f32 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f8(x)); - ::core::hint::black_box((bx_0,)); - ::core::hint::black_box(::default()) -} +#[rustc_intrinsic] +fn f8_1(x: &f32, bx_0: &f32) -> f32; pub fn f9() { #[rustc_autodiff] #[inline(never)] fn inner(x: f32) -> f32 { x * x } #[rustc_autodiff(Forward, 1, Dual, Dual)] - #[inline(never)] - fn d_inner_2(x: f32, bx_0: f32) -> (f32, f32) { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(inner(x)); - ::core::hint::black_box((bx_0,)); - ::core::hint::black_box(<(f32, f32)>::default()) - } + #[rustc_intrinsic] + fn d_inner_2(x: f32, bx_0: f32) + -> (f32, f32); #[rustc_autodiff(Forward, 1, Dual, DualOnly)] - #[inline(never)] - fn d_inner_1(x: f32, bx_0: f32) -> f32 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(inner(x)); - ::core::hint::black_box((bx_0,)); - ::core::hint::black_box(::default()) - } + #[rustc_intrinsic] + fn d_inner_1(x: f32, bx_0: f32) + -> f32; } #[rustc_autodiff] #[inline(never)] pub fn f10 + Copy>(x: &T) -> T { *x * *x } #[rustc_autodiff(Reverse, 1, Duplicated, Active)] -#[inline(never)] +#[rustc_intrinsic] pub fn d_square + - Copy>(x: &T, dx_0: &mut T, dret: T) -> T { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f10::(x)); - ::core::hint::black_box((dx_0, dret)); - ::core::hint::black_box(f10::(x)) -} +Copy>(x: &T, dx_0: &mut T, dret: T) -> T; fn main() {} diff --git a/tests/pretty/autodiff/autodiff_forward.rs b/tests/pretty/autodiff/autodiff_forward.rs index e23a1b3e241e9..b003d87dccfa7 100644 --- a/tests/pretty/autodiff/autodiff_forward.rs +++ b/tests/pretty/autodiff/autodiff_forward.rs @@ -1,6 +1,7 @@ //@ needs-enzyme #![feature(autodiff)] +#![feature(intrinsics)] //@ pretty-mode:expanded //@ pretty-compare-only //@ pp-exact:autodiff_forward.pp diff --git a/tests/pretty/autodiff/autodiff_reverse.pp b/tests/pretty/autodiff/autodiff_reverse.pp index e67c3443ddef1..6f368c74f1a26 100644 --- a/tests/pretty/autodiff/autodiff_reverse.pp +++ b/tests/pretty/autodiff/autodiff_reverse.pp @@ -3,6 +3,7 @@ //@ needs-enzyme #![feature(autodiff)] +#![feature(intrinsics)] #[prelude_import] use ::std::prelude::rust_2015::*; #[macro_use] @@ -29,58 +30,36 @@ ::core::panicking::panic("not implemented") } #[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)] -#[inline(never)] -pub fn df1(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f1(x, y)); - ::core::hint::black_box((dx_0, dret)); - ::core::hint::black_box(f1(x, y)) -} +#[rustc_intrinsic] +pub fn df1(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64; #[rustc_autodiff] #[inline(never)] pub fn f2() {} #[rustc_autodiff(Reverse, 1, None)] -#[inline(never)] -pub fn df2() { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f2()); - ::core::hint::black_box(()); -} +#[rustc_intrinsic] +pub fn df2(); #[rustc_autodiff] #[inline(never)] pub fn f3(x: &[f64], y: f64) -> f64 { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)] -#[inline(never)] -pub fn df3(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f3(x, y)); - ::core::hint::black_box((dx_0, dret)); - ::core::hint::black_box(f3(x, y)) -} +#[rustc_intrinsic] +pub fn df3(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64; enum Foo { Reverse, } use Foo::Reverse; #[rustc_autodiff] #[inline(never)] pub fn f4(x: f32) { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Reverse, 1, Const, None)] -#[inline(never)] -pub fn df4(x: f32) { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f4(x)); - ::core::hint::black_box(()); -} +#[rustc_intrinsic] +pub fn df4(x: f32); #[rustc_autodiff] #[inline(never)] pub fn f5(x: *const f32, y: &f32) { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Reverse, 1, DuplicatedOnly, Duplicated, None)] -#[inline(never)] -pub unsafe fn df5(x: *const f32, dx_0: *mut f32, y: &f32, dy_0: &mut f32) { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f5(x, y)); - ::core::hint::black_box((dx_0, dy_0)); -} +#[rustc_intrinsic] +pub unsafe fn df5(x: *const f32, dx_0: *mut f32, y: &f32, dy_0: &mut f32); fn main() {} diff --git a/tests/pretty/autodiff/autodiff_reverse.rs b/tests/pretty/autodiff/autodiff_reverse.rs index d37e5e3eb4cec..fc95ba2e5a63e 100644 --- a/tests/pretty/autodiff/autodiff_reverse.rs +++ b/tests/pretty/autodiff/autodiff_reverse.rs @@ -1,6 +1,7 @@ //@ needs-enzyme #![feature(autodiff)] +#![feature(intrinsics)] //@ pretty-mode:expanded //@ pretty-compare-only //@ pp-exact:autodiff_reverse.pp @@ -23,7 +24,9 @@ pub fn f3(x: &[f64], y: f64) -> f64 { unimplemented!() } -enum Foo { Reverse } +enum Foo { + Reverse, +} use Foo::Reverse; // What happens if we already have Reverse in type (enum variant decl) and value (enum variant // constructor) namespace? > It's expected to work normally. diff --git a/tests/pretty/autodiff/inherent_impl.pp b/tests/pretty/autodiff/inherent_impl.pp index d18061b2dbdef..4bc8dac0dc758 100644 --- a/tests/pretty/autodiff/inherent_impl.pp +++ b/tests/pretty/autodiff/inherent_impl.pp @@ -3,6 +3,7 @@ //@ needs-enzyme #![feature(autodiff)] +#![feature(intrinsics)] #[prelude_import] use ::std::prelude::rust_2015::*; #[macro_use] @@ -31,7 +32,7 @@ self.a * 0.25 * (x * x - 1.0 - 2.0 * x.ln()) } #[rustc_autodiff(Reverse, 1, Const, Active, Active)] - #[inline(never)] + #[rustc_intrinsic] fn df(&self, x: f64, dret: f64) -> (f64, f64) { unsafe { asm!("NOP", options(pure, nomem)); }; ::core::hint::black_box(self.f(x)); diff --git a/tests/pretty/autodiff/inherent_impl.rs b/tests/pretty/autodiff/inherent_impl.rs index 11ff209f9d89e..9f00ff5eb02c1 100644 --- a/tests/pretty/autodiff/inherent_impl.rs +++ b/tests/pretty/autodiff/inherent_impl.rs @@ -1,6 +1,7 @@ //@ needs-enzyme #![feature(autodiff)] +#![feature(intrinsics)] //@ pretty-mode:expanded //@ pretty-compare-only //@ pp-exact:inherent_impl.pp