181 changes: 181 additions & 0 deletions SOLUTION_SUMMARY.md
@@ -0,0 +1,181 @@
# TorchAO FP8 + Activation Checkpointing Issue: Analysis & Solution

## Executive Summary

**Issue Confirmed**: TorchAO's FP8 training implementation is indeed unaware of activation checkpointing; when the two techniques are combined, memory usage ends up far higher than with checkpointing alone instead of delivering the expected combined savings.

**Root Cause**: The FP8 autograd function always saves high precision (HP) tensors for the backward pass, conflicting with activation checkpointing's discard-and-recompute strategy.

**Solution**: Implement checkpointing context detection to adaptively save FP8 tensors instead of HP tensors when checkpointing is active.

## Detailed Analysis

### The Problem

Your test results clearly demonstrate the issue (a sketch of how such measurements can be gathered follows the table):

| Configuration | Memory Utilization | Notes |
|---------------|-------------------|-------|
| Baseline (no FP8, no checkpointing) | 76.22% | Reference point |
| FP8 only | 74.25% | ✅ Small memory reduction |
| Checkpointing only | 16.1% | ✅ Significant memory reduction |
| **FP8 + Checkpointing** | **29.70%** | ❌ **Much higher than checkpointing alone** |
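
For reference, roughly comparable figures can be collected with PyTorch's CUDA memory statistics. The helper below is a minimal sketch; its name and the single-step protocol are illustrative rather than a description of how your numbers were produced:

```python
import torch

def peak_memory_fraction(run_step) -> float:
    """Run one training step and report peak allocated CUDA memory as a
    fraction of total device memory (roughly comparable to the table above)."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    run_step()                      # forward + backward (+ optimizer step)
    torch.cuda.synchronize()
    total = torch.cuda.get_device_properties(0).total_memory
    return torch.cuda.max_memory_allocated() / total
```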

### Root Cause in Code

**File**: `torchao/float8/float8_linear.py`
**Function**: `matmul_with_hp_or_float8_args.forward()`
**Problem Line**: `ctx.save_for_backward(input_hp, weight_hp_t)`

```python
@staticmethod
def forward(ctx, input_hp: torch.Tensor, weight_hp_t: torch.Tensor, ...):
# ... FP8 conversion logic ...

# PROBLEM: Always saves HP tensors regardless of checkpointing
ctx.save_for_backward(input_hp, weight_hp_t) # ← This conflicts with checkpointing

# Forward computation uses FP8 tensors
return torch.mm(input_maybe_fp8_reshaped, weight_maybe_fp8_t)
```

### Why This Happens

1. **Activation Checkpointing Goal**: Save memory by not storing intermediate activations and recomputing them during the backward pass
2. **FP8 Implementation**: Always saves HP tensors for gradient computation
3. **Conflict**: When both are used, you get:
   - Checkpointed activations (some saved, some recomputed)
   - PLUS the HP tensors saved by the FP8 function
   - Result: much more memory usage than checkpointing alone (a reproduction sketch follows this list)
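
For concreteness, the snippet below is a minimal sketch of the configuration that triggers this behavior. It assumes torchao's `convert_to_float8_training` entry point and an FP8-capable GPU; the toy module stands in for the Llama3 8B transformer blocks in your setup.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from torchao.float8 import convert_to_float8_training

# Toy stand-in for one transformer block.
model = nn.Sequential(
    nn.Linear(4096, 4096), nn.GELU(), nn.Linear(4096, 4096)
).cuda().bfloat16()
convert_to_float8_training(model)  # swaps nn.Linear modules for Float8Linear

x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# The checkpointed region is supposed to drop its activations and recompute
# them in backward, while the FP8 matmul inside still calls
# ctx.save_for_backward(input_hp, weight_hp_t) on its HP operands.
y = checkpoint(model, x, use_reentrant=False)
y.sum().backward()
```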

## Proposed Solution

### 1. Checkpointing Context Detection

Implement a function to detect when code is running inside activation checkpointing:

```python
import inspect

def is_in_checkpointing_context() -> bool:
    """Heuristic: walk the Python call stack and look for frames coming from
    torch.utils.checkpoint (frame names vary across torch versions)."""
    for frame_info in inspect.stack():
if frame_info.function in [
'checkpoint',
'_checkpoint_without_reentrant',
'_checkpoint_with_reentrant',
'CheckpointFunction'
]:
return True
if ('checkpoint' in frame_info.filename.lower() and
'torch' in frame_info.filename.lower()):
return True
return False
```
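
A quick way to sanity-check the detector is to call it from inside and outside a checkpointed region; `probe` below is an illustrative name, and the expected outputs assume the heuristic above.

```python
import torch
from torch.utils.checkpoint import checkpoint

def probe(x):
    # Runs inside the checkpointed region, so frames from
    # torch/utils/checkpoint.py are on the stack here.
    print("inside checkpoint:", is_in_checkpointing_context())
    return x * 2

x = torch.randn(4, requires_grad=True)
print("outside checkpoint:", is_in_checkpointing_context())  # expected: False
y = checkpoint(probe, x, use_reentrant=False)                # expected: True
```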

### 2. Adaptive Memory Management

Modify the FP8 autograd function to adapt based on checkpointing context:

```python
@staticmethod
def forward(ctx, input_hp, weight_hp_t, ...):
# Convert to FP8 (same as current)
input_fp8 = hp_tensor_to_float8_dynamic(input_hp, ...)
weight_fp8_t = hp_tensor_to_float8_dynamic(weight_hp_t, ...)

# Adaptive saving strategy
if is_in_checkpointing_context():
        # Save the FP8 tensors plus the metadata needed to undo the cast
        # (scales are carried by the Float8 tensor wrappers; keep original dtypes)
        ctx.save_for_backward(input_fp8, weight_fp8_t)
        ctx.input_scale, ctx.input_dtype = input_fp8._scale, input_hp.dtype
        ctx.weight_scale, ctx.weight_dtype = weight_fp8_t._scale, weight_hp_t.dtype
        ctx.checkpointing_mode = True
else:
# Original behavior: save HP tensors
ctx.save_for_backward(input_hp, weight_hp_t)
ctx.checkpointing_mode = False

return torch.mm(input_fp8, weight_fp8_t)

@staticmethod
def backward(ctx, grad_output):
if ctx.checkpointing_mode:
        # Reconstruct HP tensors from the saved FP8 data + metadata
        # (fp8_to_hp_tensor is a hypothetical helper; see the sketch below)
        input_fp8, weight_fp8_t = ctx.saved_tensors
        input_hp = fp8_to_hp_tensor(input_fp8, ctx.input_scale, ctx.input_dtype)
        weight_hp_t = fp8_to_hp_tensor(weight_fp8_t, ctx.weight_scale, ctx.weight_dtype)
else:
# Original behavior
input_hp, weight_hp_t = ctx.saved_tensors

# Continue with gradient computation...
```
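
The backward path above relies on `fp8_to_hp_tensor`, a helper this sketch assumes rather than an existing torchao function. Under torchao's dynamic scaling convention (the FP8 payload stores `hp * scale`), it could look like:

```python
import torch

def fp8_to_hp_tensor(fp8_data: torch.Tensor, scale: torch.Tensor,
                     orig_dtype: torch.dtype) -> torch.Tensor:
    """Hypothetical helper: undo dynamic scaling (data = hp * scale).
    Precision already lost to FP8 rounding is not recovered."""
    return (fp8_data.to(torch.float32) / scale).to(orig_dtype)
```

The reconstruction is lossy relative to the original HP tensors, which is exactly what the numerical accuracy tests in Phase 2 need to quantify.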

## Implementation Plan

### Phase 1: Core Implementation
1. Add checkpointing detection utility
2. Modify `matmul_with_hp_or_float8_args` to use adaptive saving
3. Implement FP8-to-HP reconstruction for backward pass

### Phase 2: Testing & Validation
1. Create memory usage tests comparing all 4 configurations
2. Add numerical accuracy tests to ensure FP8 precision is maintained
3. Performance benchmarks to measure any overhead

### Phase 3: Integration
1. Update existing FP8 tests to include checkpointing scenarios
2. Add documentation explaining the checkpointing compatibility
3. Consider adding configuration options for advanced users

## Expected Benefits

### Memory Usage Improvements
- **FP8 + Checkpointing**: Should achieve memory usage similar to checkpointing alone (~16%)
- **Memory Savings**: FP8 tensors are half the size of the BF16/FP16 tensors they replace
- **Automatic**: No user configuration required

### Backward Compatibility
- **Zero Breaking Changes**: Existing code continues to work unchanged
- **Automatic Detection**: Seamlessly adapts to checkpointing context
- **Fallback**: Maintains original behavior when checkpointing is not detected

## Files to Modify

1. **`torchao/float8/float8_linear.py`** - Main implementation
2. **`torchao/float8/float8_linear_utils.py`** - Add detection utilities
3. **`test/float8/test_float8_linear.py`** - Add checkpointing tests
4. **`benchmarks/float8/profile_lowp_training.py`** - Update profiling script

## Validation Strategy

### Memory Tests
```python
# Test all 4 configurations and verify:
# 1. FP8 + Checkpointing < FP8 alone
# 2. FP8 + Checkpointing ≈ Checkpointing alone
# 3. Numerical accuracy maintained
```
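
A pytest-style sketch covering the first two checks is shown below; the module sizes, the 1.5x bound, and the use of `convert_to_float8_training` here are illustrative assumptions, not part of the existing test suite.

```python
import pytest
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from torchao.float8 import convert_to_float8_training

def _peak_mem(step_fn) -> int:
    # Peak allocated bytes for a single forward + backward step
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    step_fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated()

def _make_step(use_fp8: bool, use_ckpt: bool):
    model = nn.Sequential(*[nn.Linear(2048, 2048) for _ in range(8)]).cuda().bfloat16()
    if use_fp8:
        convert_to_float8_training(model)
    x = torch.randn(64, 2048, device="cuda", dtype=torch.bfloat16, requires_grad=True)

    def step():
        out = checkpoint(model, x, use_reentrant=False) if use_ckpt else model(x)
        out.sum().backward()

    return step

@pytest.mark.skipif(not torch.cuda.is_available(), reason="needs CUDA")
def test_fp8_with_checkpointing_memory():
    fp8_ckpt = _peak_mem(_make_step(use_fp8=True, use_ckpt=True))
    fp8_only = _peak_mem(_make_step(use_fp8=True, use_ckpt=False))
    ckpt_only = _peak_mem(_make_step(use_fp8=False, use_ckpt=True))

    assert fp8_ckpt < fp8_only          # check 1 above
    assert fp8_ckpt < 1.5 * ckpt_only   # check 2 above (loose bound)
```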

### Integration Tests
```python
# Test with different checkpointing strategies:
# - Full activation checkpointing
# - Selective checkpointing
# - Nested checkpointing
```

## Answer to Your Question

**"Is TorchAO unaware of activation checkpointing?"**

**YES, absolutely.** TorchAO's FP8 implementation is completely unaware of activation checkpointing. The FP8 autograd functions always save high precision tensors regardless of the checkpointing context, which directly conflicts with checkpointing's memory-saving goals.

This is why you see 29.70% memory utilization when combining FP8 with checkpointing, instead of something close to the 16.1% achieved by checkpointing alone. The proposed solution will make TorchAO checkpointing-aware and resolve this issue.

## Next Steps

1. **Implement the detection mechanism** in `float8_linear.py`
2. **Test with your Llama3 8B setup** to validate memory improvements
3. **Submit PR to TorchAO** with the fix and comprehensive tests
4. **Document the improvement** for other users facing this issue

The fix is straightforward and should provide significant memory savings for your use case!
160 changes: 160 additions & 0 deletions fp8_checkpoint_analysis.md
@@ -0,0 +1,160 @@
# FP8 Training + Activation Checkpointing Memory Issue Analysis

## Problem Summary

When FP8 training is combined with activation checkpointing in TorchAO, memory utilization is far higher than with activation checkpointing alone. This contradicts the expectation that the two techniques should compose and keep memory usage low.

**User's Test Results:**
- No float8, no activation checkpointing: 76.22% memory utilization
- Float8 enabled, no activation checkpointing: 74.25% memory utilization
- No float8, full activation checkpointing enabled: 16.1% memory utilization
- **Float8 enabled, full activation checkpointing enabled: 29.70% memory utilization** ⚠️

## Root Cause Analysis

The issue is in `torchao/float8/float8_linear.py` in the `matmul_with_hp_or_float8_args` autograd function:

```python
@staticmethod
def forward(ctx, input_hp: torch.Tensor, weight_hp_t: torch.Tensor, ...):
ctx.save_for_backward(input_hp, weight_hp_t) # ← PROBLEM: Always saves HP tensors
# ... forward computation using FP8 tensors
```

### Why This Causes Issues

1. **Activation Checkpointing Goal**: Save memory by not storing intermediate activations during the forward pass and recomputing them during the backward pass.

2. **FP8 Implementation Conflict**: The FP8 autograd function explicitly saves high precision (HP) tensors in the autograd context, regardless of checkpointing.

3. **Double Memory Usage**: When both are used together:
- Activation checkpointing saves some activations for recomputation
- FP8 implementation saves additional HP tensors
- Result: much more memory usage than checkpointing alone

### Why HP Tensors Are Saved

The backward pass needs HP tensors for:

```python
@staticmethod
def backward(ctx, grad_output):
input_hp, weight_hp_t = ctx.saved_tensors # ← Needs HP tensors

    # Cast HP tensors (and the incoming gradient) to FP8 for the gradient matmuls
    input_fp8 = hp_tensor_to_float8_dynamic(input_hp, ...)
    weight_fp8 = hp_tensor_to_float8_dynamic(weight_hp_t, ...)
    grad_output_fp8 = hp_tensor_to_float8_dynamic(grad_output, ...)

    # Compute gradients using the FP8 operands
    grad_input = torch.mm(grad_output_fp8, weight_fp8.t())
    grad_weight = torch.mm(grad_output_fp8.t(), input_fp8)
```

## Current State: TorchAO is Unaware of Activation Checkpointing

**Answer to user's question: YES, TorchAO is currently unaware of activation checkpointing.**

The FP8 implementation does not:
- Detect when it's running inside a checkpointing context
- Adapt its memory management strategy for checkpointing
- Provide checkpointing-compatible alternatives

## Potential Solutions

### Solution 1: Checkpointing Context Detection

Detect when running inside activation checkpointing and modify behavior:

```python
def is_in_checkpointing_context():
# Check if we're inside torch.utils.checkpoint.checkpoint
import inspect
for frame_info in inspect.stack():
if 'checkpoint' in frame_info.filename and 'checkpoint' in frame_info.function:
return True
return False

# Inside the matmul_with_hp_or_float8_args autograd function:
@staticmethod
def forward(ctx, input_hp, weight_hp_t, ...):
    # Cast to FP8 as today
    input_fp8 = hp_tensor_to_float8_dynamic(input_hp, ...)
    weight_fp8_t = hp_tensor_to_float8_dynamic(weight_hp_t, ...)

    if is_in_checkpointing_context():
        # Save only the FP8 tensors (which carry their scales) for
        # HP reconstruction in backward
        ctx.save_for_backward(input_fp8, weight_fp8_t)
        ctx.needs_hp_recomputation = True
    else:
        # Current behavior
        ctx.save_for_backward(input_hp, weight_hp_t)
        ctx.needs_hp_recomputation = False
```

### Solution 2: Checkpointing-Aware FP8 Function

Create a separate autograd function optimized for checkpointing:

```python
class matmul_with_fp8_checkpointing_aware(torch.autograd.Function):
@staticmethod
def forward(ctx, input_hp, weight_hp_t, ...):
# Convert to FP8
input_fp8 = hp_tensor_to_float8_dynamic(input_hp, ...)
weight_fp8_t = hp_tensor_to_float8_dynamic(weight_hp_t, ...)

        # Save only the FP8 tensors plus the cast metadata needed to undo it
        # (scales are carried by the Float8 tensor wrappers)
        ctx.save_for_backward(input_fp8, weight_fp8_t)
        ctx.input_scale, ctx.input_dtype = input_fp8._scale, input_hp.dtype
        ctx.weight_scale, ctx.weight_dtype = weight_fp8_t._scale, weight_hp_t.dtype

return torch.mm(input_fp8, weight_fp8_t)

@staticmethod
def backward(ctx, grad_output):
input_fp8, weight_fp8_t = ctx.saved_tensors

        # Convert FP8 back to HP for gradient computation
        # (fp8_to_hp_tensor is a hypothetical helper)
input_hp = fp8_to_hp_tensor(input_fp8, ctx.input_scale, ctx.input_dtype)
weight_hp_t = fp8_to_hp_tensor(weight_fp8_t, ctx.weight_scale, ctx.weight_dtype)

# Continue with gradient computation...
```

### Solution 3: Configuration-Based Approach

Add a configuration option to Float8LinearConfig:

```python
@dataclass
class Float8LinearConfig:
# ... existing fields ...
checkpointing_compatible: bool = False

def __post_init__(self):
if self.checkpointing_compatible:
# Use memory-efficient backward pass
self._use_checkpointing_aware_backward = True
```
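
If this route were taken, usage could look like the sketch below; `checkpointing_compatible` is the proposed field and does not exist in today's `Float8LinearConfig`.

```python
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

model = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 4096)).cuda().bfloat16()

# Opt in to the checkpointing-friendly backward explicitly.
config = Float8LinearConfig(checkpointing_compatible=True)  # proposed field
convert_to_float8_training(model, config=config)
```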

## Recommended Implementation

I recommend **Solution 1** (Context Detection) as it:
- Automatically adapts to checkpointing without user configuration
- Maintains backward compatibility
- Provides the most seamless user experience

## Testing Strategy

1. **Memory Usage Tests**: Verify memory reduction with checkpointing + FP8
2. **Numerical Accuracy Tests**: Ensure FP8 precision is maintained
3. **Performance Tests**: Measure any overhead from context detection
4. **Integration Tests**: Test with various checkpointing strategies (selective, full, etc.)

## Files That Need Modification

1. `torchao/float8/float8_linear.py` - Main implementation
2. `torchao/float8/float8_linear_utils.py` - Utility functions
3. `torchao/float8/config.py` - Configuration options (if needed)
4. `test/float8/` - Add checkpointing tests

## Impact Assessment

- **Breaking Changes**: None (if implemented correctly)
- **Performance Impact**: Minimal (context detection is lightweight)
- **Memory Impact**: Significant reduction when using checkpointing + FP8
- **User Experience**: Seamless (automatic detection and adaptation)