From 46c59642c80cefd1dafb65d94a93db2211d09281 Mon Sep 17 00:00:00 2001 From: Simon Suo <simonsdsuo@gmail.com> Date: Sun, 2 Jul 2023 14:49:18 -0700 Subject: [PATCH] Hook up new LLM abstraction with agents and programs (#6679) * add to agent and programs * wip * update notebook * wip * wip * wiup * wip * wip * wip * wip * wip * wip * wip * wip * remove unused imports * update change log --- CHANGELOG.md | 5 + docs/examples/agent/openai_agent.ipynb | 101 +++++++++++------ .../openai_agent_context_retrieval.ipynb | 8 +- .../agent/openai_agent_query_cookbook.ipynb | 9 +- .../agent/openai_agent_query_plan.ipynb | 10 +- .../agent/openai_agent_retrieval.ipynb | 26 ++--- .../openai_agent_with_query_engine.ipynb | 56 ++++----- .../openai_pydantic_program.ipynb | 79 ++++++------- llama_index/agent/context_retriever_agent.py | 49 ++++---- llama_index/agent/openai_agent.py | 107 +++++++++--------- llama_index/agent/retriever_openai_agent.py | 39 ++++--- llama_index/program/openai_program.py | 53 ++++----- llama_index/selectors/pydantic_selectors.py | 12 +- 13 files changed, 288 insertions(+), 266 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 00700b677a..bf3a30e066 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # ChangeLog +## Unreleased + +### Breaking/Deprecated API Changes +- Change `BaseOpenAIAgent` to use `llama_index.llms.OpenAI`. Adjust `chat_history` to use `List[ChatMessage]]` as type. + ## [v0.6.38] - 2023-07-02 ### New Features diff --git a/docs/examples/agent/openai_agent.ipynb b/docs/examples/agent/openai_agent.ipynb index 00091e7067..59c9746bee 100644 --- a/docs/examples/agent/openai_agent.ipynb +++ b/docs/examples/agent/openai_agent.ipynb @@ -1,6 +1,7 @@ { "cells": [ { + "attachments": {}, "cell_type": "markdown", "id": "99cea58c-48bc-4af6-8358-df9695659983", "metadata": { @@ -11,6 +12,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "673df1fe-eb6c-46ea-9a73-a96e7ae7942e", "metadata": { @@ -23,6 +25,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "54b7bc2e-606f-411a-9490-fcfab9236dfc", "metadata": { @@ -33,6 +36,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "23e80e5b-aaee-4f23-b338-7ae62b08141f", "metadata": {}, @@ -40,14 +44,14 @@ "Let's start by importing some simple building blocks. \n", "\n", "The main thing we need is:\n", - "1. the OpenAI API (we will use langchain's ChatOpenAI wrapper for convienience here.)\n", + "1. the OpenAI API\n", "2. a place to keep conversation history \n", "3. a definition for tools that our agent can use." 
] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 2, "id": "9d47283b-025e-4874-88ed-76245b22f82e", "metadata": { "tags": [] @@ -55,15 +59,14 @@ "outputs": [], "source": [ "import json\n", - "from typing import Sequence\n", + "from typing import Sequence, List\n", "\n", - "from langchain.chat_models import ChatOpenAI\n", - "from langchain.memory import ChatMessageHistory\n", - "from langchain.schema import FunctionMessage\n", + "from llama_index.llms import OpenAI, ChatMessage\n", "from llama_index.tools import BaseTool, FunctionTool" ] }, { + "attachments": {}, "cell_type": "markdown", "id": "6fe08eb1-e638-4c00-9103-5c305bfacccf", "metadata": {}, @@ -73,7 +76,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 3, "id": "3dd3c4a6-f3e0-46f9-ad3b-7ba57d1bc992", "metadata": { "tags": [] @@ -90,7 +93,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 4, "id": "bfcfb78b-7d4f-48d9-8d4c-ffcded23e7ac", "metadata": { "tags": [] @@ -105,6 +108,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "fbcbd5ea-f377-44a0-a492-4568daa8b0b6", "metadata": { @@ -115,6 +119,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "5b737e6c-64eb-4ae6-a8f7-350b1953e612", "metadata": {}, @@ -131,7 +136,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 16, "id": "a0e068f7-fd24-4f74-8243-5e6e4840f7a6", "metadata": { "tags": [] @@ -142,43 +147,45 @@ " def __init__(\n", " self,\n", " tools: Sequence[BaseTool] = [],\n", - " llm: ChatOpenAI = ChatOpenAI(temperature=0, model_name=\"gpt-3.5-turbo-0613\"),\n", - " chat_history: ChatMessageHistory = ChatMessageHistory(),\n", + " llm: OpenAI = OpenAI(temperature=0, model=\"gpt-3.5-turbo-0613\"),\n", + " chat_history: List[ChatMessage] = [],\n", " ) -> None:\n", " self._llm = llm\n", " self._tools = {tool.metadata.name: tool for tool in tools}\n", " self._chat_history = chat_history\n", "\n", " def reset(self) -> None:\n", - " self._chat_history.clear()\n", + " self._chat_history = []\n", "\n", " def chat(self, message: str) -> str:\n", " chat_history = self._chat_history\n", - " chat_history.add_user_message(message)\n", + " chat_history.append(ChatMessage(role='user', content=message))\n", " functions = [tool.metadata.to_openai_function() for _, tool in self._tools.items()]\n", "\n", - " ai_message = self._llm.predict_messages(chat_history.messages, functions=functions)\n", - " chat_history.add_message(ai_message)\n", + " ai_message = self._llm.chat(chat_history, functions=functions).message\n", + " chat_history.append(ai_message)\n", "\n", " function_call = ai_message.additional_kwargs.get(\"function_call\", None)\n", " if function_call is not None:\n", " function_message = self._call_function(function_call)\n", - " chat_history.add_message(function_message)\n", - " ai_message = self._llm.predict_messages(chat_history.messages)\n", - " chat_history.add_message(ai_message)\n", + " chat_history.append(function_message)\n", + " ai_message = self._llm.chat(chat_history).message\n", + " chat_history.append(ai_message)\n", "\n", " return ai_message.content\n", "\n", - " def _call_function(self, function_call: dict) -> FunctionMessage:\n", + " def _call_function(self, function_call: dict) -> ChatMessage:\n", " tool = self._tools[function_call[\"name\"]]\n", " output = tool(**json.loads(function_call[\"arguments\"]))\n", - " return FunctionMessage(\n", - " name=function_call[\"name\"],\n", + " return ChatMessage(\n", " content=str(output), \n", + " 
role='function',\n", + " name=function_call[\"name\"],\n", " )" ] }, { + "attachments": {}, "cell_type": "markdown", "id": "fbc2cec5-6cc0-4814-92a1-ca0bd237528f", "metadata": {}, @@ -188,7 +195,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 17, "id": "08928f6e-610c-420b-8a7b-7a7042bbd6c6", "metadata": { "tags": [] @@ -200,7 +207,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 18, "id": "d5cefbad-32c4-4273-807a-cc179bae4473", "metadata": { "tags": [] @@ -212,7 +219,7 @@ "'Hello! How can I assist you today?'" ] }, - "execution_count": 9, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -223,7 +230,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 19, "id": "b8f7650d-57b8-4ef4-b19d-651281ddb1be", "metadata": { "tags": [] @@ -235,7 +242,7 @@ "'The product of 2123 multiplied by 215123 is 456,706,129.'" ] }, - "execution_count": 10, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } @@ -245,6 +252,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "707d30b8-6405-4187-a9ed-6146dcc42167", "metadata": { @@ -255,6 +263,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "798ca3fd-6711-4c0c-a853-d868dd14b484", "metadata": {}, @@ -270,7 +279,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 20, "id": "38ab3938-1138-43ea-b085-f430b42f5377", "metadata": { "tags": [] @@ -282,7 +291,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 21, "id": "d852ece7-e5a1-4368-9d59-c7014e0b5b4d", "metadata": { "tags": [] @@ -294,7 +303,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 22, "id": "33ea069f-819b-4ec1-a93c-fcbaacb362a1", "metadata": { "tags": [] @@ -306,8 +315,20 @@ "text": [ "===== Entering Chat REPL =====\n", "Type \"exit\" to exit.\n", - "\n", - "Human: What's 212 * 122 + 213. 
Make sure to use tools for any calculation\n", + "\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Human: What's 212 * 122 + 213\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "=== Calling Function ===\n", "Calling function: multiply with args: {\n", " \"a\": 212,\n", @@ -323,8 +344,14 @@ "Got output: 26077\n", "========================\n", "Assistant: The result of 212 * 122 + 213 is 26077.\n", - "\n", - "Human: exit\n" + "\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Human: exit\n" ] } ], @@ -335,7 +362,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1a4199b1", + "id": "bc0c2180-3fef-410a-ac3b-74956e5e8a56", "metadata": {}, "outputs": [], "source": [] @@ -343,9 +370,9 @@ ], "metadata": { "kernelspec": { - "display_name": "llama-index", + "display_name": "Python 3 (ipykernel)", "language": "python", - "name": "llama-index" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -357,7 +384,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.0" + "version": "3.9.16" } }, "nbformat": 4, diff --git a/docs/examples/agent/openai_agent_context_retrieval.ipynb b/docs/examples/agent/openai_agent_context_retrieval.ipynb index 6deeabf1d3..34eff8ef7b 100644 --- a/docs/examples/agent/openai_agent_context_retrieval.ipynb +++ b/docs/examples/agent/openai_agent_context_retrieval.ipynb @@ -1,6 +1,7 @@ { "cells": [ { + "attachments": {}, "cell_type": "markdown", "id": "99cea58c-48bc-4af6-8358-df9695659983", "metadata": { @@ -11,6 +12,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "673df1fe-eb6c-46ea-9a73-a96e7ae7942e", "metadata": { @@ -22,6 +24,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "54b7bc2e-606f-411a-9490-fcfab9236dfc", "metadata": { @@ -32,6 +35,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "23e80e5b-aaee-4f23-b338-7ae62b08141f", "metadata": {}, @@ -51,8 +55,6 @@ "import json\n", "from typing import Sequence\n", "\n", - "from langchain.chat_models import ChatOpenAI\n", - "from langchain.memory import ChatMessageHistory\n", "from llama_index import (\n", " SimpleDirectoryReader, \n", " VectorStoreIndex, \n", @@ -166,6 +168,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "b08efb96-ce44-4706-a22d-b0c670b23a60", "metadata": {}, @@ -334,6 +337,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "ad81c4e1-4ecb-405d-bb03-a4c3549816e7", "metadata": {}, diff --git a/docs/examples/agent/openai_agent_query_cookbook.ipynb b/docs/examples/agent/openai_agent_query_cookbook.ipynb index edd0075562..e9b7b55aac 100644 --- a/docs/examples/agent/openai_agent_query_cookbook.ipynb +++ b/docs/examples/agent/openai_agent_query_cookbook.ipynb @@ -290,11 +290,11 @@ "outputs": [], "source": [ "from llama_index.agent import OpenAIAgent\n", - "from langchain.chat_models import ChatOpenAI\n", + "from llama_index.llms import OpenAI\n", "\n", "agent = OpenAIAgent.from_tools(\n", " [auto_retrieve_tool], \n", - " llm=ChatOpenAI(temperature=0, model_name=\"gpt-4-0613\"),\n", + " llm=OpenAI(temperature=0, model=\"gpt-4-0613\"),\n", " verbose=True\n", ")" ] @@ -766,12 +766,11 @@ "outputs": [], "source": [ "from llama_index.agent import OpenAIAgent\n", - "from langchain.chat_models import ChatOpenAI\n", + "from llama_index.llms import OpenAI\n", "\n", "agent = OpenAIAgent.from_tools(\n", " [sql_tool, vector_tool], \n", - " # llm=ChatOpenAI(temperature=0, 
model_name=\"gpt-3.5-turbo-0613\"),\n", - " llm=ChatOpenAI(temperature=0, model_name=\"gpt-4-0613\"),\n", + " llm=OpenAI(temperature=0, model=\"gpt-4-0613\"),\n", " verbose=True\n", ")" ] diff --git a/docs/examples/agent/openai_agent_query_plan.ipynb b/docs/examples/agent/openai_agent_query_plan.ipynb index 8ee829844d..695e3d988d 100644 --- a/docs/examples/agent/openai_agent_query_plan.ipynb +++ b/docs/examples/agent/openai_agent_query_plan.ipynb @@ -1,6 +1,7 @@ { "cells": [ { + "attachments": {}, "cell_type": "markdown", "id": "034e355d-83a0-4bd2-877e-0f493c5f713d", "metadata": { @@ -59,7 +60,6 @@ "source": [ "from llama_index import SimpleDirectoryReader, LLMPredictor, ServiceContext, GPTVectorStoreIndex\n", "from llama_index.response.pprint_utils import pprint_response\n", - "from langchain import OpenAI\n", "from langchain.chat_models import ChatOpenAI" ] }, @@ -81,6 +81,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "e7113434-0e41-46b6-a74e-284ce211fd38", "metadata": { @@ -105,6 +106,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "fd541b68-c67f-4cbf-b579-5437d48e5b8f", "metadata": {}, @@ -143,6 +145,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "2d6471ed-5645-4bb0-b8db-1b964ff7cd23", "metadata": { @@ -196,7 +199,6 @@ "source": [ "# define query plan tool\n", "from llama_index.tools import QueryPlanTool\n", - "# from llama_index.tools.query_plan_v2 import QueryPlanTool\n", "from llama_index import ResponseSynthesizer\n", "\n", "response_synthesizer = ResponseSynthesizer.from_args(\n", @@ -267,13 +269,13 @@ "outputs": [], "source": [ "from llama_index.agent import OpenAIAgent\n", - "from llama_index.selectors.llm_selectors import LLMSingleSelector\n", + "from llama_index.llms import OpenAI\n", "\n", "\n", "agent = OpenAIAgent.from_tools(\n", " [query_plan_tool],\n", " max_function_calls=10,\n", - " llm=ChatOpenAI(temperature=0, model_name=\"gpt-4-0613\"),\n", + " llm=OpenAI(temperature=0, model=\"gpt-4-0613\"),\n", " verbose=True\n", ")" ] diff --git a/docs/examples/agent/openai_agent_retrieval.ipynb b/docs/examples/agent/openai_agent_retrieval.ipynb index 4992747175..5206c367cf 100644 --- a/docs/examples/agent/openai_agent_retrieval.ipynb +++ b/docs/examples/agent/openai_agent_retrieval.ipynb @@ -43,7 +43,7 @@ "Let's start by importing some simple building blocks. \n", "\n", "The main thing we need is:\n", - "1. the OpenAI API (we will use langchain's ChatOpenAI wrapper for convienience here.)\n", + "1. the OpenAI API\n", "2. a place to keep conversation history \n", "3. a definition for tools that our agent can use." ] @@ -60,8 +60,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "/Users/jerryliu/Programming/gpt_index/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" + "/Users/suo/miniconda3/envs/llama/lib/python3.9/site-packages/deeplake/util/check_latest_version.py:32: UserWarning: A newer version of deeplake (3.6.7) is available. 
It's recommended that you update to the latest version using `pip install -U deeplake`.\n", + " warnings.warn(\n" ] } ], @@ -69,8 +69,6 @@ "import json\n", "from typing import Sequence\n", "\n", - "from langchain.chat_models import ChatOpenAI\n", - "from langchain.memory import ChatMessageHistory\n", "from llama_index.tools import BaseTool, FunctionTool" ] }, @@ -204,7 +202,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 11, "id": "33ea069f-819b-4ec1-a93c-fcbaacb362a1", "metadata": { "tags": [] @@ -226,10 +224,10 @@ { "data": { "text/plain": [ - "Response(response='The result of multiplying 212 by 122 is 25,864.', source_nodes=[], extra_info=None)" + "Response(response='212 multiplied by 122 is 25,864.', source_nodes=[], metadata=None)" ] }, - "execution_count": 9, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -240,7 +238,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 12, "id": "ec423b90-59cd-40ef-b497-a3842b3e7b58", "metadata": { "tags": [] @@ -262,10 +260,10 @@ { "data": { "text/plain": [ - "Response(response='The sum of 212 and 122 is 334.', source_nodes=[], extra_info=None)" + "Response(response='212 added to 122 is 334.', source_nodes=[], metadata=None)" ] }, - "execution_count": 7, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -285,9 +283,9 @@ ], "metadata": { "kernelspec": { - "display_name": "llama_index_v2", + "display_name": "Python 3 (ipykernel)", "language": "python", - "name": "llama_index_v2" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -299,7 +297,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.10" + "version": "3.9.16" } }, "nbformat": 4, diff --git a/docs/examples/agent/openai_agent_with_query_engine.ipynb b/docs/examples/agent/openai_agent_with_query_engine.ipynb index d1c3492aaa..4bbbd13399 100644 --- a/docs/examples/agent/openai_agent_with_query_engine.ipynb +++ b/docs/examples/agent/openai_agent_with_query_engine.ipynb @@ -25,7 +25,16 @@ "metadata": { "tags": [] }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/suo/miniconda3/envs/llama/lib/python3.9/site-packages/deeplake/util/check_latest_version.py:32: UserWarning: A newer version of deeplake (3.6.7) is available. 
It's recommended that you update to the latest version using `pip install -U deeplake`.\n", + " warnings.warn(\n" + ] + } + ], "source": [ "from llama_index import SimpleDirectoryReader, VectorStoreIndex, StorageContext, load_index_from_storage\n", "\n", @@ -91,7 +100,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 5, "id": "f9f3158a-7647-4442-8de1-4db80723b4d2", "metadata": { "tags": [] @@ -130,7 +139,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 6, "id": "32f71a46-bdf6-4365-b1f1-e23a0d913a3d", "metadata": { "tags": [] @@ -142,7 +151,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 7, "id": "ded93297-fee8-4329-bf37-cf77e87621ae", "metadata": { "tags": [] @@ -164,7 +173,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": null, "id": "42a1bce0-b398-4937-9008-6cee04368ac4", "metadata": { "tags": [] @@ -195,48 +204,25 @@ " \"input\": \"What was Lyft's revenue growth in 2021?\"\n", "}\n", "Got output: \n", - "Lyft's revenue grew by 36% in 2021 compared to the prior year.\n", + "Lyft's revenue growth in 2021 was 36%.\n", "========================\n", "=== Calling Function ===\n", - "Calling function: uber_10k with args: {\n", + "Calling function: uber_10k with args: \n", + "{\n", " \"input\": \"What was Uber's revenue growth in 2021?\"\n", "}\n", "Got output: \n", "Uber's revenue growth in 2021 was 57%.\n", "========================\n", - "Assistant: In 2021, Lyft's revenue grew by 36% compared to the previous year, while Uber's revenue growth was higher at 57%. This indicates that Uber experienced a faster rate of revenue growth than Lyft in 2021.\n", - "\n", - "The higher revenue growth for Uber can be attributed to several factors. Firstly, Uber has a larger market presence and operates in more countries and cities compared to Lyft. This broader reach allows Uber to capture a larger customer base and generate more revenue.\n", + "Assistant: In 2021, both Lyft and Uber experienced significant revenue growth. Lyft's revenue grew by 36%, while Uber's revenue grew by 57%. \n", "\n", - "Secondly, Uber has diversified its business beyond ride-hailing services. The company has expanded into food delivery with Uber Eats, which has experienced significant growth during the COVID-19 pandemic. This diversification has helped Uber generate additional revenue streams and mitigate the impact of fluctuations in the ride-hailing market.\n", + "The higher revenue growth rate of Uber indicates that it had a stronger performance in terms of generating revenue compared to Lyft. This could be attributed to several factors, including Uber's larger market presence and global reach. Uber operates in more countries and cities compared to Lyft, which allows it to capture a larger customer base and generate more revenue.\n", "\n", - "On the other hand, Lyft's revenue growth, although lower than Uber's, is still significant. Lyft has been focusing on improving its market share in the United States and has made efforts to expand its presence in suburban and rural areas. This targeted approach has helped Lyft attract new customers and increase its revenue.\n", + "However, it's important to note that revenue growth alone does not provide a complete picture of a company's financial performance. 
Other factors such as profitability, market share, and operational efficiency also play a crucial role in assessing the overall success of a company.\n", "\n", - "Overall, both Lyft and Uber have experienced strong revenue growth in 2021. However, Uber's larger market presence and diversification into other services have contributed to its higher revenue growth rate compared to Lyft. It will be interesting to see how these companies continue to innovate and adapt to changing market conditions in the future.\n", + "In terms of revenue growth, Uber outperformed Lyft in 2021. However, a comprehensive analysis would require considering other financial metrics and factors to get a complete understanding of the two companies' performance in the market.\n", "\n" ] - }, - { - "name": "stdin", - "output_type": "stream", - "text": [ - "Human: Thanks\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Assistant: You're welcome! If you have any more questions, feel free to ask.\n", - "\n" - ] - }, - { - "name": "stdin", - "output_type": "stream", - "text": [ - "Human: exit\n" - ] } ], "source": [ diff --git a/docs/examples/output_parsing/openai_pydantic_program.ipynb b/docs/examples/output_parsing/openai_pydantic_program.ipynb index 2b13ff0492..89f25a70fa 100644 --- a/docs/examples/output_parsing/openai_pydantic_program.ipynb +++ b/docs/examples/output_parsing/openai_pydantic_program.ipynb @@ -36,21 +36,12 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 15, "id": "f7a83b49-5c34-45d5-8cf4-62f348fb1299", "metadata": { "tags": [] }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/jerryliu/Programming/gpt_index/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], + "outputs": [], "source": [ "from pydantic import BaseModel\n", "from typing import List\n", @@ -68,7 +59,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 6, "id": "42053ea8-2580-4639-9dcf-566e8427c44e", "metadata": { "tags": [] @@ -76,10 +67,12 @@ "outputs": [], "source": [ "class Song(BaseModel):\n", + " \"\"\"Data model for a song.\"\"\"\n", " title: str\n", " length_seconds: int\n", " \n", "class Album(BaseModel):\n", + " \"\"\"Data model for an album.\"\"\"\n", " name: str\n", " artist: str\n", " songs: List[Song]" @@ -95,7 +88,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 7, "id": "fe756697-c299-4f9a-a108-944b6693f824", "metadata": { "tags": [] @@ -123,7 +116,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 8, "id": "25d02228-2907-4810-932e-83ec9fc71f6b", "metadata": { "tags": [] @@ -133,29 +126,29 @@ "name": "stdout", "output_type": "stream", "text": [ - "Function call: output_pydantic with args: {\n", + "Function call: Album with args: {\n", " \"name\": \"The Shining\",\n", - " \"artist\": \"Wendy Carlos\",\n", + " \"artist\": \"Various Artists\",\n", " \"songs\": [\n", " {\n", " \"title\": \"Main Title\",\n", " \"length_seconds\": 180\n", " },\n", " {\n", - " \"title\": \"Rocky Mountains\",\n", - " \"length_seconds\": 240\n", + " \"title\": \"Opening Credits\",\n", + " \"length_seconds\": 120\n", " },\n", " {\n", - " \"title\": \"Lullaby\",\n", - " \"length_seconds\": 150\n", + " \"title\": \"The Overlook Hotel\",\n", + " \"length_seconds\": 240\n", " },\n", " {\n", - " \"title\": \"Music for Strings, Percussion and Celesta\",\n", - " \"length_seconds\": 300\n", + " \"title\": \"Redrum\",\n", + " \"length_seconds\": 150\n", " },\n", " {\n", - " \"title\": \"Midnight\",\n", - " \"length_seconds\": 210\n", + " \"title\": \"Here's Johnny\",\n", + " \"length_seconds\": 200\n", " }\n", " ]\n", "}\n" @@ -178,7 +171,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 9, "id": "3e51bcf4-e7df-47b9-b380-8e5b900a31e1", "metadata": { "tags": [] @@ -187,10 +180,10 @@ { "data": { "text/plain": [ - "Album(name='The Shining', artist='Wendy Carlos', songs=[Song(title='Main Title', length_seconds=180), Song(title='Rocky Mountains', length_seconds=240), Song(title='Lullaby', length_seconds=150), Song(title='Music for Strings, Percussion and Celesta', length_seconds=300), Song(title='Midnight', length_seconds=210)])" + "Album(name='The Shining', artist='Various Artists', songs=[Song(title='Main Title', length_seconds=180), Song(title='Opening Credits', length_seconds=120), Song(title='The Overlook Hotel', length_seconds=240), Song(title='Redrum', length_seconds=150), Song(title=\"Here's Johnny\", length_seconds=200)])" ] }, - "execution_count": 5, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -215,7 +208,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 10, "id": "b58f6a12-3f5c-414b-80df-4492f6e18be5", "metadata": { "tags": [] @@ -228,7 +221,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 11, "id": "bc1c7eeb-f0a3-4d72-86ee-c6b63e01b0ff", "metadata": { "tags": [] @@ -238,7 +231,7 @@ "data": { "text/plain": [ "{'title': 'DirectoryTree',\n", - " 'description': 'Container class representing a directory tree.\\n\\nArgs:\\n root (Node): The root node of the tree.\\n\\nMethods:\\n 
print_paths: Prints the paths of the root node and its children.',\n", + " 'description': 'Container class representing a directory tree.\\n\\nArgs:\\n root (Node): The root node of the tree.',\n", " 'type': 'object',\n", " 'properties': {'root': {'title': 'Root',\n", " 'description': 'Root folder of the directory tree',\n", @@ -249,7 +242,7 @@ " 'enum': ['file', 'folder'],\n", " 'type': 'string'},\n", " 'Node': {'title': 'Node',\n", - " 'description': 'Class representing a single node in a filesystem. Can be either a file or a folder.\\nNote that a file cannot have children, but a folder can.\\n\\nArgs:\\n name (str): The name of the node.\\n children (List[Node]): The list of child nodes (if any).\\n node_type (NodeType): The type of the node, either a file or a folder.\\n\\nMethods:\\n print_paths: Prints the path of the node and its children.',\n", + " 'description': 'Class representing a single node in a filesystem. Can be either a file or a folder.\\nNote that a file cannot have children, but a folder can.\\n\\nArgs:\\n name (str): The name of the node.\\n children (List[Node]): The list of child nodes (if any).\\n node_type (NodeType): The type of the node, either a file or a folder.',\n", " 'type': 'object',\n", " 'properties': {'name': {'title': 'Name',\n", " 'description': 'Name of the folder',\n", @@ -264,7 +257,7 @@ " 'required': ['name']}}}" ] }, - "execution_count": 7, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -275,7 +268,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 12, "id": "02c4a7a1-f145-40bc-83b8-4153a531a8eb", "metadata": { "tags": [] @@ -291,7 +284,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 13, "id": "c88cf49f-a52f-4415-bddc-14d70c897629", "metadata": { "tags": [] @@ -375,7 +368,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 14, "id": "3885032f-0f3a-4afb-9157-54851e810843", "metadata": { "tags": [] @@ -387,7 +380,7 @@ "DirectoryTree(root=Node(name='root', children=[Node(name='folder1', children=[Node(name='file1.txt', children=[], node_type=<NodeType.FILE: 'file'>), Node(name='file2.txt', children=[], node_type=<NodeType.FILE: 'file'>)], node_type=<NodeType.FOLDER: 'folder'>), Node(name='folder2', children=[Node(name='file3.txt', children=[], node_type=<NodeType.FILE: 'file'>), Node(name='subfolder1', children=[Node(name='file4.txt', children=[], node_type=<NodeType.FILE: 'file'>)], node_type=<NodeType.FOLDER: 'folder'>)], node_type=<NodeType.FOLDER: 'folder'>)], node_type=<NodeType.FOLDER: 'folder'>))" ] }, - "execution_count": 10, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -395,13 +388,21 @@ "source": [ "output" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "75e712e2-5f84-4899-a000-54a7a53fc72c", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { "kernelspec": { - "display_name": "llama_index_v2", + "display_name": "Python 3 (ipykernel)", "language": "python", - "name": "llama_index_v2" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -413,7 +414,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.10" + "version": "3.9.16" } }, "nbformat": 4, diff --git a/llama_index/agent/context_retriever_agent.py b/llama_index/agent/context_retriever_agent.py index 34a3f80b64..71293555f4 100644 --- a/llama_index/agent/context_retriever_agent.py +++ b/llama_index/agent/context_retriever_agent.py @@ -2,19 +2,21 @@ 
from typing import List, Optional -from llama_index.bridge.langchain import ChatMessageHistory, ChatOpenAI, print_text - -from llama_index.callbacks.base import CallbackManager -from llama_index.schema import NodeWithScore -from llama_index.indices.base_retriever import BaseRetriever -from llama_index.response.schema import RESPONSE_TYPE -from llama_index.tools import BaseTool -from llama_index.prompts.prompts import QuestionAnswerPrompt from llama_index.agent.openai_agent import ( - BaseOpenAIAgent, DEFAULT_MAX_FUNCTION_CALLS, + DEFAULT_MODEL_NAME, SUPPORTED_MODEL_NAMES, + BaseOpenAIAgent, ) +from llama_index.bridge.langchain import print_text +from llama_index.callbacks.base import CallbackManager +from llama_index.indices.base_retriever import BaseRetriever +from llama_index.llms.base import ChatMessage +from llama_index.llms.openai import OpenAI +from llama_index.prompts.prompts import QuestionAnswerPrompt +from llama_index.response.schema import RESPONSE_TYPE +from llama_index.schema import NodeWithScore +from llama_index.tools import BaseTool # inspired by DEFAULT_QA_PROMPT_TMPL from llama_index/prompts/default_prompts.py DEFAULT_QA_PROMPT_TMPL = ( @@ -37,11 +39,12 @@ class ContextRetrieverOpenAIAgent(BaseOpenAIAgent): NOTE: this is a beta feature, function interfaces might change. Args: + tools (List[BaseTool]): A list of tools. retriever (BaseRetriever): A retriever. qa_prompt (Optional[QuestionAnswerPrompt]): A QA prompt. context_separator (str): A context separator. - llm (Optional[ChatOpenAI]): An LLM. - chat_history (Optional[ChatMessageHistory]): A chat history. + llm (Optional[OpenAI]): An LLM. + chat_history (Optional[List[ChatMessage]]): A chat history. verbose (bool): Whether to print debug statements. max_function_calls (int): Maximum number of function calls. callback_manager (Optional[CallbackManager]): A callback manager. @@ -54,8 +57,8 @@ class ContextRetrieverOpenAIAgent(BaseOpenAIAgent): retriever: BaseRetriever, qa_prompt: QuestionAnswerPrompt, context_separator: str, - llm: ChatOpenAI, - chat_history: ChatMessageHistory, + llm: OpenAI, + chat_history: List[ChatMessage], verbose: bool = False, max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS, callback_manager: Optional[CallbackManager] = None, @@ -79,8 +82,8 @@ class ContextRetrieverOpenAIAgent(BaseOpenAIAgent): retriever: BaseRetriever, qa_prompt: Optional[QuestionAnswerPrompt] = None, context_separator: str = "\n", - llm: Optional[ChatOpenAI] = None, - chat_history: Optional[ChatMessageHistory] = None, + llm: Optional[OpenAI] = None, + chat_history: Optional[List[ChatMessage]] = None, verbose: bool = False, max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS, callback_manager: Optional[CallbackManager] = None, @@ -99,14 +102,14 @@ class ContextRetrieverOpenAIAgent(BaseOpenAIAgent): """ qa_prompt = qa_prompt or DEFAULT_QA_PROMPT - lc_chat_history = chat_history or ChatMessageHistory() - llm = llm or ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo-0613") - if not isinstance(llm, ChatOpenAI): - raise ValueError("llm must be a ChatOpenAI instance") + chat_history = chat_history or [] + llm = llm or OpenAI(model=DEFAULT_MODEL_NAME) + if not isinstance(llm, OpenAI): + raise ValueError("llm must be a OpenAI instance") - if llm.model_name not in SUPPORTED_MODEL_NAMES: + if llm.model not in SUPPORTED_MODEL_NAMES: raise ValueError( - f"Model name {llm.model_name} not supported. " + f"Model name {llm.model} not supported. 
" f"Supported model names: {SUPPORTED_MODEL_NAMES}" ) @@ -116,7 +119,7 @@ class ContextRetrieverOpenAIAgent(BaseOpenAIAgent): qa_prompt=qa_prompt, context_separator=context_separator, llm=llm, - chat_history=lc_chat_history, + chat_history=chat_history, verbose=verbose, max_function_calls=max_function_calls, callback_manager=callback_manager, @@ -127,7 +130,7 @@ class ContextRetrieverOpenAIAgent(BaseOpenAIAgent): return self._tools def chat( - self, message: str, chat_history: Optional[ChatMessageHistory] = None + self, message: str, chat_history: Optional[List[ChatMessage]] = None ) -> RESPONSE_TYPE: """Chat.""" # augment user message diff --git a/llama_index/agent/openai_agent.py b/llama_index/agent/openai_agent.py index 93fb0b553e..ea773bbfa2 100644 --- a/llama_index/agent/openai_agent.py +++ b/llama_index/agent/openai_agent.py @@ -2,18 +2,19 @@ import json from abc import abstractmethod from typing import Callable, List, Optional -from llama_index.bridge.langchain import FunctionMessage, ChatMessageHistory, ChatOpenAI - from llama_index.callbacks.base import CallbackManager from llama_index.chat_engine.types import BaseChatEngine -from llama_index.schema import BaseNode, NodeWithScore from llama_index.indices.base_retriever import BaseRetriever from llama_index.indices.query.base import BaseQueryEngine from llama_index.indices.query.schema import QueryBundle +from llama_index.llms.base import ChatMessage, MessageRole +from llama_index.llms.openai import OpenAI from llama_index.response.schema import RESPONSE_TYPE, Response +from llama_index.schema import BaseNode, NodeWithScore from llama_index.tools import BaseTool DEFAULT_MAX_FUNCTION_CALLS = 5 +DEFAULT_MODEL_NAME = "gpt-3.5-turbo-0613" SUPPORTED_MODEL_NAMES = [ "gpt-3.5-turbo-0613", "gpt-4-0613", @@ -30,7 +31,7 @@ def get_function_by_name(tools: List[BaseTool], name: str) -> BaseTool: def call_function( tools: List[BaseTool], function_call: dict, verbose: bool = False -) -> FunctionMessage: +) -> ChatMessage: """Call a function and return the output as a string.""" name = function_call["name"] arguments_str = function_call["arguments"] @@ -43,7 +44,13 @@ def call_function( if verbose: print(f"Got output: {output}") print("========================") - return FunctionMessage(content=str(output), name=function_call["name"]) + return ChatMessage( + content=str(output), + role=MessageRole.FUNCTION, + additional_kwargs={ + "name": function_call["name"], + }, + ) class BaseOpenAIAgent(BaseChatEngine, BaseQueryEngine): @@ -51,8 +58,8 @@ class BaseOpenAIAgent(BaseChatEngine, BaseQueryEngine): def __init__( self, - llm: ChatOpenAI, - chat_history: ChatMessageHistory, + llm: OpenAI, + chat_history: List[ChatMessage], verbose: bool = False, max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS, callback_manager: Optional[CallbackManager] = None, @@ -71,18 +78,17 @@ class BaseOpenAIAgent(BaseChatEngine, BaseQueryEngine): """Get tools.""" def chat( - self, message: str, chat_history: Optional[ChatMessageHistory] = None + self, message: str, chat_history: Optional[List[ChatMessage]] = None ) -> RESPONSE_TYPE: chat_history = chat_history or self._chat_history - chat_history.add_user_message(message) + chat_history.append(ChatMessage(content=message, role="user")) tools = self._get_tools(message) functions = [tool.metadata.to_openai_function() for tool in tools] # TODO: Support forced function call - ai_message = self._llm.predict_messages( - chat_history.messages, functions=functions - ) - chat_history.add_message(ai_message) + chat_response = 
self._llm.chat(chat_history, functions=functions) + ai_message = chat_response.message + chat_history.append(ai_message) n_function_calls = 0 function_call = ai_message.additional_kwargs.get("function_call", None) @@ -94,31 +100,29 @@ class BaseOpenAIAgent(BaseChatEngine, BaseQueryEngine): function_message = call_function( tools, function_call, verbose=self._verbose ) - chat_history.add_message(function_message) + chat_history.append(function_message) n_function_calls += 1 # send function call & output back to get another response - ai_message = self._llm.predict_messages( - chat_history.messages, functions=functions - ) - chat_history.add_message(ai_message) + chat_response = self._llm.chat(chat_history, functions=functions) + ai_message = chat_response.message + chat_history.append(ai_message) function_call = ai_message.additional_kwargs.get("function_call", None) return Response(ai_message.content) async def achat( - self, message: str, chat_history: Optional[ChatMessageHistory] = None + self, message: str, chat_history: Optional[List[ChatMessage]] = None ) -> RESPONSE_TYPE: chat_history = chat_history or self._chat_history - chat_history.add_user_message(message) + chat_history.append(ChatMessage(content=message, role="user")) tools = self._get_tools(message) functions = [tool.metadata.to_openai_function() for tool in tools] # TODO: Support forced function call - ai_message = await self._llm.apredict_messages( - chat_history.messages, functions=functions - ) - chat_history.add_message(ai_message) + chat_response = await self._llm.achat(chat_history, functions=functions) + ai_message = chat_response.message + chat_history.append(ai_message) n_function_calls = 0 function_call = ai_message.additional_kwargs.get("function_call", None) @@ -130,14 +134,13 @@ class BaseOpenAIAgent(BaseChatEngine, BaseQueryEngine): function_message = call_function( tools, function_call, verbose=self._verbose ) - chat_history.add_message(function_message) + chat_history.append(function_message) n_function_calls += 1 # send function call & output back to get another response - ai_message = await self._llm.apredict_messages( - chat_history.messages, functions=functions - ) - chat_history.add_message(ai_message) + response = await self._llm.achat(chat_history, functions=functions) + ai_message = response.message + chat_history.append(ai_message) function_call = ai_message.additional_kwargs.get("function_call", None) return Response(ai_message.content) @@ -146,13 +149,13 @@ class BaseOpenAIAgent(BaseChatEngine, BaseQueryEngine): def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE: return self.chat( query_bundle.query_str, - chat_history=ChatMessageHistory(), + chat_history=[], ) async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE: return await self.achat( query_bundle.query_str, - chat_history=ChatMessageHistory(), + chat_history=[], ) @@ -160,8 +163,8 @@ class OpenAIAgent(BaseOpenAIAgent): def __init__( self, tools: List[BaseTool], - llm: ChatOpenAI, - chat_history: ChatMessageHistory, + llm: OpenAI, + chat_history: List[ChatMessage], verbose: bool = False, max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS, callback_manager: Optional[CallbackManager] = None, @@ -179,28 +182,28 @@ class OpenAIAgent(BaseOpenAIAgent): def from_tools( cls, tools: Optional[List[BaseTool]] = None, - llm: Optional[ChatOpenAI] = None, - chat_history: Optional[ChatMessageHistory] = None, + llm: Optional[OpenAI] = None, + chat_history: Optional[List[ChatMessage]] = None, verbose: bool = False, max_function_calls: 
int = DEFAULT_MAX_FUNCTION_CALLS, callback_manager: Optional[CallbackManager] = None, ) -> "OpenAIAgent": tools = tools or [] - lc_chat_history = chat_history or ChatMessageHistory() - llm = llm or ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo-0613") - if not isinstance(llm, ChatOpenAI): - raise ValueError("llm must be a ChatOpenAI instance") + chat_history = chat_history or [] + llm = llm or OpenAI(model=DEFAULT_MODEL_NAME) + if not isinstance(llm, OpenAI): + raise ValueError("llm must be a OpenAI instance") - if llm.model_name not in SUPPORTED_MODEL_NAMES: + if llm.model not in SUPPORTED_MODEL_NAMES: raise ValueError( - f"Model name {llm.model_name} not supported. " + f"Model name {llm.model} not supported. " f"Supported model names: {SUPPORTED_MODEL_NAMES}" ) return cls( tools=tools, llm=llm, - chat_history=lc_chat_history, + chat_history=chat_history, verbose=verbose, max_function_calls=max_function_calls, callback_manager=callback_manager, @@ -229,8 +232,8 @@ class RetrieverOpenAIAgent(BaseOpenAIAgent): self, retriever: BaseRetriever, node_to_tool_fn: Callable[[BaseNode], BaseTool], - llm: ChatOpenAI, - chat_history: ChatMessageHistory, + llm: OpenAI, + chat_history: List[ChatMessage], verbose: bool = False, max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS, callback_manager: Optional[CallbackManager] = None, @@ -250,20 +253,20 @@ class RetrieverOpenAIAgent(BaseOpenAIAgent): cls, retriever: BaseRetriever, node_to_tool_fn: Callable[[BaseNode], BaseTool], - llm: Optional[ChatOpenAI] = None, - chat_history: Optional[ChatMessageHistory] = None, + llm: Optional[OpenAI] = None, + chat_history: Optional[List[ChatMessage]] = None, verbose: bool = False, max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS, callback_manager: Optional[CallbackManager] = None, ) -> "RetrieverOpenAIAgent": - lc_chat_history = chat_history or ChatMessageHistory() - llm = llm or ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo-0613") - if not isinstance(llm, ChatOpenAI): - raise ValueError("llm must be a ChatOpenAI instance") + lc_chat_history = chat_history or [] + llm = llm or OpenAI(model=DEFAULT_MODEL_NAME) + if not isinstance(llm, OpenAI): + raise ValueError("llm must be a OpenAI instance") - if llm.model_name not in SUPPORTED_MODEL_NAMES: + if llm.model not in SUPPORTED_MODEL_NAMES: raise ValueError( - f"Model name {llm.model_name} not supported. " + f"Model name {llm.model} not supported. 
" f"Supported model names: {SUPPORTED_MODEL_NAMES}" ) diff --git a/llama_index/agent/retriever_openai_agent.py b/llama_index/agent/retriever_openai_agent.py index e03f0e6b80..39bbcbe3e3 100644 --- a/llama_index/agent/retriever_openai_agent.py +++ b/llama_index/agent/retriever_openai_agent.py @@ -1,15 +1,18 @@ """Retriever OpenAI agent.""" -from llama_index.agent.openai_agent import BaseOpenAIAgent -from llama_index.objects.base import ObjectRetriever -from llama_index.tools.types import BaseTool -from typing import Optional, List -from llama_index.bridge.langchain import ChatOpenAI, ChatMessageHistory -from llama_index.callbacks.base import CallbackManager +from typing import List, Optional + from llama_index.agent.openai_agent import ( - SUPPORTED_MODEL_NAMES, DEFAULT_MAX_FUNCTION_CALLS, + DEFAULT_MODEL_NAME, + SUPPORTED_MODEL_NAMES, + BaseOpenAIAgent, ) +from llama_index.callbacks.base import CallbackManager +from llama_index.llms.base import ChatMessage +from llama_index.llms.openai import OpenAI +from llama_index.objects.base import ObjectRetriever +from llama_index.tools.types import BaseTool class FnRetrieverOpenAIAgent(BaseOpenAIAgent): @@ -22,8 +25,8 @@ class FnRetrieverOpenAIAgent(BaseOpenAIAgent): def __init__( self, retriever: ObjectRetriever[BaseTool], - llm: ChatOpenAI, - chat_history: ChatMessageHistory, + llm: OpenAI, + chat_history: List[ChatMessage], verbose: bool = False, max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS, callback_manager: Optional[CallbackManager] = None, @@ -41,27 +44,27 @@ class FnRetrieverOpenAIAgent(BaseOpenAIAgent): def from_retriever( cls, retriever: ObjectRetriever[BaseTool], - llm: Optional[ChatOpenAI] = None, - chat_history: Optional[ChatMessageHistory] = None, + llm: Optional[OpenAI] = None, + chat_history: Optional[List[ChatMessage]] = None, verbose: bool = False, max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS, callback_manager: Optional[CallbackManager] = None, ) -> "FnRetrieverOpenAIAgent": - lc_chat_history = chat_history or ChatMessageHistory() - llm = llm or ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo-0613") - if not isinstance(llm, ChatOpenAI): - raise ValueError("llm must be a ChatOpenAI instance") + chat_history = chat_history or [] + llm = llm or OpenAI(model=DEFAULT_MODEL_NAME) + if not isinstance(llm, OpenAI): + raise ValueError("llm must be a OpenAI instance") - if llm.model_name not in SUPPORTED_MODEL_NAMES: + if llm.model not in SUPPORTED_MODEL_NAMES: raise ValueError( - f"Model name {llm.model_name} not supported. " + f"Model name {llm.model} not supported. 
" f"Supported model names: {SUPPORTED_MODEL_NAMES}" ) return cls( retriever=retriever, llm=llm, - chat_history=lc_chat_history, + chat_history=chat_history, verbose=verbose, max_function_calls=max_function_calls, callback_manager=callback_manager, diff --git a/llama_index/program/openai_program.py b/llama_index/program/openai_program.py index 7bb26f51a7..71ebdf8124 100644 --- a/llama_index/program/openai_program.py +++ b/llama_index/program/openai_program.py @@ -1,6 +1,9 @@ from typing import Any, Dict, Optional, Type, Union -from llama_index.bridge.langchain import ChatOpenAI, HumanMessage +from llama_index.llms.openai import OpenAI +from llama_index.llms.base import ChatMessage, MessageRole +from llama_index.llms.openai_utils import to_openai_function +from llama_index.types import Model from llama_index.program.llm_prompt_program import BaseLLMFunctionProgram from llama_index.prompts.base import Prompt @@ -12,17 +15,7 @@ SUPPORTED_MODEL_NAMES = [ ] -def _openai_function(output_cls: Type[BaseModel]) -> Dict[str, Any]: - """Convert pydantic class to OpenAI function.""" - schema = output_cls.schema() - return { - "name": schema["title"], - "description": schema["description"], - "parameters": output_cls.schema(), - } - - -def _openai_function_call(output_cls: Type[BaseModel]) -> Dict[str, Any]: +def _default_function_call(output_cls: Type[BaseModel]) -> Dict[str, Any]: """Default OpenAI function to call.""" schema = output_cls.schema() return { @@ -30,7 +23,7 @@ def _openai_function_call(output_cls: Type[BaseModel]) -> Dict[str, Any]: } -class OpenAIPydanticProgram(BaseLLMFunctionProgram[ChatOpenAI]): +class OpenAIPydanticProgram(BaseLLMFunctionProgram[OpenAI]): """ An OpenAI-based function that returns a pydantic model. @@ -39,8 +32,8 @@ class OpenAIPydanticProgram(BaseLLMFunctionProgram[ChatOpenAI]): def __init__( self, - output_cls: Type[BaseModel], - llm: ChatOpenAI, + output_cls: Type[Model], + llm: OpenAI, prompt: Prompt, function_call: Union[str, Dict[str, Any]], verbose: bool = False, @@ -55,24 +48,24 @@ class OpenAIPydanticProgram(BaseLLMFunctionProgram[ChatOpenAI]): @classmethod def from_defaults( cls, - output_cls: Type[BaseModel], + output_cls: Type[Model], prompt_template_str: str, - llm: Optional[ChatOpenAI] = None, + llm: Optional[OpenAI] = None, verbose: bool = False, function_call: Optional[Union[str, Dict[str, Any]]] = None, **kwargs: Any, - ) -> "BaseLLMFunctionProgram": - llm = llm or ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo-0613") - if not isinstance(llm, ChatOpenAI): - raise ValueError("llm must be a ChatOpenAI instance") + ) -> "OpenAIPydanticProgram": + llm = llm or OpenAI(model="gpt-3.5-turbo-0613") + if not isinstance(llm, OpenAI): + raise ValueError("llm must be a OpenAI instance") - if llm.model_name not in SUPPORTED_MODEL_NAMES: + if llm.model not in SUPPORTED_MODEL_NAMES: raise ValueError( - f"Model name {llm.model_name} not supported. " + f"Model name {llm.model} not supported. 
" f"Supported model names: {SUPPORTED_MODEL_NAMES}" ) prompt = Prompt(prompt_template_str) - function_call = function_call or {"name": output_cls.schema()["title"]} + function_call = function_call or _default_function_call(output_cls) return cls( output_cls=output_cls, llm=llm, @@ -92,21 +85,21 @@ class OpenAIPydanticProgram(BaseLLMFunctionProgram[ChatOpenAI]): ) -> BaseModel: formatted_prompt = self._prompt.format(**kwargs) - openai_fn_spec = _openai_function(self._output_cls) + openai_fn_spec = to_openai_function(self._output_cls) - ai_message = self._llm.predict_messages( - messages=[HumanMessage(content=formatted_prompt)], + chat_response = self._llm.chat( + messages=[ChatMessage(role=MessageRole.USER, content=formatted_prompt)], functions=[openai_fn_spec], - # TODO: support forcing the desired function call function_call=self._function_call, ) - if "function_call" not in ai_message.additional_kwargs: + message = chat_response.message + if "function_call" not in message.additional_kwargs: raise ValueError( "Expected function call in ai_message.additional_kwargs, " "but none found." ) - function_call = ai_message.additional_kwargs["function_call"] + function_call = message.additional_kwargs["function_call"] if self._verbose: name = function_call["name"] arguments_str = function_call["arguments"] diff --git a/llama_index/selectors/pydantic_selectors.py b/llama_index/selectors/pydantic_selectors.py index 6d7b021565..d3eef300a8 100644 --- a/llama_index/selectors/pydantic_selectors.py +++ b/llama_index/selectors/pydantic_selectors.py @@ -1,11 +1,9 @@ -from llama_index.bridge.langchain import ChatOpenAI from typing import Any, Optional, Sequence from llama_index.indices.query.schema import QueryBundle -from llama_index.program.openai_program import ( - OpenAIPydanticProgram, -) +from llama_index.llms.openai import OpenAI from llama_index.program.base_program import BasePydanticProgram +from llama_index.program.openai_program import OpenAIPydanticProgram from llama_index.selectors.llm_selectors import _build_choices_text from llama_index.selectors.prompts import ( DEFAULT_MULTI_PYD_SELECT_PROMPT_TMPL, @@ -13,8 +11,8 @@ from llama_index.selectors.prompts import ( ) from llama_index.selectors.types import ( BaseSelector, - SelectorResult, MultiSelection, + SelectorResult, SingleSelection, ) from llama_index.tools.types import ToolMetadata @@ -44,7 +42,7 @@ class PydanticSingleSelector(BaseSelector): def from_defaults( cls, program: Optional[BasePydanticProgram] = None, - llm: Optional[ChatOpenAI] = None, + llm: Optional[OpenAI] = None, prompt_template_str: str = DEFAULT_SINGLE_PYD_SELECT_PROMPT_TMPL, verbose: bool = False, ) -> "PydanticSingleSelector": @@ -93,7 +91,7 @@ class PydanticMultiSelector(BaseSelector): def from_defaults( cls, program: Optional[BasePydanticProgram] = None, - llm: Optional[ChatOpenAI] = None, + llm: Optional[OpenAI] = None, prompt_template_str: str = DEFAULT_MULTI_PYD_SELECT_PROMPT_TMPL, max_outputs: Optional[int] = None, verbose: bool = False, -- GitLab