From ae5fe63843bf469d36ad9e5dc250e7fd69749879 Mon Sep 17 00:00:00 2001 From: Logan <logan.markewich@live.com> Date: Thu, 4 Apr 2024 18:38:53 -0600 Subject: [PATCH] add anthropic tool calling (#12591) * add anthropic tool calling * update docs --- .../docs/examples/agent/anthropic_agent.ipynb | 414 ++++++++++++++++++ docs/mkdocs.yml | 1 + .../llama_index/llms/anthropic/base.py | 170 ++++++- .../llama_index/llms/anthropic/utils.py | 52 ++- .../llama-index-llms-anthropic/pyproject.toml | 4 +- 5 files changed, 624 insertions(+), 17 deletions(-) create mode 100644 docs/docs/examples/agent/anthropic_agent.ipynb diff --git a/docs/docs/examples/agent/anthropic_agent.ipynb b/docs/docs/examples/agent/anthropic_agent.ipynb new file mode 100644 index 000000000..4ff125a99 --- /dev/null +++ b/docs/docs/examples/agent/anthropic_agent.ipynb @@ -0,0 +1,414 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "24103c51", + "metadata": {}, + "source": [ + "<a href=\"https://colab.research.google.com/github/run-llama/llama_index/blob/main/docs/docs/examples/agent/mistral_agent.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" + ] + }, + { + "cell_type": "markdown", + "id": "99cea58c-48bc-4af6-8358-df9695659983", + "metadata": {}, + "source": [ + "# Function Calling Anthropic Agent" + ] + }, + { + "cell_type": "markdown", + "id": "673df1fe-eb6c-46ea-9a73-a96e7ae7942e", + "metadata": {}, + "source": [ + "This notebook shows you how to use our Anthropic agent, powered by function calling capabilities.\n", + "\n", + "**NOTE:** Only claude-3 models support function calling using Anthropic's API." + ] + }, + { + "cell_type": "markdown", + "id": "54b7bc2e-606f-411a-9490-fcfab9236dfc", + "metadata": {}, + "source": [ + "## Initial Setup " + ] + }, + { + "cell_type": "markdown", + "id": "23e80e5b-aaee-4f23-b338-7ae62b08141f", + "metadata": {}, + "source": [ + "Let's start by importing some simple building blocks. \n", + "\n", + "The main thing we need is:\n", + "1. the Anthropic API (using our own `llama_index` LLM class)\n", + "2. a place to keep conversation history \n", + "3. a definition for tools that our agent can use." + ] + }, + { + "cell_type": "markdown", + "id": "41101795", + "metadata": {}, + "source": [ + "If you're opening this Notebook on colab, you will probably need to install LlamaIndex 🦙.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4985c578", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install llama-index-llms-anthropic\n", + "%pip install llama-index-embeddings-openai" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c61c873d", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install llama-index" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9d47283b-025e-4874-88ed-76245b22f82e", + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.llms.anthropic import Anthropic\n", + "from llama_index.core.tools import FunctionTool\n", + "\n", + "import nest_asyncio\n", + "\n", + "nest_asyncio.apply()" + ] + }, + { + "cell_type": "markdown", + "id": "6fe08eb1-e638-4c00-9103-5c305bfacccf", + "metadata": {}, + "source": [ + "Let's define some very simple calculator tools for our agent." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3dd3c4a6-f3e0-46f9-ad3b-7ba57d1bc992", + "metadata": {}, + "outputs": [], + "source": [ + "def multiply(a: int, b: int) -> int:\n", + " \"\"\"Multiple two integers and returns the result integer\"\"\"\n", + " return a * b\n", + "\n", + "\n", + "multiply_tool = FunctionTool.from_defaults(fn=multiply)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bfcfb78b-7d4f-48d9-8d4c-ffcded23e7ac", + "metadata": {}, + "outputs": [], + "source": [ + "def add(a: int, b: int) -> int:\n", + " \"\"\"Add two integers and returns the result integer\"\"\"\n", + " return a + b\n", + "\n", + "\n", + "add_tool = FunctionTool.from_defaults(fn=add)" + ] + }, + { + "cell_type": "markdown", + "id": "eeac7d4c-58fd-42a5-9da9-c258375c61a0", + "metadata": {}, + "source": [ + "Make sure your ANTHROPIC_API_KEY is set. Otherwise explicitly specify the `api_key` parameter." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4becf171-6632-42e5-bdec-918a00934696", + "metadata": {}, + "outputs": [], + "source": [ + "llm = Anthropic(model=\"claude-3-opus-20240229\", api_key=\"sk-ant-...\")" + ] + }, + { + "cell_type": "markdown", + "id": "707d30b8-6405-4187-a9ed-6146dcc42167", + "metadata": {}, + "source": [ + "## Initialize Anthropic Agent" + ] + }, + { + "cell_type": "markdown", + "id": "798ca3fd-6711-4c0c-a853-d868dd14b484", + "metadata": {}, + "source": [ + "Here we initialize a simple Mistral agent with calculator functions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38ab3938-1138-43ea-b085-f430b42f5377", + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.core.agent import FunctionCallingAgentWorker\n", + "from llama_index.core.agent import AgentRunner\n", + "\n", + "agent_worker = FunctionCallingAgentWorker.from_tools(\n", + " [multiply_tool, add_tool],\n", + " llm=llm,\n", + " verbose=True,\n", + " allow_parallel_tool_calls=False,\n", + ")\n", + "agent = AgentRunner(agent_worker)" + ] + }, + { + "cell_type": "markdown", + "id": "500cbee4", + "metadata": {}, + "source": [ + "### Chat" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9450401d-769f-46e8-8bab-0f27f7362f5d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Added user message to memory: What is (121 + 2) * 5?\n", + "=== Calling Function ===\n", + "Calling function: add with args: {\"a\": 121, \"b\": 2}\n", + "=== Calling Function ===\n", + "Calling function: multiply with args: {\"a\": 123, \"b\": 5}\n", + "assistant: Therefore, (121 + 2) * 5 = 615\n" + ] + } + ], + "source": [ + "response = agent.chat(\"What is (121 + 2) * 5?\")\n", + "print(str(response))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "538bf32f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[ToolOutput(content='123', tool_name='add', raw_input={'args': (), 'kwargs': {'a': 121, 'b': 2}}, raw_output=123), ToolOutput(content='615', tool_name='multiply', raw_input={'args': (), 'kwargs': {'a': 123, 'b': 5}}, raw_output=615)]\n" + ] + } + ], + "source": [ + "# inspect sources\n", + "print(response.sources)" + ] + }, + { + "cell_type": "markdown", + "id": "fb33983c", + "metadata": {}, + "source": [ + "### Async Chat\n", + "\n", + "Also let's re-enable parallel function calling so that we can call two `multiply` operations simultaneously." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1d1fc974", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Added user message to memory: What is (121 * 3) + (5 * 8)?\n", + "=== Calling Function ===\n", + "Calling function: multiply with args: {\"a\": 121, \"b\": 3}\n", + "=== Calling Function ===\n", + "Calling function: multiply with args: {\"a\": 5, \"b\": 8}\n", + "=== Calling Function ===\n", + "Calling function: add with args: {\"a\": 363, \"b\": 40}\n", + "assistant: Therefore, the result of (121 * 3) + (5 * 8) is 403.\n" + ] + } + ], + "source": [ + "# enable parallel function calling\n", + "agent_worker = FunctionCallingAgentWorker.from_tools(\n", + " [multiply_tool, add_tool],\n", + " llm=llm,\n", + " verbose=True,\n", + " allow_parallel_tool_calls=True,\n", + ")\n", + "agent = AgentRunner(agent_worker)\n", + "response = await agent.achat(\"What is (121 * 3) + (5 * 8)?\")\n", + "print(str(response))" + ] + }, + { + "cell_type": "markdown", + "id": "cabfdf01-8d63-43ff-b06e-a3059ede2ddf", + "metadata": {}, + "source": [ + "## Anthropic Agent over RAG Pipeline\n", + "\n", + "Build a Anthropic agent over a simple 10K document. We use OpenAI embeddings and claude-3-haiku-20240307 to construct the RAG pipeline, and pass it to the Anthropic Opus agent as a tool." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "48120dd4-7f50-426f-bc7e-a903e090d32e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2024-04-04 18:12:42-- https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/10k/uber_2021.pdf\n", + "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.108.133, 185.199.109.133, ...\n", + "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 1880483 (1.8M) [application/octet-stream]\n", + "Saving to: ‘data/10k/uber_2021.pdf’\n", + "\n", + "data/10k/uber_2021. 100%[===================>] 1.79M 6.09MB/s in 0.3s \n", + "\n", + "2024-04-04 18:12:43 (6.09 MB/s) - ‘data/10k/uber_2021.pdf’ saved [1880483/1880483]\n", + "\n" + ] + } + ], + "source": [ + "!mkdir -p 'data/10k/'\n", + "!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/10k/uber_2021.pdf' -O 'data/10k/uber_2021.pdf'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "48c0cf98-3f10-4599-8437-d88dc89cefad", + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.core.tools import QueryEngineTool, ToolMetadata\n", + "from llama_index.core import SimpleDirectoryReader, VectorStoreIndex\n", + "from llama_index.embeddings.openai import OpenAIEmbedding\n", + "from llama_index.llms.anthropic import Anthropic\n", + "\n", + "embed_model = OpenAIEmbedding(api_key=\"sk-...\")\n", + "query_llm = Anthropic(model=\"claude-3-haiku-20240307\", api_key=\"sk-ant-...\")\n", + "\n", + "# load data\n", + "uber_docs = SimpleDirectoryReader(\n", + " input_files=[\"./data/10k/uber_2021.pdf\"]\n", + ").load_data()\n", + "# build index\n", + "uber_index = VectorStoreIndex.from_documents(\n", + " uber_docs, embed_model=embed_model\n", + ")\n", + "uber_engine = uber_index.as_query_engine(similarity_top_k=3, llm=query_llm)\n", + "query_engine_tool = QueryEngineTool(\n", + " query_engine=uber_engine,\n", + " metadata=ToolMetadata(\n", + " name=\"uber_10k\",\n", + " description=(\n", + " \"Provides information about Uber financials for year 2021. \"\n", + " \"Use a detailed plain text question as input to the tool.\"\n", + " ),\n", + " ),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ebfdaf80-e5e1-4c60-b556-20558da3d5e3", + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.core.agent import FunctionCallingAgentWorker\n", + "from llama_index.core.agent import AgentRunner\n", + "\n", + "agent_worker = FunctionCallingAgentWorker.from_tools(\n", + " [query_engine_tool], llm=llm, verbose=True\n", + ")\n", + "agent = AgentRunner(agent_worker)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "58c53f2a-0a3f-4abe-b8b6-97a974ec7546", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Added user message to memory: Tell me both the risk factors and tailwinds for Uber?\n", + "=== Calling Function ===\n", + "Calling function: uber_10k with args: {\"input\": \"What were some of the key risk factors and tailwinds mentioned for Uber's business in 2021?\"}\n", + "assistant: In summary, some of the key risk factors Uber faced in 2021 included regulatory challenges, IP protection, staying competitive with new technologies, seasonality and forecasting challenges due to COVID-19, and risks of international expansion. However, Uber also benefited from tailwinds like accelerated growth in food delivery due to the pandemic and adapting well to new remote work arrangements.\n" + ] + } + ], + "source": [ + "response = agent.chat(\"Tell me both the risk factors and tailwinds for Uber?\")\n", + "print(str(response))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index b3ba86175..1e45a1b6b 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -79,6 +79,7 @@ nav: - ./examples/agent/agent_runner/agent_around_query_pipeline_with_HyDE_for_PDFs.ipynb - ./examples/agent/mistral_agent.ipynb - ./examples/agent/openai_agent_tool_call_parser.ipynb + - ./examples/agent/anthropic_agent.ipynb - Callbacks: - ./examples/callbacks/HoneyHiveLlamaIndexTracer.ipynb - ./examples/callbacks/PromptLayerHandler.ipynb diff --git a/llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/base.py b/llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/base.py index 58d223d47..3e159d405 100644 --- a/llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/base.py +++ b/llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/base.py @@ -1,5 +1,19 @@ -from typing import Any, Callable, Dict, Optional, Sequence -from anthropic.types import ContentBlockDeltaEvent +import anthropic +import json +from anthropic.types import ContentBlockDeltaEvent, TextBlock +from anthropic.types.beta.tools import ToolUseBlock +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Tuple, + Union, + TYPE_CHECKING, +) + from llama_index.core.base.llms.types import ( ChatMessage, ChatResponse, @@ -24,21 +38,26 @@ from llama_index.core.base.llms.generic_utils import ( chat_to_completion_decorator, stream_chat_to_completion_decorator, ) -from llama_index.core.llms.llm import LLM +from llama_index.core.llms.function_calling import FunctionCallingLLM, ToolSelection from llama_index.core.types import BaseOutputParser, PydanticProgramMode from llama_index.llms.anthropic.utils import ( anthropic_modelname_to_contextsize, + force_single_tool_call, + is_function_calling_model, messages_to_anthropic_messages, ) from llama_index.core.utils import Tokenizer -import anthropic +if TYPE_CHECKING: + from llama_index.core.chat_engine.types import AgentChatResponse + from llama_index.core.tools.types import BaseTool + DEFAULT_ANTHROPIC_MODEL = "claude-2.1" DEFAULT_ANTHROPIC_MAX_TOKENS = 512 -class Anthropic(LLM): +class Anthropic(FunctionCallingLLM): """Anthropic LLM. Examples: @@ -146,6 +165,7 @@ class Anthropic(LLM): num_output=self.max_tokens, is_chat_model=True, model_name=self.model, + is_function_calling_model=is_function_calling_model(self.model), ) @property @@ -170,20 +190,38 @@ class Anthropic(LLM): **kwargs, } + def _get_content_and_tool_calls( + self, response: Any + ) -> Tuple[str, List[ToolUseBlock]]: + tool_calls = [] + content = "" + for content_block in response.content: + if isinstance(content_block, TextBlock): + content += content_block.text + elif isinstance(content_block, ToolUseBlock): + tool_calls.append(content_block.dict()) + + return content, tool_calls + @llm_chat_callback() def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: anthropic_messages, system_prompt = messages_to_anthropic_messages(messages) all_kwargs = self._get_all_kwargs(**kwargs) - response = self._client.messages.create( + response = self._client.beta.tools.messages.create( messages=anthropic_messages, stream=False, system=system_prompt, **all_kwargs, ) + + content, tool_calls = self._get_content_and_tool_calls(response) + return ChatResponse( message=ChatMessage( - role=MessageRole.ASSISTANT, content=response.content[0].text + role=MessageRole.ASSISTANT, + content=content, + additional_kwargs={"tool_calls": tool_calls}, ), raw=dict(response), ) @@ -235,15 +273,20 @@ class Anthropic(LLM): anthropic_messages, system_prompt = messages_to_anthropic_messages(messages) all_kwargs = self._get_all_kwargs(**kwargs) - response = await self._aclient.messages.create( + response = await self._aclient.beta.tools.messages.create( messages=anthropic_messages, system=system_prompt, stream=False, **all_kwargs, ) + + content, tool_calls = self._get_content_and_tool_calls(response) + return ChatResponse( message=ChatMessage( - role=MessageRole.ASSISTANT, content=response.content[0].text + role=MessageRole.ASSISTANT, + content=content, + additional_kwargs={"tool_calls": tool_calls}, ), raw=dict(response), ) @@ -287,3 +330,112 @@ class Anthropic(LLM): ) -> CompletionResponseAsyncGen: astream_complete_fn = astream_chat_to_completion_decorator(self.astream_chat) return await astream_complete_fn(prompt, **kwargs) + + def chat_with_tools( + self, + tools: List["BaseTool"], + user_msg: Optional[Union[str, ChatMessage]] = None, + chat_history: Optional[List[ChatMessage]] = None, + verbose: bool = False, + allow_parallel_tool_calls: bool = False, + **kwargs: Any, + ) -> ChatResponse: + """Predict and call the tool.""" + chat_history = chat_history or [] + + if isinstance(user_msg, str): + user_msg = ChatMessage(role=MessageRole.USER, content=user_msg) + chat_history.append(user_msg) + + tool_dicts = [] + for tool in tools: + tool_dicts.append( + { + "name": tool.metadata.name, + "description": tool.metadata.description, + "input_schema": tool.metadata.get_parameters_dict(), + } + ) + + response = self.chat(chat_history, tools=tool_dicts, **kwargs) + + if not allow_parallel_tool_calls: + force_single_tool_call(response) + + return response + + async def achat_with_tools( + self, + tools: List["BaseTool"], + user_msg: Optional[Union[str, ChatMessage]] = None, + chat_history: Optional[List[ChatMessage]] = None, + verbose: bool = False, + allow_parallel_tool_calls: bool = False, + **kwargs: Any, + ) -> ChatResponse: + """Predict and call the tool.""" + chat_history = chat_history or [] + + if isinstance(user_msg, str): + user_msg = ChatMessage(role=MessageRole.USER, content=user_msg) + chat_history.append(user_msg) + + tool_dicts = [] + for tool in tools: + tool_dicts.append( + { + "name": tool.metadata.name, + "description": tool.metadata.description, + "input_schema": tool.metadata.get_parameters_dict(), + } + ) + + response = await self.achat(chat_history, tools=tool_dicts, **kwargs) + + if not allow_parallel_tool_calls: + force_single_tool_call(response) + + return response + + def get_tool_calls_from_response( + self, + response: "AgentChatResponse", + error_on_no_tool_call: bool = True, + **kwargs: Any, + ) -> List[ToolSelection]: + """Predict and call the tool.""" + tool_calls = response.message.additional_kwargs.get("tool_calls", []) + + if len(tool_calls) < 1: + if error_on_no_tool_call: + raise ValueError( + f"Expected at least one tool call, but got {len(tool_calls)} tool calls." + ) + else: + return [] + + tool_selections = [] + for tool_call in tool_calls: + if ( + "input" not in tool_call + or "id" not in tool_call + or "name" not in tool_call + ): + raise ValueError("Invalid tool call.") + if tool_call["type"] != "tool_use": + raise ValueError("Invalid tool type. Unsupported by Anthropic") + argument_dict = ( + json.loads(tool_call["input"]) + if isinstance(tool_call["input"], str) + else tool_call["input"] + ) + + tool_selections.append( + ToolSelection( + tool_id=tool_call["id"], + tool_name=tool_call["name"], + tool_kwargs=argument_dict, + ) + ) + + return tool_selections diff --git a/llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/utils.py b/llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/utils.py index 28f5aba9e..961eabf04 100644 --- a/llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/utils.py +++ b/llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/utils.py @@ -1,8 +1,9 @@ from typing import Dict, Sequence, Tuple -from llama_index.core.base.llms.types import ChatMessage, MessageRole +from llama_index.core.base.llms.types import ChatMessage, ChatResponse, MessageRole from anthropic.types import MessageParam, TextBlockParam +from anthropic.types.beta.tools import ToolResultBlockParam, ToolUseBlockParam HUMAN_PREFIX = "\n\nHuman:" ASSISTANT_PREFIX = "\n\nAssistant:" @@ -19,6 +20,10 @@ CLAUDE_MODELS: Dict[str, int] = { } +def is_function_calling_model(modelname: str) -> bool: + return "claude-3" in modelname + + def anthropic_modelname_to_contextsize(modelname: str) -> int: if modelname not in CLAUDE_MODELS: raise ValueError( @@ -63,14 +68,43 @@ def messages_to_anthropic_messages( for message in messages: if message.role == MessageRole.SYSTEM: system_prompt = message.content + elif message.role == MessageRole.FUNCTION or message.role == MessageRole.TOOL: + content = ToolResultBlockParam( + tool_use_id=message.additional_kwargs["tool_call_id"], + type="tool_result", + content=[TextBlockParam(text=message.content, type="text")], + ) + anth_message = MessageParam( + role=MessageRole.USER.value, + content=[content], + ) + anthropic_messages.append(anth_message) else: - message = MessageParam( + content = [] + if message.content: + content.append(TextBlockParam(text=message.content, type="text")) + + tool_calls = message.additional_kwargs.get("tool_calls", []) + for tool_call in tool_calls: + assert "id" in tool_call + assert "input" in tool_call + assert "name" in tool_call + + content.append( + ToolUseBlockParam( + id=tool_call["id"], + input=tool_call["input"], + name=tool_call["name"], + type="tool_use", + ) + ) + + anth_message = MessageParam( role=message.role.value, - content=[ - TextBlockParam(text=message.content, type="text") - ], # TODO: type detect for multimodal + content=content, # TODO: type detect for multimodal ) - anthropic_messages.append(message) + anthropic_messages.append(anth_message) + return __merge_common_role_msgs(anthropic_messages), system_prompt @@ -103,3 +137,9 @@ def messages_to_anthropic_prompt(messages: Sequence[ChatMessage]) -> str: str_list = [_message_to_anthropic_prompt(message) for message in messages] return "".join(str_list) + + +def force_single_tool_call(response: ChatResponse) -> None: + tool_calls = response.message.additional_kwargs.get("tool_calls", []) + if len(tool_calls) > 1: + response.message.additional_kwargs["tool_calls"] = [tool_calls[0]] diff --git a/llama-index-integrations/llms/llama-index-llms-anthropic/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-anthropic/pyproject.toml index af7ab4cd6..393bc782d 100644 --- a/llama-index-integrations/llms/llama-index-llms-anthropic/pyproject.toml +++ b/llama-index-integrations/llms/llama-index-llms-anthropic/pyproject.toml @@ -27,12 +27,12 @@ exclude = ["**/BUILD"] license = "MIT" name = "llama-index-llms-anthropic" readme = "README.md" -version = "0.1.8" +version = "0.1.9" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" llama-index-core = "^0.10.1" -anthropic = "^0.20.0" +anthropic = "^0.23.1" [tool.poetry.group.dev.dependencies] ipython = "8.10.0" -- GitLab