Skip to content

Implement autodiff using intrinsics #142640

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 10 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 174 additions & 8 deletions compiler/rustc_builtin_macros/src/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
use rustc_ast::tokenstream::*;
use rustc_ast::visit::AssocCtxt::*;
use rustc_ast::{
self as ast, AssocItemKind, BindingMode, ExprKind, FnRetTy, FnSig, Generics, ItemKind,
MetaItemInner, PatKind, QSelf, TyKind, Visibility,
self as ast, AngleBracketedArg, AngleBracketedArgs, AssocItemKind, BindingMode, ExprKind,
FnRetTy, FnSig, GenericArg, GenericArgs, Generics, ItemKind, MetaItemInner, PatKind, Path,
PathSegment, QSelf, TyKind, Visibility,
};
use rustc_expand::base::{Annotatable, ExtCtxt};
use rustc_span::{Ident, Span, Symbol, kw, sym};
Expand Down Expand Up @@ -330,17 +331,28 @@
.filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly)
.count() as u32;
let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
let d_body = gen_enzyme_body(

// TODO(Sa4dUs): Remove this and all the related logic

Check failure on line 335 in compiler/rustc_builtin_macros/src/autodiff.rs

View workflow job for this annotation

GitHub Actions / PR - tidy

TODO is used for tasks that should be done before merging a PR; If you want to leave a message in the codebase use FIXME
let _d_body = gen_enzyme_body(
ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored,
&generics,
);

let d_body = call_enzyme_autodiff(
ecx,
primal,
first_ident(&meta_item_vec[0]),
span,
&d_sig,
&generics,
);

// The first element of it is the name of the function to be generated
let asdf = Box::new(ast::Fn {
defaultness: ast::Defaultness::Final,
sig: d_sig,
ident: first_ident(&meta_item_vec[0]),
generics,
generics: generics.clone(),
contract: None,
body: Some(d_body),
define_opaque: None,
Expand Down Expand Up @@ -429,12 +441,15 @@
tokens: ts,
});

let vis_clone = vis.clone();

let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
let d_annotatable = match &item {
Annotatable::AssocItem(_, _) => {
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
let d_fn = P(ast::AssocItem {
attrs: thin_vec![d_attr, inline_never],
attrs: thin_vec![d_attr],
id: ast::DUMMY_NODE_ID,
span,
vis,
Expand All @@ -444,13 +459,13 @@
Annotatable::AssocItem(d_fn, Impl { of_trait: false })
}
Annotatable::Item(_) => {
let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf));
d_fn.vis = vis;

Annotatable::Item(d_fn)
}
Annotatable::Stmt(_) => {
let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf));
d_fn.vis = vis;

Annotatable::Stmt(P(ast::Stmt {
Expand All @@ -464,7 +479,9 @@
}
};

return vec![orig_annotatable, d_annotatable];
let dummy_const_annotatable = gen_dummy_const(ecx, span, primal, sig, generics, vis_clone);

return vec![orig_annotatable, dummy_const_annotatable, d_annotatable];
}

// shadow arguments (the extra ones which were not in the original (primal) function), in reverse mode must be
Expand All @@ -485,6 +502,155 @@
ty
}

// Generate `enzyme_autodiff` intrinsic call
// ```
// std::intrinsics::enzyme_autodiff(source, diff, (args))
// ```
fn call_enzyme_autodiff(
ecx: &ExtCtxt<'_>,
primal: Ident,
diff: Ident,
span: Span,
d_sig: &FnSig,
generics: &Generics,
) -> P<ast::Block> {
let primal_path_expr = gen_turbofish_expr(ecx, primal, generics, span);
let diff_path_expr = gen_turbofish_expr(ecx, diff, generics, span);

let tuple_expr = ecx.expr_tuple(
span,
d_sig
.decl
.inputs
.iter()
.map(|arg| match arg.pat.kind {
PatKind::Ident(_, ident, _) => ecx.expr_path(ecx.path_ident(span, ident)),
_ => todo!(),
})
.collect::<ThinVec<_>>()
.into(),
);

let enzyme_path = ecx.path(
span,
vec![
Ident::from_str("std"),
Ident::from_str("intrinsics"),
Ident::from_str("enzyme_autodiff"),
],
);
let call_expr = ecx.expr_call(
span,
ecx.expr_path(enzyme_path),
vec![primal_path_expr, diff_path_expr, tuple_expr].into(),
);

let block = ecx.block_expr(call_expr);

block
}

// Generate turbofish expression from fn name and generics
// Given `foo` and `<A, B, C>`, gen `foo::<A, B, C>`
fn gen_turbofish_expr(
ecx: &ExtCtxt<'_>,
ident: Ident,
generics: &Generics,
span: Span,
) -> P<ast::Expr> {
let generic_args = generics
.params
.iter()
.map(|p| {
let path = ast::Path::from_ident(p.ident);
let ty = ecx.ty_path(path);
AngleBracketedArg::Arg(GenericArg::Type(ty))
})
.collect::<ThinVec<_>>();

let args = AngleBracketedArgs { span, args: generic_args };

let segment = PathSegment {
ident,
id: ast::DUMMY_NODE_ID,
args: Some(P(GenericArgs::AngleBracketed(args))),
};

let path = Path { span, segments: thin_vec![segment], tokens: None };

ecx.expr_path(path)
}

// Generate dummy const to prevent primal function
// from being optimized away before applying enzyme
// ```
// const _: () =
// {
// #[used]
// pub static DUMMY_PTR: fn_type = primal_fn;
// };
// ```
fn gen_dummy_const(
ecx: &ExtCtxt<'_>,
span: Span,
primal: Ident,
sig: FnSig,
generics: Generics,
vis: Visibility,
) -> Annotatable {
// #[used]
let used_attr = P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::used)));
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
let used_attr = outer_normal_attr(&used_attr, new_id, span);

// static DUMMY_PTR: <fn_type> = <primal_ident>
let static_ident = Ident::from_str_and_span("DUMMY_PTR", span);
let fn_ptr_ty = ast::TyKind::BareFn(Box::new(ast::BareFnTy {
safety: sig.header.safety,
ext: sig.header.ext,
generic_params: generics.params,
decl: sig.decl,
decl_span: sig.span,
}));
let static_ty = ecx.ty(span, fn_ptr_ty);

let static_expr = ecx.expr_path(ecx.path(span, vec![primal]));
let static_item_kind = ast::ItemKind::Static(Box::new(ast::StaticItem {
ident: static_ident,
ty: static_ty,
safety: ast::Safety::Default,
mutability: ast::Mutability::Not,
expr: Some(static_expr),
define_opaque: None,
}));

let static_item = ast::Item {
attrs: thin_vec![used_attr],
id: ast::DUMMY_NODE_ID,
span,
vis,
kind: static_item_kind,
tokens: None,
};

let block_expr = ecx.expr_block(Box::new(ast::Block {
stmts: thin_vec![ecx.stmt_item(span, P(static_item))],
id: ast::DUMMY_NODE_ID,
rules: ast::BlockCheckMode::Default,
span,
tokens: None,
}));

let const_item = ecx.item_const(
span,
Ident::from_str_and_span("_", span),
ecx.ty(span, ast::TyKind::Tup(thin_vec![])),
block_expr,
);

Annotatable::Item(const_item)
}

// Will generate a body of the type:
// ```
// {
Expand Down
Loading
Loading