From 95add73c38275b2d222e7f0373545b0791aa8df5 Mon Sep 17 00:00:00 2001
From: Emanuel Ferreira <contatoferreirads@gmail.com>
Date: Sun, 18 Feb 2024 18:43:52 -0300
Subject: [PATCH] feat: multi-document agents (#531)

---
 .changeset/thin-seals-cough.md | 5 +
 apps/docs/docs/modules/agent/_category_.yml | 1 +
 .../modules/agent/multi_document_agent.mdx | 314 ++++++++++++++++++
 apps/docs/docs/modules/agent/openai.mdx | 4 +
 .../docs/modules/agent/query_engine_tool.mdx | 4 +
 examples/.gitignore | 1 +
 examples/agent/helpers/extractWikipedia.ts | 55 +++
 examples/agent/multi_document_agent.ts | 157 +++++++++
 examples/agent/query_openai_agent.ts | 46 +++
 examples/vectorIndexEmbed3.ts | 1 +
 packages/core/package.json | 5 +
 packages/core/src/agent/openai/base.ts | 4 +-
 packages/core/src/agent/openai/worker.ts | 7 +-
 packages/core/src/agent/react/base.ts | 2 +-
 packages/core/src/agent/react/worker.ts | 2 +-
 .../engines/query/SubQuestionQueryEngine.ts | 2 +
 packages/core/src/index.ts | 1 +
 .../indices/vectorStore/VectorStoreIndex.ts | 4 +-
 packages/core/src/llm/LLM.ts | 7 +-
 packages/core/src/llm/openai/utils.ts | 7 -
 packages/core/src/objects/base.ts | 160 +++++++--
 packages/core/src/objects/index.ts | 1 +
 .../storage/vectorStore/SimpleVectorStore.ts | 1 -
 packages/core/src/tests/Embedding.test.ts | 4 +-
 .../core/src/tests/llms/openai/utils.test.ts | 6 +-
 .../src/tests/objects/ObjectIndex.test.ts | 130 ++++++++
 packages/core/src/tests/utility/mockOpenAI.ts | 10 +-
 packages/core/src/tools/index.ts | 1 +
 28 files changed, 895 insertions(+), 47 deletions(-)
 create mode 100644 .changeset/thin-seals-cough.md
 create mode 100644 apps/docs/docs/modules/agent/multi_document_agent.mdx
 create mode 100644 examples/agent/helpers/extractWikipedia.ts
 create mode 100644 examples/agent/multi_document_agent.ts
 create mode 100644 examples/agent/query_openai_agent.ts
 delete mode 100644 packages/core/src/llm/openai/utils.ts
 create mode 100644 packages/core/src/objects/index.ts
 create mode 100644 packages/core/src/tests/objects/ObjectIndex.test.ts

diff --git a/.changeset/thin-seals-cough.md b/.changeset/thin-seals-cough.md
new file mode 100644
index 000000000..6d07e650d
--- /dev/null
+++ b/.changeset/thin-seals-cough.md
@@ -0,0 +1,5 @@
+---
+"llamaindex": patch
+---
+
+feat: multi-document agent
diff --git a/apps/docs/docs/modules/agent/_category_.yml b/apps/docs/docs/modules/agent/_category_.yml
index 549b74c5a..0ea1a6657 100644
--- a/apps/docs/docs/modules/agent/_category_.yml
+++ b/apps/docs/docs/modules/agent/_category_.yml
@@ -1 +1,2 @@
 label: "Agents"
+position: 3
diff --git a/apps/docs/docs/modules/agent/multi_document_agent.mdx b/apps/docs/docs/modules/agent/multi_document_agent.mdx
new file mode 100644
index 000000000..cbe9e3c21
--- /dev/null
+++ b/apps/docs/docs/modules/agent/multi_document_agent.mdx
@@ -0,0 +1,314 @@
+# Multi-Document Agent
+
+In this guide, you will learn how to set up an agent that can effectively answer different types of questions over a large set of documents.
+
+These questions include the following:
+
+- QA over a specific doc
+- QA comparing different docs
+- Summaries over a specific doc
+- Comparing summaries between different docs
+
+We do this with the following architecture:
+
+- set up a “document agent” over each Document: each doc agent can do QA/summarization within its doc
+- set up a top-level agent over this set of document agents: it retrieves the relevant tools and then reasons (CoT) over them to answer a question (see the sketch below).
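+
+At a glance, the pieces fit together as shown in the condensed sketch below. This is not runnable as-is (names such as `vectorTool`, `summaryTool`, `allTools`, `title`, and `llm` are placeholders); the full working version is built step by step in the sections that follow.
+
+```ts
+// 1. Per-document agent: a vector tool and a summary tool wrapped by an OpenAI agent
+const docAgent = new OpenAIAgent({ tools: [vectorTool, summaryTool], llm });
+
+// 2. Each document agent is itself exposed as a query engine tool
+const docTool = new QueryEngineTool({
+  queryEngine: docAgent,
+  metadata: { name: `tool_${title}`, description: `Answers questions about ${title}` },
+});
+
+// 3. All document tools are indexed so they can be retrieved per query
+const objectIndex = await ObjectIndex.fromObjects(
+  allTools,
+  SimpleToolNodeMapping.fromObjects(allTools),
+  VectorStoreIndex,
+);
+
+// 4. The top-level agent retrieves only the relevant tools for each question
+const topAgent = new OpenAIAgent({
+  toolRetriever: await objectIndex.asRetriever({}),
+  llm,
+});
+```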
+
+## Setup and Download Data
+
+We start by installing the necessary libraries and downloading the data.
+
+```bash
+pnpm i llamaindex
+```
+
+```ts
+import {
+  Document,
+  ObjectIndex,
+  OpenAI,
+  OpenAIAgent,
+  QueryEngineTool,
+  SimpleNodeParser,
+  SimpleToolNodeMapping,
+  SummaryIndex,
+  VectorStoreIndex,
+  serviceContextFromDefaults,
+  storageContextFromDefaults,
+} from "llamaindex";
+```
+
+Then, for the data, we will go through a list of countries and download the Wikipedia page for each one.
+
+```ts
+import fs from "fs";
+import path from "path";
+
+const dataPath = path.join(__dirname, "tmp_data");
+
+const extractWikipediaTitle = async (title: string) => {
+  const fileExists = fs.existsSync(path.join(dataPath, `${title}.txt`));
+
+  if (fileExists) {
+    console.log(`File already exists for the title: ${title}`);
+    return;
+  }
+
+  const queryParams = new URLSearchParams({
+    action: "query",
+    format: "json",
+    titles: title,
+    prop: "extracts",
+    explaintext: "true",
+  });
+
+  const url = `https://en.wikipedia.org/w/api.php?${queryParams}`;
+
+  const response = await fetch(url);
+  const data: any = await response.json();
+
+  const pages = data.query.pages;
+  const page = pages[Object.keys(pages)[0]];
+  const wikiText = page.extract;
+
+  await new Promise((resolve) => {
+    fs.writeFile(path.join(dataPath, `${title}.txt`), wikiText, (err: any) => {
+      if (err) {
+        console.error(err);
+        resolve(title);
+        return;
+      }
+      console.log(`${title} stored in file!`);
+
+      resolve(title);
+    });
+  });
+};
+```
+
+```ts
+export const extractWikipedia = async (titles: string[]) => {
+  if (!fs.existsSync(dataPath)) {
+    fs.mkdirSync(dataPath);
+  }
+
+  for await (const title of titles) {
+    await extractWikipediaTitle(title);
+  }
+
+  console.log("Extraction finished!");
+};
+```
+
+These files will be saved in the `tmp_data` folder.
+
+Now we can define the list of countries and call the function to download the data for each one.
+
+```ts
+const wikiTitles = [
+  "Brazil",
+  "United States",
+  "Canada",
+  "Mexico",
+  "Argentina",
+  "Chile",
+  "Colombia",
+  "Peru",
+  "Venezuela",
+  "Ecuador",
+  "Bolivia",
+  "Paraguay",
+  "Uruguay",
+  "Guyana",
+  "Suriname",
+  "French Guiana",
+  "Falkland Islands",
+];
+
+await extractWikipedia(wikiTitles);
+```
+
+## Load the Data
+
+Now that we have the data, we can load it into LlamaIndex and store each file as a `Document`.
+
+```ts
+import fs from "node:fs/promises";
+
+import { Document } from "llamaindex";
+
+const countryDocs: Record<string, Document> = {};
+
+for (const title of wikiTitles) {
+  const filePath = `./agent/helpers/tmp_data/${title}.txt`;
+  const text = await fs.readFile(filePath, "utf-8");
+  const document = new Document({ text, id_: filePath });
+  countryDocs[title] = document;
+}
+```
+
+## Setup LLM and StorageContext
+
+We will be using gpt-4 for this example, together with a `StorageContext` that persists the indexes to the `./storage` directory.
+
+```ts
+const llm = new OpenAI({
+  model: "gpt-4",
+});
+
+const serviceContext = serviceContextFromDefaults({ llm });
+
+const storageContext = await storageContextFromDefaults({
+  persistDir: "./storage",
+});
+```
+
+## Building Multi-Document Agents
+
+In this section we show you how to construct the multi-document agent. We first build a document agent for each document, and then define the top-level parent agent with an object index.
+
+```ts
+const documentAgents: Record<string, any> = {};
+const queryEngines: Record<string, any> = {};
+```
+
+Now we iterate over each country and create a document agent for each one.
+
+### Build Agent for each Document
+
+In this section we define “document agents” for each document.
+
+We define both a vector index (for semantic search) and a summary index (for summarization) for each document. The two query engines are then converted into tools that are passed to an OpenAI function calling agent.
+
+This document agent can dynamically choose to perform semantic search or summarization within a given document.
+
+We create a separate document agent for each country.
+
+```ts
+for (const title of wikiTitles) {
+  // parse the document into nodes
+  const nodes = new SimpleNodeParser({
+    chunkSize: 200,
+    chunkOverlap: 20,
+  }).getNodesFromDocuments([countryDocs[title]]);
+
+  // create the vector index for semantic search
+  const vectorIndex = await VectorStoreIndex.init({
+    serviceContext: serviceContext,
+    storageContext: storageContext,
+    nodes,
+  });
+
+  // create the summary index for summarization
+  const summaryIndex = await SummaryIndex.init({
+    serviceContext: serviceContext,
+    nodes,
+  });
+
+  const vectorQueryEngine = vectorIndex.asQueryEngine();
+  const summaryQueryEngine = summaryIndex.asQueryEngine();
+
+  // wrap the query engines as tools, one per task
+  const queryEngineTools = [
+    new QueryEngineTool({
+      queryEngine: vectorQueryEngine,
+      metadata: {
+        name: "vector_tool",
+        description: `Useful for questions related to specific aspects of ${title} (e.g. the history, arts and culture, sports, demographics, or more).`,
+      },
+    }),
+    new QueryEngineTool({
+      queryEngine: summaryQueryEngine,
+      metadata: {
+        name: "summary_tool",
+        description: `Useful for any requests that require a holistic summary of EVERYTHING about ${title}. For questions about more specific sections, please use the vector_tool.`,
+      },
+    }),
+  ];
+
+  // create the document agent
+  const agent = new OpenAIAgent({
+    tools: queryEngineTools,
+    llm,
+    verbose: true,
+  });
+
+  documentAgents[title] = agent;
+  queryEngines[title] = vectorIndex.asQueryEngine();
+}
+```
+
+## Build Top-Level Agent
+
+Now we define the top-level agent that can answer questions over the set of document agents.
+
+This agent takes in all of the document agents as tools. It performs tool retrieval before tool use (unlike a default agent, which tries to put all tools in the prompt).
+
+Here we use a top-k retriever, but we encourage you to customize the tool retrieval method!
+
+First, we create a tool for each document agent:
+
+```ts
+const allTools: QueryEngineTool[] = [];
+```
+
+```ts
+for (const title of wikiTitles) {
+  const wikiSummary = `
+    This content contains Wikipedia articles about ${title}.
+    Use this tool if you want to answer any questions about ${title}
+  `;
+
+  const docTool = new QueryEngineTool({
+    queryEngine: documentAgents[title],
+    metadata: {
+      name: `tool_${title}`,
+      description: wikiSummary,
+    },
+  });
+
+  allTools.push(docTool);
+}
+```
+
+Our top-level agent will use these document agents as tools, and it will use a `toolRetriever` to retrieve the best tool for answering each question.
+
+```ts
+// map the tools to nodes
+const toolMapping = SimpleToolNodeMapping.fromObjects(allTools);
+
+// create the object index
+const objectIndex = await ObjectIndex.fromObjects(
+  allTools,
+  toolMapping,
+  VectorStoreIndex,
+  {
+    serviceContext,
+    storageContext,
+  },
+);
+
+// create the top agent
+const topAgent = new OpenAIAgent({
+  toolRetriever: await objectIndex.asRetriever({}),
+  llm,
+  verbose: true,
+  prefixMessages: [
+    {
+      content:
+        "You are an agent designed to answer queries about a set of given countries. Please always use the tools provided to answer a question. Do not rely on prior knowledge.",
+      role: "system",
+    },
+  ],
+});
+```
+
+## Use the Agent
+
+Now we can use the agent to answer questions.
+
+```ts
+const response = await topAgent.chat({
+  message: "Tell me the differences between Brazil and Canada economics?",
+});
+
+// print output
+console.log(response);
+```
+
+You can find the full code for this example [here](https://github.com/run-llama/LlamaIndexTS/tree/main/examples/agent/multi_document_agent.ts).
diff --git a/apps/docs/docs/modules/agent/openai.mdx b/apps/docs/docs/modules/agent/openai.mdx
index f0130a985..37a4364f9 100644
--- a/apps/docs/docs/modules/agent/openai.mdx
+++ b/apps/docs/docs/modules/agent/openai.mdx
@@ -1,3 +1,7 @@
+---
+sidebar_position: 0
+---
+
 # OpenAI Agent
 
 OpenAI API that supports function calling, it’s never been easier to build your own agent!
diff --git a/apps/docs/docs/modules/agent/query_engine_tool.mdx b/apps/docs/docs/modules/agent/query_engine_tool.mdx
index 83ce66f92..85b007014 100644
--- a/apps/docs/docs/modules/agent/query_engine_tool.mdx
+++ b/apps/docs/docs/modules/agent/query_engine_tool.mdx
@@ -1,3 +1,7 @@
+---
+sidebar_position: 1
+---
+
 # OpenAI Agent + QueryEngineTool
 
 QueryEngineTool is a tool that allows you to query a vector index. In this example, we will create a vector index from a set of documents and then create a QueryEngineTool from the vector index. We will then create an OpenAIAgent with the QueryEngineTool and chat with the agent.
diff --git a/examples/.gitignore b/examples/.gitignore
index 8a77cb22f..0e6a8af04 100644
--- a/examples/.gitignore
+++ b/examples/.gitignore
@@ -1,2 +1,3 @@
 package-lock.json
 storage
+tmp_data
\ No newline at end of file
diff --git a/examples/agent/helpers/extractWikipedia.ts b/examples/agent/helpers/extractWikipedia.ts
new file mode 100644
index 000000000..05f391c29
--- /dev/null
+++ b/examples/agent/helpers/extractWikipedia.ts
@@ -0,0 +1,55 @@
+import fs from "fs";
+import path from "path";
+
+const dataPath = path.join(__dirname, "tmp_data");
+
+const extractWikipediaTitle = async (title: string) => {
+  const fileExists = fs.existsSync(path.join(dataPath, `${title}.txt`));
+
+  if (fileExists) {
+    console.log(`File already exists for the title: ${title}`);
+    return;
+  }
+
+  const queryParams = new URLSearchParams({
+    action: "query",
+    format: "json",
+    titles: title,
+    prop: "extracts",
+    explaintext: "true",
+  });
+
+  const url = `https://en.wikipedia.org/w/api.php?${queryParams}`;
+
+  const response = await fetch(url);
+  const data: any = await response.json();
+
+  const pages = data.query.pages;
+  const page = pages[Object.keys(pages)[0]];
+  const wikiText = page.extract;
+
+  await new Promise((resolve) => {
+    fs.writeFile(path.join(dataPath, `${title}.txt`), wikiText, (err: any) => {
+      if (err) {
+        console.error(err);
+        resolve(title);
+        return;
+      }
+      console.log(`${title} stored!`);
+
+      resolve(title);
+    });
+  });
+};
+
+export const extractWikipedia = async (titles: string[]) => {
+  if (!fs.existsSync(dataPath)) {
+    fs.mkdirSync(dataPath);
+  }
+
+  for await (const title of titles) {
+    await extractWikipediaTitle(title);
+  }
+
+  console.log("Extraction finished!");
+};
diff --git a/examples/agent/multi_document_agent.ts b/examples/agent/multi_document_agent.ts
new file mode 100644
index 000000000..b14449fff
--- /dev/null
+++ b/examples/agent/multi_document_agent.ts
@@ -0,0 +1,157 @@
+import fs from "node:fs/promises";
+
+import {
+  Document,
+  ObjectIndex,
+  OpenAI,
+  OpenAIAgent,
+  QueryEngineTool,
+  SimpleNodeParser,
+  SimpleToolNodeMapping,
+  SummaryIndex,
+  VectorStoreIndex,
+  serviceContextFromDefaults,
+  storageContextFromDefaults,
+} from "llamaindex";
+
+import { extractWikipedia } from "./helpers/extractWikipedia";
+
+const wikiTitles = ["Brazil", "Canada"];
+
+async function main() {
+  await extractWikipedia(wikiTitles);
+
+  const countryDocs: Record<string, Document> = {};
+
+  for (const title of wikiTitles) {
+    const path = `./agent/helpers/tmp_data/${title}.txt`;
+    const text = await fs.readFile(path, "utf-8");
+    const document = new Document({ text: text, id_: path });
+    countryDocs[title] = document;
+  }
+
+  const llm = new OpenAI({
+    model: "gpt-4",
+  });
+
+  const serviceContext = serviceContextFromDefaults({ llm });
+  const storageContext = await storageContextFromDefaults({
+    persistDir: "./storage",
+  });
+
+  // TODO: fix any
+  const documentAgents: any = {};
+  const queryEngines: any = {};
+
+  for (const title of wikiTitles) {
+    console.log(`Processing ${title}`);
+
+    const nodes = new SimpleNodeParser({
+      chunkSize: 200,
+      chunkOverlap: 20,
+    }).getNodesFromDocuments([countryDocs[title]]);
+
+    console.log(`Creating index for ${title}`);
+
+    const vectorIndex = await VectorStoreIndex.init({
+      serviceContext: serviceContext,
+      storageContext: storageContext,
+      nodes,
+    });
+
+    const summaryIndex = await SummaryIndex.init({
+      serviceContext: serviceContext,
+      nodes,
+    });
+
+    console.log(`Creating query engines for ${title}`);
+
+    const vectorQueryEngine = vectorIndex.asQueryEngine();
+    const summaryQueryEngine = summaryIndex.asQueryEngine();
+
+    const queryEngineTools = [
+      new QueryEngineTool({
+        queryEngine: vectorQueryEngine,
+        metadata: {
+          name: "vector_tool",
+          description: `Useful for questions related to specific aspects of ${title} (e.g. the history, arts and culture, sports, demographics, or more).`,
+        },
+      }),
+      new QueryEngineTool({
+        queryEngine: summaryQueryEngine,
+        metadata: {
+          name: "summary_tool",
+          description: `Useful for any requests that require a holistic summary of EVERYTHING about ${title}. For questions about more specific sections, please use the vector_tool.`,
+        },
+      }),
+    ];
+
+    console.log(`Creating agents for ${title}`);
+
+    const agent = new OpenAIAgent({
+      tools: queryEngineTools,
+      llm,
+      verbose: true,
+    });
+
+    documentAgents[title] = agent;
+    queryEngines[title] = vectorIndex.asQueryEngine();
+  }
+
+  const allTools: QueryEngineTool[] = [];
+
+  console.log(`Creating tools for all countries`);
+
+  for (const title of wikiTitles) {
+    const wikiSummary = `This content contains Wikipedia articles about ${title}. Use this tool if you want to answer any questions about ${title}`;
+
+    console.log(`Creating tool for ${title}`);
+
+    const docTool = new QueryEngineTool({
+      queryEngine: documentAgents[title],
+      metadata: {
+        name: `tool_${title}`,
+        description: wikiSummary,
+      },
+    });
+
+    allTools.push(docTool);
+  }
+
+  console.log("creating tool mapping");
+
+  const toolMapping = SimpleToolNodeMapping.fromObjects(allTools);
+
+  const objectIndex = await ObjectIndex.fromObjects(
+    allTools,
+    toolMapping,
+    VectorStoreIndex,
+    {
+      serviceContext,
+      storageContext,
+    },
+  );
+
+  const topAgent = new OpenAIAgent({
+    toolRetriever: await objectIndex.asRetriever({}),
+    llm,
+    verbose: true,
+    prefixMessages: [
+      {
+        content:
+          "You are an agent designed to answer queries about a set of given countries. Please always use the tools provided to answer a question.
Do not rely on prior knowledge.", + role: "system", + }, + ], + }); + + const response = await topAgent.chat({ + message: "Tell me the differences between Brazil and Canada economics?", + }); + + console.log({ + capitalOfBrazil: response, + }); +} + +main(); diff --git a/examples/agent/query_openai_agent.ts b/examples/agent/query_openai_agent.ts new file mode 100644 index 000000000..37614b46d --- /dev/null +++ b/examples/agent/query_openai_agent.ts @@ -0,0 +1,46 @@ +import { + OpenAIAgent, + QueryEngineTool, + SimpleDirectoryReader, + VectorStoreIndex, +} from "llamaindex"; + +async function main() { + // Load the documents + const documents = await new SimpleDirectoryReader().loadData({ + directoryPath: "node_modules/llamaindex/examples/", + }); + + // Create a vector index from the documents + const vectorIndex = await VectorStoreIndex.fromDocuments(documents); + + // Create a query engine from the vector index + const abramovQueryEngine = vectorIndex.asQueryEngine(); + + // Create a QueryEngineTool with the query engine + const queryEngineTool = new QueryEngineTool({ + queryEngine: abramovQueryEngine, + metadata: { + name: "abramov_query_engine", + description: "A query engine for the Abramov documents", + }, + }); + + // Create an OpenAIAgent with the function tools + const agent = new OpenAIAgent({ + tools: [queryEngineTool], + verbose: true, + }); + + // Chat with the agent + const response = await agent.chat({ + message: "What was his salary?", + }); + + // Print the response + console.log(String(response)); +} + +main().then(() => { + console.log("Done"); +}); diff --git a/examples/vectorIndexEmbed3.ts b/examples/vectorIndexEmbed3.ts index 9ecc33d28..1d6c52b90 100644 --- a/examples/vectorIndexEmbed3.ts +++ b/examples/vectorIndexEmbed3.ts @@ -30,6 +30,7 @@ async function main() { // Query the index const queryEngine = index.asQueryEngine(); + const response = await queryEngine.query({ query: "What did the author do in college?", }); diff --git a/packages/core/package.json b/packages/core/package.json index ff4032902..93e1bf867 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -134,6 +134,11 @@ "import": "./dist/tools.mjs", "require": "./dist/tools.js" }, + "./objects": { + "types": "./dist/objects.d.mts", + "import": "./dist/objects.mjs", + "require": "./dist/objects.js" + }, "./readers": { "types": "./dist/readers.d.mts", "import": "./dist/readers.mjs", diff --git a/packages/core/src/agent/openai/base.ts b/packages/core/src/agent/openai/base.ts index 5b916ebae..6c5c15e61 100644 --- a/packages/core/src/agent/openai/base.ts +++ b/packages/core/src/agent/openai/base.ts @@ -6,7 +6,7 @@ import { AgentRunner } from "../runner/base"; import { OpenAIAgentWorker } from "./worker"; type OpenAIAgentParams = { - tools: BaseTool[]; + tools?: BaseTool[]; llm?: OpenAI; memory?: any; prefixMessages?: ChatMessage[]; @@ -14,7 +14,7 @@ type OpenAIAgentParams = { maxFunctionCalls?: number; defaultToolChoice?: string; callbackManager?: CallbackManager; - toolRetriever?: ObjectRetriever<BaseTool>; + toolRetriever?: ObjectRetriever; systemPrompt?: string; }; diff --git a/packages/core/src/agent/openai/worker.ts b/packages/core/src/agent/openai/worker.ts index 33b6e470f..779a7c014 100644 --- a/packages/core/src/agent/openai/worker.ts +++ b/packages/core/src/agent/openai/worker.ts @@ -69,13 +69,13 @@ async function callFunction( } type OpenAIAgentWorkerParams = { - tools: BaseTool[]; + tools?: BaseTool[]; llm?: OpenAI; prefixMessages?: ChatMessage[]; verbose?: boolean; 
maxFunctionCalls?: number; callbackManager?: CallbackManager | undefined; - toolRetriever?: ObjectRetriever<BaseTool>; + toolRetriever?: ObjectRetriever; }; type CallFunctionOutput = { @@ -101,7 +101,7 @@ export class OpenAIAgentWorker implements AgentWorker { * Initialize. */ constructor({ - tools, + tools = [], llm, prefixMessages, verbose, @@ -191,6 +191,7 @@ export class OpenAIAgentWorker implements AgentWorker { ): AgentChatResponse | AsyncIterable<ChatResponseChunk> { const aiMessage = chatResponse.message; task.extraState.newMemory.put(aiMessage); + return new AgentChatResponse(aiMessage.content, task.extraState.sources); } diff --git a/packages/core/src/agent/react/base.ts b/packages/core/src/agent/react/base.ts index ecafcff38..e0883e4bb 100644 --- a/packages/core/src/agent/react/base.ts +++ b/packages/core/src/agent/react/base.ts @@ -14,7 +14,7 @@ type ReActAgentParams = { maxInteractions?: number; defaultToolChoice?: string; callbackManager?: CallbackManager; - toolRetriever?: ObjectRetriever<BaseTool>; + toolRetriever?: ObjectRetriever; }; /** diff --git a/packages/core/src/agent/react/worker.ts b/packages/core/src/agent/react/worker.ts index ff5b1a9fd..1bb48691a 100644 --- a/packages/core/src/agent/react/worker.ts +++ b/packages/core/src/agent/react/worker.ts @@ -24,7 +24,7 @@ type ReActAgentWorkerParams = { outputParser?: ReActOutputParser | undefined; callbackManager?: CallbackManager | undefined; verbose?: boolean | undefined; - toolRetriever?: ObjectRetriever<BaseTool> | undefined; + toolRetriever?: ObjectRetriever | undefined; }; /** diff --git a/packages/core/src/engines/query/SubQuestionQueryEngine.ts b/packages/core/src/engines/query/SubQuestionQueryEngine.ts index 4c874b15a..73d0ada6c 100644 --- a/packages/core/src/engines/query/SubQuestionQueryEngine.ts +++ b/packages/core/src/engines/query/SubQuestionQueryEngine.ts @@ -12,6 +12,7 @@ import { CompactAndRefine, ResponseSynthesizer, } from "../../synthesizers"; + import { BaseQueryEngine, BaseTool, @@ -19,6 +20,7 @@ import { QueryEngineParamsStreaming, ToolMetadata, } from "../../types"; + import { BaseQuestionGenerator, SubQuestion } from "./types"; /** diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 16faedf93..f2449885b 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -20,6 +20,7 @@ export * from "./indices"; export * from "./ingestion"; export * from "./llm"; export * from "./nodeParsers"; +export * from "./objects"; export * from "./postprocessors"; export * from "./readers"; export * from "./selectors"; diff --git a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts index b416126e7..9515bab44 100644 --- a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts +++ b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts @@ -84,7 +84,9 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { * @param options * @returns */ - static async init(options: VectorIndexOptions): Promise<VectorStoreIndex> { + public static async init( + options: VectorIndexOptions, + ): Promise<VectorStoreIndex> { const storageContext = options.storageContext ?? 
(await storageContextFromDefaults({})); const serviceContext = diff --git a/packages/core/src/llm/LLM.ts b/packages/core/src/llm/LLM.ts index bdba512d9..1a3a1ad41 100644 --- a/packages/core/src/llm/LLM.ts +++ b/packages/core/src/llm/LLM.ts @@ -26,7 +26,6 @@ import { } from "./azure"; import { BaseLLM } from "./base"; import { OpenAISession, getOpenAISession } from "./open_ai"; -import { isFunctionCallingModel } from "./openai/utils"; import { PortkeySession, getPortkeySession } from "./portkey"; import { ReplicateSession } from "./replicate_ai"; import { @@ -67,6 +66,12 @@ export const ALL_AVAILABLE_OPENAI_MODELS = { ...GPT35_MODELS, }; +export const isFunctionCallingModel = (model: string): boolean => { + const isChatModel = Object.keys(ALL_AVAILABLE_OPENAI_MODELS).includes(model); + const isOld = model.includes("0314") || model.includes("0301"); + return isChatModel && !isOld; +}; + /** * OpenAI LLM implementation */ diff --git a/packages/core/src/llm/openai/utils.ts b/packages/core/src/llm/openai/utils.ts deleted file mode 100644 index 0521c9a41..000000000 --- a/packages/core/src/llm/openai/utils.ts +++ /dev/null @@ -1,7 +0,0 @@ -import { ALL_AVAILABLE_OPENAI_MODELS } from ".."; - -export const isFunctionCallingModel = (model: string): boolean => { - const isChatModel = Object.keys(ALL_AVAILABLE_OPENAI_MODELS).includes(model); - const isOld = model.includes("0314") || model.includes("0301"); - return isChatModel && !isOld; -}; diff --git a/packages/core/src/objects/base.ts b/packages/core/src/objects/base.ts index 7adeb9490..db39e2f92 100644 --- a/packages/core/src/objects/base.ts +++ b/packages/core/src/objects/base.ts @@ -1,42 +1,41 @@ -import { BaseNode, TextNode } from "../Node"; +import { BaseNode, Metadata, TextNode } from "../Node"; import { BaseRetriever } from "../Retriever"; +import { VectorStoreIndex } from "../indices"; +import { BaseTool } from "../types"; // Assuming that necessary interfaces and classes (like OT, TextNode, BaseNode, etc.) 
are defined elsewhere // Import statements (e.g., for TextNode, BaseNode) should be added based on your project's structure -export abstract class BaseObjectNodeMapping<OT> { +export abstract class BaseObjectNodeMapping { // TypeScript doesn't support Python's classmethod directly, but we can use static methods as an alternative - abstract fromObjects<OT>( - objs: OT[], - ...args: any[] - ): BaseObjectNodeMapping<OT>; + abstract fromObjects<OT>(objs: OT[], ...args: any[]): BaseObjectNodeMapping; // Abstract methods in TypeScript abstract objNodeMapping(): Record<any, any>; - abstract toNode(obj: OT): TextNode; + abstract toNode(obj: any): TextNode; // Concrete methods can be defined as usual - validateObject(obj: OT): void {} + validateObject(obj: any): void {} // Implementing the add object logic - addObj(obj: OT): void { + addObj(obj: any): void { this.validateObject(obj); this._addObj(obj); } // Abstract method for internal add object logic - protected abstract _addObj(obj: OT): void; + abstract _addObj(obj: any): void; // Implementing toNodes method - toNodes(objs: OT[]): TextNode[] { + toNodes(objs: any[]): TextNode[] { return objs.map((obj) => this.toNode(obj)); } // Abstract method for internal from node logic - protected abstract _fromNode(node: BaseNode): OT; + abstract _fromNode(node: BaseNode): any; // Implementing fromNode method - fromNode(node: BaseNode): OT { + fromNode(node: BaseNode): any { const obj = this._fromNode(node); this.validateObject(obj); return obj; @@ -50,13 +49,13 @@ export abstract class BaseObjectNodeMapping<OT> { type QueryType = string; -export class ObjectRetriever<OT> { - private _retriever: BaseRetriever; - private _objectNodeMapping: BaseObjectNodeMapping<OT>; +export class ObjectRetriever { + _retriever: BaseRetriever; + _objectNodeMapping: BaseObjectNodeMapping; constructor( retriever: BaseRetriever, - objectNodeMapping: BaseObjectNodeMapping<OT>, + objectNodeMapping: BaseObjectNodeMapping, ) { this._retriever = retriever; this._objectNodeMapping = objectNodeMapping; @@ -68,13 +67,126 @@ export class ObjectRetriever<OT> { } // Translating the retrieve method - async retrieve(strOrQueryBundle: QueryType): Promise<OT[]> { - const nodes = await this._retriever.retrieve(strOrQueryBundle); - return nodes.map((node) => this._objectNodeMapping.fromNode(node.node)); + async retrieve(strOrQueryBundle: QueryType): Promise<any> { + const nodes = await this.retriever.retrieve(strOrQueryBundle); + const objs = nodes.map((n) => this._objectNodeMapping.fromNode(n.node)); + return objs; + } +} + +const convertToolToNode = (tool: BaseTool): TextNode => { + const nodeText = ` + Tool name: ${tool.metadata.name} + Tool description: ${tool.metadata.description} + `; + return new TextNode({ + text: nodeText, + metadata: { name: tool.metadata.name }, + excludedEmbedMetadataKeys: ["name"], + excludedLlmMetadataKeys: ["name"], + }); +}; + +export class SimpleToolNodeMapping extends BaseObjectNodeMapping { + private _tools: Record<string, BaseTool>; + + private constructor(objs: BaseTool[] = []) { + super(); + this._tools = {}; + for (const tool of objs) { + this._tools[tool.metadata.name] = tool; + } + } + + objNodeMapping(): Record<any, any> { + return this._tools; + } + + toNode(tool: BaseTool): TextNode { + return convertToolToNode(tool); + } + + _addObj(tool: BaseTool): void { + this._tools[tool.metadata.name] = tool; + } + + _fromNode(node: BaseNode): BaseTool { + if (!node.metadata) { + throw new Error("Metadata must be set"); + } + return 
this._tools[node.metadata.name]; + } + + persist(persistDir: string, objNodeMappingFilename: string): void { + // Implement the persist method + } + + toNodes(objs: BaseTool[]): TextNode<Metadata>[] { + return objs.map((obj) => this.toNode(obj)); + } + + addObj(obj: BaseTool): void { + this._addObj(obj); + } + + fromNode(node: BaseNode): BaseTool { + return this._fromNode(node); } - // // Translating the _asQueryComponent method - // public asQueryComponent(kwargs: any): any { - // return new ObjectRetrieverComponent(this); - // } + static fromObjects(objs: any, ...args: any[]): BaseObjectNodeMapping { + return new SimpleToolNodeMapping(objs); + } + + fromObjects<OT>(objs: any, ...args: any[]): BaseObjectNodeMapping { + return new SimpleToolNodeMapping(objs); + } +} + +export class ObjectIndex { + private _index: VectorStoreIndex; + private _objectNodeMapping: BaseObjectNodeMapping; + + private constructor(index: any, objectNodeMapping: BaseObjectNodeMapping) { + this._index = index; + this._objectNodeMapping = objectNodeMapping; + } + + static async fromObjects( + objects: any, + objectMapping: BaseObjectNodeMapping, + // TODO: fix any (bundling issue) + indexCls: any, + indexKwargs?: Record<string, any>, + ): Promise<ObjectIndex> { + if (objectMapping === null) { + objectMapping = SimpleToolNodeMapping.fromObjects(objects, {}); + } + + const nodes = objectMapping.toNodes(objects); + + const index = await indexCls.init({ nodes, ...indexKwargs }); + + return new ObjectIndex(index, objectMapping); + } + + insertObject(obj: any): void { + this._objectNodeMapping.addObj(obj); + const node = this._objectNodeMapping.toNode(obj); + this._index.insertNodes([node]); + } + + get tools(): Record<string, BaseTool> { + return this._objectNodeMapping.objNodeMapping(); + } + + async asRetriever(kwargs: any): Promise<ObjectRetriever> { + return new ObjectRetriever( + this._index.asRetriever(kwargs), + this._objectNodeMapping, + ); + } + + asNodeRetriever(kwargs: any): any { + return this._index.asRetriever(kwargs); + } } diff --git a/packages/core/src/objects/index.ts b/packages/core/src/objects/index.ts new file mode 100644 index 000000000..955fdd143 --- /dev/null +++ b/packages/core/src/objects/index.ts @@ -0,0 +1 @@ +export * from "./base"; diff --git a/packages/core/src/storage/vectorStore/SimpleVectorStore.ts b/packages/core/src/storage/vectorStore/SimpleVectorStore.ts index f4095e242..3c9f64606 100644 --- a/packages/core/src/storage/vectorStore/SimpleVectorStore.ts +++ b/packages/core/src/storage/vectorStore/SimpleVectorStore.ts @@ -60,7 +60,6 @@ export class SimpleVectorStore implements VectorStore { this.data.embeddingDict[node.id_] = node.getEmbedding(); if (!node.sourceNode) { - console.error("Missing source node from TextNode."); continue; } diff --git a/packages/core/src/tests/Embedding.test.ts b/packages/core/src/tests/Embedding.test.ts index 68d5b2078..4f3b4262b 100644 --- a/packages/core/src/tests/Embedding.test.ts +++ b/packages/core/src/tests/Embedding.test.ts @@ -70,12 +70,12 @@ describe("[OpenAIEmbedding]", () => { test("getTextEmbeddings", async () => { const texts = ["hello", "world"]; const embeddings = await embedModel.getTextEmbeddings(texts); - expect(embeddings.length).toEqual(1); + expect(embeddings.length).toEqual(2); }); test("getTextEmbeddingsBatch", async () => { const texts = ["hello", "world"]; const embeddings = await embedModel.getTextEmbeddingsBatch(texts); - expect(embeddings.length).toEqual(1); + expect(embeddings.length).toEqual(2); }); }); diff --git 
a/packages/core/src/tests/llms/openai/utils.test.ts b/packages/core/src/tests/llms/openai/utils.test.ts index 95eb3bc03..5ebace03d 100644 --- a/packages/core/src/tests/llms/openai/utils.test.ts +++ b/packages/core/src/tests/llms/openai/utils.test.ts @@ -1,5 +1,7 @@ -import { ALL_AVAILABLE_OPENAI_MODELS } from "../../../llm"; -import { isFunctionCallingModel } from "../../../llm/openai/utils"; +import { + ALL_AVAILABLE_OPENAI_MODELS, + isFunctionCallingModel, +} from "../../../llm"; describe("openai/utils", () => { test("shouldn't be a old model", () => { diff --git a/packages/core/src/tests/objects/ObjectIndex.test.ts b/packages/core/src/tests/objects/ObjectIndex.test.ts new file mode 100644 index 000000000..b46616486 --- /dev/null +++ b/packages/core/src/tests/objects/ObjectIndex.test.ts @@ -0,0 +1,130 @@ +import { + FunctionTool, + ObjectIndex, + OpenAI, + OpenAIEmbedding, + ServiceContext, + SimpleToolNodeMapping, + VectorStoreIndex, + serviceContextFromDefaults, +} from "../../index"; +import { mockEmbeddingModel, mockLlmGeneration } from "../utility/mockOpenAI"; + +jest.mock("../../llm/open_ai", () => { + return { + getOpenAISession: jest.fn().mockImplementation(() => null), + }; +}); + +describe("ObjectIndex", () => { + let serviceContext: ServiceContext; + + beforeAll(() => { + const embeddingModel = new OpenAIEmbedding(); + const llm = new OpenAI(); + + mockEmbeddingModel(embeddingModel); + mockLlmGeneration({ languageModel: llm }); + + const ctx = serviceContextFromDefaults({ + embedModel: embeddingModel, + llm, + }); + + serviceContext = ctx; + }); + + test("test_object_with_tools", async () => { + const tool1 = new FunctionTool((x: any) => x, { + name: "test_tool", + description: "test tool", + parameters: { + type: "object", + properties: { + x: { + type: "string", + }, + }, + }, + }); + + const tool2 = new FunctionTool((x: any) => x, { + name: "test_tool_2", + description: "test tool 2", + parameters: { + type: "object", + properties: { + x: { + type: "string", + }, + }, + }, + }); + + const toolMapping = SimpleToolNodeMapping.fromObjects([tool1, tool2]); + + const objectRetriever = await ObjectIndex.fromObjects( + [tool1, tool2], + toolMapping, + VectorStoreIndex, + { + serviceContext, + }, + ); + + const retriever = await objectRetriever.asRetriever({ + serviceContext, + }); + + expect(await retriever.retrieve("test")).toStrictEqual([tool1, tool2]); + }); + + test("add a new object", async () => { + const tool1 = new FunctionTool((x: any) => x, { + name: "test_tool", + description: "test tool", + parameters: { + type: "object", + properties: { + x: { + type: "string", + }, + }, + }, + }); + + const tool2 = new FunctionTool((x: any) => x, { + name: "test_tool_2", + description: "test tool 2", + parameters: { + type: "object", + properties: { + x: { + type: "string", + }, + }, + }, + }); + + const toolMapping = SimpleToolNodeMapping.fromObjects([tool1]); + + const objectRetriever = await ObjectIndex.fromObjects( + [tool1], + toolMapping, + VectorStoreIndex, + { + serviceContext, + }, + ); + + let tools = objectRetriever.tools; + + expect(Object.keys(tools).length).toBe(1); + + objectRetriever.insertObject(tool2); + + tools = objectRetriever.tools; + + expect(Object.keys(tools).length).toBe(2); + }); +}); diff --git a/packages/core/src/tests/utility/mockOpenAI.ts b/packages/core/src/tests/utility/mockOpenAI.ts index a9343cd1f..aaa77d743 100644 --- a/packages/core/src/tests/utility/mockOpenAI.ts +++ b/packages/core/src/tests/utility/mockOpenAI.ts @@ -11,7 +11,7 @@ export 
function mockLlmGeneration({ callbackManager, }: { languageModel: OpenAI; - callbackManager: CallbackManager; + callbackManager?: CallbackManager; }) { jest .spyOn(languageModel, "chat") @@ -84,7 +84,10 @@ export function mockLlmToolCallGeneration({ ); } -export function mockEmbeddingModel(embedModel: OpenAIEmbedding) { +export function mockEmbeddingModel( + embedModel: OpenAIEmbedding, + embeddingsLength: number = 1, +) { jest.spyOn(embedModel, "getTextEmbedding").mockImplementation(async (x) => { return new Promise((resolve) => { resolve([1, 0, 0, 0, 0, 0]); @@ -92,6 +95,9 @@ export function mockEmbeddingModel(embedModel: OpenAIEmbedding) { }); jest.spyOn(embedModel, "getTextEmbeddings").mockImplementation(async (x) => { return new Promise((resolve) => { + if (x.length > 1) { + resolve(Array(x.length).fill([1, 0, 0, 0, 0, 0])); + } resolve([[1, 0, 0, 0, 0, 0]]); }); }); diff --git a/packages/core/src/tools/index.ts b/packages/core/src/tools/index.ts index 1215bef7e..bd695d4f9 100644 --- a/packages/core/src/tools/index.ts +++ b/packages/core/src/tools/index.ts @@ -1,3 +1,4 @@ export * from "./QueryEngineTool"; export * from "./functionTool"; export * from "./types"; +export * from "./utils"; -- GitLab