diff --git a/.changeset/smart-ligers-occur.md b/.changeset/smart-ligers-occur.md new file mode 100644 index 0000000000000000000000000000000000000000..acce80359452196ddcd593a325e5206d0366feec --- /dev/null +++ b/.changeset/smart-ligers-occur.md @@ -0,0 +1,5 @@ +--- +"llamaindex": patch +--- + +feat(queryEngineTool): add query engine tool to agents diff --git a/apps/docs/docs/modules/agent/query_engine_tool.mdx b/apps/docs/docs/modules/agent/query_engine_tool.mdx new file mode 100644 index 0000000000000000000000000000000000000000..83ce66f929ea168c26a1aceaa308fd56bdb77bb8 --- /dev/null +++ b/apps/docs/docs/modules/agent/query_engine_tool.mdx @@ -0,0 +1,128 @@ +# OpenAI Agent + QueryEngineTool + +QueryEngineTool is a tool that allows you to query a vector index. In this example, we will create a vector index from a set of documents and then create a QueryEngineTool from the vector index. We will then create an OpenAIAgent with the QueryEngineTool and chat with the agent. + +## Setup + +First, you need to install the `llamaindex` package. You can do this by running the following command in your terminal: + +```bash +pnpm i llamaindex +``` + +Then you can import the necessary classes and functions. + +```ts +import { + OpenAIAgent, + SimpleDirectoryReader, + VectorStoreIndex, + QueryEngineTool, +} from "llamaindex"; +``` + +## Create a vector index + +Now we can create a vector index from a set of documents. + +```ts +// Load the documents +const documents = await new SimpleDirectoryReader().loadData({ + directoryPath: "node_modules/llamaindex/examples/", +}); + +// Create a vector index from the documents +const vectorIndex = await VectorStoreIndex.fromDocuments(documents); +``` + +## Create a QueryEngineTool + +Now we can create a QueryEngineTool from the vector index. + +```ts +// Create a query engine from the vector index +const abramovQueryEngine = vectorIndex.asQueryEngine(); + +// Create a QueryEngineTool with the query engine +const queryEngineTool = new QueryEngineTool({ + queryEngine: abramovQueryEngine, + metadata: { + name: "abramov_query_engine", + description: "A query engine for the Abramov documents", + }, +}); +``` + +## Create an OpenAIAgent + +```ts +// Create an OpenAIAgent with the query engine tool tools + +const agent = new OpenAIAgent({ + tools: [queryEngineTool], + verbose: true, +}); +``` + +## Chat with the agent + +Now we can chat with the agent. + +```ts +const response = await agent.chat({ + message: "What was his salary?", +}); + +console.log(String(response)); +``` + +## Full code + +```ts +import { + OpenAIAgent, + SimpleDirectoryReader, + VectorStoreIndex, + QueryEngineTool, +} from "llamaindex"; + +async function main() { + // Load the documents + const documents = await new SimpleDirectoryReader().loadData({ + directoryPath: "node_modules/llamaindex/examples/", + }); + + // Create a vector index from the documents + const vectorIndex = await VectorStoreIndex.fromDocuments(documents); + + // Create a query engine from the vector index + const abramovQueryEngine = vectorIndex.asQueryEngine(); + + // Create a QueryEngineTool with the query engine + const queryEngineTool = new QueryEngineTool({ + queryEngine: abramovQueryEngine, + metadata: { + name: "abramov_query_engine", + description: "A query engine for the Abramov documents", + }, + }); + + // Create an OpenAIAgent with the function tools + const agent = new OpenAIAgent({ + tools: [queryEngineTool], + verbose: true, + }); + + // Chat with the agent + const response = await agent.chat({ + message: "What was his salary?", + }); + + // Print the response + console.log(String(response)); +} + +main().then(() => { + console.log("Done"); +}); +``` diff --git a/examples/agent/query_openai_agent.ts b/examples/agent/query_openai_agent.ts new file mode 100644 index 0000000000000000000000000000000000000000..37614b46df6703ba7f03451b0ac43f89d66a8957 --- /dev/null +++ b/examples/agent/query_openai_agent.ts @@ -0,0 +1,46 @@ +import { + OpenAIAgent, + QueryEngineTool, + SimpleDirectoryReader, + VectorStoreIndex, +} from "llamaindex"; + +async function main() { + // Load the documents + const documents = await new SimpleDirectoryReader().loadData({ + directoryPath: "node_modules/llamaindex/examples/", + }); + + // Create a vector index from the documents + const vectorIndex = await VectorStoreIndex.fromDocuments(documents); + + // Create a query engine from the vector index + const abramovQueryEngine = vectorIndex.asQueryEngine(); + + // Create a QueryEngineTool with the query engine + const queryEngineTool = new QueryEngineTool({ + queryEngine: abramovQueryEngine, + metadata: { + name: "abramov_query_engine", + description: "A query engine for the Abramov documents", + }, + }); + + // Create an OpenAIAgent with the function tools + const agent = new OpenAIAgent({ + tools: [queryEngineTool], + verbose: true, + }); + + // Chat with the agent + const response = await agent.chat({ + message: "What was his salary?", + }); + + // Print the response + console.log(String(response)); +} + +main().then(() => { + console.log("Done"); +}); diff --git a/examples/subquestion.ts b/examples/subquestion.ts index b1e692e1f305668a21f8ac12076ea247cc7046e8..b1f8b3e4bec417f54baa41e14f24cd3c9c8dbacf 100644 --- a/examples/subquestion.ts +++ b/examples/subquestion.ts @@ -1,4 +1,9 @@ -import { Document, SubQuestionQueryEngine, VectorStoreIndex } from "llamaindex"; +import { + Document, + QueryEngineTool, + SubQuestionQueryEngine, + VectorStoreIndex, +} from "llamaindex"; import essay from "./essay"; @@ -6,16 +11,18 @@ import essay from "./essay"; const document = new Document({ text: essay, id_: essay }); const index = await VectorStoreIndex.fromDocuments([document]); - const queryEngine = SubQuestionQueryEngine.fromDefaults({ - queryEngineTools: [ - { - queryEngine: index.asQueryEngine(), - metadata: { - name: "pg_essay", - description: "Paul Graham essay on What I Worked On", - }, + const queryEngineTools = [ + new QueryEngineTool({ + queryEngine: index.asQueryEngine(), + metadata: { + name: "pg_essay", + description: "Paul Graham essay on What I Worked On", }, - ], + }), + ]; + + const queryEngine = SubQuestionQueryEngine.fromDefaults({ + queryEngineTools, }); const response = await queryEngine.query({ diff --git a/packages/core/src/agent/runner/base.ts b/packages/core/src/agent/runner/base.ts index 39e4d6379d7372935f57934e7bc0d05277c616ca..5183f16f98015fe6c1a77aa6eb5bafcc87729922 100644 --- a/packages/core/src/agent/runner/base.ts +++ b/packages/core/src/agent/runner/base.ts @@ -266,7 +266,14 @@ export class AgentRunner extends BaseAgentRunner { let resultOutput; while (true) { - const curStepOutput = await this._runStep(task.taskId); + const curStepOutput = await this._runStep( + task.taskId, + undefined, + ChatResponseMode.WAIT, + { + toolChoice, + }, + ); if (curStepOutput.isLast) { resultOutput = curStepOutput; diff --git a/packages/core/src/engines/query/SubQuestionQueryEngine.ts b/packages/core/src/engines/query/SubQuestionQueryEngine.ts index a70dfbb9e98ffc24037aabe9883fcdf763633b23..4c874b15a46c7220418eef686e36f38fe6c958fa 100644 --- a/packages/core/src/engines/query/SubQuestionQueryEngine.ts +++ b/packages/core/src/engines/query/SubQuestionQueryEngine.ts @@ -14,9 +14,9 @@ import { } from "../../synthesizers"; import { BaseQueryEngine, + BaseTool, QueryEngineParamsNonStreaming, QueryEngineParamsStreaming, - QueryEngineTool, ToolMetadata, } from "../../types"; import { BaseQuestionGenerator, SubQuestion } from "./types"; @@ -27,28 +27,23 @@ import { BaseQuestionGenerator, SubQuestion } from "./types"; export class SubQuestionQueryEngine implements BaseQueryEngine { responseSynthesizer: BaseSynthesizer; questionGen: BaseQuestionGenerator; - queryEngines: Record<string, BaseQueryEngine>; + queryEngines: BaseTool[]; metadatas: ToolMetadata[]; constructor(init: { questionGen: BaseQuestionGenerator; responseSynthesizer: BaseSynthesizer; - queryEngineTools: QueryEngineTool[]; + queryEngineTools: BaseTool[]; }) { this.questionGen = init.questionGen; this.responseSynthesizer = init.responseSynthesizer ?? new ResponseSynthesizer(); - this.queryEngines = init.queryEngineTools.reduce< - Record<string, BaseQueryEngine> - >((acc, tool) => { - acc[tool.metadata.name] = tool.queryEngine; - return acc; - }, {}); + this.queryEngines = init.queryEngineTools; this.metadatas = init.queryEngineTools.map((tool) => tool.metadata); } static fromDefaults(init: { - queryEngineTools: QueryEngineTool[]; + queryEngineTools: BaseTool[]; questionGen?: BaseQuestionGenerator; responseSynthesizer?: BaseSynthesizer; serviceContext?: ServiceContext; @@ -122,13 +117,24 @@ export class SubQuestionQueryEngine implements BaseQueryEngine { ): Promise<NodeWithScore | null> { try { const question = subQ.subQuestion; - const queryEngine = this.queryEngines[subQ.toolName]; - const response = await queryEngine.query({ + const queryEngine = this.queryEngines.find( + (tool) => tool.metadata.name === subQ.toolName, + ); + + if (!queryEngine) { + return null; + } + + const responseText = await queryEngine?.call?.({ query: question, parentEvent, }); - const responseText = response.response; + + if (!responseText) { + return null; + } + const nodeText = `Sub question: ${question}\nResponse: ${responseText}`; const node = new TextNode({ text: nodeText }); return { node, score: 0 }; diff --git a/packages/core/src/tools/QueryEngineTool.ts b/packages/core/src/tools/QueryEngineTool.ts new file mode 100644 index 0000000000000000000000000000000000000000..41c0cc85be5f5c5e3bcc0b38cf1c14de58b53a13 --- /dev/null +++ b/packages/core/src/tools/QueryEngineTool.ts @@ -0,0 +1,54 @@ +import { BaseQueryEngine, BaseTool, ToolMetadata } from "../types"; + +export type QueryEngineToolParams = { + queryEngine: BaseQueryEngine; + metadata: ToolMetadata; +}; + +type QueryEngineCallParams = { + query: string; +}; + +const DEFAULT_NAME = "query_engine_tool"; +const DEFAULT_DESCRIPTION = + "Useful for running a natural language query against a knowledge base and get back a natural language response."; +const DEFAULT_PARAMETERS = { + type: "object", + properties: { + query: { + type: "string", + description: "The query to search for", + }, + }, + required: ["query"], +}; + +export class QueryEngineTool implements BaseTool { + private queryEngine: BaseQueryEngine; + metadata: ToolMetadata; + + constructor({ queryEngine, metadata }: QueryEngineToolParams) { + this.queryEngine = queryEngine; + this.metadata = { + name: metadata?.name ?? DEFAULT_NAME, + description: metadata?.description ?? DEFAULT_DESCRIPTION, + parameters: metadata?.parameters ?? DEFAULT_PARAMETERS, + }; + } + + async call(...args: QueryEngineCallParams[]): Promise<any> { + let queryStr: string; + + if (args && args.length > 0) { + queryStr = String(args[0].query); + } else { + throw new Error( + "Cannot call query engine without specifying `input` parameter.", + ); + } + + const response = await this.queryEngine.query({ query: queryStr }); + + return response.response; + } +} diff --git a/packages/core/src/tools/index.ts b/packages/core/src/tools/index.ts index 2c87cd60ed0fcfb23760ffa446e4ee0a59ddad35..1215bef7ea3bfb3f14b2cb58aaa70315e69c4b0f 100644 --- a/packages/core/src/tools/index.ts +++ b/packages/core/src/tools/index.ts @@ -1,2 +1,3 @@ +export * from "./QueryEngineTool"; export * from "./functionTool"; export * from "./types"; diff --git a/packages/core/src/types.ts b/packages/core/src/types.ts index 61dee2fb5cbedf34f85968b6894965bff8a25aec..f71ee294d18fc3e777ed27f23f236ec191324ba3 100644 --- a/packages/core/src/types.ts +++ b/packages/core/src/types.ts @@ -40,13 +40,6 @@ export interface BaseTool { metadata: ToolMetadata; } -/** - * A Tool that uses a QueryEngine. - */ -export interface QueryEngineTool extends BaseTool { - queryEngine: BaseQueryEngine; -} - /** * An OutputParser is used to extract structured data from the raw output of the LLM. */