From 26ab74bd9a62e79191f4ecb71d23d2d082488bef Mon Sep 17 00:00:00 2001
From: Thuc Pham <51660321+thucpn@users.noreply.github.com>
Date: Mon, 25 Mar 2024 09:00:25 +0700
Subject: [PATCH] feat: support agent in typescript templates (#5)

---
 helpers/tools.ts                              |  4 +++
 helpers/typescript.ts                         | 25 ++++++++++++++++++
 questions.ts                                  | 11 ++++----
 .../engines/typescript/agent/chat.ts          | 26 +++++++++++++++++++
 .../engines/typescript/chat/chat.ts           | 13 ++++++++++
 .../vectordbs/typescript/none/index.ts        | 14 +---------
 .../src/controllers/chat.controller.ts        |  2 +-
 .../controllers/engine/{index.ts => chat.ts}  |  0
 .../src/controllers/chat.controller.ts        |  2 +-
 .../controllers/engine/{index.ts => chat.ts}  |  0
 .../src/controllers/llamaindex-stream.ts      |  8 ++++--
 .../app/api/chat/engine/{index.ts => chat.ts} |  0
 .../nextjs/app/api/chat/llamaindex-stream.ts  |  8 ++++--
 .../streaming/nextjs/app/api/chat/route.ts    |  2 +-
 14 files changed, 89 insertions(+), 26 deletions(-)
 create mode 100644 templates/components/engines/typescript/agent/chat.ts
 create mode 100644 templates/components/engines/typescript/chat/chat.ts
 rename templates/types/simple/express/src/controllers/engine/{index.ts => chat.ts} (100%)
 rename templates/types/streaming/express/src/controllers/engine/{index.ts => chat.ts} (100%)
 rename templates/types/streaming/nextjs/app/api/chat/engine/{index.ts => chat.ts} (100%)

diff --git a/helpers/tools.ts b/helpers/tools.ts
index 49559253..f2e44bd0 100644
--- a/helpers/tools.ts
+++ b/helpers/tools.ts
@@ -1,10 +1,12 @@
 import { red } from "picocolors";
+import { TemplateFramework } from "./types";
 
 export type Tool = {
   display: string;
   name: string;
   config?: Record<string, any>;
   dependencies?: ToolDependencies[];
+  supportedFrameworks?: Array<TemplateFramework>;
 };
 export type ToolDependencies = {
   name: string;
@@ -27,6 +29,7 @@ export const supportedTools: Tool[] = [
         version: "0.1.2",
       },
     ],
+    supportedFrameworks: ["fastapi"],
   },
   {
     display: "Wikipedia",
@@ -37,6 +40,7 @@ export const supportedTools: Tool[] = [
         version: "0.1.2",
       },
     ],
+    supportedFrameworks: ["fastapi", "express", "nextjs"],
   },
 ];
 
diff --git a/helpers/typescript.ts b/helpers/typescript.ts
index 91dd5b82..4fdbca98 100644
--- a/helpers/typescript.ts
+++ b/helpers/typescript.ts
@@ -64,6 +64,7 @@ export const installTSTemplate = async ({
   postInstallAction,
   backend,
   observability,
+  tools,
   dataSource,
 }: InstallTemplateArgs & { backend: boolean }) => {
   console.log(bold(`Using ${packageManager}.`));
@@ -186,6 +187,30 @@ export const installTSTemplate = async ({
         cwd: path.join(compPath, "loaders", "typescript", loaderFolder),
       });
     }
+
+    // copy tools component
+    if (tools?.length) {
+      await copy("**", enginePath, {
+        parents: true,
+        cwd: path.join(compPath, "engines", "typescript", "agent"),
+      });
+
+      // Write tools_config.json
+      const configContent: Record<string, any> = {};
+      tools.forEach((tool) => {
+        configContent[tool.name] = tool.config ?? {};
+      });
+      const configFilePath = path.join(enginePath, "tools_config.json");
+      await fs.writeFile(
+        configFilePath,
+        JSON.stringify(configContent, null, 2),
+      );
+    } else if (engine !== "simple") {
+      await copy("**", enginePath, {
+        parents: true,
+        cwd: path.join(compPath, "engines", "typescript", "chat"),
+      });
+    }
   }
 
   /**
diff --git a/questions.ts b/questions.ts
index bbf9322a..cb97fae7 100644
--- a/questions.ts
+++ b/questions.ts
@@ -805,15 +805,14 @@ export const askQuestions = async (
     }
   }
 
-  if (
-    !program.tools &&
-    program.framework === "fastapi" &&
-    program.engine === "context"
-  ) {
+  if (!program.tools && program.engine === "context") {
     if (ciInfo.isCI) {
       program.tools = getPrefOrDefault("tools");
     } else {
-      const toolChoices = supportedTools.map((tool) => ({
+      const options = supportedTools.filter((t) =>
+        t.supportedFrameworks?.includes(program.framework),
+      );
+      const toolChoices = options.map((tool) => ({
         title: tool.display,
         value: tool.name,
       }));
diff --git a/templates/components/engines/typescript/agent/chat.ts b/templates/components/engines/typescript/agent/chat.ts
new file mode 100644
index 00000000..f3660c25
--- /dev/null
+++ b/templates/components/engines/typescript/agent/chat.ts
@@ -0,0 +1,26 @@
+import { OpenAI, OpenAIAgent, QueryEngineTool, ToolFactory } from "llamaindex";
+import { STORAGE_CACHE_DIR } from "./constants.mjs";
+import { getDataSource } from "./index";
+import config from "./tools_config.json";
+
+export async function createChatEngine(llm: OpenAI) {
+  const index = await getDataSource(llm);
+  const queryEngine = index.asQueryEngine();
+  const queryEngineTool = new QueryEngineTool({
+    queryEngine: queryEngine,
+    metadata: {
+      name: "data_query_engine",
+      description: `A query engine for documents in storage folder: ${STORAGE_CACHE_DIR}`,
+    },
+  });
+
+  const externalTools = await ToolFactory.createTools(config);
+
+  const agent = new OpenAIAgent({
+    tools: [queryEngineTool, ...externalTools],
+    verbose: true,
+    llm,
+  });
+
+  return agent;
+}
diff --git a/templates/components/engines/typescript/chat/chat.ts b/templates/components/engines/typescript/chat/chat.ts
new file mode 100644
index 00000000..cf77edb3
--- /dev/null
+++ b/templates/components/engines/typescript/chat/chat.ts
@@ -0,0 +1,13 @@
+import { ContextChatEngine, LLM } from "llamaindex";
+import { getDataSource } from "./index";
+
+export async function createChatEngine(llm: LLM) {
+  const index = await getDataSource(llm);
+  const retriever = index.asRetriever();
+  retriever.similarityTopK = 3;
+
+  return new ContextChatEngine({
+    chatModel: llm,
+    retriever,
+  });
+}
diff --git a/templates/components/vectordbs/typescript/none/index.ts b/templates/components/vectordbs/typescript/none/index.ts
index e335446c..528d6057 100644
--- a/templates/components/vectordbs/typescript/none/index.ts
+++ b/templates/components/vectordbs/typescript/none/index.ts
@@ -1,5 +1,4 @@
 import {
-  ContextChatEngine,
   LLM,
   serviceContextFromDefaults,
   SimpleDocumentStore,
@@ -8,7 +7,7 @@ import {
 } from "llamaindex";
 import { CHUNK_OVERLAP, CHUNK_SIZE, STORAGE_CACHE_DIR } from "./constants.mjs";
 
-async function getDataSource(llm: LLM) {
+export async function getDataSource(llm: LLM) {
   const serviceContext = serviceContextFromDefaults({
     llm,
     chunkSize: CHUNK_SIZE,
@@ -31,14 +30,3 @@ async function getDataSource(llm: LLM) {
     serviceContext,
   });
 }
-
-export async function createChatEngine(llm: LLM) {
-  const index = await getDataSource(llm);
-  const retriever = index.asRetriever();
-  retriever.similarityTopK = 3;
-
-  return new ContextChatEngine({
-    chatModel: llm,
-    retriever,
-  });
-}
diff --git a/templates/types/simple/express/src/controllers/chat.controller.ts b/templates/types/simple/express/src/controllers/chat.controller.ts
index 9f9639b7..5fdb8825 100644
--- a/templates/types/simple/express/src/controllers/chat.controller.ts
+++ b/templates/types/simple/express/src/controllers/chat.controller.ts
@@ -1,6 +1,6 @@
 import { Request, Response } from "express";
 import { ChatMessage, MessageContent, OpenAI } from "llamaindex";
-import { createChatEngine } from "./engine";
+import { createChatEngine } from "./engine/chat";
 
 const convertMessageContent = (
   textMessage: string,
diff --git a/templates/types/simple/express/src/controllers/engine/index.ts b/templates/types/simple/express/src/controllers/engine/chat.ts
similarity index 100%
rename from templates/types/simple/express/src/controllers/engine/index.ts
rename to templates/types/simple/express/src/controllers/engine/chat.ts
diff --git a/templates/types/streaming/express/src/controllers/chat.controller.ts b/templates/types/streaming/express/src/controllers/chat.controller.ts
index 9d1eb0c6..1d9cd56a 100644
--- a/templates/types/streaming/express/src/controllers/chat.controller.ts
+++ b/templates/types/streaming/express/src/controllers/chat.controller.ts
@@ -1,7 +1,7 @@
 import { streamToResponse } from "ai";
 import { Request, Response } from "express";
 import { ChatMessage, MessageContent, OpenAI } from "llamaindex";
-import { createChatEngine } from "./engine";
+import { createChatEngine } from "./engine/chat";
 import { LlamaIndexStream } from "./llamaindex-stream";
 
 const convertMessageContent = (
diff --git a/templates/types/streaming/express/src/controllers/engine/index.ts b/templates/types/streaming/express/src/controllers/engine/chat.ts
similarity index 100%
rename from templates/types/streaming/express/src/controllers/engine/index.ts
rename to templates/types/streaming/express/src/controllers/engine/chat.ts
diff --git a/templates/types/streaming/express/src/controllers/llamaindex-stream.ts b/templates/types/streaming/express/src/controllers/llamaindex-stream.ts
index 6ddd8eae..f0c9d80c 100644
--- a/templates/types/streaming/express/src/controllers/llamaindex-stream.ts
+++ b/templates/types/streaming/express/src/controllers/llamaindex-stream.ts
@@ -6,7 +6,7 @@ import {
   trimStartOfStreamHelper,
   type AIStreamCallbacksAndOptions,
 } from "ai";
-import { Response } from "llamaindex";
+import { Response, StreamingAgentChatResponse } from "llamaindex";
 
 type ParserOptions = {
   image_url?: string;
@@ -52,13 +52,17 @@ function createParser(
 }
 
 export function LlamaIndexStream(
-  res: AsyncIterable<Response>,
+  response: StreamingAgentChatResponse | AsyncIterable<Response>,
   opts?: {
     callbacks?: AIStreamCallbacksAndOptions;
     parserOptions?: ParserOptions;
   },
 ): { stream: ReadableStream; data: experimental_StreamData } {
   const data = new experimental_StreamData();
+  const res =
+    response instanceof StreamingAgentChatResponse
+      ? response.response
+      : response;
   return {
     stream: createParser(res, data, opts?.parserOptions)
       .pipeThrough(createCallbacksTransformer(opts?.callbacks))
diff --git a/templates/types/streaming/nextjs/app/api/chat/engine/index.ts b/templates/types/streaming/nextjs/app/api/chat/engine/chat.ts
similarity index 100%
rename from templates/types/streaming/nextjs/app/api/chat/engine/index.ts
rename to templates/types/streaming/nextjs/app/api/chat/engine/chat.ts
diff --git a/templates/types/streaming/nextjs/app/api/chat/llamaindex-stream.ts b/templates/types/streaming/nextjs/app/api/chat/llamaindex-stream.ts
index 6ddd8eae..f0c9d80c 100644
--- a/templates/types/streaming/nextjs/app/api/chat/llamaindex-stream.ts
+++ b/templates/types/streaming/nextjs/app/api/chat/llamaindex-stream.ts
@@ -6,7 +6,7 @@ import {
   trimStartOfStreamHelper,
   type AIStreamCallbacksAndOptions,
 } from "ai";
-import { Response } from "llamaindex";
+import { Response, StreamingAgentChatResponse } from "llamaindex";
 
 type ParserOptions = {
   image_url?: string;
@@ -52,13 +52,17 @@ function createParser(
 }
 
 export function LlamaIndexStream(
-  res: AsyncIterable<Response>,
+  response: StreamingAgentChatResponse | AsyncIterable<Response>,
   opts?: {
     callbacks?: AIStreamCallbacksAndOptions;
     parserOptions?: ParserOptions;
   },
 ): { stream: ReadableStream; data: experimental_StreamData } {
   const data = new experimental_StreamData();
+  const res =
+    response instanceof StreamingAgentChatResponse
+      ? response.response
+      : response;
   return {
     stream: createParser(res, data, opts?.parserOptions)
       .pipeThrough(createCallbacksTransformer(opts?.callbacks))
diff --git a/templates/types/streaming/nextjs/app/api/chat/route.ts b/templates/types/streaming/nextjs/app/api/chat/route.ts
index 32b9bb16..484262f2 100644
--- a/templates/types/streaming/nextjs/app/api/chat/route.ts
+++ b/templates/types/streaming/nextjs/app/api/chat/route.ts
@@ -2,7 +2,7 @@ import { initObservability } from "@/app/observability";
 import { StreamingTextResponse } from "ai";
 import { ChatMessage, MessageContent, OpenAI } from "llamaindex";
 import { NextRequest, NextResponse } from "next/server";
-import { createChatEngine } from "./engine";
+import { createChatEngine } from "./engine/chat";
 import { LlamaIndexStream } from "./llamaindex-stream";
 
 initObservability();
-- 
GitLab