diff --git a/docs/docs/examples/agent/mistral_agent.ipynb b/docs/docs/examples/agent/mistral_agent.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..289cc25f453e170926ffc6d7dfbf0e851edddcdd --- /dev/null +++ b/docs/docs/examples/agent/mistral_agent.ipynb @@ -0,0 +1,403 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "24103c51", + "metadata": {}, + "source": [ + "<a href=\"https://colab.research.google.com/github/run-llama/llama_index/blob/main/docs/examples/agent/mistral_agent.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" + ] + }, + { + "cell_type": "markdown", + "id": "99cea58c-48bc-4af6-8358-df9695659983", + "metadata": {}, + "source": [ + "# Function Calling Mistral Agent" + ] + }, + { + "cell_type": "markdown", + "id": "673df1fe-eb6c-46ea-9a73-a96e7ae7942e", + "metadata": {}, + "source": [ + "This notebook shows you how to use our Mistral agent, powered by function calling capabilities." + ] + }, + { + "cell_type": "markdown", + "id": "54b7bc2e-606f-411a-9490-fcfab9236dfc", + "metadata": {}, + "source": [ + "## Initial Setup " + ] + }, + { + "cell_type": "markdown", + "id": "23e80e5b-aaee-4f23-b338-7ae62b08141f", + "metadata": {}, + "source": [ + "Let's start by importing some simple building blocks. \n", + "\n", + "The main things we need are:\n", + "1. the MistralAI API (using our own `llama_index` LLM class)\n", + "2. a place to keep conversation history \n", + "3. a definition for tools that our agent can use." + ] + }, + { + "cell_type": "markdown", + "id": "41101795", + "metadata": {}, + "source": [ + "If you're opening this notebook on Colab, you will probably need to install LlamaIndex 🦙.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4985c578", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install llama-index-llms-mistralai\n", + "%pip install llama-index-embeddings-mistralai" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c61c873d", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install llama-index" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9d47283b-025e-4874-88ed-76245b22f82e", + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "from typing import Sequence, List\n", + "\n", + "from llama_index.llms.mistralai import MistralAI\n", + "from llama_index.core.llms import ChatMessage\n", + "from llama_index.core.tools import BaseTool, FunctionTool\n", + "\n", + "import nest_asyncio\n", + "\n", + "nest_asyncio.apply()" + ] + }, + { + "cell_type": "markdown", + "id": "6fe08eb1-e638-4c00-9103-5c305bfacccf", + "metadata": {}, + "source": [ + "Let's define some very simple calculator tools for our agent."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3dd3c4a6-f3e0-46f9-ad3b-7ba57d1bc992", + "metadata": {}, + "outputs": [], + "source": [ + "def multiply(a: int, b: int) -> int:\n", + " \"\"\"Multiply two integers and return the result integer\"\"\"\n", + " return a * b\n", + "\n", + "\n", + "multiply_tool = FunctionTool.from_defaults(fn=multiply)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bfcfb78b-7d4f-48d9-8d4c-ffcded23e7ac", + "metadata": {}, + "outputs": [], + "source": [ + "def add(a: int, b: int) -> int:\n", + " \"\"\"Add two integers and return the result integer\"\"\"\n", + " return a + b\n", + "\n", + "\n", + "add_tool = FunctionTool.from_defaults(fn=add)" + ] + }, + { + "cell_type": "markdown", + "id": "eeac7d4c-58fd-42a5-9da9-c258375c61a0", + "metadata": {}, + "source": [ + "Make sure your MISTRAL_API_KEY is set. Otherwise, explicitly specify the `api_key` parameter." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4becf171-6632-42e5-bdec-918a00934696", + "metadata": {}, + "outputs": [], + "source": [ + "llm = MistralAI(model=\"mistral-large-latest\")" + ] + }, + { + "cell_type": "markdown", + "id": "707d30b8-6405-4187-a9ed-6146dcc42167", + "metadata": {}, + "source": [ + "## Initialize Mistral Agent" + ] + }, + { + "cell_type": "markdown", + "id": "798ca3fd-6711-4c0c-a853-d868dd14b484", + "metadata": {}, + "source": [ + "Here we initialize a simple Mistral agent with calculator functions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38ab3938-1138-43ea-b085-f430b42f5377", + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.core.agent import FunctionCallingAgentWorker\n", + "from llama_index.core.agent import AgentRunner\n", + "\n", + "agent_worker = FunctionCallingAgentWorker.from_tools(\n", + " [multiply_tool, add_tool], llm=llm, verbose=True\n", + ")\n", + "agent = AgentRunner(agent_worker)" + ] + }, + { + "cell_type": "markdown", + "id": "500cbee4", + "metadata": {}, + "source": [ + "### Chat" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9450401d-769f-46e8-8bab-0f27f7362f5d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Added user message to memory: What is (121 + 2) * 5?\n", + "=== Calling Function ===\n", + "Calling function: add with args: {\"a\": 121, \"b\": 2}\n", + "=== Calling Function ===\n", + "Calling function: multiply with args: {\"a\": 123, \"b\": 5}\n", + "assistant: The result of (121 + 2) * 5 is 615.\n" + ] + } + ], + "source": [ + "response = agent.chat(\"What is (121 + 2) * 5?\")\n", + "print(str(response))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "538bf32f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[ToolOutput(content='123', tool_name='add', raw_input={'args': (), 'kwargs': {'a': 121, 'b': 2}}, raw_output=123), ToolOutput(content='615', tool_name='multiply', raw_input={'args': (), 'kwargs': {'a': 123, 'b': 5}}, raw_output=615)]\n" + ] + } + ], + "source": [ + "# inspect sources\n", + "print(response.sources)" + ] + }, + { + "cell_type": "markdown", + "id": "fb33983c", + "metadata": {}, + "source": [ + "### Async Chat" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1d1fc974", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Added user message to memory: What is (121 * 3) + 5? 
Use one tool at a time.\n", + "=== Calling Function ===\n", + "Calling function: multiply with args: {\"a\": 121, \"b\": 3}\n", + "=== Calling Function ===\n", + "Calling function: add with args: {\"a\": 363, \"b\": 5}\n", + "assistant: The result of (121 * 3) + 5 is 368.\n" + ] + } + ], + "source": [ + "response = await agent.achat(\"What is (121 * 3) + 5? Use one tool at a time.\")\n", + "print(str(response))" + ] + }, + { + "cell_type": "markdown", + "id": "cabfdf01-8d63-43ff-b06e-a3059ede2ddf", + "metadata": {}, + "source": [ + "## Mistral Agent over RAG Pipeline\n", + "\n", + "Build a Mistral agent over a simple 10K document. We use both Mistral embeddings and mistral-medium to construct the RAG pipeline, and pass it to the Mistral agent as a tool." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "48120dd4-7f50-426f-bc7e-a903e090d32e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2024-03-23 11:13:41-- https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/10k/uber_2021.pdf\n", + "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 2606:50c0:8003::154, 2606:50c0:8002::154, 2606:50c0:8001::154, ...\n", + "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|2606:50c0:8003::154|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 1880483 (1.8M) [application/octet-stream]\n", + "Saving to: ‘data/10k/uber_2021.pdf’\n", + "\n", + "data/10k/uber_2021. 100%[===================>] 1.79M --.-KB/s in 0.09s \n", + "\n", + "2024-03-23 11:13:41 (19.3 MB/s) - ‘data/10k/uber_2021.pdf’ saved [1880483/1880483]\n", + "\n" + ] + } + ], + "source": [ + "!mkdir -p 'data/10k/'\n", + "!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/10k/uber_2021.pdf' -O 'data/10k/uber_2021.pdf'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "48c0cf98-3f10-4599-8437-d88dc89cefad", + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.core.tools import QueryEngineTool, ToolMetadata\n", + "from llama_index.core import SimpleDirectoryReader, VectorStoreIndex\n", + "from llama_index.embeddings.mistralai import MistralAIEmbedding\n", + "from llama_index.llms.mistralai import MistralAI\n", + "\n", + "embed_model = MistralAIEmbedding()\n", + "query_llm = MistralAI(model=\"mistral-medium\")\n", + "\n", + "# load data\n", + "uber_docs = SimpleDirectoryReader(\n", + " input_files=[\"./data/10k/uber_2021.pdf\"]\n", + ").load_data()\n", + "# build index\n", + "uber_index = VectorStoreIndex.from_documents(\n", + " uber_docs, embed_model=embed_model\n", + ")\n", + "uber_engine = uber_index.as_query_engine(similarity_top_k=3, llm=query_llm)\n", + "query_engine_tool = QueryEngineTool(\n", + " query_engine=uber_engine,\n", + " metadata=ToolMetadata(\n", + " name=\"uber_10k\",\n", + " description=(\n", + " \"Provides information about Uber financials for year 2021. 
\"\n", + " \"Use a detailed plain text question as input to the tool.\"\n", + " ),\n", + " ),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ebfdaf80-e5e1-4c60-b556-20558da3d5e3", + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.core.agent import FunctionCallingAgentWorker\n", + "from llama_index.core.agent import AgentRunner\n", + "\n", + "agent_worker = FunctionCallingAgentWorker.from_tools(\n", + " [query_engine_tool], llm=llm, verbose=True\n", + ")\n", + "agent = AgentRunner(agent_worker)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "58c53f2a-0a3f-4abe-b8b6-97a974ec7546", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Added user message to memory: Tell me the risk factors for Uber? Use one tool at a time.\n", + "=== Calling Function ===\n", + "Calling function: uber_10k with args: {\"input\": \"What are the risk factors for Uber?\"}\n", + "assistant: Uber faces several risk factors that could negatively impact its business. These include the potential failure to offer autonomous vehicle technologies on its platform, the loss of high-quality personnel due to attrition or unsuccessful succession planning, and security or data privacy breaches. Additionally, cyberattacks such as malware, ransomware, and phishing attacks could harm Uber's reputation and business. The company is also subject to climate change risks and legal and regulatory risks. Furthermore, Uber relies on third parties to maintain open marketplaces for distributing its platform and providing software, and any interference from these parties could adversely affect Uber's business. The company may also require additional capital to support its growth, and there is no guarantee that this capital will be available on reasonable terms or at all. Finally, Uber's business is subject to extensive government regulation and oversight relating to the provision of payment and financial services, and the company faces risks related to its collection, use, transfer, disclosure, and other processing of data.\n" + ] + } + ], + "source": [ + "response = agent.chat(\n", + " \"Tell me the risk factors for Uber? 
Use one tool at a time.\"\n", + ")\n", + "print(str(response))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "llama_index_v3", + "language": "python", + "name": "llama_index_v3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/docs/examples/agent/react_agent_with_query_engine.ipynb b/docs/docs/examples/agent/react_agent_with_query_engine.ipynb index f69358e48e706fc4cc1bb14249adce0c13cd8ce7..df92fc98ed03b33ad29f414d5e931c1e327e43fd 100644 --- a/docs/docs/examples/agent/react_agent_with_query_engine.ipynb +++ b/docs/docs/examples/agent/react_agent_with_query_engine.ipynb @@ -1,7 +1,6 @@ { "cells": [ { - "attachments": {}, "cell_type": "markdown", "id": "6b0186a4", "metadata": {}, @@ -87,7 +86,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "6a79cbc9", "metadata": {}, diff --git a/docs/docs/module_guides/deploying/agents/modules.md b/docs/docs/module_guides/deploying/agents/modules.md index 5a6454c0dc5bf9288ee9fc6440b03ecb76b9b4c6..a6627f80feb5bec34d0555d596477751dd140cdb 100644 --- a/docs/docs/module_guides/deploying/agents/modules.md +++ b/docs/docs/module_guides/deploying/agents/modules.md @@ -28,6 +28,10 @@ For more detailed guides on how to use specific tools, check out our [tools modu - [ReAct Agent](../../../examples/agent/react_agent.ipynb) - [ReAct Agent with Query Engine Tools](../../../examples/agent/react_agent_with_query_engine.ipynb) +## Function Calling Agents + +- [Mistral Agent](../../../examples/agent/mistral_agent.ipynb) + ## Additional Agents (available on LlamaHub) - [LLMCompiler Agent Cookbook](https://github.com/run-llama/llama-hub/blob/main/llama_hub/llama_packs/agents/llm_compiler/llm_compiler.ipynb) diff --git a/llama-index-core/llama_index/core/agent/__init__.py b/llama-index-core/llama_index/core/agent/__init__.py index 1206908d86894e695b30d0868e873d5837f8b3b2..3e9579ad6db21b7a5a675cd0a17fb44d48f929e0 100644 --- a/llama-index-core/llama_index/core/agent/__init__.py +++ b/llama-index-core/llama_index/core/agent/__init__.py @@ -9,6 +9,7 @@ from llama_index.core.agent.runner.base import AgentRunner from llama_index.core.agent.runner.parallel import ParallelAgentRunner from llama_index.core.agent.types import Task from llama_index.core.chat_engine.types import AgentChatResponse +from llama_index.core.agent.function_calling.step import FunctionCallingAgentWorker __all__ = [ "AgentRunner", @@ -18,6 +19,7 @@ __all__ = [ "CustomSimpleAgentWorker", "QueryPipelineAgentWorker", "ReActChatFormatter", + "FunctionCallingAgentWorker", # beta "MultimodalReActAgentWorker", # schema-related diff --git a/llama-index-core/llama_index/core/agent/function_calling/BUILD b/llama-index-core/llama_index/core/agent/function_calling/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..db46e8d6c978c67e301dd6c47bee08c1b3fd141c --- /dev/null +++ b/llama-index-core/llama_index/core/agent/function_calling/BUILD @@ -0,0 +1 @@ +python_sources() diff --git a/llama-index-core/llama_index/core/agent/function_calling/__init__.py b/llama-index-core/llama_index/core/agent/function_calling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/llama-index-core/llama_index/core/agent/function_calling/base.py 
b/llama-index-core/llama_index/core/agent/function_calling/base.py new file mode 100644 index 0000000000000000000000000000000000000000..bb7a5f2ba2fe3941ff6f820a8bdabbbdee3ea1e7 --- /dev/null +++ b/llama-index-core/llama_index/core/agent/function_calling/base.py @@ -0,0 +1,23 @@ +"""Function calling agent.""" + + +from llama_index.core.agent.runner.base import AgentRunner + + +class FunctionCallingAgent(AgentRunner): + """Function calling agent. + + Calls any LLM that supports function calling in a while loop until the task is complete. + + """ + + # def __init__( + # self, + # tools: List[BaseTool], + # llm: OpenAI, + # memory: BaseMemory, + # prefix_messages: List[ChatMessage], + # verbose: bool = False, + # max_function_calls: int = 5, + # default_tool_choice: str = "auto", + # ) diff --git a/llama-index-core/llama_index/core/agent/function_calling/step.py b/llama-index-core/llama_index/core/agent/function_calling/step.py new file mode 100644 index 0000000000000000000000000000000000000000..7cdc5795e7cd505fa8aac431a89d181bb9b2836f --- /dev/null +++ b/llama-index-core/llama_index/core/agent/function_calling/step.py @@ -0,0 +1,354 @@ +"""Function calling agent worker.""" + +import json +import logging +import uuid +from typing import Any, List, Optional, cast + +from llama_index.core.agent.types import ( + BaseAgentWorker, + Task, + TaskStep, + TaskStepOutput, +) +from llama_index.core.agent.utils import add_user_step_to_memory +from llama_index.core.base.llms.types import MessageRole +from llama_index.core.callbacks import ( + CallbackManager, + CBEventType, + EventPayload, + trace_method, +) +from llama_index.core.chat_engine.types import ( + AgentChatResponse, +) +from llama_index.core.base.llms.types import ChatMessage +from llama_index.core.llms.llm import LLM, ToolSelection +from llama_index.core.memory import BaseMemory, ChatMemoryBuffer +from llama_index.core.objects.base import ObjectRetriever +from llama_index.core.settings import Settings +from llama_index.core.tools import BaseTool, ToolOutput, adapt_to_async_tool +from llama_index.core.tools.calling import ( + call_tool_with_selection, + acall_tool_with_selection, +) +from llama_index.llms.openai import OpenAI +from llama_index.core.tools import BaseTool, ToolOutput, adapt_to_async_tool +from llama_index.core.tools.types import AsyncBaseTool + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + +DEFAULT_MAX_FUNCTION_CALLS = 5 + + +def get_function_by_name(tools: List[BaseTool], name: str) -> BaseTool: + """Get function by name.""" + name_to_tool = {tool.metadata.name: tool for tool in tools} + if name not in name_to_tool: + raise ValueError(f"Tool with name {name} not found") + return name_to_tool[name] + + +class FunctionCallingAgentWorker(BaseAgentWorker): + """Function calling agent worker.""" + + def __init__( + self, + tools: List[BaseTool], + llm: OpenAI, + prefix_messages: List[ChatMessage], + verbose: bool = False, + max_function_calls: int = 5, + callback_manager: Optional[CallbackManager] = None, + tool_retriever: Optional[ObjectRetriever[BaseTool]] = None, + ) -> None: + """Init params.""" + if not llm.metadata.is_function_calling_model: + raise ValueError( + f"Model name {llm.model} does not support function calling API. 
" + ) + self._llm = llm + self._verbose = verbose + self._max_function_calls = max_function_calls + self.prefix_messages = prefix_messages + self.callback_manager = callback_manager or self._llm.callback_manager + + if len(tools) > 0 and tool_retriever is not None: + raise ValueError("Cannot specify both tools and tool_retriever") + elif len(tools) > 0: + self._get_tools = lambda _: tools + elif tool_retriever is not None: + tool_retriever_c = cast(ObjectRetriever[BaseTool], tool_retriever) + self._get_tools = lambda message: tool_retriever_c.retrieve(message) + else: + # no tools + self._get_tools = lambda _: [] + + @classmethod + def from_tools( + cls, + tools: Optional[List[BaseTool]] = None, + tool_retriever: Optional[ObjectRetriever[BaseTool]] = None, + llm: Optional[LLM] = None, + verbose: bool = False, + max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS, + callback_manager: Optional[CallbackManager] = None, + system_prompt: Optional[str] = None, + prefix_messages: Optional[List[ChatMessage]] = None, + **kwargs: Any, + ) -> "FunctionCallingAgentWorker": + """Create an FunctionCallingAgentWorker from a list of tools. + + Similar to `from_defaults` in other classes, this method will + infer defaults for a variety of parameters, including the LLM, + if they are not specified. + + """ + tools = tools or [] + + llm = llm or Settings.llm + if callback_manager is not None: + llm.callback_manager = callback_manager + + if system_prompt is not None: + if prefix_messages is not None: + raise ValueError( + "Cannot specify both system_prompt and prefix_messages" + ) + prefix_messages = [ChatMessage(content=system_prompt, role="system")] + + prefix_messages = prefix_messages or [] + + return cls( + tools=tools, + tool_retriever=tool_retriever, + llm=llm, + prefix_messages=prefix_messages, + verbose=verbose, + max_function_calls=max_function_calls, + callback_manager=callback_manager, + ) + + def initialize_step(self, task: Task, **kwargs: Any) -> TaskStep: + """Initialize step from task.""" + sources: List[ToolOutput] = [] + # temporary memory for new messages + new_memory = ChatMemoryBuffer.from_defaults() + # initialize task state + task_state = { + "sources": sources, + "n_function_calls": 0, + "new_memory": new_memory, + } + task.extra_state.update(task_state) + + return TaskStep( + task_id=task.task_id, + step_id=str(uuid.uuid4()), + input=task.input, + ) + + def get_tools(self, input: str) -> List[AsyncBaseTool]: + """Get tools.""" + return [adapt_to_async_tool(t) for t in self._get_tools(input)] + + def get_all_messages(self, task: Task) -> List[ChatMessage]: + return ( + self.prefix_messages + + task.memory.get() + + task.extra_state["new_memory"].get_all() + ) + + def _call_function( + self, + tools: List[BaseTool], + tool_call: ToolSelection, + memory: BaseMemory, + sources: List[ToolOutput], + verbose: bool = False, + ) -> None: + with self.callback_manager.event( + CBEventType.FUNCTION_CALL, + payload={ + EventPayload.FUNCTION_CALL: json.dumps(tool_call.tool_kwargs), + EventPayload.TOOL: get_function_by_name( + tools, tool_call.tool_name + ).metadata, + }, + ) as event: + tool_output = call_tool_with_selection(tool_call, tools, verbose=verbose) + event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)}) + + function_message = ChatMessage( + content=str(tool_output), + role=MessageRole.TOOL, + additional_kwargs={ + "name": tool_call.tool_name, + "tool_call_id": tool_call.tool_id, + }, + ) + sources.append(tool_output) + memory.put(function_message) + + async def 
_acall_function( + self, + tools: List[BaseTool], + tool_call: ToolSelection, + memory: BaseMemory, + sources: List[ToolOutput], + verbose: bool = False, + ) -> None: + with self.callback_manager.event( + CBEventType.FUNCTION_CALL, + payload={ + EventPayload.FUNCTION_CALL: json.dumps(tool_call.tool_kwargs), + EventPayload.TOOL: get_function_by_name( + tools, tool_call.tool_name + ).metadata, + }, + ) as event: + tool_output = await acall_tool_with_selection( + tool_call, tools, verbose=verbose + ) + event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)}) + + function_message = ChatMessage( + content=str(tool_output), + role=MessageRole.TOOL, + additional_kwargs={ + "name": tool_call.tool_name, + "tool_call_id": tool_call.tool_id, + }, + ) + sources.append(tool_output) + memory.put(function_message) + + @trace_method("run_step") + def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput: + """Run step.""" + if step.input is not None: + add_user_step_to_memory( + step, task.extra_state["new_memory"], verbose=self._verbose + ) + # TODO: see if we want to do step-based inputs + tools = self.get_tools(task.input) + + # get response and tool call (if exists) + response = self._llm.chat_with_tool( + tools=tools, + user_msg=None, + chat_history=self.get_all_messages(task), + verbose=self._verbose, + ) + tool_call = self._llm._get_tool_call_from_response( + response, error_on_no_tool_call=False + ) + task.extra_state["new_memory"].put(response.message) + if tool_call is None: + # we are done + is_done = True + new_steps = [] + else: + is_done = False + self._call_function( + tools, + tool_call, + task.extra_state["new_memory"], + task.extra_state["sources"], + verbose=self._verbose, + ) + # put tool output in sources and memory + new_steps = [ + step.get_next_step( + step_id=str(uuid.uuid4()), + # NOTE: input is unused + input=None, + ) + ] + agent_response = AgentChatResponse( + response=str(response), sources=task.extra_state["sources"] + ) + + return TaskStepOutput( + output=agent_response, + task_step=step, + is_last=is_done, + next_steps=new_steps, + ) + + @trace_method("run_step") + async def arun_step( + self, step: TaskStep, task: Task, **kwargs: Any + ) -> TaskStepOutput: + """Run step (async).""" + if step.input is not None: + add_user_step_to_memory( + step, task.extra_state["new_memory"], verbose=self._verbose + ) + # TODO: see if we want to do step-based inputs + tools = self.get_tools(task.input) + + response = await self._llm.achat_with_tool( + tools=tools, + user_msg=None, + chat_history=self.get_all_messages(task), + verbose=self._verbose, + ) + tool_call = self._llm._get_tool_call_from_response( + response, error_on_no_tool_call=False + ) + task.extra_state["new_memory"].put(response.message) + if tool_call is None: + # we are done + is_done = True + new_steps = [] + else: + is_done = False + await self._acall_function( + tools, + tool_call, + task.extra_state["new_memory"], + task.extra_state["sources"], + verbose=self._verbose, + ) + # put tool output in sources and memory + new_steps = [ + step.get_next_step( + step_id=str(uuid.uuid4()), + # NOTE: input is unused + input=None, + ) + ] + + agent_response = AgentChatResponse( + response=str(response), sources=task.extra_state["sources"] + ) + + return TaskStepOutput( + output=agent_response, + task_step=step, + is_last=is_done, + next_steps=new_steps, + ) + + @trace_method("run_step") + def stream_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput: + """Run step 
(stream).""" + raise NotImplementedError("Stream not supported for function calling agent") + + @trace_method("run_step") + async def astream_step( + self, step: TaskStep, task: Task, **kwargs: Any + ) -> TaskStepOutput: + """Run step (async stream).""" + raise NotImplementedError("Stream not supported for function calling agent") + + def finalize_task(self, task: Task, **kwargs: Any) -> None: + """Finalize task, after all the steps are completed.""" + # add new messages to memory + task.memory.set( + task.memory.get_all() + task.extra_state["new_memory"].get_all() + ) + # reset new memory + task.extra_state["new_memory"].reset() diff --git a/llama-index-core/llama_index/core/llms/llm.py b/llama-index-core/llama_index/core/llms/llm.py index 04a201ce128f2d5da4b73fe26a2a2db0405014a5..e2470684ce11bc4033093b44b8f3f5e2a3c4159d 100644 --- a/llama-index-core/llama_index/core/llms/llm.py +++ b/llama-index-core/llama_index/core/llms/llm.py @@ -54,6 +54,10 @@ from llama_index.core.instrumentation.events.llm import ( ) import llama_index.core.instrumentation as instrument +from llama_index.core.base.llms.types import ( + ChatMessage, + ChatResponse, +) dispatcher = instrument.get_dispatcher(__name__) @@ -62,6 +66,15 @@ if TYPE_CHECKING: from llama_index.core.tools.types import BaseTool +class ToolSelection(BaseModel): + """Tool selection.""" + + tool_id: str = Field(description="Tool ID to select.") + tool_name: str = Field(description="Tool name to select.") + tool_kwargs: Dict[str, Any] = Field(description="Keyword arguments for the tool.") + # NOTE: no args for now + + # NOTE: These two protocols are needed to appease mypy @runtime_checkable class MessagesToPromptType(Protocol): @@ -534,6 +547,38 @@ class LLM(BaseLLM): # -- Tool Calling -- + def chat_with_tool( + self, + tools: List["BaseTool"], + user_msg: Optional[Union[str, ChatMessage]] = None, + chat_history: Optional[List[ChatMessage]] = None, + verbose: bool = False, + **kwargs: Any, + ) -> ChatResponse: + """Predict and call the tool.""" + raise NotImplementedError("predict_tool is not supported by default.") + + async def achat_with_tool( + self, + tools: List["BaseTool"], + user_msg: Optional[Union[str, ChatMessage]] = None, + chat_history: Optional[List[ChatMessage]] = None, + verbose: bool = False, + **kwargs: Any, + ) -> ChatResponse: + """Predict and call the tool.""" + raise NotImplementedError("predict_tool is not supported by default.") + + def _get_tool_call_from_response( + self, + response: "AgentChatResponse", + **kwargs: Any, + ) -> ToolSelection: + """Predict and call the tool.""" + raise NotImplementedError( + "_get_tool_call_from_response is not supported by default." 
+ ) + def predict_and_call( self, tools: List["BaseTool"], diff --git a/llama-index-core/llama_index/core/tools/calling.py b/llama-index-core/llama_index/core/tools/calling.py index 7dbe600674c96b51ba9ca7a8b54b033119cffb81..76e9b89ced0780aab04ab3ff09a0749997d66da1 100644 --- a/llama-index-core/llama_index/core/tools/calling.py +++ b/llama-index-core/llama_index/core/tools/calling.py @@ -1,4 +1,10 @@ from llama_index.core.tools.types import BaseTool, ToolOutput, adapt_to_async_tool +from typing import TYPE_CHECKING, List +from llama_index.core.llms.llm import ToolSelection +import json + +if TYPE_CHECKING: + from llama_index.core.tools.types import BaseTool def call_tool(tool: BaseTool, arguments: dict) -> ToolOutput: @@ -40,3 +46,37 @@ async def acall_tool(tool: BaseTool, arguments: dict) -> ToolOutput: raw_input=arguments, raw_output=str(e), ) + + +def call_tool_with_selection( + tool_call: ToolSelection, + tools: List["BaseTool"], + verbose: bool = False, +) -> ToolOutput: + from llama_index.core.tools.calling import call_tool + + tools_by_name = {tool.metadata.name: tool for tool in tools} + name = tool_call.tool_name + if verbose: + arguments_str = json.dumps(tool_call.tool_kwargs) + print("=== Calling Function ===") + print(f"Calling function: {name} with args: {arguments_str}") + tool = tools_by_name[name] + return call_tool(tool, tool_call.tool_kwargs) + + +async def acall_tool_with_selection( + tool_call: ToolSelection, + tools: List["BaseTool"], + verbose: bool = False, +) -> ToolOutput: + from llama_index.core.tools.calling import acall_tool + + tools_by_name = {tool.metadata.name: tool for tool in tools} + name = tool_call.tool_name + if verbose: + arguments_str = json.dumps(tool_call.tool_kwargs) + print("=== Calling Function ===") + print(f"Calling function: {name} with args: {arguments_str}") + tool = tools_by_name[name] + return await acall_tool(tool, tool_call.tool_kwargs) diff --git a/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py b/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py index 7484494b4556e9d2c135452805663ec933be21f7..bd3b1953da737cae90c7e9ec15ee359534d9eecf 100644 --- a/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py +++ b/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py @@ -26,7 +26,7 @@ from llama_index.core.base.llms.generic_utils import ( get_from_param_or_env, stream_chat_to_completion_decorator, ) -from llama_index.core.llms.llm import LLM +from llama_index.core.llms.llm import LLM, ToolSelection from llama_index.core.types import BaseOutputParser, PydanticProgramMode from llama_index.llms.mistralai.utils import ( is_mistralai_function_calling_model, @@ -45,6 +45,23 @@ DEFAULT_MISTRALAI_MODEL = "mistral-tiny" DEFAULT_MISTRALAI_ENDPOINT = "https://api.mistral.ai" DEFAULT_MISTRALAI_MAX_TOKENS = 512 +from mistralai.models.chat_completion import ChatMessage as mistral_chatmessage + + +def to_mistral_chatmessage( + messages: Sequence[ChatMessage], +) -> List[mistral_chatmessage]: + new_messages = [] + for m in messages: + tool_calls = m.additional_kwargs.get("tool_calls") + new_messages.append( + mistral_chatmessage( + role=m.role.value, content=m.content, tool_calls=tool_calls + ) + ) + + return new_messages + class MistralAI(LLM): """MistralAI LLM. 
@@ -200,11 +217,8 @@ class MistralAI(LLM): @llm_chat_callback() def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: # convert messages to mistral ChatMessage - from mistralai.models.chat_completion import ChatMessage as mistral_chatmessage - messages = [ - mistral_chatmessage(role=x.role, content=x.content) for x in messages - ] + messages = to_mistral_chatmessage(messages) all_kwargs = self._get_all_kwargs(**kwargs) response = self._client.chat(messages=messages, **all_kwargs) @@ -233,12 +247,8 @@ class MistralAI(LLM): self, messages: Sequence[ChatMessage], **kwargs: Any ) -> ChatResponseGen: # convert messages to mistral ChatMessage - from mistralai.models.chat_completion import ChatMessage as mistral_chatmessage - messages = [ - mistral_chatmessage(role=message.role, content=message.content) - for message in messages - ] + messages = to_mistral_chatmessage(messages) all_kwargs = self._get_all_kwargs(**kwargs) response = self._client.chat_stream(messages=messages, **all_kwargs) @@ -266,48 +276,6 @@ class MistralAI(LLM): stream_complete_fn = stream_chat_to_completion_decorator(self.stream_chat) return stream_complete_fn(prompt, **kwargs) - def _get_tool_call( - self, - response: ChatResponse, - ) -> ToolCall: - tool_calls = response.message.additional_kwargs.get("tool_calls", []) - - if len(tool_calls) < 1: - raise ValueError( - f"Expected at least one tool call, but got {len(tool_calls)} tool calls." - ) - - # TODO: support more than one tool call? - tool_call = tool_calls[0] - if not isinstance(tool_call, ToolCall): - raise ValueError("Invalid tool_call object") - - if tool_call.type != "function": - raise ValueError("Invalid tool type. Unsupported by Mistralai.") - - return tool_call - - def _call_tool( - self, - tool_call: ToolCall, - tools_by_name: Dict[str, "BaseTool"], - verbose: bool = False, - ) -> "AgentChatResponse": - from llama_index.core.chat_engine.types import AgentChatResponse - from llama_index.core.tools.calling import call_tool - - arguments_str = tool_call.function.arguments - name = tool_call.function.name - if verbose: - print("=== Calling Function ===") - print(f"Calling function: {name} with args: {arguments_str}") - tool = tools_by_name[name] - argument_dict = json.loads(arguments_str) - - tool_output = call_tool(tool, argument_dict) - - return AgentChatResponse(response=tool_output.content, sources=[tool_output]) - def predict_and_call( self, tools: List["BaseTool"], @@ -316,6 +284,10 @@ class MistralAI(LLM): verbose: bool = False, **kwargs: Any, ) -> "AgentChatResponse": + from llama_index.core.tools.calling import ( + call_tool_with_selection, + ) + if not self.metadata.is_function_calling_model: return super().predict_and_call( tools, @@ -325,43 +297,30 @@ class MistralAI(LLM): **kwargs, ) - # misralai uses the same openai tool format - tool_specs = [tool.metadata.to_openai_tool() for tool in tools] - tools_by_name = {tool.metadata.name: tool for tool in tools} - - if isinstance(user_msg, str): - user_msg = ChatMessage(role=MessageRole.USER, content=user_msg) - - messages = chat_history or [] - if user_msg: - messages.append(user_msg) - - response = self.chat( - messages, - tools=tool_specs, - **kwargs, + response = self.chat_with_tool( + tools, user_msg, chat_history=chat_history, verbose=verbose, **kwargs ) - - tool_call = self._get_tool_call(response) - - return self._call_tool(tool_call, tools_by_name, verbose=verbose) + tool_call = self._get_tool_call_from_response(response) + tool_output = call_tool_with_selection(tool_call, 
tools, verbose=verbose) + return AgentChatResponse(response=tool_output.content, sources=[tool_output]) @llm_chat_callback() async def achat( self, messages: Sequence[ChatMessage], **kwargs: Any ) -> ChatResponse: # convert messages to mistral ChatMessage - from mistralai.models.chat_completion import ChatMessage as mistral_chatmessage - messages = [ - mistral_chatmessage(role=message.role, content=message.content) - for message in messages - ] + messages = to_mistral_chatmessage(messages) all_kwargs = self._get_all_kwargs(**kwargs) response = await self._aclient.chat(messages=messages, **all_kwargs) + tool_calls = response.choices[0].message.tool_calls return ChatResponse( message=ChatMessage( - role=MessageRole.ASSISTANT, content=response.choices[0].message.content + role=MessageRole.ASSISTANT, + content=response.choices[0].message.content, + additional_kwargs={"tool_calls": tool_calls} + if tool_calls is not None + else {}, ), raw=dict(response), ) @@ -378,11 +337,8 @@ class MistralAI(LLM): self, messages: Sequence[ChatMessage], **kwargs: Any ) -> ChatResponseAsyncGen: # convert messages to mistral ChatMessage - from mistralai.models.chat_completion import ChatMessage as mistral_chatmessage - messages = [ - mistral_chatmessage(role=x.role, content=x.content) for x in messages - ] + messages = to_mistral_chatmessage(messages) all_kwargs = self._get_all_kwargs(**kwargs) response = await self._aclient.chat_stream(messages=messages, **all_kwargs) @@ -410,41 +366,75 @@ class MistralAI(LLM): astream_complete_fn = astream_chat_to_completion_decorator(self.astream_chat) return await astream_complete_fn(prompt, **kwargs) - async def _acall_tool( + async def apredict_and_call( self, - tool_call: ToolCall, - tools_by_name: Dict[str, "BaseTool"], + tools: List["BaseTool"], + user_msg: Optional[Union[str, ChatMessage]] = None, + chat_history: Optional[List[ChatMessage]] = None, verbose: bool = False, + **kwargs: Any, ) -> "AgentChatResponse": - from llama_index.core.chat_engine.types import AgentChatResponse - from llama_index.core.tools.calling import acall_tool - - arguments_str = tool_call.function.arguments - name = tool_call.function.name - if verbose: - print("=== Calling Function ===") - print(f"Calling function: {name} with args: {arguments_str}") - tool = tools_by_name[name] - argument_dict = json.loads(arguments_str) + from llama_index.core.tools.calling import ( + acall_tool_with_selection, + ) - tool_output = await acall_tool(tool, argument_dict) + if not self.metadata.is_function_calling_model: + return await super().apredict_and_call( + tools, + user_msg=user_msg, + chat_history=chat_history, + verbose=verbose, + **kwargs, + ) + response = await self.achat_with_tool( + tools, user_msg, chat_history=chat_history, verbose=verbose, **kwargs + ) + tool_call = self._get_tool_call_from_response(response) + tool_output = await acall_tool_with_selection(tool_call, tools, verbose=verbose) return AgentChatResponse(response=tool_output.content, sources=[tool_output]) - async def apredict_and_call( + def chat_with_tool( self, tools: List["BaseTool"], user_msg: Optional[Union[str, ChatMessage]] = None, chat_history: Optional[List[ChatMessage]] = None, verbose: bool = False, **kwargs: Any, - ) -> "AgentChatResponse": - if not self.metadata.is_function_calling_model: - return await super().apredict_and_call(user_msg, tools, verbose, **kwargs) + ) -> ChatResponse: + """Predict and call the tool.""" + # misralai uses the same openai tool format + tool_specs = [tool.metadata.to_openai_tool() for tool in tools]
+ + if isinstance(user_msg, str): + user_msg = ChatMessage(role=MessageRole.USER, content=user_msg) + + messages = chat_history or [] + if user_msg: + messages.append(user_msg) + + response = self.chat( + messages, + tools=tool_specs, + **kwargs, + ) + # TODO: this is a hack, in the future we should support multiple tool calls + tool_calls = response.message.additional_kwargs.get("tool_calls", []) + if len(tool_calls) > 1: + response.message.additional_kwargs["tool_calls"] = [tool_calls[0]] + return response + async def achat_with_tool( + self, + tools: List["BaseTool"], + user_msg: Optional[Union[str, ChatMessage]] = None, + chat_history: Optional[List[ChatMessage]] = None, + verbose: bool = False, + **kwargs: Any, + ) -> ChatResponse: + """Predict and call the tool.""" # misralai uses the same openai tool format tool_specs = [tool.metadata.to_openai_tool() for tool in tools] - tools_by_name = {tool.metadata.name: tool for tool in tools} if isinstance(user_msg, str): user_msg = ChatMessage(role=MessageRole.USER, content=user_msg) @@ -458,7 +448,40 @@ class MistralAI(LLM): tools=tool_specs, **kwargs, ) + # TODO: this is a hack, in the future we should support multiple tool calls + tool_calls = response.message.additional_kwargs.get("tool_calls", []) + if len(tool_calls) > 1: + response.message.additional_kwargs["tool_calls"] = [tool_calls[0]] + return response + + def _get_tool_call_from_response( + self, + response: "AgentChatResponse", + error_on_no_tool_call: bool = True, + ) -> Optional[ToolSelection]: + """Predict and call the tool.""" + tool_calls = response.message.additional_kwargs.get("tool_calls", []) + + if len(tool_calls) < 1: + if error_on_no_tool_call: + raise ValueError( + f"Expected at least one tool call, but got {len(tool_calls)} tool calls." + ) + else: + return None + + # TODO: support more than one tool call? + tool_call = tool_calls[0] + if not isinstance(tool_call, ToolCall): + raise ValueError("Invalid tool_call object") - tool_call = self._get_tool_call(response) + if tool_call.type != "function": + raise ValueError("Invalid tool type. Unsupported by Mistralai.") + + argument_dict = json.loads(tool_call.function.arguments) - return await self._acall_tool(tool_call, tools_by_name, verbose=verbose) + return ToolSelection( + tool_id=tool_call.id, + tool_name=tool_call.function.name, + tool_kwargs=argument_dict, + )
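A minimal usage sketch of the tool-calling path this patch adds to `MistralAI` (`predict_and_call`, plus the lower-level `chat_with_tool` / `_get_tool_call_from_response` pair used by `FunctionCallingAgentWorker`). This is illustrative only and not part of the patch: it assumes `MISTRAL_API_KEY` is exported and that `mistral-large-latest` is a function-calling model; the `add` tool is a hypothetical example helper.

# Sketch only: assumes MISTRAL_API_KEY is set and mistral-large-latest supports function calling.
from llama_index.core.tools import FunctionTool
from llama_index.llms.mistralai import MistralAI


def add(a: int, b: int) -> int:
    """Add two integers and return the result integer."""
    return a + b


add_tool = FunctionTool.from_defaults(fn=add)
llm = MistralAI(model="mistral-large-latest")

# High-level path: pick a tool, execute it, and wrap the output in an AgentChatResponse.
agent_response = llm.predict_and_call([add_tool], user_msg="What is 121 + 2?", verbose=True)
print(agent_response.response)

# Lower-level path used by FunctionCallingAgentWorker: get the raw chat response,
# then parse out the ToolSelection (tool_id, tool_name, tool_kwargs) if one exists.
chat_response = llm.chat_with_tool([add_tool], user_msg="What is 121 + 2?")
tool_call = llm._get_tool_call_from_response(chat_response, error_on_no_tool_call=False)
if tool_call is not None:
    print(tool_call.tool_name, tool_call.tool_kwargs)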