diff --git a/src/SUMMARY.md b/src/SUMMARY.md
index a16127385..249140956 100644
--- a/src/SUMMARY.md
+++ b/src/SUMMARY.md
@@ -108,6 +108,7 @@
     - [Installation](./autodiff/installation.md)
     - [How to debug](./autodiff/debugging.md)
     - [Autodiff flags](./autodiff/flags.md)
+    - [Type Trees](./autodiff/type-trees.md)
 
 # Source Code Representation
 
diff --git a/src/autodiff/type-trees.md b/src/autodiff/type-trees.md
new file mode 100644
index 000000000..68cb78650
--- /dev/null
+++ b/src/autodiff/type-trees.md
@@ -0,0 +1,193 @@
+# TypeTrees for Autodiff
+
+## What are TypeTrees?
+TypeTrees are memory layout descriptors for Enzyme. They tell Enzyme exactly how types are laid out in memory so that it can compute derivatives efficiently.
+
+## Structure
+```rust
+TypeTree(Vec<Type>)
+
+Type {
+    offset: isize,  // byte offset (-1 = everywhere)
+    size: usize,    // size in bytes
+    kind: Kind,     // Float, Integer, Pointer, etc.
+    child: TypeTree // nested structure
+}
+```
+
+## Example: `fn compute(x: &f32, data: &[f32]) -> f32`
+
+**Input 0: `x: &f32`**
+```rust
+TypeTree(vec![Type {
+    offset: -1, size: 8, kind: Pointer,
+    child: TypeTree(vec![Type {
+        offset: 0, size: 4, kind: Float, // Single value: use offset 0
+        child: TypeTree::new()
+    }])
+}])
+```
+
+**Input 1: `data: &[f32]`**
+```rust
+TypeTree(vec![Type {
+    offset: -1, size: 8, kind: Pointer,
+    child: TypeTree(vec![Type {
+        offset: -1, size: 4, kind: Float, // -1 = all elements
+        child: TypeTree::new()
+    }])
+}])
+```
+
+**Output: `f32`**
+```rust
+TypeTree(vec![Type {
+    offset: 0, size: 4, kind: Float, // Single scalar: use offset 0
+    child: TypeTree::new()
+}])
+```
+
+## Why Needed?
+- Enzyme can't deduce complex type layouts from LLVM IR
+- Prevents slow memory pattern analysis
+- Enables correct derivative computation for nested structures
+- Tells Enzyme which bytes are differentiable and which are metadata
+
+## What Enzyme Does With This Information:
+
+Without TypeTrees:
+```llvm
+; Enzyme sees generic LLVM IR:
+define float @distance(ptr %p1, ptr %p2) {
+; Has to guess what these pointers point to
+; Slow analysis of all memory operations
+; May miss optimization opportunities
+}
+```
+
+With TypeTrees:
+```llvm
+define "enzyme_type"="{[-1]:Float@float}" float @distance(
+    ptr "enzyme_type"="{[-1]:Pointer, [-1,0]:Float@float}" %p1,
+    ptr "enzyme_type"="{[-1]:Pointer, [-1,0]:Float@float}" %p2
+) {
+; Enzyme knows exact type layout
+; Can generate efficient derivative code directly
+}
+```
+
+# TypeTrees - Offset and -1 Explained
+
+## Type Structure
+
+```rust
+Type {
+    offset: isize,  // WHERE this type starts
+    size: usize,    // HOW BIG this type is
+    kind: Kind,     // WHAT KIND of data (Float, Integer, Pointer)
+    child: TypeTree // WHAT'S INSIDE (for pointers/containers)
+}
+```
+
+## Offset Values
+
+### Regular Offset (0, 4, 8, etc.)
+**Specific byte position within a structure**
+
+```rust
+struct Point {
+    x: f32,  // offset 0, size 4
+    y: f32,  // offset 4, size 4
+    id: i32, // offset 8, size 4
+}
+```
+
+TypeTree for `&Point` (internal representation):
+```rust
+TypeTree(vec![
+    Type { offset: 0, size: 4, kind: Float },  // x at byte 0
+    Type { offset: 4, size: 4, kind: Float },  // y at byte 4
+    Type { offset: 8, size: 4, kind: Integer } // id at byte 8
+])
+```
+
+Generates LLVM:
+```llvm
+"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer}"
+```
+
+### Offset -1 (Special: "Everywhere")
+**Means "this pattern repeats for ALL elements"**
+
+#### Example 1: Direct Array `[f32; 100]` (no pointer indirection)
+```rust
+TypeTree(vec![Type {
+    offset: -1,  // ALL positions
+    size: 4,     // each f32 is 4 bytes
+    kind: Float, // every element is float
+}])
+```
+
+Generates LLVM: `"enzyme_type"="{[-1]:Float@float}"`
+
+#### Example 1b: Array Reference `&[f32; 100]` (with pointer indirection)
+```rust
+TypeTree(vec![Type {
+    offset: -1, size: 8, kind: Pointer,
+    child: TypeTree(vec![Type {
+        offset: -1,  // ALL array elements
+        size: 4,     // each f32 is 4 bytes
+        kind: Float, // every element is float
+    }])
+}])
+```
+
+Generates LLVM: `"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@float}"`
+
+A single `-1` entry replaces 100 separate Types with offsets `0,4,8,12...396`.
+
+#### Example 2: Slice `&[i32]`
+```rust
+// Pointer to slice data
+TypeTree(vec![Type {
+    offset: -1, size: 8, kind: Pointer,
+    child: TypeTree(vec![Type {
+        offset: -1, // ALL slice elements
+        size: 4,    // each i32 is 4 bytes
+        kind: Integer
+    }])
+}])
+```
+
+Generates LLVM: `"enzyme_type"="{[-1]:Pointer, [-1,-1]:Integer}"`
+
+#### Example 3: Mixed Structure
+```rust
+struct Container {
+    header: i64,       // offset 0
+    data: [f32; 1000], // offset 8, but elements use -1
+}
+```
+
+```rust
+TypeTree(vec![
+    Type { offset: 0, size: 8, kind: Integer }, // header
+    Type {
+        offset: 8, size: 4000, kind: Pointer,
+        child: TypeTree(vec![Type {
+            offset: -1, size: 4, kind: Float // ALL array elements
+        }])
+    }
+])
+```
+
+## Key Distinction: Single Values vs Arrays
+
+**Single Values** use offset `0` for precision:
+- `&f32` has exactly one f32 value at offset 0
+- More precise than using -1 ("everywhere")
+- Generates: `{[-1]:Pointer, [-1,0]:Float@float}`
+
+**Arrays** use offset `-1` for efficiency:
+- `&[f32; 100]` has the same pattern repeated 100 times
+- Using -1 avoids listing 100 separate offsets
+- Generates: `{[-1]:Pointer, [-1,-1]:Float@float}`
\ No newline at end of file
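The mapping from nested `Type` entries to the flat `enzyme_type` strings quoted in this page can be sketched as a small standalone Rust program. This is an illustration, not the rustc implementation: `leaf`, `pointer_to`, `collect` and `render` are hypothetical helper names, the `Kind` enum is reduced to the three kinds used in the examples, and the byte-by-byte listing of integers at concrete offsets is an assumption inferred from the `&Point` example.

```rust
#[derive(Clone, Copy)]
enum Kind {
    Float,
    Integer,
    Pointer,
}

struct Type {
    offset: isize, // byte offset, -1 = everywhere
    size: usize,   // size in bytes
    kind: Kind,
    child: TypeTree,
}

struct TypeTree(Vec<Type>);

/// Hypothetical helper: a type with no nested structure.
fn leaf(offset: isize, size: usize, kind: Kind) -> Type {
    Type { offset, size, kind, child: TypeTree(Vec::new()) }
}

/// Hypothetical helper: an 8-byte pointer whose pointee has the given layout.
fn pointer_to(pointee: Vec<Type>) -> TypeTree {
    TypeTree(vec![Type { offset: -1, size: 8, kind: Kind::Pointer, child: TypeTree(pointee) }])
}

/// Walks the tree and appends one `[path]:Kind` entry per type.
fn collect(tree: &TypeTree, prefix: &[isize], out: &mut Vec<String>) {
    for ty in &tree.0 {
        // Assumption: integers at a concrete offset are listed byte by byte
        // (as in the `&Point` example); everything else gets one entry.
        let offsets: Vec<isize> = match ty.kind {
            Kind::Integer if ty.offset >= 0 => {
                (ty.offset..ty.offset + ty.size as isize).collect()
            }
            _ => vec![ty.offset],
        };
        for off in offsets {
            let mut path = prefix.to_vec();
            path.push(off);
            let label = match (ty.kind, ty.size) {
                (Kind::Float, 4) => "Float@float",
                (Kind::Float, 8) => "Float@double",
                (Kind::Float, _) => "Float", // fallback for widths not covered here
                (Kind::Integer, _) => "Integer",
                (Kind::Pointer, _) => "Pointer",
            };
            let path_str: Vec<String> = path.iter().map(|o| o.to_string()).collect();
            out.push(format!("[{}]:{}", path_str.join(","), label));
            // Children are addressed relative to the path of their parent.
            collect(&ty.child, &path, out);
        }
    }
}

/// Renders a tree in the `{[..]:Kind, ...}` form used in the examples.
fn render(tree: &TypeTree) -> String {
    let mut out = Vec::new();
    collect(tree, &[], &mut out);
    format!("{{{}}}", out.join(", "))
}

fn main() {
    // `&f32`: a single value at offset 0 behind the pointer.
    let scalar_ref = pointer_to(vec![leaf(0, 4, Kind::Float)]);
    // `&[f32]`: the same element pattern at every offset.
    let slice_ref = pointer_to(vec![leaf(-1, 4, Kind::Float)]);
    // `&Point`: two floats followed by a 4-byte integer.
    let point_ref = pointer_to(vec![
        leaf(0, 4, Kind::Float),
        leaf(4, 4, Kind::Float),
        leaf(8, 4, Kind::Integer),
    ]);

    println!("{}", render(&scalar_ref));
    println!("{}", render(&slice_ref));
    println!("{}", render(&point_ref));
}
```

Under these assumptions the three trees render to the same strings shown for `&f32`, `&[f32]` and `&Point` above, which makes the difference between offset `0` and offset `-1` visible in the second path component.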