Skip to content
Snippets Groups Projects
Unverified Commit 5731f915 authored by Nicholas Albion's avatar Nicholas Albion Committed by GitHub
Browse files

support caching of anthropic system prompt (#18008)

parent ea1f987b
No related branches found
No related tags found
No related merge requests found
...@@ -467,14 +467,20 @@ class Anthropic(FunctionCallingLLM): ...@@ -467,14 +467,20 @@ class Anthropic(FunctionCallingLLM):
chat_history.append(user_msg) chat_history.append(user_msg)
tool_dicts = [] tool_dicts = []
for tool in tools: if tools:
tool_dicts.append( for tool in tools:
{ tool_dicts.append(
"name": tool.metadata.name, {
"description": tool.metadata.description, "name": tool.metadata.name,
"input_schema": tool.metadata.get_parameters_dict(), "description": tool.metadata.description,
} "input_schema": tool.metadata.get_parameters_dict(),
) }
)
if "prompt-caching" in kwargs.get("extra_headers", {}).get(
"anthropic-beta", ""
):
tool_dicts[-1]["cache_control"] = {"type": "ephemeral"}
return {"messages": chat_history, "tools": tool_dicts, **kwargs} return {"messages": chat_history, "tools": tool_dicts, **kwargs}
def _validate_chat_with_tools_response( def _validate_chat_with_tools_response(
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
Utility functions for the Anthropic SDK LLM integration. Utility functions for the Anthropic SDK LLM integration.
""" """
from typing import Dict, Sequence, Tuple from typing import Any, Dict, Sequence, Tuple
from llama_index.core.base.llms.types import ( from llama_index.core.base.llms.types import (
ChatMessage, ChatMessage,
...@@ -139,13 +139,16 @@ def messages_to_anthropic_messages( ...@@ -139,13 +139,16 @@ def messages_to_anthropic_messages(
- System prompt - System prompt
""" """
anthropic_messages = [] anthropic_messages = []
system_prompt = "" system_prompt = []
for message in messages: for message in messages:
if message.role == MessageRole.SYSTEM: if message.role == MessageRole.SYSTEM:
# For system messages, concatenate all text blocks
for block in message.blocks: for block in message.blocks:
if isinstance(block, TextBlock): if isinstance(block, TextBlock) and block.text:
system_prompt += block.text + "\n" system_prompt.append(
_text_block_to_anthropic_message(
block, message.additional_kwargs
)
)
elif message.role == MessageRole.FUNCTION or message.role == MessageRole.TOOL: elif message.role == MessageRole.FUNCTION or message.role == MessageRole.TOOL:
content = ToolResultBlockParam( content = ToolResultBlockParam(
tool_use_id=message.additional_kwargs["tool_call_id"], tool_use_id=message.additional_kwargs["tool_call_id"],
...@@ -161,19 +164,12 @@ def messages_to_anthropic_messages( ...@@ -161,19 +164,12 @@ def messages_to_anthropic_messages(
content: list[TextBlockParam | ImageBlockParam] = [] content: list[TextBlockParam | ImageBlockParam] = []
for block in message.blocks: for block in message.blocks:
if isinstance(block, TextBlock): if isinstance(block, TextBlock):
content_ = ( if block.text:
TextBlockParam( content.append(
text=block.text, _text_block_to_anthropic_message(
type="text", block, message.additional_kwargs
cache_control=CacheControlEphemeralParam(type="ephemeral"), )
) )
if "cache_control" in message.additional_kwargs
else TextBlockParam(text=block.text, type="text")
)
# avoid empty text blocks
if content_["text"]:
content.append(content_)
elif isinstance(block, ImageBlock): elif isinstance(block, ImageBlock):
# FUTURE: Claude does not support URLs, so we need to always convert to base64 # FUTURE: Claude does not support URLs, so we need to always convert to base64
img_bytes = block.resolve_image(as_base64=True).read() img_bytes = block.resolve_image(as_base64=True).read()
...@@ -214,7 +210,19 @@ def messages_to_anthropic_messages( ...@@ -214,7 +210,19 @@ def messages_to_anthropic_messages(
content=content, content=content,
) )
anthropic_messages.append(anth_message) anthropic_messages.append(anth_message)
return __merge_common_role_msgs(anthropic_messages), system_prompt.strip() return __merge_common_role_msgs(anthropic_messages), system_prompt
def _text_block_to_anthropic_message(
    block: TextBlock, kwargs: dict[str, Any]
) -> TextBlockParam:
    """Convert a llama-index TextBlock into an Anthropic ``TextBlockParam``.

    When the owning message's additional kwargs contain a ``cache_control``
    entry, an ephemeral cache-control marker is attached so the Anthropic API
    caches this prompt segment; otherwise a plain text block is returned.
    """
    params: dict[str, Any] = {"text": block.text, "type": "text"}
    if "cache_control" in kwargs:
        # Only the presence of the key matters; the marker is always ephemeral.
        params["cache_control"] = CacheControlEphemeralParam(type="ephemeral")
    return TextBlockParam(**params)
# Function used in bedrock # Function used in bedrock
......
...@@ -27,7 +27,7 @@ exclude = ["**/BUILD"] ...@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
license = "MIT" license = "MIT"
name = "llama-index-llms-anthropic" name = "llama-index-llms-anthropic"
readme = "README.md" readme = "README.md"
version = "0.6.7" version = "0.6.8"
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = ">=3.9,<4.0" python = ">=3.9,<4.0"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment