From cd829474d6630d7166e8f02dfd024f3e3f32cc34 Mon Sep 17 00:00:00 2001
From: Emanuel Ferreira <contatoferreirads@gmail.com>
Date: Tue, 6 Feb 2024 11:11:26 -0300
Subject: [PATCH] feat(queryEngineTool): add query engine tool to agents (#509)

---
 .changeset/smart-ligers-occur.md              |   5 +
 .../docs/modules/agent/query_engine_tool.mdx  | 128 ++++++++++++++++++
 examples/agent/query_openai_agent.ts          |  46 +++++++
 examples/subquestion.ts                       |  27 ++--
 packages/core/src/agent/runner/base.ts        |   9 +-
 .../engines/query/SubQuestionQueryEngine.ts   |  32 +++--
 packages/core/src/tools/QueryEngineTool.ts    |  54 ++++++++
 packages/core/src/tools/index.ts              |   1 +
 packages/core/src/types.ts                    |   7 -
 9 files changed, 278 insertions(+), 31 deletions(-)
 create mode 100644 .changeset/smart-ligers-occur.md
 create mode 100644 apps/docs/docs/modules/agent/query_engine_tool.mdx
 create mode 100644 examples/agent/query_openai_agent.ts
 create mode 100644 packages/core/src/tools/QueryEngineTool.ts

diff --git a/.changeset/smart-ligers-occur.md b/.changeset/smart-ligers-occur.md
new file mode 100644
index 000000000..acce80359
--- /dev/null
+++ b/.changeset/smart-ligers-occur.md
@@ -0,0 +1,5 @@
+---
+"llamaindex": patch
+---
+
+feat(queryEngineTool): add query engine tool to agents
diff --git a/apps/docs/docs/modules/agent/query_engine_tool.mdx b/apps/docs/docs/modules/agent/query_engine_tool.mdx
new file mode 100644
index 000000000..83ce66f92
--- /dev/null
+++ b/apps/docs/docs/modules/agent/query_engine_tool.mdx
@@ -0,0 +1,128 @@
+# OpenAI Agent + QueryEngineTool
+
+QueryEngineTool is a tool that allows you to query a vector index. In this example, we will create a vector index from a set of documents and then create a QueryEngineTool from the vector index. We will then create an OpenAIAgent with the QueryEngineTool and chat with the agent.
+
+## Setup
+
+First, you need to install the `llamaindex` package. You can do this by running the following command in your terminal:
+
+```bash
+pnpm i llamaindex
+```
+
+Then you can import the necessary classes and functions.
+
+```ts
+import {
+  OpenAIAgent,
+  SimpleDirectoryReader,
+  VectorStoreIndex,
+  QueryEngineTool,
+} from "llamaindex";
+```
+
+## Create a vector index
+
+Now we can create a vector index from a set of documents.
+
+```ts
+// Load the documents
+const documents = await new SimpleDirectoryReader().loadData({
+  directoryPath: "node_modules/llamaindex/examples/",
+});
+
+// Create a vector index from the documents
+const vectorIndex = await VectorStoreIndex.fromDocuments(documents);
+```
+
+## Create a QueryEngineTool
+
+Now we can create a QueryEngineTool from the vector index.
+
+```ts
+// Create a query engine from the vector index
+const abramovQueryEngine = vectorIndex.asQueryEngine();
+
+// Create a QueryEngineTool with the query engine
+const queryEngineTool = new QueryEngineTool({
+  queryEngine: abramovQueryEngine,
+  metadata: {
+    name: "abramov_query_engine",
+    description: "A query engine for the Abramov documents",
+  },
+});
+```
+
+## Create an OpenAIAgent
+
+```ts
+// Create an OpenAIAgent with the query engine tool tools
+
+const agent = new OpenAIAgent({
+  tools: [queryEngineTool],
+  verbose: true,
+});
+```
+
+## Chat with the agent
+
+Now we can chat with the agent.
+
+```ts
+const response = await agent.chat({
+  message: "What was his salary?",
+});
+
+console.log(String(response));
+```
+
+## Full code
+
+```ts
+import {
+  OpenAIAgent,
+  SimpleDirectoryReader,
+  VectorStoreIndex,
+  QueryEngineTool,
+} from "llamaindex";
+
+async function main() {
+  // Load the documents
+  const documents = await new SimpleDirectoryReader().loadData({
+    directoryPath: "node_modules/llamaindex/examples/",
+  });
+
+  // Create a vector index from the documents
+  const vectorIndex = await VectorStoreIndex.fromDocuments(documents);
+
+  // Create a query engine from the vector index
+  const abramovQueryEngine = vectorIndex.asQueryEngine();
+
+  // Create a QueryEngineTool with the query engine
+  const queryEngineTool = new QueryEngineTool({
+    queryEngine: abramovQueryEngine,
+    metadata: {
+      name: "abramov_query_engine",
+      description: "A query engine for the Abramov documents",
+    },
+  });
+
+  // Create an OpenAIAgent with the function tools
+  const agent = new OpenAIAgent({
+    tools: [queryEngineTool],
+    verbose: true,
+  });
+
+  // Chat with the agent
+  const response = await agent.chat({
+    message: "What was his salary?",
+  });
+
+  // Print the response
+  console.log(String(response));
+}
+
+main().then(() => {
+  console.log("Done");
+});
+```
diff --git a/examples/agent/query_openai_agent.ts b/examples/agent/query_openai_agent.ts
new file mode 100644
index 000000000..37614b46d
--- /dev/null
+++ b/examples/agent/query_openai_agent.ts
@@ -0,0 +1,46 @@
+import {
+  OpenAIAgent,
+  QueryEngineTool,
+  SimpleDirectoryReader,
+  VectorStoreIndex,
+} from "llamaindex";
+
+async function main() {
+  // Load the documents
+  const documents = await new SimpleDirectoryReader().loadData({
+    directoryPath: "node_modules/llamaindex/examples/",
+  });
+
+  // Create a vector index from the documents
+  const vectorIndex = await VectorStoreIndex.fromDocuments(documents);
+
+  // Create a query engine from the vector index
+  const abramovQueryEngine = vectorIndex.asQueryEngine();
+
+  // Create a QueryEngineTool with the query engine
+  const queryEngineTool = new QueryEngineTool({
+    queryEngine: abramovQueryEngine,
+    metadata: {
+      name: "abramov_query_engine",
+      description: "A query engine for the Abramov documents",
+    },
+  });
+
+  // Create an OpenAIAgent with the function tools
+  const agent = new OpenAIAgent({
+    tools: [queryEngineTool],
+    verbose: true,
+  });
+
+  // Chat with the agent
+  const response = await agent.chat({
+    message: "What was his salary?",
+  });
+
+  // Print the response
+  console.log(String(response));
+}
+
+main().then(() => {
+  console.log("Done");
+});
diff --git a/examples/subquestion.ts b/examples/subquestion.ts
index b1e692e1f..b1f8b3e4b 100644
--- a/examples/subquestion.ts
+++ b/examples/subquestion.ts
@@ -1,4 +1,9 @@
-import { Document, SubQuestionQueryEngine, VectorStoreIndex } from "llamaindex";
+import {
+  Document,
+  QueryEngineTool,
+  SubQuestionQueryEngine,
+  VectorStoreIndex,
+} from "llamaindex";
 
 import essay from "./essay";
 
@@ -6,16 +11,18 @@ import essay from "./essay";
   const document = new Document({ text: essay, id_: essay });
   const index = await VectorStoreIndex.fromDocuments([document]);
 
-  const queryEngine = SubQuestionQueryEngine.fromDefaults({
-    queryEngineTools: [
-      {
-        queryEngine: index.asQueryEngine(),
-        metadata: {
-          name: "pg_essay",
-          description: "Paul Graham essay on What I Worked On",
-        },
+  const queryEngineTools = [
+    new QueryEngineTool({
+      queryEngine: index.asQueryEngine(),
+      metadata: {
+        name: "pg_essay",
+        description: "Paul Graham essay on What I Worked On",
       },
-    ],
+    }),
+  ];
+
+  const queryEngine = SubQuestionQueryEngine.fromDefaults({
+    queryEngineTools,
   });
 
   const response = await queryEngine.query({
diff --git a/packages/core/src/agent/runner/base.ts b/packages/core/src/agent/runner/base.ts
index 39e4d6379..5183f16f9 100644
--- a/packages/core/src/agent/runner/base.ts
+++ b/packages/core/src/agent/runner/base.ts
@@ -266,7 +266,14 @@ export class AgentRunner extends BaseAgentRunner {
     let resultOutput;
 
     while (true) {
-      const curStepOutput = await this._runStep(task.taskId);
+      const curStepOutput = await this._runStep(
+        task.taskId,
+        undefined,
+        ChatResponseMode.WAIT,
+        {
+          toolChoice,
+        },
+      );
 
       if (curStepOutput.isLast) {
         resultOutput = curStepOutput;
diff --git a/packages/core/src/engines/query/SubQuestionQueryEngine.ts b/packages/core/src/engines/query/SubQuestionQueryEngine.ts
index a70dfbb9e..4c874b15a 100644
--- a/packages/core/src/engines/query/SubQuestionQueryEngine.ts
+++ b/packages/core/src/engines/query/SubQuestionQueryEngine.ts
@@ -14,9 +14,9 @@ import {
 } from "../../synthesizers";
 import {
   BaseQueryEngine,
+  BaseTool,
   QueryEngineParamsNonStreaming,
   QueryEngineParamsStreaming,
-  QueryEngineTool,
   ToolMetadata,
 } from "../../types";
 import { BaseQuestionGenerator, SubQuestion } from "./types";
@@ -27,28 +27,23 @@ import { BaseQuestionGenerator, SubQuestion } from "./types";
 export class SubQuestionQueryEngine implements BaseQueryEngine {
   responseSynthesizer: BaseSynthesizer;
   questionGen: BaseQuestionGenerator;
-  queryEngines: Record<string, BaseQueryEngine>;
+  queryEngines: BaseTool[];
   metadatas: ToolMetadata[];
 
   constructor(init: {
     questionGen: BaseQuestionGenerator;
     responseSynthesizer: BaseSynthesizer;
-    queryEngineTools: QueryEngineTool[];
+    queryEngineTools: BaseTool[];
   }) {
     this.questionGen = init.questionGen;
     this.responseSynthesizer =
       init.responseSynthesizer ?? new ResponseSynthesizer();
-    this.queryEngines = init.queryEngineTools.reduce<
-      Record<string, BaseQueryEngine>
-    >((acc, tool) => {
-      acc[tool.metadata.name] = tool.queryEngine;
-      return acc;
-    }, {});
+    this.queryEngines = init.queryEngineTools;
     this.metadatas = init.queryEngineTools.map((tool) => tool.metadata);
   }
 
   static fromDefaults(init: {
-    queryEngineTools: QueryEngineTool[];
+    queryEngineTools: BaseTool[];
     questionGen?: BaseQuestionGenerator;
     responseSynthesizer?: BaseSynthesizer;
     serviceContext?: ServiceContext;
@@ -122,13 +117,24 @@ export class SubQuestionQueryEngine implements BaseQueryEngine {
   ): Promise<NodeWithScore | null> {
     try {
       const question = subQ.subQuestion;
-      const queryEngine = this.queryEngines[subQ.toolName];
 
-      const response = await queryEngine.query({
+      const queryEngine = this.queryEngines.find(
+        (tool) => tool.metadata.name === subQ.toolName,
+      );
+
+      if (!queryEngine) {
+        return null;
+      }
+
+      const responseText = await queryEngine?.call?.({
         query: question,
         parentEvent,
       });
-      const responseText = response.response;
+
+      if (!responseText) {
+        return null;
+      }
+
       const nodeText = `Sub question: ${question}\nResponse: ${responseText}`;
       const node = new TextNode({ text: nodeText });
       return { node, score: 0 };
diff --git a/packages/core/src/tools/QueryEngineTool.ts b/packages/core/src/tools/QueryEngineTool.ts
new file mode 100644
index 000000000..41c0cc85b
--- /dev/null
+++ b/packages/core/src/tools/QueryEngineTool.ts
@@ -0,0 +1,54 @@
+import { BaseQueryEngine, BaseTool, ToolMetadata } from "../types";
+
+export type QueryEngineToolParams = {
+  queryEngine: BaseQueryEngine;
+  metadata: ToolMetadata;
+};
+
+type QueryEngineCallParams = {
+  query: string;
+};
+
+const DEFAULT_NAME = "query_engine_tool";
+const DEFAULT_DESCRIPTION =
+  "Useful for running a natural language query against a knowledge base and get back a natural language response.";
+const DEFAULT_PARAMETERS = {
+  type: "object",
+  properties: {
+    query: {
+      type: "string",
+      description: "The query to search for",
+    },
+  },
+  required: ["query"],
+};
+
+export class QueryEngineTool implements BaseTool {
+  private queryEngine: BaseQueryEngine;
+  metadata: ToolMetadata;
+
+  constructor({ queryEngine, metadata }: QueryEngineToolParams) {
+    this.queryEngine = queryEngine;
+    this.metadata = {
+      name: metadata?.name ?? DEFAULT_NAME,
+      description: metadata?.description ?? DEFAULT_DESCRIPTION,
+      parameters: metadata?.parameters ?? DEFAULT_PARAMETERS,
+    };
+  }
+
+  async call(...args: QueryEngineCallParams[]): Promise<any> {
+    let queryStr: string;
+
+    if (args && args.length > 0) {
+      queryStr = String(args[0].query);
+    } else {
+      throw new Error(
+        "Cannot call query engine without specifying `input` parameter.",
+      );
+    }
+
+    const response = await this.queryEngine.query({ query: queryStr });
+
+    return response.response;
+  }
+}
diff --git a/packages/core/src/tools/index.ts b/packages/core/src/tools/index.ts
index 2c87cd60e..1215bef7e 100644
--- a/packages/core/src/tools/index.ts
+++ b/packages/core/src/tools/index.ts
@@ -1,2 +1,3 @@
+export * from "./QueryEngineTool";
 export * from "./functionTool";
 export * from "./types";
diff --git a/packages/core/src/types.ts b/packages/core/src/types.ts
index 61dee2fb5..f71ee294d 100644
--- a/packages/core/src/types.ts
+++ b/packages/core/src/types.ts
@@ -40,13 +40,6 @@ export interface BaseTool {
   metadata: ToolMetadata;
 }
 
-/**
- * A Tool that uses a QueryEngine.
- */
-export interface QueryEngineTool extends BaseTool {
-  queryEngine: BaseQueryEngine;
-}
-
 /**
  * An OutputParser is used to extract structured data from the raw output of the LLM.
  */
-- 
GitLab