diff --git a/.changeset/thin-seals-cough.md b/.changeset/thin-seals-cough.md new file mode 100644 index 0000000000000000000000000000000000000000..6d07e650dd6d6514767c5d00ca01b2b72eb419c6 --- /dev/null +++ b/.changeset/thin-seals-cough.md @@ -0,0 +1,5 @@ +--- +"llamaindex": patch +--- + +feat: multi-document agent diff --git a/apps/docs/docs/modules/agent/_category_.yml b/apps/docs/docs/modules/agent/_category_.yml index 549b74c5a136746cc283d3307172ae7f5cd42c38..0ea1a6657ff35f718e2187badfd0f107ab19e272 100644 --- a/apps/docs/docs/modules/agent/_category_.yml +++ b/apps/docs/docs/modules/agent/_category_.yml @@ -1 +1,2 @@ label: "Agents" +position: 3 diff --git a/apps/docs/docs/modules/agent/multi_document_agent.mdx b/apps/docs/docs/modules/agent/multi_document_agent.mdx new file mode 100644 index 0000000000000000000000000000000000000000..cbe9e3c2111bb2efd36f2595a01cd387c22e5483 --- /dev/null +++ b/apps/docs/docs/modules/agent/multi_document_agent.mdx @@ -0,0 +1,314 @@ +# Multi-Document Agent + +In this guide, you will learn how to set up an agent that can effectively answer different types of questions over a larger set of documents. + +These questions include the following: + +- QA over a specific doc +- QA comparing different docs +- Summaries over a specific doc +- Comparing summaries between different docs + +We do this with the following architecture: + +- set up a “document agent” over each Document: each doc agent can do QA/summarization within its doc +- set up a top-level agent over this set of document agents. Do tool retrieval and then do CoT over the set of tools to answer a question. + +## Setup and Download Data + +We first start by installing the necessary libraries and downloading the data. 
+ +```bash +pnpm i llamaindex +``` + +```ts +import { + Document, + ObjectIndex, + OpenAI, + OpenAIAgent, + QueryEngineTool, + SimpleNodeParser, + SimpleToolNodeMapping, + SummaryIndex, + VectorStoreIndex, + serviceContextFromDefaults, + storageContextFromDefaults, +} from "llamaindex"; +``` + +And then for the data we will run through a list of countries and download the wikipedia page for each country. + +```ts +import fs from "fs"; +import path from "path"; + +const dataPath = path.join(__dirname, "tmp_data"); + +const extractWikipediaTitle = async (title: string) => { + const fileExists = fs.existsSync(path.join(dataPath, `${title}.txt`)); + + if (fileExists) { + console.log(`File already exists for the title: ${title}`); + return; + } + + const queryParams = new URLSearchParams({ + action: "query", + format: "json", + titles: title, + prop: "extracts", + explaintext: "true", + }); + + const url = `https://en.wikipedia.org/w/api.php?${queryParams}`; + + const response = await fetch(url); + const data: any = await response.json(); + + const pages = data.query.pages; + const page = pages[Object.keys(pages)[0]]; + const wikiText = page.extract; + + await new Promise((resolve) => { + fs.writeFile(path.join(dataPath, `${title}.txt`), wikiText, (err: any) => { + if (err) { + console.error(err); + resolve(title); + return; + } + console.log(`${title} stored in file!`); + + resolve(title); + }); + }); +}; +``` + +```ts +export const extractWikipedia = async (titles: string[]) => { + if (!fs.existsSync(dataPath)) { + fs.mkdirSync(dataPath); + } + + for await (const title of titles) { + await extractWikipediaTitle(title); + } + + console.log("Extration finished!"); +``` + +These files will be saved in the `tmp_data` folder. + +Now we can call the function to download the data for each country. 
+ +```ts +await extractWikipedia([ + "Brazil", + "United States", + "Canada", + "Mexico", + "Argentina", + "Chile", + "Colombia", + "Peru", + "Venezuela", + "Ecuador", + "Bolivia", + "Paraguay", + "Uruguay", + "Guyana", + "Suriname", + "French Guiana", + "Falkland Islands", +]); +``` + +## Load the data + +Now that we have the data, we can load it into the LlamaIndex and store as a document. + +```ts +import { Document } from "llamaindex"; + +const countryDocs: Record<string, Document> = {}; + +for (const title of wikiTitles) { + const path = `./agent/helpers/tmp_data/${title}.txt`; + const text = await fs.readFile(path, "utf-8"); + const document = new Document({ text: text, id_: path }); + countryDocs[title] = document; +} +``` + +## Setup LLM and StorageContext + +We will be using gpt-4 for this example and we will use the `StorageContext` to store the documents in-memory. + +```ts +const llm = new OpenAI({ + model: "gpt-4", +}); + +const ctx = serviceContextFromDefaults({ llm }); + +const storageContext = await storageContextFromDefaults({ + persistDir: "./storage", +}); +``` + +## Building Multi-Document Agents + +In this section we show you how to construct the multi-document agent. We first build a document agent for each document, and then define the top-level parent agent with an object index. + +```ts +const documentAgents: Record<string, any> = {}; +const queryEngines: Record<string, any> = {}; +``` + +Now we iterate over each country and create a document agent for each one. + +### Build Agent for each Document + +In this section we define “document agents” for each document. + +We define both a vector index (for semantic search) and summary index (for summarization) for each document. The two query engines are then converted into tools that are passed to an OpenAI function calling agent. + +This document agent can dynamically choose to perform semantic search or summarization within a given document. 
+ +We create a separate document agent for each country. + +```ts +for (const title of wikiTitles) { + // parse the document into nodes + const nodes = new SimpleNodeParser({ + chunkSize: 200, + chunkOverlap: 20, + }).getNodesFromDocuments([countryDocs[title]]); + + // create the vector index for specific search + const vectorIndex = await VectorStoreIndex.init({ + serviceContext: serviceContext, + storageContext: storageContext, + nodes, + }); + + // create the summary index for broader search + const summaryIndex = await SummaryIndex.init({ + serviceContext: serviceContext, + nodes, + }); + + const vectorQueryEngine = summaryIndex.asQueryEngine(); + const summaryQueryEngine = summaryIndex.asQueryEngine(); + + // create the query engines for each task + const queryEngineTools = [ + new QueryEngineTool({ + queryEngine: vectorQueryEngine, + metadata: { + name: "vector_tool", + description: `Useful for questions related to specific aspects of ${title} (e.g. the history, arts and culture, sports, demographics, or more).`, + }, + }), + new QueryEngineTool({ + queryEngine: summaryQueryEngine, + metadata: { + name: "summary_tool", + description: `Useful for any requests that require a holistic summary of EVERYTHING about ${title}. For questions about more specific sections, please use the vector_tool.`, + }, + }), + ]; + + // create the document agent + const agent = new OpenAIAgent({ + tools: queryEngineTools, + llm, + verbose: true, + }); + + documentAgents[title] = agent; + queryEngines[title] = vectorIndex.asQueryEngine(); +} +``` + +## Build Top-Level Agent + +Now we define the top-level agent that can answer questions over the set of document agents. + +This agent takes in all document agents as tools. This specific agent, RetrieverOpenAIAgent, performs tool retrieval before tool use (unlike a default agent that tries to put all tools in the prompt). + +Here we use a top-k retriever, but we encourage you to customize the tool retriever method! 
+ +Firstly, we create a tool for each document agent + +```ts +const allTools: QueryEngineTool[] = []; +``` + +```ts +for (const title of wikiTitles) { + const wikiSummary = ` + This content contains Wikipedia articles about ${title}. + Use this tool if you want to answer any questions about ${title} + `; + + const docTool = new QueryEngineTool({ + queryEngine: documentAgents[title], + metadata: { + name: `tool_${title}`, + description: wikiSummary, + }, + }); + + allTools.push(docTool); +} +``` + +Our top-level agent will use these document agents as tools and use toolRetriever to retrieve the best tool to answer a question. + +```ts +// map the tools to nodes +const toolMapping = SimpleToolNodeMapping.fromObjects(allTools); + +// create the object index +const objectIndex = await ObjectIndex.fromObjects( + allTools, + toolMapping, + VectorStoreIndex, + { + serviceContext, + storageContext, + }, +); + +// create the top agent +const topAgent = new OpenAIAgent({ + toolRetriever: await objectIndex.asRetriever({}), + llm, + verbose: true, + prefixMessages: [ + { + content: + "You are an agent designed to answer queries about a set of given countries. Please always use the tools provided to answer a question. Do not rely on prior knowledge.", + role: "system", + }, + ], +}); +``` + +## Use the Agent + +Now we can use the agent to answer questions. 
+ +```ts +const response = await topAgent.chat({ + message: "Tell me the differences between Brazil and Canada economics?", +}); + +// print output +console.log(response); +``` + +You can find the full code for this example [here](https://github.com/run-llama/LlamaIndexTS/tree/main/examples/agent/multi-document-agent.ts) diff --git a/apps/docs/docs/modules/agent/openai.mdx b/apps/docs/docs/modules/agent/openai.mdx index f0130a985e75ca89c79ab1fb8d3f7ced78d4e23e..37a4364f9a7c99b414e84b05d31db0225768359d 100644 --- a/apps/docs/docs/modules/agent/openai.mdx +++ b/apps/docs/docs/modules/agent/openai.mdx @@ -1,3 +1,7 @@ +--- +sidebar_position: 0 +--- + # OpenAI Agent OpenAI API that supports function calling, it’s never been easier to build your own agent! diff --git a/apps/docs/docs/modules/agent/query_engine_tool.mdx b/apps/docs/docs/modules/agent/query_engine_tool.mdx index 83ce66f929ea168c26a1aceaa308fd56bdb77bb8..85b007014f5c2e62f30237dc81ae063b43ca06e7 100644 --- a/apps/docs/docs/modules/agent/query_engine_tool.mdx +++ b/apps/docs/docs/modules/agent/query_engine_tool.mdx @@ -1,3 +1,7 @@ +--- +sidebar_position: 1 +--- + # OpenAI Agent + QueryEngineTool QueryEngineTool is a tool that allows you to query a vector index. In this example, we will create a vector index from a set of documents and then create a QueryEngineTool from the vector index. We will then create an OpenAIAgent with the QueryEngineTool and chat with the agent. 
diff --git a/examples/.gitignore b/examples/.gitignore index 8a77cb22f9809ef84223cfd2eac7b45d184a45ee..0e6a8af04e4904f9432e59f6aafe98b73dad6893 100644 --- a/examples/.gitignore +++ b/examples/.gitignore @@ -1,2 +1,3 @@ package-lock.json storage +tmp_data \ No newline at end of file diff --git a/examples/agent/helpers/extractWikipedia.ts b/examples/agent/helpers/extractWikipedia.ts new file mode 100644 index 0000000000000000000000000000000000000000..05f391c292a140d846d8e68d762f4aafbf59cba6 --- /dev/null +++ b/examples/agent/helpers/extractWikipedia.ts @@ -0,0 +1,55 @@ +import fs from "fs"; +import path from "path"; + +const dataPath = path.join(__dirname, "tmp_data"); + +const extractWikipediaTitle = async (title: string) => { + const fileExists = fs.existsSync(path.join(dataPath, `${title}.txt`)); + + if (fileExists) { + console.log(`Arquivo já existe para o título: ${title}`); + return; + } + + const queryParams = new URLSearchParams({ + action: "query", + format: "json", + titles: title, + prop: "extracts", + explaintext: "true", + }); + + const url = `https://en.wikipedia.org/w/api.php?${queryParams}`; + + const response = await fetch(url); + const data: any = await response.json(); + + const pages = data.query.pages; + const page = pages[Object.keys(pages)[0]]; + const wikiText = page.extract; + + await new Promise((resolve) => { + fs.writeFile(path.join(dataPath, `${title}.txt`), wikiText, (err: any) => { + if (err) { + console.error(err); + resolve(title); + return; + } + console.log(`${title} stored!`); + + resolve(title); + }); + }); +}; + +export const extractWikipedia = async (titles: string[]) => { + if (!fs.existsSync(dataPath)) { + fs.mkdirSync(dataPath); + } + + for await (const title of titles) { + await extractWikipediaTitle(title); + } + + console.log("Extration finished!"); +}; diff --git a/examples/agent/multi_document_agent.ts b/examples/agent/multi_document_agent.ts new file mode 100644 index 
0000000000000000000000000000000000000000..b14449fffaefbe165eab0cf8da1a30ed2dac7806 --- /dev/null +++ b/examples/agent/multi_document_agent.ts @@ -0,0 +1,157 @@ +import fs from "node:fs/promises"; + +import { + Document, + ObjectIndex, + OpenAI, + OpenAIAgent, + QueryEngineTool, + SimpleNodeParser, + SimpleToolNodeMapping, + SummaryIndex, + VectorStoreIndex, + serviceContextFromDefaults, + storageContextFromDefaults, +} from "llamaindex"; + +import { extractWikipedia } from "./helpers/extractWikipedia"; + +const wikiTitles = ["Brazil", "Canada"]; + +async function main() { + await extractWikipedia(wikiTitles); + + const countryDocs: Record<string, Document> = {}; + + for (const title of wikiTitles) { + const path = `./agent/helpers/tmp_data/${title}.txt`; + const text = await fs.readFile(path, "utf-8"); + const document = new Document({ text: text, id_: path }); + countryDocs[title] = document; + } + + const llm = new OpenAI({ + model: "gpt-4", + }); + + const serviceContext = serviceContextFromDefaults({ llm }); + const storageContext = await storageContextFromDefaults({ + persistDir: "./storage", + }); + + // TODO: fix any + const documentAgents: any = {}; + const queryEngines: any = {}; + + for (const title of wikiTitles) { + console.log(`Processing ${title}`); + + const nodes = new SimpleNodeParser({ + chunkSize: 200, + chunkOverlap: 20, + }).getNodesFromDocuments([countryDocs[title]]); + + console.log(`Creating index for ${title}`); + + const vectorIndex = await VectorStoreIndex.init({ + serviceContext: serviceContext, + storageContext: storageContext, + nodes, + }); + + const summaryIndex = await SummaryIndex.init({ + serviceContext: serviceContext, + nodes, + }); + + console.log(`Creating query engines for ${title}`); + + const vectorQueryEngine = summaryIndex.asQueryEngine(); + const summaryQueryEngine = summaryIndex.asQueryEngine(); + + const queryEngineTools = [ + new QueryEngineTool({ + queryEngine: vectorQueryEngine, + metadata: { + name: "vector_tool", 
+ description: `Useful for questions related to specific aspects of ${title} (e.g. the history, arts and culture, sports, demographics, or more).`, + }, + }), + new QueryEngineTool({ + queryEngine: summaryQueryEngine, + metadata: { + name: "summary_tool", + description: `Useful for any requests that require a holistic summary of EVERYTHING about ${title}. For questions about more specific sections, please use the vector_tool.`, + }, + }), + ]; + + console.log(`Creating agents for ${title}`); + + const agent = new OpenAIAgent({ + tools: queryEngineTools, + llm, + verbose: true, + }); + + documentAgents[title] = agent; + queryEngines[title] = vectorIndex.asQueryEngine(); + } + + const allTools: QueryEngineTool[] = []; + + console.log(`Creating tools for all countries`); + + for (const title of wikiTitles) { + const wikiSummary = `This content contains Wikipedia articles about ${title}. Use this tool if you want to answer any questions about ${title}`; + + console.log(`Creating tool for ${title}`); + + const docTool = new QueryEngineTool({ + queryEngine: documentAgents[title], + metadata: { + name: `tool_${title}`, + description: wikiSummary, + }, + }); + + allTools.push(docTool); + } + + console.log("creating tool mapping"); + + const toolMapping = SimpleToolNodeMapping.fromObjects(allTools); + + const objectIndex = await ObjectIndex.fromObjects( + allTools, + toolMapping, + VectorStoreIndex, + { + serviceContext, + storageContext, + }, + ); + + const topAgent = new OpenAIAgent({ + toolRetriever: await objectIndex.asRetriever({}), + llm, + verbose: true, + prefixMessages: [ + { + content: + "You are an agent designed to answer queries about a set of given countries. Please always use the tools provided to answer a question. 
Do not rely on prior knowledge.", + role: "system", + }, + ], + }); + + const response = await topAgent.chat({ + message: "Tell me the differences between Brazil and Canada economics?", + }); + + console.log({ + capitalOfBrazil: response, + }); +} + +main(); diff --git a/examples/agent/query_openai_agent.ts b/examples/agent/query_openai_agent.ts new file mode 100644 index 0000000000000000000000000000000000000000..37614b46df6703ba7f03451b0ac43f89d66a8957 --- /dev/null +++ b/examples/agent/query_openai_agent.ts @@ -0,0 +1,46 @@ +import { + OpenAIAgent, + QueryEngineTool, + SimpleDirectoryReader, + VectorStoreIndex, +} from "llamaindex"; + +async function main() { + // Load the documents + const documents = await new SimpleDirectoryReader().loadData({ + directoryPath: "node_modules/llamaindex/examples/", + }); + + // Create a vector index from the documents + const vectorIndex = await VectorStoreIndex.fromDocuments(documents); + + // Create a query engine from the vector index + const abramovQueryEngine = vectorIndex.asQueryEngine(); + + // Create a QueryEngineTool with the query engine + const queryEngineTool = new QueryEngineTool({ + queryEngine: abramovQueryEngine, + metadata: { + name: "abramov_query_engine", + description: "A query engine for the Abramov documents", + }, + }); + + // Create an OpenAIAgent with the function tools + const agent = new OpenAIAgent({ + tools: [queryEngineTool], + verbose: true, + }); + + // Chat with the agent + const response = await agent.chat({ + message: "What was his salary?", + }); + + // Print the response + console.log(String(response)); +} + +main().then(() => { + console.log("Done"); +}); diff --git a/examples/vectorIndexEmbed3.ts b/examples/vectorIndexEmbed3.ts index 9ecc33d288cd94cad152ae28f546f2b20581465d..1d6c52b9048056817553514b88b78df8cd8785d0 100644 --- a/examples/vectorIndexEmbed3.ts +++ b/examples/vectorIndexEmbed3.ts @@ -30,6 +30,7 @@ async function main() { // Query the index const queryEngine = 
index.asQueryEngine(); + const response = await queryEngine.query({ query: "What did the author do in college?", }); diff --git a/packages/core/package.json b/packages/core/package.json index ff403290252f787f9be6215ded66f51db041c855..93e1bf867cf4422790af3ba897f0082d7cb87dd6 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -134,6 +134,11 @@ "import": "./dist/tools.mjs", "require": "./dist/tools.js" }, + "./objects": { + "types": "./dist/objects.d.mts", + "import": "./dist/objects.mjs", + "require": "./dist/objects.js" + }, "./readers": { "types": "./dist/readers.d.mts", "import": "./dist/readers.mjs", diff --git a/packages/core/src/agent/openai/base.ts b/packages/core/src/agent/openai/base.ts index 5b916ebae63208e838da400cdf66e08e9b520354..6c5c15e6182127958e6457a697ac52ff9c0e1b77 100644 --- a/packages/core/src/agent/openai/base.ts +++ b/packages/core/src/agent/openai/base.ts @@ -6,7 +6,7 @@ import { AgentRunner } from "../runner/base"; import { OpenAIAgentWorker } from "./worker"; type OpenAIAgentParams = { - tools: BaseTool[]; + tools?: BaseTool[]; llm?: OpenAI; memory?: any; prefixMessages?: ChatMessage[]; @@ -14,7 +14,7 @@ type OpenAIAgentParams = { maxFunctionCalls?: number; defaultToolChoice?: string; callbackManager?: CallbackManager; - toolRetriever?: ObjectRetriever<BaseTool>; + toolRetriever?: ObjectRetriever; systemPrompt?: string; }; diff --git a/packages/core/src/agent/openai/worker.ts b/packages/core/src/agent/openai/worker.ts index 33b6e470f40d72d512ef9df5f2366b46df41615e..779a7c0142e112ac060a44eac9226fa220b97a07 100644 --- a/packages/core/src/agent/openai/worker.ts +++ b/packages/core/src/agent/openai/worker.ts @@ -69,13 +69,13 @@ async function callFunction( } type OpenAIAgentWorkerParams = { - tools: BaseTool[]; + tools?: BaseTool[]; llm?: OpenAI; prefixMessages?: ChatMessage[]; verbose?: boolean; maxFunctionCalls?: number; callbackManager?: CallbackManager | undefined; - toolRetriever?: ObjectRetriever<BaseTool>; + 
toolRetriever?: ObjectRetriever; }; type CallFunctionOutput = { @@ -101,7 +101,7 @@ export class OpenAIAgentWorker implements AgentWorker { * Initialize. */ constructor({ - tools, + tools = [], llm, prefixMessages, verbose, @@ -191,6 +191,7 @@ export class OpenAIAgentWorker implements AgentWorker { ): AgentChatResponse | AsyncIterable<ChatResponseChunk> { const aiMessage = chatResponse.message; task.extraState.newMemory.put(aiMessage); + return new AgentChatResponse(aiMessage.content, task.extraState.sources); } diff --git a/packages/core/src/agent/react/base.ts b/packages/core/src/agent/react/base.ts index ecafcff382ebbdccb6652ecb1b14a9bd797dbe6c..e0883e4bbe027a06e6650cfe9d029ceade8076cf 100644 --- a/packages/core/src/agent/react/base.ts +++ b/packages/core/src/agent/react/base.ts @@ -14,7 +14,7 @@ type ReActAgentParams = { maxInteractions?: number; defaultToolChoice?: string; callbackManager?: CallbackManager; - toolRetriever?: ObjectRetriever<BaseTool>; + toolRetriever?: ObjectRetriever; }; /** diff --git a/packages/core/src/agent/react/worker.ts b/packages/core/src/agent/react/worker.ts index ff5b1a9fdfe7613cbd33ebabc7d895be131075e4..1bb48691a8c5fbb4ee46fbd5dbe7c5fbac126a7e 100644 --- a/packages/core/src/agent/react/worker.ts +++ b/packages/core/src/agent/react/worker.ts @@ -24,7 +24,7 @@ type ReActAgentWorkerParams = { outputParser?: ReActOutputParser | undefined; callbackManager?: CallbackManager | undefined; verbose?: boolean | undefined; - toolRetriever?: ObjectRetriever<BaseTool> | undefined; + toolRetriever?: ObjectRetriever | undefined; }; /** diff --git a/packages/core/src/engines/query/SubQuestionQueryEngine.ts b/packages/core/src/engines/query/SubQuestionQueryEngine.ts index 4c874b15a46c7220418eef686e36f38fe6c958fa..73d0ada6c52d0cd3edc19698f7d36ae867139e0a 100644 --- a/packages/core/src/engines/query/SubQuestionQueryEngine.ts +++ b/packages/core/src/engines/query/SubQuestionQueryEngine.ts @@ -12,6 +12,7 @@ import { CompactAndRefine, 
ResponseSynthesizer, } from "../../synthesizers"; + import { BaseQueryEngine, BaseTool, @@ -19,6 +20,7 @@ import { QueryEngineParamsStreaming, ToolMetadata, } from "../../types"; + import { BaseQuestionGenerator, SubQuestion } from "./types"; /** diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 16faedf93bc5043805f352a8509ade0f74cf653c..f2449885b585d95ef189f61b090a6110e1d35168 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -20,6 +20,7 @@ export * from "./indices"; export * from "./ingestion"; export * from "./llm"; export * from "./nodeParsers"; +export * from "./objects"; export * from "./postprocessors"; export * from "./readers"; export * from "./selectors"; diff --git a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts index b416126e73ab04e0dbcff65f4d629391da8a27bf..9515bab4471d35eb438c194d99ddcbf03a89cfb2 100644 --- a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts +++ b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts @@ -84,7 +84,9 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { * @param options * @returns */ - static async init(options: VectorIndexOptions): Promise<VectorStoreIndex> { + public static async init( + options: VectorIndexOptions, + ): Promise<VectorStoreIndex> { const storageContext = options.storageContext ?? 
(await storageContextFromDefaults({})); const serviceContext = diff --git a/packages/core/src/llm/LLM.ts b/packages/core/src/llm/LLM.ts index bdba512d99c7249275dec3441b7ec7025e437be9..1a3a1ad41e49f946e07e8e3dcbc8287068e425d2 100644 --- a/packages/core/src/llm/LLM.ts +++ b/packages/core/src/llm/LLM.ts @@ -26,7 +26,6 @@ import { } from "./azure"; import { BaseLLM } from "./base"; import { OpenAISession, getOpenAISession } from "./open_ai"; -import { isFunctionCallingModel } from "./openai/utils"; import { PortkeySession, getPortkeySession } from "./portkey"; import { ReplicateSession } from "./replicate_ai"; import { @@ -67,6 +66,12 @@ export const ALL_AVAILABLE_OPENAI_MODELS = { ...GPT35_MODELS, }; +export const isFunctionCallingModel = (model: string): boolean => { + const isChatModel = Object.keys(ALL_AVAILABLE_OPENAI_MODELS).includes(model); + const isOld = model.includes("0314") || model.includes("0301"); + return isChatModel && !isOld; +}; + /** * OpenAI LLM implementation */ diff --git a/packages/core/src/llm/openai/utils.ts b/packages/core/src/llm/openai/utils.ts deleted file mode 100644 index 0521c9a410b80fe43c16ee19d7a457022cc8f7b0..0000000000000000000000000000000000000000 --- a/packages/core/src/llm/openai/utils.ts +++ /dev/null @@ -1,7 +0,0 @@ -import { ALL_AVAILABLE_OPENAI_MODELS } from ".."; - -export const isFunctionCallingModel = (model: string): boolean => { - const isChatModel = Object.keys(ALL_AVAILABLE_OPENAI_MODELS).includes(model); - const isOld = model.includes("0314") || model.includes("0301"); - return isChatModel && !isOld; -}; diff --git a/packages/core/src/objects/base.ts b/packages/core/src/objects/base.ts index 7adeb94902709556c8956ba4f8de43781ae79918..db39e2f92a7fec6aa6345db85ed126d4bb1f9bda 100644 --- a/packages/core/src/objects/base.ts +++ b/packages/core/src/objects/base.ts @@ -1,42 +1,41 @@ -import { BaseNode, TextNode } from "../Node"; +import { BaseNode, Metadata, TextNode } from "../Node"; import { BaseRetriever } from 
"../Retriever"; +import { VectorStoreIndex } from "../indices"; +import { BaseTool } from "../types"; // Assuming that necessary interfaces and classes (like OT, TextNode, BaseNode, etc.) are defined elsewhere // Import statements (e.g., for TextNode, BaseNode) should be added based on your project's structure -export abstract class BaseObjectNodeMapping<OT> { +export abstract class BaseObjectNodeMapping { // TypeScript doesn't support Python's classmethod directly, but we can use static methods as an alternative - abstract fromObjects<OT>( - objs: OT[], - ...args: any[] - ): BaseObjectNodeMapping<OT>; + abstract fromObjects<OT>(objs: OT[], ...args: any[]): BaseObjectNodeMapping; // Abstract methods in TypeScript abstract objNodeMapping(): Record<any, any>; - abstract toNode(obj: OT): TextNode; + abstract toNode(obj: any): TextNode; // Concrete methods can be defined as usual - validateObject(obj: OT): void {} + validateObject(obj: any): void {} // Implementing the add object logic - addObj(obj: OT): void { + addObj(obj: any): void { this.validateObject(obj); this._addObj(obj); } // Abstract method for internal add object logic - protected abstract _addObj(obj: OT): void; + abstract _addObj(obj: any): void; // Implementing toNodes method - toNodes(objs: OT[]): TextNode[] { + toNodes(objs: any[]): TextNode[] { return objs.map((obj) => this.toNode(obj)); } // Abstract method for internal from node logic - protected abstract _fromNode(node: BaseNode): OT; + abstract _fromNode(node: BaseNode): any; // Implementing fromNode method - fromNode(node: BaseNode): OT { + fromNode(node: BaseNode): any { const obj = this._fromNode(node); this.validateObject(obj); return obj; @@ -50,13 +49,13 @@ export abstract class BaseObjectNodeMapping<OT> { type QueryType = string; -export class ObjectRetriever<OT> { - private _retriever: BaseRetriever; - private _objectNodeMapping: BaseObjectNodeMapping<OT>; +export class ObjectRetriever { + _retriever: BaseRetriever; + _objectNodeMapping: 
BaseObjectNodeMapping; constructor( retriever: BaseRetriever, - objectNodeMapping: BaseObjectNodeMapping<OT>, + objectNodeMapping: BaseObjectNodeMapping, ) { this._retriever = retriever; this._objectNodeMapping = objectNodeMapping; @@ -68,13 +67,126 @@ export class ObjectRetriever<OT> { } // Translating the retrieve method - async retrieve(strOrQueryBundle: QueryType): Promise<OT[]> { - const nodes = await this._retriever.retrieve(strOrQueryBundle); - return nodes.map((node) => this._objectNodeMapping.fromNode(node.node)); + async retrieve(strOrQueryBundle: QueryType): Promise<any> { + const nodes = await this.retriever.retrieve(strOrQueryBundle); + const objs = nodes.map((n) => this._objectNodeMapping.fromNode(n.node)); + return objs; + } +} + +const convertToolToNode = (tool: BaseTool): TextNode => { + const nodeText = ` + Tool name: ${tool.metadata.name} + Tool description: ${tool.metadata.description} + `; + return new TextNode({ + text: nodeText, + metadata: { name: tool.metadata.name }, + excludedEmbedMetadataKeys: ["name"], + excludedLlmMetadataKeys: ["name"], + }); +}; + +export class SimpleToolNodeMapping extends BaseObjectNodeMapping { + private _tools: Record<string, BaseTool>; + + private constructor(objs: BaseTool[] = []) { + super(); + this._tools = {}; + for (const tool of objs) { + this._tools[tool.metadata.name] = tool; + } + } + + objNodeMapping(): Record<any, any> { + return this._tools; + } + + toNode(tool: BaseTool): TextNode { + return convertToolToNode(tool); + } + + _addObj(tool: BaseTool): void { + this._tools[tool.metadata.name] = tool; + } + + _fromNode(node: BaseNode): BaseTool { + if (!node.metadata) { + throw new Error("Metadata must be set"); + } + return this._tools[node.metadata.name]; + } + + persist(persistDir: string, objNodeMappingFilename: string): void { + // Implement the persist method + } + + toNodes(objs: BaseTool[]): TextNode<Metadata>[] { + return objs.map((obj) => this.toNode(obj)); + } + + addObj(obj: BaseTool): void { 
+ this._addObj(obj); + } + + fromNode(node: BaseNode): BaseTool { + return this._fromNode(node); } - // // Translating the _asQueryComponent method - // public asQueryComponent(kwargs: any): any { - // return new ObjectRetrieverComponent(this); - // } + static fromObjects(objs: any, ...args: any[]): BaseObjectNodeMapping { + return new SimpleToolNodeMapping(objs); + } + + fromObjects<OT>(objs: any, ...args: any[]): BaseObjectNodeMapping { + return new SimpleToolNodeMapping(objs); + } +} + +export class ObjectIndex { + private _index: VectorStoreIndex; + private _objectNodeMapping: BaseObjectNodeMapping; + + private constructor(index: any, objectNodeMapping: BaseObjectNodeMapping) { + this._index = index; + this._objectNodeMapping = objectNodeMapping; + } + + static async fromObjects( + objects: any, + objectMapping: BaseObjectNodeMapping, + // TODO: fix any (bundling issue) + indexCls: any, + indexKwargs?: Record<string, any>, + ): Promise<ObjectIndex> { + if (objectMapping === null) { + objectMapping = SimpleToolNodeMapping.fromObjects(objects, {}); + } + + const nodes = objectMapping.toNodes(objects); + + const index = await indexCls.init({ nodes, ...indexKwargs }); + + return new ObjectIndex(index, objectMapping); + } + + insertObject(obj: any): void { + this._objectNodeMapping.addObj(obj); + const node = this._objectNodeMapping.toNode(obj); + this._index.insertNodes([node]); + } + + get tools(): Record<string, BaseTool> { + return this._objectNodeMapping.objNodeMapping(); + } + + async asRetriever(kwargs: any): Promise<ObjectRetriever> { + return new ObjectRetriever( + this._index.asRetriever(kwargs), + this._objectNodeMapping, + ); + } + + asNodeRetriever(kwargs: any): any { + return this._index.asRetriever(kwargs); + } } diff --git a/packages/core/src/objects/index.ts b/packages/core/src/objects/index.ts new file mode 100644 index 0000000000000000000000000000000000000000..955fdd14398fe62365b8ce88feeb7cac6272d3e3 --- /dev/null +++ 
b/packages/core/src/objects/index.ts @@ -0,0 +1 @@ +export * from "./base"; diff --git a/packages/core/src/storage/vectorStore/SimpleVectorStore.ts b/packages/core/src/storage/vectorStore/SimpleVectorStore.ts index f4095e242e4e36abc2e4d5e49424dc93eafb0acc..3c9f646060fea10fef5e8c3bf2fa010800b53995 100644 --- a/packages/core/src/storage/vectorStore/SimpleVectorStore.ts +++ b/packages/core/src/storage/vectorStore/SimpleVectorStore.ts @@ -60,7 +60,6 @@ export class SimpleVectorStore implements VectorStore { this.data.embeddingDict[node.id_] = node.getEmbedding(); if (!node.sourceNode) { - console.error("Missing source node from TextNode."); continue; } diff --git a/packages/core/src/tests/Embedding.test.ts b/packages/core/src/tests/Embedding.test.ts index 68d5b207890ed3cacc5169aaee5e2b4757bc69e4..4f3b4262b3a97f626240bd69d0280c54738ab9c2 100644 --- a/packages/core/src/tests/Embedding.test.ts +++ b/packages/core/src/tests/Embedding.test.ts @@ -70,12 +70,12 @@ describe("[OpenAIEmbedding]", () => { test("getTextEmbeddings", async () => { const texts = ["hello", "world"]; const embeddings = await embedModel.getTextEmbeddings(texts); - expect(embeddings.length).toEqual(1); + expect(embeddings.length).toEqual(2); }); test("getTextEmbeddingsBatch", async () => { const texts = ["hello", "world"]; const embeddings = await embedModel.getTextEmbeddingsBatch(texts); - expect(embeddings.length).toEqual(1); + expect(embeddings.length).toEqual(2); }); }); diff --git a/packages/core/src/tests/llms/openai/utils.test.ts b/packages/core/src/tests/llms/openai/utils.test.ts index 95eb3bc03eddf74653b538f5104a180d906c5b92..5ebace03d8201810e71ba6ae9c92ec6303de7c20 100644 --- a/packages/core/src/tests/llms/openai/utils.test.ts +++ b/packages/core/src/tests/llms/openai/utils.test.ts @@ -1,5 +1,7 @@ -import { ALL_AVAILABLE_OPENAI_MODELS } from "../../../llm"; -import { isFunctionCallingModel } from "../../../llm/openai/utils"; +import { + ALL_AVAILABLE_OPENAI_MODELS, + isFunctionCallingModel, +} 
from "../../../llm"; describe("openai/utils", () => { test("shouldn't be a old model", () => { diff --git a/packages/core/src/tests/objects/ObjectIndex.test.ts b/packages/core/src/tests/objects/ObjectIndex.test.ts new file mode 100644 index 0000000000000000000000000000000000000000..b46616486e0e492150c016a9d8020fc64f3a0667 --- /dev/null +++ b/packages/core/src/tests/objects/ObjectIndex.test.ts @@ -0,0 +1,130 @@ +import { + FunctionTool, + ObjectIndex, + OpenAI, + OpenAIEmbedding, + ServiceContext, + SimpleToolNodeMapping, + VectorStoreIndex, + serviceContextFromDefaults, +} from "../../index"; +import { mockEmbeddingModel, mockLlmGeneration } from "../utility/mockOpenAI"; + +jest.mock("../../llm/open_ai", () => { + return { + getOpenAISession: jest.fn().mockImplementation(() => null), + }; +}); + +describe("ObjectIndex", () => { + let serviceContext: ServiceContext; + + beforeAll(() => { + const embeddingModel = new OpenAIEmbedding(); + const llm = new OpenAI(); + + mockEmbeddingModel(embeddingModel); + mockLlmGeneration({ languageModel: llm }); + + const ctx = serviceContextFromDefaults({ + embedModel: embeddingModel, + llm, + }); + + serviceContext = ctx; + }); + + test("test_object_with_tools", async () => { + const tool1 = new FunctionTool((x: any) => x, { + name: "test_tool", + description: "test tool", + parameters: { + type: "object", + properties: { + x: { + type: "string", + }, + }, + }, + }); + + const tool2 = new FunctionTool((x: any) => x, { + name: "test_tool_2", + description: "test tool 2", + parameters: { + type: "object", + properties: { + x: { + type: "string", + }, + }, + }, + }); + + const toolMapping = SimpleToolNodeMapping.fromObjects([tool1, tool2]); + + const objectRetriever = await ObjectIndex.fromObjects( + [tool1, tool2], + toolMapping, + VectorStoreIndex, + { + serviceContext, + }, + ); + + const retriever = await objectRetriever.asRetriever({ + serviceContext, + }); + + expect(await retriever.retrieve("test")).toStrictEqual([tool1, 
tool2]); + }); + + test("add a new object", async () => { + const tool1 = new FunctionTool((x: any) => x, { + name: "test_tool", + description: "test tool", + parameters: { + type: "object", + properties: { + x: { + type: "string", + }, + }, + }, + }); + + const tool2 = new FunctionTool((x: any) => x, { + name: "test_tool_2", + description: "test tool 2", + parameters: { + type: "object", + properties: { + x: { + type: "string", + }, + }, + }, + }); + + const toolMapping = SimpleToolNodeMapping.fromObjects([tool1]); + + const objectRetriever = await ObjectIndex.fromObjects( + [tool1], + toolMapping, + VectorStoreIndex, + { + serviceContext, + }, + ); + + let tools = objectRetriever.tools; + + expect(Object.keys(tools).length).toBe(1); + + objectRetriever.insertObject(tool2); + + tools = objectRetriever.tools; + + expect(Object.keys(tools).length).toBe(2); + }); +}); diff --git a/packages/core/src/tests/utility/mockOpenAI.ts b/packages/core/src/tests/utility/mockOpenAI.ts index a9343cd1ff231467c97e994a78ea93e4f6b49965..aaa77d74306842d38dfe37a535ff086bb3089fd9 100644 --- a/packages/core/src/tests/utility/mockOpenAI.ts +++ b/packages/core/src/tests/utility/mockOpenAI.ts @@ -11,7 +11,7 @@ export function mockLlmGeneration({ callbackManager, }: { languageModel: OpenAI; - callbackManager: CallbackManager; + callbackManager?: CallbackManager; }) { jest .spyOn(languageModel, "chat") @@ -84,7 +84,10 @@ export function mockLlmToolCallGeneration({ ); } -export function mockEmbeddingModel(embedModel: OpenAIEmbedding) { +export function mockEmbeddingModel( + embedModel: OpenAIEmbedding, + embeddingsLength: number = 1, +) { jest.spyOn(embedModel, "getTextEmbedding").mockImplementation(async (x) => { return new Promise((resolve) => { resolve([1, 0, 0, 0, 0, 0]); @@ -92,6 +95,9 @@ export function mockEmbeddingModel(embedModel: OpenAIEmbedding) { }); jest.spyOn(embedModel, "getTextEmbeddings").mockImplementation(async (x) => { return new Promise((resolve) => { + if (x.length > 1) { 
+ resolve(Array(x.length).fill([1, 0, 0, 0, 0, 0])); + } resolve([[1, 0, 0, 0, 0, 0]]); }); }); diff --git a/packages/core/src/tools/index.ts b/packages/core/src/tools/index.ts index 1215bef7ea3bfb3f14b2cb58aaa70315e69c4b0f..bd695d4f9a66d5025807899554f0c05b0219c3d6 100644 --- a/packages/core/src/tools/index.ts +++ b/packages/core/src/tools/index.ts @@ -1,3 +1,4 @@ export * from "./QueryEngineTool"; export * from "./functionTool"; export * from "./types"; +export * from "./utils";