Skip to content

Commit b4862a0

Browse files
committed
feat: add cache all and limit cache point in AnthropicModel
1 parent 1b576dd commit b4862a0

File tree

3 files changed

+309
-4
lines changed

3 files changed

+309
-4
lines changed

docs/models/anthropic.md

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,18 +80,29 @@ agent = Agent(model)
8080

8181
## Prompt Caching
8282

83-
Anthropic supports [prompt caching](https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching) to reduce costs by caching parts of your prompts. Pydantic AI provides three ways to use prompt caching:
83+
Anthropic supports [prompt caching](https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching) to reduce costs by caching parts of your prompts. Pydantic AI provides four ways to use prompt caching:
8484

8585
1. **Cache User Messages with [`CachePoint`][pydantic_ai.messages.CachePoint]**: Insert a `CachePoint` marker in your user messages to cache everything before it
8686
2. **Cache System Instructions**: Set [`AnthropicModelSettings.anthropic_cache_instructions`][pydantic_ai.models.anthropic.AnthropicModelSettings.anthropic_cache_instructions] to `True` (uses 5m TTL by default) or specify `'5m'` / `'1h'` directly
8787
3. **Cache Tool Definitions**: Set [`AnthropicModelSettings.anthropic_cache_tool_definitions`][pydantic_ai.models.anthropic.AnthropicModelSettings.anthropic_cache_tool_definitions] to `True` (uses 5m TTL by default) or specify `'5m'` / `'1h'` directly
88+
4. **Cache All (Convenience)**: Set [`AnthropicModelSettings.anthropic_cache_all`][pydantic_ai.models.anthropic.AnthropicModelSettings.anthropic_cache_all] to `True` to automatically cache the system instructions, the tool definitions (when tools are present), and the last user message
8889

89-
You can combine all three strategies for maximum savings:
90+
You can combine multiple strategies for maximum savings:
9091

9192
```python {test="skip"}
9293
from pydantic_ai import Agent, CachePoint, RunContext
9394
from pydantic_ai.models.anthropic import AnthropicModelSettings
9495

96+
# Option 1: Use anthropic_cache_all for convenience (caches system + last message)
97+
agent = Agent(
98+
'anthropic:claude-sonnet-4-5',
99+
system_prompt='Detailed instructions...',
100+
model_settings=AnthropicModelSettings(
101+
anthropic_cache_all=True, # Caches both system prompt and last message
102+
),
103+
)
104+
105+
# Option 2: Fine-grained control with individual settings
95106
agent = Agent(
96107
'anthropic:claude-sonnet-4-5',
97108
system_prompt='Detailed instructions...',
@@ -145,3 +156,37 @@ async def main():
145156
print(f'Cache write tokens: {usage.cache_write_tokens}')
146157
print(f'Cache read tokens: {usage.cache_read_tokens}')
147158
```
159+
160+
### Cache Point Limits
161+
162+
Anthropic enforces a maximum of 4 cache points per request. Pydantic AI automatically manages this limit:
163+
164+
- **`anthropic_cache_all`**: Uses up to 3 cache points (system instructions + last message, plus tool definitions when tools are present)
165+
- **`anthropic_cache_instructions`**: Uses 1 cache point
166+
- **`anthropic_cache_tool_definitions`**: Uses 1 cache point
167+
- **`CachePoint` markers**: Use remaining available cache points
168+
169+
When the total exceeds 4 cache points, Pydantic AI automatically removes cache points from **older messages** (keeping the most recent ones), ensuring your requests always comply with Anthropic's limits without errors.
170+
171+
```python {test="skip"}
172+
from pydantic_ai import Agent, CachePoint
173+
from pydantic_ai.models.anthropic import AnthropicModelSettings
174+
175+
agent = Agent(
176+
'anthropic:claude-sonnet-4-5',
177+
system_prompt='Instructions...',
178+
model_settings=AnthropicModelSettings(
179+
anthropic_cache_all=True, # Uses 2 cache points
180+
),
181+
)
182+
183+
async def main():
184+
# Even with multiple CachePoint markers, only 2 more will be kept
185+
# (4 total limit - 2 from cache_all = 2 available)
186+
result = await agent.run([
187+
'Context 1', CachePoint(), # Will be kept
188+
'Context 2', CachePoint(), # Will be kept
189+
'Context 3', CachePoint(), # Automatically removed (oldest)
190+
'Question'
191+
])
192+
```

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 102 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,22 @@ class AnthropicModelSettings(ModelSettings, total=False):
169169
See https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching for more information.
170170
"""
171171

172+
anthropic_cache_all: bool | Literal['5m', '1h']
173+
"""Convenience setting to enable caching for both system instructions and the last user message.
174+
175+
When enabled, this automatically adds cache points to:
176+
1. The last system prompt block (system instructions)
177+
2. The last content block in the final user message
178+
179+
This is equivalent to setting both `anthropic_cache_instructions` and adding a cache point
180+
to the last message, but more convenient for common use cases.
181+
If `True`, uses TTL='5m'. You can also specify '5m' or '1h' directly.
182+
183+
Note: Uses 2 of Anthropic's 4 available cache points per request. Any additional CachePoint
184+
markers in messages will be automatically limited to respect the 4-cache-point maximum.
185+
See https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching for more information.
186+
"""
187+
172188

173189
@dataclass(init=False)
174190
class AnthropicModel(Model):
@@ -478,7 +494,10 @@ def _get_tools(
478494
]
479495

480496
# Add cache_control to the last tool if enabled
481-
if tools and (cache_tool_defs := model_settings.get('anthropic_cache_tool_definitions')):
497+
if tools and (
498+
cache_tool_defs := model_settings.get('anthropic_cache_tool_definitions')
499+
or model_settings.get('anthropic_cache_all')
500+
):
482501
# If True, use '5m'; otherwise use the specified ttl value
483502
ttl: Literal['5m', '1h'] = '5m' if cache_tool_defs is True else cache_tool_defs
484503
last_tool = tools[-1]
@@ -747,8 +766,32 @@ async def _map_message( # noqa: C901
747766
system_prompt_parts.insert(0, instructions)
748767
system_prompt = '\n\n'.join(system_prompt_parts)
749768

769+
# Add cache_control to the last message content if anthropic_cache_all is enabled
770+
if anthropic_messages and (cache_all := model_settings.get('anthropic_cache_all')):
771+
ttl: Literal['5m', '1h'] = '5m' if cache_all is True else cache_all
772+
m = anthropic_messages[-1]
773+
content = m['content']
774+
if isinstance(content, str):
775+
# Convert string content to list format with cache_control
776+
m['content'] = [
777+
{
778+
'text': content,
779+
'type': 'text',
780+
'cache_control': BetaCacheControlEphemeralParam(type='ephemeral', ttl=ttl),
781+
}
782+
]
783+
else:
784+
# Add cache_control to the last content block
785+
content = cast(list[BetaContentBlockParam], content)
786+
self._add_cache_control_to_last_param(content, ttl)
787+
788+
# Ensure total cache points don't exceed Anthropic's limit of 4
789+
self._limit_cache_points(anthropic_messages, model_settings)
750790
# If anthropic_cache_instructions is enabled, return system prompt as a list with cache_control
751-
if system_prompt and (cache_instructions := model_settings.get('anthropic_cache_instructions')):
791+
if system_prompt and (
792+
cache_instructions := model_settings.get('anthropic_cache_instructions')
793+
or model_settings.get('anthropic_cache_all')
794+
):
752795
# If True, use '5m'; otherwise use the specified ttl value
753796
ttl: Literal['5m', '1h'] = '5m' if cache_instructions is True else cache_instructions
754797
system_prompt_blocks = [
@@ -762,6 +805,63 @@ async def _map_message( # noqa: C901
762805

763806
return system_prompt, anthropic_messages
764807

808+
@staticmethod
def _limit_cache_points(messages: list[BetaMessageParam], model_settings: AnthropicModelSettings) -> None:
    """Strip excess `cache_control` markers so the request honors Anthropic's cache-point limit.

    Anthropic allows at most 4 cache breakpoints per request. Breakpoints claimed by
    request-level settings (`anthropic_cache_all`, `anthropic_cache_instructions`,
    `anthropic_cache_tool_definitions`) are reserved first; whatever budget remains is
    granted to message-level cache points, newest first, and any `cache_control` entries
    beyond that budget are deleted in place from the older blocks.

    Args:
        messages: Message parameters to scan; mutated in place.
        model_settings: Model settings describing which request-level caches are enabled.
    """
    # Anthropic rejects requests carrying more than 4 cache breakpoints.
    budget = 4

    # Reserve the breakpoints consumed by request-level settings before handing
    # the remainder to message content.
    if model_settings.get('anthropic_cache_all'):
        # NOTE(review): `anthropic_cache_all` marks the system prompt and the last
        # message (2 points), and it also enables tool-definition caching in
        # `_get_tools`, which this reservation does not count — confirm the
        # combined total cannot exceed the limit when tools are present.
        budget -= 2
    else:
        if model_settings.get('anthropic_cache_instructions'):
            budget -= 1
        if model_settings.get('anthropic_cache_tool_definitions'):
            # Reserved even if no tools end up in the request (conservative).
            budget -= 1

    # Walk messages newest-to-oldest, and blocks newest-first within each message,
    # keeping cache points while the budget lasts and stripping the rest.
    for message in reversed(messages):
        blocks = message['content']
        if isinstance(blocks, str):
            # Plain-string content carries no cache_control blocks.
            continue
        for raw_block in reversed(cast(list[BetaContentBlockParam], blocks)):
            # TypedDict blocks are plain dicts at runtime; mutate directly.
            block = cast(dict[str, Any], raw_block)
            if 'cache_control' not in block:
                continue
            if budget > 0:
                budget -= 1
            else:
                # Over the limit: drop the cache point from this older block.
                del block['cache_control']
864+
765865
@staticmethod
766866
def _add_cache_control_to_last_param(params: list[BetaContentBlockParam], ttl: Literal['5m', '1h'] = '5m') -> None:
767867
"""Add cache control to the last content block param.

tests/models/test_anthropic.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,166 @@ def my_tool(value: str) -> str: # pragma: no cover
588588
assert system[0]['cache_control'] == snapshot({'type': 'ephemeral', 'ttl': '5m'})
589589

590590

591+
async def test_anthropic_cache_all(allow_model_requests: None):
    """anthropic_cache_all should add cache_control to both the system prompt and the last message."""
    response = completion_message(
        [BetaTextBlock(text='Response', type='text')],
        usage=BetaUsage(input_tokens=10, output_tokens=5),
    )
    client = MockAnthropic.create_mock(response)
    model = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=client))
    agent = Agent(
        model,
        system_prompt='System instructions to cache.',
        model_settings=AnthropicModelSettings(anthropic_cache_all=True),
    )

    await agent.run('User message')

    kwargs = get_mock_chat_completion_kwargs(client)[0]

    # The system prompt is sent as a single cached text block with the default 5m TTL.
    assert kwargs['system'] == snapshot(
        [{'type': 'text', 'text': 'System instructions to cache.', 'cache_control': {'type': 'ephemeral', 'ttl': '5m'}}]
    )

    # The final content block of the last message is cached as well.
    assert kwargs['messages'][-1]['content'][-1] == snapshot(
        {'type': 'text', 'text': 'User message', 'cache_control': {'type': 'ephemeral', 'ttl': '5m'}}
    )
623+
624+
625+
async def test_anthropic_cache_all_with_custom_ttl(allow_model_requests: None):
    """anthropic_cache_all should accept an explicit TTL value instead of True."""
    response = completion_message(
        [BetaTextBlock(text='Response', type='text')],
        usage=BetaUsage(input_tokens=10, output_tokens=5),
    )
    client = MockAnthropic.create_mock(response)
    model = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=client))
    agent = Agent(
        model,
        system_prompt='System instructions.',
        # '1h' should propagate to every cache point cache_all creates.
        model_settings=AnthropicModelSettings(anthropic_cache_all='1h'),
    )

    await agent.run('User message')

    kwargs = get_mock_chat_completion_kwargs(client)[0]

    # Both the system prompt and the last message block carry the 1h TTL.
    assert kwargs['system'][0]['cache_control'] == snapshot({'type': 'ephemeral', 'ttl': '1h'})
    assert kwargs['messages'][-1]['content'][-1]['cache_control'] == snapshot({'type': 'ephemeral', 'ttl': '1h'})
650+
651+
652+
async def test_limit_cache_points_with_cache_all(allow_model_requests: None):
653+
"""Test that cache points are limited when using cache_all + CachePoint markers."""
654+
c = completion_message(
655+
[BetaTextBlock(text='Response', type='text')],
656+
usage=BetaUsage(input_tokens=10, output_tokens=5),
657+
)
658+
mock_client = MockAnthropic.create_mock(c)
659+
m = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=mock_client))
660+
agent = Agent(
661+
m,
662+
system_prompt='System instructions.',
663+
model_settings=AnthropicModelSettings(
664+
anthropic_cache_all=True, # Uses 2 cache points
665+
),
666+
)
667+
668+
# Add 3 CachePoint markers (total would be 5: 2 from cache_all + 3 from markers)
669+
# Only 2 CachePoint markers should be kept (newest ones)
670+
await agent.run(
671+
[
672+
'Context 1',
673+
CachePoint(), # Oldest, should be removed
674+
'Context 2',
675+
CachePoint(), # Should be kept
676+
'Context 3',
677+
CachePoint(), # Should be kept
678+
'Question',
679+
]
680+
)
681+
682+
completion_kwargs = get_mock_chat_completion_kwargs(mock_client)[0]
683+
messages = completion_kwargs['messages']
684+
685+
# Count cache_control occurrences in messages
686+
cache_count = 0
687+
for msg in messages:
688+
for block in msg['content']:
689+
if 'cache_control' in block:
690+
cache_count += 1
691+
692+
# anthropic_cache_all uses 2 cache points (system + last message)
693+
# With 3 CachePoint markers, we'd have 5 total
694+
# Limit is 4, so 1 oldest CachePoint should be removed
695+
# Result: 2 cache points in messages (from the 2 newest CachePoints)
696+
# The cache_all's last message cache is applied after limiting
697+
assert cache_count == 2
698+
699+
700+
async def test_limit_cache_points_all_settings(allow_model_requests: None):
    """Cache-point limiting accounts for instruction and tool-definition caching."""
    response = completion_message(
        [BetaTextBlock(text='Response', type='text')],
        usage=BetaUsage(input_tokens=10, output_tokens=5),
    )
    client = MockAnthropic.create_mock(response)
    model = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=client))

    agent = Agent(
        model,
        system_prompt='System instructions.',
        model_settings=AnthropicModelSettings(
            anthropic_cache_instructions=True,  # 1 cache point
            anthropic_cache_tool_definitions=True,  # 1 cache point
        ),
    )

    @agent.tool_plain
    def my_tool() -> str:  # pragma: no cover
        return 'result'

    # Settings reserve 2 cache points, so of the 3 CachePoint markers only the
    # 2 newest may keep their cache_control.
    await agent.run(
        [
            'Context 1',
            CachePoint(),  # oldest — stripped by the limiter
            'Context 2',
            CachePoint(),  # kept
            'Context 3',
            CachePoint(),  # kept
            'Question',
        ]
    )

    kwargs = get_mock_chat_completion_kwargs(client)[0]

    # Count surviving cache_control entries in message content only (system and
    # tool cache points live outside `messages`).
    cache_count = sum(
        1 for msg in kwargs['messages'] for block in msg['content'] if 'cache_control' in block
    )

    # 4 total - 1 (instructions) - 1 (tool definitions) = 2 left for messages.
    assert cache_count == 2
749+
750+
591751
async def test_async_request_text_response(allow_model_requests: None):
592752
c = completion_message(
593753
[BetaTextBlock(text='world', type='text')],

0 commit comments

Comments
 (0)