Export Huggingface models with StaticCache #155862

@yuanyao-nv

🐛 Describe the bug

I'm trying to export a Llama 2 model with StaticCache (https://huggingface.co/docs/transformers/v4.52.3/en/internal/generation_utils#transformers.StaticCache) using torch.export:

from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache


model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

inputs = tokenizer(text="My name is Llama", return_tensors="pt")

# Prepare a cache class and pass it to model's forward
# Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
max_generated_length = inputs.input_ids.shape[1] + 10
past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)  # model can run with the given inputs

import torch
ep = torch.export.export(model, (inputs['input_ids'], inputs['attention_mask']), {'past_key_values': past_key_values, 'use_cache': True})  # export fails

The eager forward pass runs, but the export call fails with this error:

Traceback (most recent call last):
  File "/workspace/GenAI/static_cache.py", line 19, in <module>
    ep = torch.export.export(model, (inputs['input_ids'], inputs['attention_mask']), {'past_key_values': past_key_values, 'use_cache': True})
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/export/__init__.py", line 319, in export
    raise e
  File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/export/__init__.py", line 286, in export
    return _export(
           ^^^^^^^^
  File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/export/_trace.py", line 1159, in wrapper
    raise e
  File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/export/_trace.py", line 1125, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/export/exported_program.py", line 123, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/export/_trace.py", line 2172, in _export
    ep = _export_for_training(
         ^^^^^^^^^^^^^^^^^^^^^
  File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/export/_trace.py", line 1159, in wrapper
    raise e
  File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/export/_trace.py", line 1125, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/export/exported_program.py", line 123, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/export/_trace.py", line 2033, in _export_for_training
    export_artifact = export_func(
                      ^^^^^^^^^^^^
  File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/export/_trace.py", line 1933, in _non_strict_export
    ) = make_fake_inputs(
        ^^^^^^^^^^^^^^^^^
  File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/_export/non_strict_utils.py", line 347, in make_fake_inputs
    fake_args, fake_kwargs = tree_map_with_path(
                             ^^^^^^^^^^^^^^^^^^^
  File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/utils/_pytree.py", line 2077, in tree_map_with_path
    return treespec.unflatten(func(*xs) for xs in zip(*all_keypath_leaves))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/utils/_pytree.py", line 1197, in unflatten
    leaves = list(leaves)
             ^^^^^^^^^^^^
  File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/utils/_pytree.py", line 2077, in <genexpr>
    return treespec.unflatten(func(*xs) for xs in zip(*all_keypath_leaves))
                              ^^^^^^^^^
  File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/_export/non_strict_utils.py", line 348, in <lambda>
    lambda kp, val: fakify(fake_mode, kp, val, t_constraints, sources),
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yuanyao/.local/lib/python3.12/site-packages/torch/_export/non_strict_utils.py", line 124, in fakify
    raise ValueError(
ValueError: Unsupported input type <class 'transformers.cache_utils.StaticCache'>. Export only supports pytree containers of basic types (Tensor, int, float, ...) as input. To register a custom dataclass, use torch.export.register_dataclass. To register a custom container type, use torch.utils._pytree.register_pytree_node. To register a constant input, use torch.utils._pytree.register_constant

Is this usage of StaticCache correct?

I also tried playing with https://huggingface.co/docs/transformers/en/main_classes/executorch#transformers.convert_and_export_with_cache but could not set it up correctly.

Versions

PyTorch version: 2.8.0.dev20250612+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.1 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: 18.1.3 (1ubuntu1)
CMake version: version 3.27.0
Libc version: glibc-2.39

Python version: 3.12.3 (main, Jan 17 2025, 18:03:48) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-6.8.0-52-generic-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 12.8.93
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA RTX A5000
Nvidia driver version: 560.28.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

cc @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4
