Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/SUMMARY.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@
- [Installation](./autodiff/installation.md)
- [How to debug](./autodiff/debugging.md)
- [Autodiff flags](./autodiff/flags.md)
- [Type Trees](./autodiff/type-trees.md)

# Source Code Representation

Expand Down
193 changes: 193 additions & 0 deletions src/autodiff/type-trees.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
# TypeTrees for Autodiff

## What are TypeTrees?
Memory layout descriptors for Enzyme. Tell Enzyme exactly how types are structured in memory so it can compute derivatives efficiently.

## Structure
```rust
TypeTree(Vec<Type>)

Type {
offset: isize, // byte offset (-1 = everywhere)
size: usize, // size in bytes
kind: Kind, // Float, Integer, Pointer, etc.
child: TypeTree // nested structure
}
```

## Example: `fn compute(x: &f32, data: &[f32]) -> f32`

**Input 0: `x: &f32`**
```rust
TypeTree(vec![Type {
offset: -1, size: 8, kind: Pointer,
child: TypeTree(vec![Type {
offset: 0, size: 4, kind: Float, // Single value: use offset 0
child: TypeTree::new()
}])
}])
```

**Input 1: `data: &[f32]`**
```rust
TypeTree(vec![Type {
offset: -1, size: 8, kind: Pointer,
child: TypeTree(vec![Type {
offset: -1, size: 4, kind: Float, // -1 = all elements
child: TypeTree::new()
}])
}])
```

**Output: `f32`**
```rust
TypeTree(vec![Type {
offset: 0, size: 4, kind: Float, // Single scalar: use offset 0
child: TypeTree::new()
}])
```

## Why Needed?
- Enzyme can't deduce complex type layouts from LLVM IR
- Prevents slow memory pattern analysis
- Enables correct derivative computation for nested structures
- Tells Enzyme which bytes are differentiable vs metadata

## What Enzyme Does With This Information:

Without TypeTrees:
```llvm
; Enzyme sees generic LLVM IR:
define float @distance(ptr %p1, ptr %p2) {
; Has to guess what these pointers point to
; Slow analysis of all memory operations
; May miss optimization opportunities
}
```

With TypeTrees:
```llvm
define "enzyme_type"="{[-1]:Float@float}" float @distance(
ptr "enzyme_type"="{[-1]:Pointer, [-1,0]:Float@float}" %p1,
ptr "enzyme_type"="{[-1]:Pointer, [-1,0]:Float@float}" %p2
) {
; Enzyme knows exact type layout
; Can generate efficient derivative code directly
}
```

# TypeTrees - Offset and -1 Explained

## Type Structure

```rust
Type {
offset: isize, // WHERE this type starts
size: usize, // HOW BIG this type is
kind: Kind, // WHAT KIND of data (Float, Int, Pointer)
child: TypeTree // WHAT'S INSIDE (for pointers/containers)
}
```

## Offset Values

### Regular Offset (0, 4, 8, etc.)
**Specific byte position within a structure**

```rust
struct Point {
x: f32, // offset 0, size 4
y: f32, // offset 4, size 4
id: i32, // offset 8, size 4
}
```

TypeTree for `&Point` (internal representation):
```rust
TypeTree(vec![
Type { offset: 0, size: 4, kind: Float }, // x at byte 0
Type { offset: 4, size: 4, kind: Float }, // y at byte 4
Type { offset: 8, size: 4, kind: Integer } // id at byte 8
])
```

Generates LLVM
```llvm
"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer}"
```

### Offset -1 (Special: "Everywhere")
**Means "this pattern repeats for ALL elements"**

#### Example 1: Direct Array `[f32; 100]` (no pointer indirection)
```rust
TypeTree(vec![Type {
offset: -1, // ALL positions
size: 4, // each f32 is 4 bytes
kind: Float, // every element is float
}])
```

Generates LLVM: `"enzyme_type"="{[-1]:Float@float}"`

#### Example 1b: Array Reference `&[f32; 100]` (with pointer indirection)
```rust
TypeTree(vec![Type {
offset: -1, size: 8, kind: Pointer,
child: TypeTree(vec![Type {
offset: -1, // ALL array elements
size: 4, // each f32 is 4 bytes
kind: Float, // every element is float
}])
}])
```

Generates LLVM: `"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@float}"`

Instead of listing 100 separate Types with offsets `0,4,8,12...396`

#### Example 2: Slice `&[i32]`
```rust
// Pointer to slice data
TypeTree(vec![Type {
offset: -1, size: 8, kind: Pointer,
child: TypeTree(vec![Type {
offset: -1, // ALL slice elements
size: 4, // each i32 is 4 bytes
kind: Integer
}])
}])
```

Generates LLVM: `"enzyme_type"="{[-1]:Pointer, [-1,-1]:Integer}"`

#### Example 3: Mixed Structure
```rust
struct Container {
header: i64, // offset 0
data: [f32; 1000], // offset 8, but elements use -1
}
```

```rust
TypeTree(vec![
Type { offset: 0, size: 8, kind: Integer }, // header
Type { offset: 8, size: 4000, kind: Pointer,
child: TypeTree(vec![Type {
offset: -1, size: 4, kind: Float // ALL array elements
}])
}
])
```

## Key Distinction: Single Values vs Arrays

**Single Values** use offset `0` for precision:
- `&f32` has exactly one f32 value at offset 0
- More precise than using -1 ("everywhere")
- Generates: `{[-1]:Pointer, [-1,0]:Float@float}`

**Arrays** use offset `-1` for efficiency:
- `&[f32; 100]` has the same pattern repeated 100 times
- Using -1 avoids listing 100 separate offsets
- Generates: `{[-1]:Pointer, [-1,-1]:Float@float}`