🐛 Describe the bug
I'm trying to export a Llama 2 model with StaticCache (https://huggingface.co/docs/transformers/v4.52.3/en/internal/generation_utils#transformers.StaticCache):
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
inputs = tokenizer(text="My name is Llama", return_tensors="pt")

# Prepare a cache class and pass it to the model's forward.
# Leave room for 10 new tokens, which can be used when calling forward iteratively 10 times to generate.
max_generated_length = inputs.input_ids.shape[1] + 10
past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)

outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)  # the model runs fine with the given inputs

ep = torch.export.export(model, (inputs["input_ids"], inputs["attention_mask"]), {"past_key_values": past_key_values, "use_cache": True})  # export fails
```
I hit this error:

```
Traceback (most recent call last):
File "/workspace/GenAI/static_cache.py", line 19, in <module>
ep = torch.export.export(model, (inputs['input_ids'], inputs['attention_mask']), {'past_key_values': past_key_values, 'use_cache': True})
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/export/__init__.py", line 319, in export
raise e
File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/export/__init__.py", line 286, in export
return _export(
^^^^^^^^
File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/export/_trace.py", line 1159, in wrapper
raise e
File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/export/_trace.py", line 1125, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/export/exported_program.py", line 123, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/export/_trace.py", line 2172, in _export
ep = _export_for_training(
^^^^^^^^^^^^^^^^^^^^^
File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/export/_trace.py", line 1159, in wrapper
raise e
File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/export/_trace.py", line 1125, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/export/exported_program.py", line 123, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/export/_trace.py", line 2033, in _export_for_training
export_artifact = export_func(
^^^^^^^^^^^^
File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/export/_trace.py", line 1933, in _non_strict_export
) = make_fake_inputs(
^^^^^^^^^^^^^^^^^
File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/_export/non_strict_utils.py", line 347, in make_fake_inputs
fake_args, fake_kwargs = tree_map_with_path(
^^^^^^^^^^^^^^^^^^^
File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/utils/_pytree.py", line 2077, in tree_map_with_path
return treespec.unflatten(func(*xs) for xs in zip(*all_keypath_leaves))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/utils/_pytree.py", line 1197, in unflatten
leaves = list(leaves)
^^^^^^^^^^^^
File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/utils/_pytree.py", line 2077, in <genexpr>
return treespec.unflatten(func(*xs) for xs in zip(*all_keypath_leaves))
^^^^^^^^^
File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/_export/non_strict_utils.py", line 348, in <lambda>
lambda kp, val: fakify(fake_mode, kp, val, t_constraints, sources),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/_export/non_strict_utils.py", line 124, in fakify
raise ValueError(
ValueError: Unsupported input type <class 'transformers.cache_utils.StaticCache'>. Export only supports pytree containers of basic types (Tensor, int, float, ...) as input. To register a custom dataclass, use torch.export.register_dataclass. To register a custom container type, use torch.utils._pytree.register_pytree_node. To register a constant input, use torch.utils._pytree.register_constant
```
Is this usage of StaticCache correct?
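For reference, here is a minimal sketch of the "register a constant input" route that the error message itself suggests, continuing from the repro above. I'm assuming `register_constant` takes just the class; I haven't verified this is the intended usage:

```python
# Sketch of the error message's suggestion: treat the StaticCache instance as
# a constant input rather than a pytree container. Assumes register_constant
# accepts just the class (an assumption, not verified against torch docs).
import torch.utils._pytree as pytree
from transformers import StaticCache

pytree.register_constant(StaticCache)

# Retry the export with the cache now treated as a constant input.
ep = torch.export.export(
    model,
    (inputs["input_ids"], inputs["attention_mask"]),
    {"past_key_values": past_key_values, "use_cache": True},
)
```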
I also tried playing with convert_and_export_with_cache (https://huggingface.co/docs/transformers/en/main_classes/executorch#transformers.convert_and_export_with_cache) but could not set it up correctly; my rough attempt is sketched below.
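This is roughly the setup I attempted, pieced together from that page. The import path, the GenerationConfig / cache_config fields, and the example_* argument names are my best guesses from the docs and may differ across transformers versions:

```python
# Rough sketch based on the transformers ExecuTorch docs; field and argument
# names below are assumptions from that page, not verified.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers.integrations.executorch import convert_and_export_with_cache

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    attn_implementation="sdpa",
    generation_config=GenerationConfig(
        use_cache=True,
        cache_implementation="static",  # the wrapper expects a static cache setup
        cache_config={"batch_size": 1, "max_cache_len": 32},
    ),
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
inputs = tokenizer(text="My name is Llama", return_tensors="pt")

# Export a forward step; here the static cache lives inside the exported
# module instead of being passed in as a StaticCache input.
ep = convert_and_export_with_cache(
    model,
    example_input_ids=inputs.input_ids,
    example_cache_position=torch.arange(inputs.input_ids.shape[1]),
)
```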
Versions
```
PyTorch version: 2.8.0.dev20250612+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A
OS: Ubuntu 24.04.1 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: 18.1.3 (1ubuntu1)
CMake version: version 3.27.0
Libc version: glibc-2.39
Python version: 3.12.3 (main, Jan 17 2025, 18:03:48) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-6.8.0-52-generic-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 12.8.93
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA RTX A5000
Nvidia driver version: 560.28.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
```
cc @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4