Skip to content

Commit b38bcbf

Browse files
committed
autodiff: recursive depth limit added in typetree
1 parent 731a98a commit b38bcbf

File tree

4 files changed

+56
-4
lines changed

4 files changed

+56
-4
lines changed

compiler/rustc_middle/src/ty/mod.rs

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2257,7 +2257,25 @@ pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>) -> FncTree {
22572257

22582258
/// Generate TypeTree for a specific type.
22592259
/// This function analyzes a Rust type and creates appropriate TypeTree metadata.
2260+
2261+
/// Maximum recursion depth for TypeTree generation to prevent infinite loops
2262+
/// Set to 32 levels which should be sufficient for most practical type hierarchies
2263+
/// while preventing stack overflow from pathological recursive types.
2264+
const MAX_TYPETREE_DEPTH: usize = 32;
2265+
2266+
fn typetree_from_ty_with_depth<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>, depth: usize) -> TypeTree {
2267+
if depth > MAX_TYPETREE_DEPTH {
2268+
return TypeTree::new();
2269+
}
2270+
2271+
typetree_from_ty_impl(tcx, ty, depth)
2272+
}
2273+
22602274
pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
2275+
typetree_from_ty_with_depth(tcx, ty, 0)
2276+
}
2277+
2278+
fn typetree_from_ty_impl<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>, depth: usize) -> TypeTree {
22612279
if ty.is_scalar() {
22622280
let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() {
22632281
(Kind::Integer, ty.primitive_size(tcx).bytes_usize())
@@ -2299,7 +2317,7 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
22992317
return TypeTree::new();
23002318
}
23012319

2302-
let element_tree = typetree_from_ty(tcx, *element_ty);
2320+
let element_tree = typetree_from_ty_impl(tcx, *element_ty, depth + 1);
23032321

23042322
let element_layout = tcx
23052323
.layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(*element_ty))
@@ -2335,7 +2353,7 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
23352353

23362354
if ty.is_slice() {
23372355
if let ty::Slice(element_ty) = ty.kind() {
2338-
let element_tree = typetree_from_ty(tcx, *element_ty);
2356+
let element_tree = typetree_from_ty_impl(tcx, *element_ty, depth + 1);
23392357
return element_tree;
23402358
}
23412359
}
@@ -2349,7 +2367,7 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
23492367
let mut current_offset = 0;
23502368

23512369
for tuple_ty in tuple_types.iter() {
2352-
let element_tree = typetree_from_ty(tcx, tuple_ty);
2370+
let element_tree = typetree_from_ty_impl(tcx, tuple_ty, depth + 1);
23532371

23542372
let element_layout = tcx
23552373
.layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(tuple_ty))
@@ -2385,7 +2403,7 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
23852403

23862404
for (field_idx, field_def) in adt_def.all_fields().enumerate() {
23872405
let field_ty = field_def.ty(tcx, args);
2388-
let field_tree = typetree_from_ty(tcx, field_ty);
2406+
let field_tree = typetree_from_ty_impl(tcx, field_ty, depth + 1);
23892407

23902408
let field_offset = layout.fields.offset(field_idx).bytes_usize();
23912409

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
; Check that recursive TypeTree generation doesn't infinite loop
2+
; The depth limit should prevent infinite recursion
3+
4+
CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_recursive{{.*}}"enzyme_type"="{[]:Pointer}"
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
//@ needs-enzyme
2+
//@ ignore-cross-compile
3+
4+
use run_make_support::{llvm_filecheck, rfs, rustc};
5+
6+
fn main() {
7+
rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run();
8+
llvm_filecheck().patterns("recursive.check").stdin_buf(rfs::read("test.ll")).run();
9+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#![feature(autodiff)]
2+
3+
use std::autodiff::autodiff_reverse;
4+
5+
#[repr(C)]
6+
struct Node {
7+
value: f64,
8+
next: Option<Box<Node>>,
9+
}
10+
11+
#[autodiff_reverse(d_test, Duplicated, Active)]
12+
#[no_mangle]
13+
fn test_recursive(node: &Node) -> f64 {
14+
node.value
15+
}
16+
17+
fn main() {
18+
let node = Node { value: 1.0, next: None };
19+
let mut d_node = Node { value: 0.0, next: None };
20+
let _result = d_test(&node, &mut d_node, 1.0);
21+
}

0 commit comments

Comments
 (0)