Skip to content

Implement autodiff using intrinsics #142640

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 14 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
327 changes: 138 additions & 189 deletions compiler/rustc_builtin_macros/src/autodiff.rs

Large diffs are not rendered by default.

357 changes: 56 additions & 301 deletions compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion compiler/rustc_codegen_llvm/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,7 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
}
}
impl<'ll> SimpleCx<'ll> {
pub(crate) fn get_return_type(&self, ty: &'ll Type) -> &'ll Type {
pub(crate) fn _get_return_type(&self, ty: &'ll Type) -> &'ll Type {
assert_eq!(self.type_kind(ty), TypeKind::Function);
unsafe { llvm::LLVMGetReturnType(ty) }
}
Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_codegen_llvm/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,16 @@ impl<G: EmissionGuarantee> Diagnostic<'_, G> for ParseTargetMachineConfig<'_> {
}
}

// TODO(Sa4dUs): we will need to reintroduce these errors somewhere
/*
#[derive(Diagnostic)]
#[diag(codegen_llvm_autodiff_without_lto)]
pub(crate) struct AutoDiffWithoutLTO;

#[derive(Diagnostic)]
#[diag(codegen_llvm_autodiff_without_enable)]
pub(crate) struct AutoDiffWithoutEnable;
*/

#[derive(Diagnostic)]
#[diag(codegen_llvm_lto_disallowed)]
Expand Down
93 changes: 90 additions & 3 deletions compiler/rustc_codegen_llvm/src/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,26 @@ use std::cmp::Ordering;

use rustc_abi::{Align, BackendRepr, ExternAbi, Float, HasDataLayout, Primitive, Size};
use rustc_codegen_ssa::base::{compare_simd_types, wants_msvc_seh, wants_wasm_eh};
use rustc_codegen_ssa::codegen_attrs::autodiff_attrs;
use rustc_codegen_ssa::common::{IntPredicate, TypeKind};
use rustc_codegen_ssa::errors::{ExpectedPointerMutability, InvalidMonomorphization};
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
use rustc_codegen_ssa::mir::place::{PlaceRef, PlaceValue};
use rustc_codegen_ssa::traits::*;
use rustc_hir as hir;
use rustc_hir::def_id::LOCAL_CRATE;
use rustc_hir::{self as hir};
use rustc_middle::mir::BinOp;
use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, HasTypingEnv, LayoutOf};
use rustc_middle::ty::{self, GenericArgsRef, Ty};
use rustc_middle::ty::{self, GenericArgsRef, Instance, Ty, TyCtxt};
use rustc_middle::{bug, span_bug};
use rustc_span::{Span, Symbol, sym};
use rustc_symbol_mangling::mangle_internal_symbol;
use rustc_symbol_mangling::{mangle_internal_symbol, symbol_name_for_instance_in_crate};
use rustc_target::spec::PanicStrategy;
use tracing::debug;

use crate::abi::FnAbiLlvmExt;
use crate::builder::Builder;
use crate::builder::autodiff::generate_enzyme_call;
use crate::context::CodegenCx;
use crate::llvm::{self, Metadata};
use crate::type_::Type;
Expand Down Expand Up @@ -187,6 +190,10 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
&[ptr, args[1].immediate()],
)
}
sym::enzyme_autodiff => {
codegen_enzyme_autodiff(self, tcx, instance, args, result);
return Ok(());
}
sym::is_val_statically_known => {
if let OperandValue::Immediate(imm) = args[0].val {
self.call_intrinsic(
Expand Down Expand Up @@ -1121,6 +1128,86 @@ fn get_rust_try_fn<'a, 'll, 'tcx>(
rust_try
}

fn codegen_enzyme_autodiff<'ll, 'tcx>(
bx: &mut Builder<'_, 'll, 'tcx>,
tcx: TyCtxt<'tcx>,
instance: ty::Instance<'tcx>,
args: &[OperandRef<'tcx, &'ll Value>],
result: PlaceRef<'tcx, &'ll Value>,
) {
let fn_args = instance.args;
let callee_ty = instance.ty(tcx, bx.typing_env());

let sig = callee_ty.fn_sig(tcx);
let sig = tcx.normalize_erasing_late_bound_regions(bx.typing_env(), sig);

let ret_ty = sig.output();
let llret_ty = bx.layout_of(ret_ty).llvm_type(bx);

let val_arr: Vec<&'ll Value> = get_args_from_tuple(bx, args[2]);

// Get source, diff, and attrs
let (source_id, source_args) = match fn_args.into_type_list(tcx)[0].kind() {
ty::FnDef(def_id, source_params) => (def_id, source_params),
_ => bug!("invalid args"),
};
let fn_source =
Instance::try_resolve(tcx, bx.cx.typing_env(), *source_id, source_args).unwrap().unwrap();
let source_symbol = symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE);
let fn_to_diff: Option<&'ll llvm::Value> = bx.cx.get_function(&source_symbol);
let Some(fn_to_diff) = fn_to_diff else { bug!("could not find source function") };

let (diff_id, diff_args) = match fn_args.into_type_list(tcx)[1].kind() {
ty::FnDef(def_id, diff_args) => (def_id, diff_args),
_ => bug!("invalid args"),
};
let fn_diff =
Instance::try_resolve(tcx, bx.cx.typing_env(), *diff_id, diff_args).unwrap().unwrap();
let diff_symbol = symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE);

let diff_attrs = autodiff_attrs(tcx, fn_diff.def_id());
let Some(diff_attrs) = diff_attrs else { bug!("could not find autodiff attrs") };

// Build body
generate_enzyme_call(
bx,
bx.cx,
fn_to_diff,
&diff_symbol,
llret_ty,
&val_arr,
diff_attrs.clone(),
result,
);
}

fn get_args_from_tuple<'ll, 'tcx>(
bx: &mut Builder<'_, 'll, 'tcx>,
op: OperandRef<'tcx, &'ll Value>,
) -> Vec<&'ll Value> {
match op.val {
OperandValue::Ref(ref place_value) => {
let mut ret_arr = vec![];
let tuple_place = PlaceRef { val: *place_value, layout: op.layout };

for i in 0..tuple_place.layout.layout.0.fields.count() {
let field_place = tuple_place.project_field(bx, i);
let field_layout = tuple_place.layout.field(bx, i);
let llvm_ty = field_layout.llvm_type(bx.cx);

let field_val = bx.load(llvm_ty, field_place.val.llval, field_place.val.align);

ret_arr.push(field_val)
}

ret_arr
}
OperandValue::Pair(v1, v2) => vec![v1, v2],
OperandValue::Immediate(v) => vec![v],
OperandValue::ZeroSized => bug!("unexpected `ZeroSized` arg"),
}
}

fn generic_simd_intrinsic<'ll, 'tcx>(
bx: &mut Builder<'_, 'll, 'tcx>,
name: Symbol,
Expand Down
18 changes: 2 additions & 16 deletions compiler/rustc_codegen_llvm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,9 @@ use std::mem::ManuallyDrop;
use back::owned_target_machine::OwnedTargetMachine;
use back::write::{create_informational_target_machine, create_target_machine};
use context::SimpleCx;
use errors::{AutoDiffWithoutLTO, ParseTargetMachineConfig};
use errors::ParseTargetMachineConfig;
use llvm_util::target_config;
use rustc_ast::expand::allocator::AllocatorKind;
use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
use rustc_codegen_ssa::back::lto::{LtoModuleCodegen, SerializedModule, ThinModule};
use rustc_codegen_ssa::back::write::{
CodegenContext, FatLtoInput, ModuleConfig, TargetMachineFactoryConfig, TargetMachineFactoryFn,
Expand All @@ -43,7 +42,7 @@ use rustc_middle::dep_graph::{WorkProduct, WorkProductId};
use rustc_middle::ty::TyCtxt;
use rustc_middle::util::Providers;
use rustc_session::Session;
use rustc_session::config::{Lto, OptLevel, OutputFilenames, PrintKind, PrintRequest};
use rustc_session::config::{OptLevel, OutputFilenames, PrintKind, PrintRequest};
use rustc_span::Symbol;

mod back {
Expand Down Expand Up @@ -227,19 +226,6 @@ impl WriteBackendMethods for LlvmCodegenBackend {
fn serialize_module(module: ModuleCodegen<Self::Module>) -> (String, Self::ModuleBuffer) {
(module.name, back::lto::ModuleBuffer::new(module.module_llvm.llmod()))
}
/// Generate autodiff rules
fn autodiff(
cgcx: &CodegenContext<Self>,
module: &ModuleCodegen<Self::Module>,
diff_fncs: Vec<AutoDiffItem>,
config: &ModuleConfig,
) -> Result<(), FatalError> {
if cgcx.lto != Lto::Fat {
let dcx = cgcx.create_dcx();
return Err(dcx.handle().emit_almost_fatal(AutoDiffWithoutLTO));
}
builder::autodiff::differentiate(module, cgcx, diff_fncs, config)
}
}

impl LlvmCodegenBackend {
Expand Down
19 changes: 0 additions & 19 deletions compiler/rustc_codegen_ssa/src/back/lto.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
use std::ffi::CString;
use std::sync::Arc;

use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
use rustc_data_structures::memmap::Mmap;
use rustc_errors::FatalError;

use super::write::CodegenContext;
use crate::ModuleCodegen;
use crate::back::write::ModuleConfig;
use crate::traits::*;

pub struct ThinModule<B: WriteBackendMethods> {
Expand Down Expand Up @@ -78,23 +76,6 @@ impl<B: WriteBackendMethods> LtoModuleCodegen<B> {
LtoModuleCodegen::Thin(ref m) => m.cost(),
}
}

/// Run autodiff on Fat LTO module
pub fn autodiff(
self,
cgcx: &CodegenContext<B>,
diff_fncs: Vec<AutoDiffItem>,
config: &ModuleConfig,
) -> Result<LtoModuleCodegen<B>, FatalError> {
match &self {
LtoModuleCodegen::Fat(module) => {
B::autodiff(cgcx, &module, diff_fncs, config)?;
}
_ => panic!("autodiff called with non-fat LTO module"),
}

Ok(self)
}
}

pub enum SerializedModule<M: ModuleBufferMethods> {
Expand Down
6 changes: 1 addition & 5 deletions compiler/rustc_codegen_ssa/src/back/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -408,12 +408,8 @@ fn generate_lto_work<B: ExtraBackendMethods>(

if !needs_fat_lto.is_empty() {
assert!(needs_thin_lto.is_empty());
let mut module =
let module =
B::run_fat_lto(cgcx, needs_fat_lto, import_only_modules).unwrap_or_else(|e| e.raise());
if cgcx.lto == Lto::Fat && !autodiff.is_empty() {
let config = cgcx.config(ModuleKind::Regular);
module = module.autodiff(cgcx, autodiff, config).unwrap_or_else(|e| e.raise());
}
// We are adding a single work item, so the cost doesn't matter.
vec![(WorkItem::LTO(module), 0)]
} else {
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_codegen_ssa/src/codegen_attrs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,7 @@ impl<'a> MixedExportNameAndNoMangleState<'a> {
/// being differentiated). The other form is #[rustc_autodiff(Mode, ActivityList)] on top of the
/// placeholder functions. We wrote the rustc_autodiff attributes ourself, so this should never
/// panic, unless we introduced a bug when parsing the autodiff macro.
fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
pub fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
let attrs = tcx.get_attrs(id, sym::rustc_autodiff);

let attrs = attrs.filter(|attr| attr.has_name(sym::rustc_autodiff)).collect::<Vec<_>>();
Expand Down
7 changes: 0 additions & 7 deletions compiler/rustc_codegen_ssa/src/traits/write.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
use rustc_errors::{DiagCtxtHandle, FatalError};
use rustc_middle::dep_graph::WorkProduct;

Expand Down Expand Up @@ -62,12 +61,6 @@ pub trait WriteBackendMethods: Clone + 'static {
want_summary: bool,
) -> (String, Self::ThinBuffer);
fn serialize_module(module: ModuleCodegen<Self::Module>) -> (String, Self::ModuleBuffer);
fn autodiff(
cgcx: &CodegenContext<Self>,
module: &ModuleCodegen<Self::Module>,
diff_fncs: Vec<AutoDiffItem>,
config: &ModuleConfig,
) -> Result<(), FatalError>;
}

pub trait ThinBufferMethods: Send + Sync {
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_hir_analysis/src/check/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ fn intrinsic_operation_unsafety(tcx: TyCtxt<'_>, intrinsic_id: LocalDefId) -> hi
| sym::round_ties_even_f32
| sym::round_ties_even_f64
| sym::round_ties_even_f128
| sym::enzyme_autodiff
| sym::const_eval_select => hir::Safety::Safe,
_ => hir::Safety::Unsafe,
};
Expand Down Expand Up @@ -197,6 +198,7 @@ pub(crate) fn check_intrinsic_type(
let safety = intrinsic_operation_unsafety(tcx, intrinsic_id);
let n_lts = 0;
let (n_tps, n_cts, inputs, output) = match intrinsic_name {
sym::enzyme_autodiff => (4, 0, vec![param(0), param(1), param(2)], param(3)),
sym::abort => (0, 0, vec![], tcx.types.never),
sym::unreachable => (0, 0, vec![], tcx.types.never),
sym::breakpoint => (0, 0, vec![], tcx.types.unit),
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_span/src/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,7 @@ symbols! {
enumerate_method,
env,
env_CFG_RELEASE: env!("CFG_RELEASE"),
enzyme_autodiff,
eprint_macro,
eprintln_macro,
eq,
Expand Down
4 changes: 4 additions & 0 deletions library/core/src/intrinsics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3114,6 +3114,10 @@ pub const unsafe fn copysignf64(x: f64, y: f64) -> f64;
#[rustc_intrinsic]
pub const unsafe fn copysignf128(x: f128, y: f128) -> f128;

#[rustc_nounwind]
#[rustc_intrinsic]
pub const fn enzyme_autodiff<F, G, T: crate::marker::Tuple, R>(f: F, df: G, args: T) -> R;

/// Inform Miri that a given pointer definitely has a certain alignment.
#[cfg(miri)]
#[rustc_allow_const_fn_unstable(const_eval_select)]
Expand Down
1 change: 1 addition & 0 deletions tests/codegen/autodiff/batched.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
// reduce this test to only match the first lines and the ret instructions.

#![feature(autodiff)]
#![feature(intrinsics)]

use std::autodiff::autodiff_forward;

Expand Down
1 change: 1 addition & 0 deletions tests/codegen/autodiff/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
//@ no-prefer-dynamic
//@ needs-enzyme
#![feature(autodiff)]
#![feature(intrinsics)]

use std::autodiff::autodiff_reverse;

Expand Down
1 change: 1 addition & 0 deletions tests/codegen/autodiff/identical_fnc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
// We also explicetly test that we keep running merge_function after AD, by checking for two
// identical function calls in the LLVM-IR, while having two different calls in the Rust code.
#![feature(autodiff)]
#![feature(intrinsics)]

use std::autodiff::autodiff_reverse;

Expand Down
1 change: 1 addition & 0 deletions tests/codegen/autodiff/inline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
//@ needs-enzyme

#![feature(autodiff)]
#![feature(intrinsics)]

use std::autodiff::autodiff_reverse;

Expand Down
1 change: 1 addition & 0 deletions tests/codegen/autodiff/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
//@ no-prefer-dynamic
//@ needs-enzyme
#![feature(autodiff)]
#![feature(intrinsics)]

use std::autodiff::autodiff_reverse;

Expand Down
28 changes: 14 additions & 14 deletions tests/codegen/autodiff/sret.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
// We therefore use this test to verify some of our sret handling.

#![feature(autodiff)]
#![feature(intrinsics)]

use std::autodiff::autodiff_reverse;

Expand All @@ -17,26 +18,25 @@ fn primal(x: f32, y: f32) -> f64 {
(x * x * y) as f64
}

// CHECK:define internal fastcc void @_ZN4sret2df17h93be4316dd8ea006E(ptr dead_on_unwind noalias nocapture noundef nonnull writable writeonly align 8 dereferenceable(16) initializes((0, 16)) %_0, float noundef %x, float noundef %y)
// CHECK-NEXT:start:
// CHECK-NEXT: %0 = tail call fastcc { double, float, float } @diffeprimal(float %x, float %y)
// CHECK-NEXT: %.elt = extractvalue { double, float, float } %0, 0
// CHECK-NEXT: store double %.elt, ptr %_0, align 8
// CHECK-NEXT: %_0.repack1 = getelementptr inbounds nuw i8, ptr %_0, i64 8
// CHECK-NEXT: %.elt2 = extractvalue { double, float, float } %0, 1
// CHECK-NEXT: store float %.elt2, ptr %_0.repack1, align 8
// CHECK-NEXT: %_0.repack3 = getelementptr inbounds nuw i8, ptr %_0, i64 12
// CHECK-NEXT: %.elt4 = extractvalue { double, float, float } %0, 2
// CHECK-NEXT: store float %.elt4, ptr %_0.repack3, align 4
// CHECK-NEXT: ret void
// CHECK-NEXT:}
// CHECK: define internal fastcc { double, float, float } @diffeprimal(float noundef %x, float noundef %y)
// CHECK-NEXT: invertstart:
// CHECK-NEXT: %_4 = fmul float %x, %x
// CHECK-NEXT: %_3 = fmul float %_4, %y
// CHECK-NEXT: %_0 = fpext float %_3 to double
// CHECK-NEXT: %0 = fadd fast float %y, %y
// CHECK-NEXT: %1 = fmul fast float %0, %x
// CHECK-NEXT: %2 = insertvalue { double, float, float } undef, double %_0, 0
// CHECK-NEXT: %3 = insertvalue { double, float, float } %2, float %1, 1
// CHECK-NEXT: %4 = insertvalue { double, float, float } %3, float %_4, 2
// CHECK-NEXT: ret { double, float, float } %4
// CHECK-NEXT: }

fn main() {
let x = std::hint::black_box(3.0);
let y = std::hint::black_box(2.5);
let scalar = std::hint::black_box(1.0);
let (r1, r2, r3) = df(x, y, scalar);
// 3*3*1.5 = 22.5
// 3*3*2.5 = 22.5
assert_eq!(r1, 22.5);
// 2*x*y = 2*3*2.5 = 15.0
assert_eq!(r2, 15.0);
Expand Down
Loading