-
Notifications
You must be signed in to change notification settings - Fork 54
Added memory optimization for onnx transforms #538
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
5d1fecc
85c8a0d
7436ccb
ca8a0e8
76c5329
e636e24
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,17 +5,25 @@ | |
# | ||
# ---------------------------------------------------------------------------- | ||
|
||
import gc | ||
import logging | ||
from typing import Optional, Tuple | ||
|
||
import numpy as np | ||
from onnx import ModelProto, external_data_helper, numpy_helper | ||
|
||
from QEfficient.utils.constants import ONNX_TRANSFORM_MEMORY_CLEANUP_INTERVAL | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class OnnxTransform: | ||
""" | ||
OnnxTransform is the base class for graph modifications on exported onnx. | ||
""" | ||
|
||
_external_data_loaded_cache = {} # Dict[int, bool] | ||
|
||
def __init__(self): | ||
raise TypeError("Transform classes are not to be instantiated. Directly use the `apply` method.") | ||
|
||
|
@@ -31,6 +39,68 @@ def apply(cls, model: ModelProto, **kwargs) -> Tuple[ModelProto, bool]: | |
""" | ||
raise NotImplementedError("Use subclasses for ONNX transform") | ||
|
||
@classmethod | ||
def _check_external_data_loaded(cls, model: ModelProto) -> bool: | ||
""" | ||
Check if external data is already loaded in the model. | ||
|
||
:param model: The ONNX model to check | ||
:returns: True if external data is already loaded, False otherwise | ||
""" | ||
# Use object ID as key instead of the object itself | ||
model_id = id(model) | ||
# Return cached result if available | ||
if model_id in cls._external_data_loaded_cache: | ||
return cls._external_data_loaded_cache[model_id] | ||
|
||
# Load the model if not already loaded | ||
for tensor in external_data_helper._get_all_tensors(model): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we think of skipping this extra loop for checking whether for all the tensors external data has been loaded or not. The place where we are loading the external data there we can maintain a flag. This flag by default will be set to false and then once all the external data is loaded we can mark it to TRUE. Then in code we may have to just check the flag. or may not need this function if you want to directly use the flag. |
||
# Check if tensor has external data but no raw data loaded | ||
if len(tensor.external_data) > 0 and not tensor.HasField("raw_data"): | ||
cls._external_data_loaded_cache[model_id] = False | ||
return False | ||
|
||
cls._external_data_loaded_cache[model_id] = True | ||
return True | ||
|
||
@classmethod | ||
def _load_external_data(cls, model: ModelProto, onnx_base_dir: Optional[str] = None): | ||
""" | ||
Performs a bulk load of external data if it's not already loaded. | ||
Updates the cache upon successful load. | ||
""" | ||
model_id = id(model) | ||
if not cls._check_external_data_loaded(model): | ||
logger.info("External data not loaded. Performing bulk load.") | ||
external_data_helper.load_external_data_for_model(model, onnx_base_dir) | ||
cls._external_data_loaded_cache[model_id] = True | ||
else: | ||
logger.info("External data already loaded (or cached). Skipping bulk load.") | ||
|
||
@classmethod | ||
def _cleanup_external_data_and_cache(cls, model: ModelProto): | ||
""" | ||
Combines clearing external data from the model and its cache entry. | ||
""" | ||
# Remove the loaded raw data from tensors | ||
for tensor in external_data_helper._get_all_tensors(model): | ||
if tensor.HasField("raw_data"): | ||
tensor.ClearField("raw_data") | ||
|
||
# Clear the cache entry for this model using its ID | ||
model_id = id(model) | ||
if model_id in cls._external_data_loaded_cache: | ||
del cls._external_data_loaded_cache[model_id] | ||
|
||
logger.info("External data and cache cleaned up.") | ||
|
||
@classmethod | ||
def _cleanup_memory(cls): | ||
""" | ||
Force garbage collection to free up memory after tensor processing. | ||
""" | ||
gc.collect() | ||
|
||
|
||
class FP16ClipTransform(OnnxTransform): | ||
""" | ||
|
@@ -42,26 +112,42 @@ def apply(cls, model: ModelProto, *, onnx_base_dir: Optional[str] = None, **kwar | |
""" | ||
:param onnx_base_dir: Base directory to load tensors | ||
""" | ||
finfo = np.finfo(np.float16) | ||
fp16_max = finfo.max | ||
fp16_min = finfo.min | ||
transformed = False | ||
try: | ||
# --- FIX: Ensure external data is loaded efficiently BEFORE processing --- | ||
cls._load_external_data(model, onnx_base_dir) | ||
|
||
for tensor in external_data_helper._get_all_tensors(model): | ||
nptensor = numpy_helper.to_array(tensor, onnx_base_dir) | ||
if nptensor.dtype == np.float32 and (np.any(nptensor > fp16_max) or np.any(nptensor < fp16_min)): | ||
neg_inf_mask = np.isinf(nptensor) & (nptensor < 0) | ||
clipped_tensor = np.clip(nptensor, fp16_min, fp16_max) | ||
finfo = np.finfo(np.float16) | ||
fp16_max = finfo.max | ||
fp16_min = finfo.min | ||
transformed = False | ||
|
||
processed_count = 0 | ||
for tensor in external_data_helper._get_all_tensors(model): | ||
nptensor = numpy_helper.to_array(tensor) # Removed onnx_base_dir as data is already loaded | ||
if nptensor.dtype == np.float32 and (np.any(nptensor > fp16_max) or np.any(nptensor < fp16_min)): | ||
neg_inf_mask = np.isinf(nptensor) & (nptensor < 0) | ||
clipped_tensor = np.clip(nptensor, fp16_min, fp16_max) | ||
|
||
# Restore -inf values | ||
if neg_inf_mask.any(): | ||
clipped_tensor = np.where(neg_inf_mask, np.float32("-inf"), clipped_tensor) | ||
# Restore -inf values | ||
if neg_inf_mask.any(): | ||
clipped_tensor = np.where(neg_inf_mask, np.float32("-inf"), clipped_tensor) | ||
|
||
new_tensor = numpy_helper.from_array(clipped_tensor, tensor.name) | ||
tensor.CopyFrom(new_tensor) | ||
transformed = True | ||
new_tensor = numpy_helper.from_array(clipped_tensor, tensor.name) | ||
tensor.CopyFrom(new_tensor) | ||
transformed = True | ||
|
||
return model, transformed | ||
del neg_inf_mask, clipped_tensor, new_tensor | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In this loop itself you can check and then update flag |
||
del nptensor | ||
processed_count += 1 | ||
|
||
if processed_count % ONNX_TRANSFORM_MEMORY_CLEANUP_INTERVAL == 0: | ||
cls._cleanup_memory() | ||
|
||
return model, transformed | ||
finally: | ||
# Ensure cleanup happens even if an exception occurs | ||
cls._cleanup_memory() | ||
|
||
|
||
class SplitTensorsTransform(OnnxTransform): | ||
|
@@ -86,16 +172,30 @@ def apply( | |
:param file_chunk_size: Chunk size to split external files into. | ||
:param size_threshold: Only tensors greater than this threshold (in bytes) will be saved externally. | ||
""" | ||
file_num = 0 | ||
current_file_size = 0 | ||
transformed = False | ||
external_data_helper.load_external_data_for_model(model, onnx_base_dir) | ||
for tensor in external_data_helper._get_all_tensors(model): | ||
if tensor.HasField("raw_data") and ((tsize := len(tensor.raw_data)) > size_threshold): | ||
transformed = True | ||
current_file_size += tsize | ||
if current_file_size > file_chunk_size: | ||
file_num += 1 | ||
current_file_size = tsize | ||
external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data") | ||
return model, transformed | ||
try: | ||
file_num = 0 | ||
current_file_size = 0 | ||
transformed = False | ||
|
||
# --- Adjustment: The initial check and load will now use the new bulk loader --- | ||
# This will either use the cache (if FP16ClipTransform loaded it) or perform the bulk load itself. | ||
cls._load_external_data(model, onnx_base_dir) | ||
|
||
processed_count = 0 | ||
for tensor in external_data_helper._get_all_tensors(model): | ||
if tensor.HasField("raw_data") and ((tsize := len(tensor.raw_data)) > size_threshold): | ||
transformed = True | ||
current_file_size += tsize | ||
if current_file_size > file_chunk_size: | ||
file_num += 1 | ||
current_file_size = tsize | ||
external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data") | ||
|
||
processed_count += 1 | ||
if processed_count % ONNX_TRANSFORM_MEMORY_CLEANUP_INTERVAL == 0: | ||
cls._cleanup_memory() | ||
|
||
return model, transformed | ||
finally: | ||
# Ensure cleanup happens even if an exception occurs | ||
cls._cleanup_memory() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
# QEfficient Memory Profiling | ||
|
||
A memory profiling solution for QEfficient workflows with manual operation marking. | ||
|
||
|
||
|
||
## Quick Start | ||
|
||
```python | ||
from scripts.memory_profiling import QEffMemoryProfiler | ||
from QEfficient import QEFFAutoModelForCausalLM | ||
|
||
# Initialize profiler | ||
profiler = QEffMemoryProfiler(verbose=True) | ||
profiler.start_monitoring() | ||
|
||
# Your QEfficient workflow | ||
model = QEFFAutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct") | ||
model.export() | ||
model.compile(prefill_seq_len=128, ctx_len=256, num_cores=16) | ||
output = model.generate(prompts=["Hello world"]) | ||
|
||
# Generate report and visualization | ||
profiler.stop_monitoring() | ||
print(profiler.get_memory_report()) | ||
profiler.generate_memory_graph("profile.png") | ||
``` | ||
|
||
## Configuration | ||
|
||
### Basic Configuration | ||
|
||
```python | ||
profiler = QEffMemoryProfiler( | ||
sampling_interval=0.1, # Sample every 100ms | ||
output_file="my_profile.png", # Custom output file | ||
verbose=True, # Enable detailed logging | ||
enable_cpu_monitoring=True, # Monitor CPU usage | ||
enable_disk_monitoring=True, # Monitor disk I/O | ||
) | ||
``` | ||
|
||
### Manual Operation Marking | ||
|
||
```python | ||
profiler = QEffMemoryProfiler() | ||
profiler.start_monitoring() | ||
|
||
# Manual operation marking | ||
profiler.mark_operation("Custom Operation 1") | ||
# ... your code ... | ||
|
||
profiler.mark_operation("Custom Operation 2") | ||
# ... more code ... | ||
|
||
profiler.stop_monitoring() | ||
``` | ||
|
||
## API Reference | ||
|
||
### QEffMemoryProfiler | ||
|
||
#### Constructor Parameters | ||
|
||
| Parameter | Type | Default | Description | | ||
|-----------|------|---------|-------------| | ||
| `sampling_interval` | `float` | `0.05` | Time between samples (seconds) | | ||
| `output_file` | `str` | `"qeff_memory_profile.png"` | Output file path | | ||
| `verbose` | `bool` | `False` | Enable verbose logging | | ||
| `enable_cpu_monitoring` | `bool` | `True` | Monitor CPU usage | | ||
| `enable_disk_monitoring` | `bool` | `True` | Monitor disk I/O | | ||
|
||
#### Methods | ||
|
||
- **`start_monitoring()`**: Start background monitoring | ||
- **`stop_monitoring()`**: Stop monitoring and mark completion | ||
- **`mark_operation(name: str)`**: Manually mark operation start | ||
- **`get_memory_report() -> str`**: Generate comprehensive text report | ||
- **`generate_memory_graph(filename: str)`**: Create visualization | ||
- **`stop_and_save(filename: str) -> str`**: Convenience method to stop and save | ||
|
||
#### Properties | ||
|
||
- **`peak_rss`**: Peak RSS memory usage (MB) | ||
- **`peak_operation`**: Operation during peak memory | ||
- **`samples`**: List of collected profiling samples | ||
- **`operations`**: List of marked operations with timestamps | ||
|
||
## Operation Types | ||
|
||
The profiler supports marking these common QEfficient operations: | ||
|
||
- **Model Loading**: `from_pretrained`, `AutoModel`, `AutoTokenizer` | ||
- **Export**: `model.export()`, ONNX transforms, PyTorch transforms | ||
- **Compilation**: `model.compile()`, QNN compilation | ||
- **Generation**: `model.generate()`, inference execution | ||
- **Cleanup**: Memory cleanup, garbage collection | ||
|
||
## Output | ||
|
||
### Console Report | ||
``` | ||
QEFFICIENT PERFORMANCE MONITORING REPORT | ||
============================================================ | ||
Peak Memory Usage: | ||
• RSS (Physical): 18.7 GB at 14:23:45 | ||
• Peak during: Compilation | ||
|
||
Memory Statistics: | ||
• Current RSS: 16.2 GB (Delta: +15.8 GB) | ||
• Duration: 185.3 seconds | ||
• Operations: 4 | ||
|
||
QEfficient Operations Timeline: | ||
1. 0.0s - Model Loading (25.2s) [+8.2 GB] | ||
2. 25.2s - Export (15.4s) [+2.1 GB] | ||
3. 40.6s - Compilation (120.8s) [+6.3 GB] <- Peak | ||
4. 161.4s - Generation (18.7s) [+1.2 GB] | ||
``` | ||
|
||
### Visualization | ||
|
||
The profiler generates a comprehensive 4-panel visualization: | ||
|
||
1. **Memory Timeline**: RSS usage with colored operation phases | ||
2. **CPU Usage**: CPU utilization with performance zones | ||
3. **Disk I/O**: Read/write activity per operation phase | ||
4. **Phase Duration**: Timing analysis with duration labels | ||
|
||
#### Sample Output | ||
|
||
 | ||
|
||
*Example memory profiling output showing QEfficient workflow phases including model loading, ONNX transforms, compilation, and generation phases with detailed memory, CPU, and disk I/O metrics.* | ||
|
||
## Advanced Usage | ||
|
||
|
||
### Accessing Raw Data | ||
|
||
```python | ||
# Get synchronized data arrays | ||
data = profiler.get_synchronized_data() | ||
timestamps = data['timestamps'] | ||
memory_usage = data['rss_memory'] | ||
cpu_usage = data['cpu_usage'] | ||
|
||
# Access individual samples | ||
for sample in profiler.samples: | ||
print(f"Time: {sample.timestamp}, RSS: {sample.rss_mb} MB") | ||
``` | ||
|
||
## Integration Examples | ||
|
||
### With Existing QEfficient Scripts | ||
|
||
```python | ||
# Add to existing QEfficient workflow | ||
profiler = QEffMemoryProfiler(output_file="workflow_profile.png") | ||
profiler.start_monitoring() | ||
|
||
# Existing QEfficient code unchanged | ||
model = QEFFAutoModelForCausalLM.from_pretrained(model_name) | ||
# ... rest of workflow ... | ||
|
||
# Add at end | ||
report = profiler.stop_and_save() | ||
print(report) | ||
``` | ||
|
||
|
||
## Compatibility | ||
|
||
- **Python**: 3.7+ | ||
- **Dependencies**: `psutil`, `matplotlib`, `numpy` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NIT: TRANSFORM - spell check