From 26ab74bd9a62e79191f4ecb71d23d2d082488bef Mon Sep 17 00:00:00 2001 From: Thuc Pham <51660321+thucpn@users.noreply.github.com> Date: Mon, 25 Mar 2024 09:00:25 +0700 Subject: [PATCH] feat: support agent in typescript templates (#5) --- helpers/tools.ts | 4 +++ helpers/typescript.ts | 25 ++++++++++++++++++ questions.ts | 11 ++++---- .../engines/typescript/agent/chat.ts | 26 +++++++++++++++++++ .../engines/typescript/chat/chat.ts | 13 ++++++++++ .../vectordbs/typescript/none/index.ts | 14 +--------- .../src/controllers/chat.controller.ts | 2 +- .../controllers/engine/{index.ts => chat.ts} | 0 .../src/controllers/chat.controller.ts | 2 +- .../controllers/engine/{index.ts => chat.ts} | 0 .../src/controllers/llamaindex-stream.ts | 8 ++++-- .../app/api/chat/engine/{index.ts => chat.ts} | 0 .../nextjs/app/api/chat/llamaindex-stream.ts | 8 ++++-- .../streaming/nextjs/app/api/chat/route.ts | 2 +- 14 files changed, 89 insertions(+), 26 deletions(-) create mode 100644 templates/components/engines/typescript/agent/chat.ts create mode 100644 templates/components/engines/typescript/chat/chat.ts rename templates/types/simple/express/src/controllers/engine/{index.ts => chat.ts} (100%) rename templates/types/streaming/express/src/controllers/engine/{index.ts => chat.ts} (100%) rename templates/types/streaming/nextjs/app/api/chat/engine/{index.ts => chat.ts} (100%) diff --git a/helpers/tools.ts b/helpers/tools.ts index 49559253..f2e44bd0 100644 --- a/helpers/tools.ts +++ b/helpers/tools.ts @@ -1,10 +1,12 @@ import { red } from "picocolors"; +import { TemplateFramework } from "./types"; export type Tool = { display: string; name: string; config?: Record<string, any>; dependencies?: ToolDependencies[]; + supportedFrameworks?: Array<TemplateFramework>; }; export type ToolDependencies = { name: string; @@ -27,6 +29,7 @@ export const supportedTools: Tool[] = [ version: "0.1.2", }, ], + supportedFrameworks: ["fastapi"], }, { display: "Wikipedia", @@ -37,6 +40,7 @@ export const supportedTools: Tool[] = [ version: "0.1.2", }, ], + supportedFrameworks: ["fastapi", "express", "nextjs"], }, ]; diff --git a/helpers/typescript.ts b/helpers/typescript.ts index 91dd5b82..4fdbca98 100644 --- a/helpers/typescript.ts +++ b/helpers/typescript.ts @@ -64,6 +64,7 @@ export const installTSTemplate = async ({ postInstallAction, backend, observability, + tools, dataSource, }: InstallTemplateArgs & { backend: boolean }) => { console.log(bold(`Using ${packageManager}.`)); @@ -186,6 +187,30 @@ export const installTSTemplate = async ({ cwd: path.join(compPath, "loaders", "typescript", loaderFolder), }); } + + // copy tools component + if (tools?.length) { + await copy("**", enginePath, { + parents: true, + cwd: path.join(compPath, "engines", "typescript", "agent"), + }); + + // Write tools_config.json + const configContent: Record<string, any> = {}; + tools.forEach((tool) => { + configContent[tool.name] = tool.config ?? {}; + }); + const configFilePath = path.join(enginePath, "tools_config.json"); + await fs.writeFile( + configFilePath, + JSON.stringify(configContent, null, 2), + ); + } else if (engine !== "simple") { + await copy("**", enginePath, { + parents: true, + cwd: path.join(compPath, "engines", "typescript", "chat"), + }); + } } /** diff --git a/questions.ts b/questions.ts index bbf9322a..cb97fae7 100644 --- a/questions.ts +++ b/questions.ts @@ -805,15 +805,14 @@ export const askQuestions = async ( } } - if ( - !program.tools && - program.framework === "fastapi" && - program.engine === "context" - ) { + if (!program.tools && program.engine === "context") { if (ciInfo.isCI) { program.tools = getPrefOrDefault("tools"); } else { - const toolChoices = supportedTools.map((tool) => ({ + const options = supportedTools.filter((t) => + t.supportedFrameworks?.includes(program.framework), + ); + const toolChoices = options.map((tool) => ({ title: tool.display, value: tool.name, })); diff --git a/templates/components/engines/typescript/agent/chat.ts b/templates/components/engines/typescript/agent/chat.ts new file mode 100644 index 00000000..f3660c25 --- /dev/null +++ b/templates/components/engines/typescript/agent/chat.ts @@ -0,0 +1,26 @@ +import { OpenAI, OpenAIAgent, QueryEngineTool, ToolFactory } from "llamaindex"; +import { STORAGE_CACHE_DIR } from "./constants.mjs"; +import { getDataSource } from "./index"; +import config from "./tools_config.json"; + +export async function createChatEngine(llm: OpenAI) { + const index = await getDataSource(llm); + const queryEngine = index.asQueryEngine(); + const queryEngineTool = new QueryEngineTool({ + queryEngine: queryEngine, + metadata: { + name: "data_query_engine", + description: `A query engine for documents in storage folder: ${STORAGE_CACHE_DIR}`, + }, + }); + + const externalTools = await ToolFactory.createTools(config); + + const agent = new OpenAIAgent({ + tools: [queryEngineTool, ...externalTools], + verbose: true, + llm, + }); + + return agent; +} diff --git a/templates/components/engines/typescript/chat/chat.ts b/templates/components/engines/typescript/chat/chat.ts new file mode 100644 index 00000000..cf77edb3 --- /dev/null +++ b/templates/components/engines/typescript/chat/chat.ts @@ -0,0 +1,13 @@ +import { ContextChatEngine, LLM } from "llamaindex"; +import { getDataSource } from "./index"; + +export async function createChatEngine(llm: LLM) { + const index = await getDataSource(llm); + const retriever = index.asRetriever(); + retriever.similarityTopK = 3; + + return new ContextChatEngine({ + chatModel: llm, + retriever, + }); +} diff --git a/templates/components/vectordbs/typescript/none/index.ts b/templates/components/vectordbs/typescript/none/index.ts index e335446c..528d6057 100644 --- a/templates/components/vectordbs/typescript/none/index.ts +++ b/templates/components/vectordbs/typescript/none/index.ts @@ -1,5 +1,4 @@ import { - ContextChatEngine, LLM, serviceContextFromDefaults, SimpleDocumentStore, @@ -8,7 +7,7 @@ import { } from "llamaindex"; import { CHUNK_OVERLAP, CHUNK_SIZE, STORAGE_CACHE_DIR } from "./constants.mjs"; -async function getDataSource(llm: LLM) { +export async function getDataSource(llm: LLM) { const serviceContext = serviceContextFromDefaults({ llm, chunkSize: CHUNK_SIZE, @@ -31,14 +30,3 @@ async function getDataSource(llm: LLM) { serviceContext, }); } - -export async function createChatEngine(llm: LLM) { - const index = await getDataSource(llm); - const retriever = index.asRetriever(); - retriever.similarityTopK = 3; - - return new ContextChatEngine({ - chatModel: llm, - retriever, - }); -} diff --git a/templates/types/simple/express/src/controllers/chat.controller.ts b/templates/types/simple/express/src/controllers/chat.controller.ts index 9f9639b7..5fdb8825 100644 --- a/templates/types/simple/express/src/controllers/chat.controller.ts +++ b/templates/types/simple/express/src/controllers/chat.controller.ts @@ -1,6 +1,6 @@ import { Request, Response } from "express"; import { ChatMessage, MessageContent, OpenAI } from "llamaindex"; -import { createChatEngine } from "./engine"; +import { createChatEngine } from "./engine/chat"; const convertMessageContent = ( textMessage: string, diff --git a/templates/types/simple/express/src/controllers/engine/index.ts b/templates/types/simple/express/src/controllers/engine/chat.ts similarity index 100% rename from templates/types/simple/express/src/controllers/engine/index.ts rename to templates/types/simple/express/src/controllers/engine/chat.ts diff --git a/templates/types/streaming/express/src/controllers/chat.controller.ts b/templates/types/streaming/express/src/controllers/chat.controller.ts index 9d1eb0c6..1d9cd56a 100644 --- a/templates/types/streaming/express/src/controllers/chat.controller.ts +++ b/templates/types/streaming/express/src/controllers/chat.controller.ts @@ -1,7 +1,7 @@ import { streamToResponse } from "ai"; import { Request, Response } from "express"; import { ChatMessage, MessageContent, OpenAI } from "llamaindex"; -import { createChatEngine } from "./engine"; +import { createChatEngine } from "./engine/chat"; import { LlamaIndexStream } from "./llamaindex-stream"; const convertMessageContent = ( diff --git a/templates/types/streaming/express/src/controllers/engine/index.ts b/templates/types/streaming/express/src/controllers/engine/chat.ts similarity index 100% rename from templates/types/streaming/express/src/controllers/engine/index.ts rename to templates/types/streaming/express/src/controllers/engine/chat.ts diff --git a/templates/types/streaming/express/src/controllers/llamaindex-stream.ts b/templates/types/streaming/express/src/controllers/llamaindex-stream.ts index 6ddd8eae..f0c9d80c 100644 --- a/templates/types/streaming/express/src/controllers/llamaindex-stream.ts +++ b/templates/types/streaming/express/src/controllers/llamaindex-stream.ts @@ -6,7 +6,7 @@ import { trimStartOfStreamHelper, type AIStreamCallbacksAndOptions, } from "ai"; -import { Response } from "llamaindex"; +import { Response, StreamingAgentChatResponse } from "llamaindex"; type ParserOptions = { image_url?: string; @@ -52,13 +52,17 @@ function createParser( } export function LlamaIndexStream( - res: AsyncIterable<Response>, + response: StreamingAgentChatResponse | AsyncIterable<Response>, opts?: { callbacks?: AIStreamCallbacksAndOptions; parserOptions?: ParserOptions; }, ): { stream: ReadableStream; data: experimental_StreamData } { const data = new experimental_StreamData(); + const res = + response instanceof StreamingAgentChatResponse + ? response.response + : response; return { stream: createParser(res, data, opts?.parserOptions) .pipeThrough(createCallbacksTransformer(opts?.callbacks)) diff --git a/templates/types/streaming/nextjs/app/api/chat/engine/index.ts b/templates/types/streaming/nextjs/app/api/chat/engine/chat.ts similarity index 100% rename from templates/types/streaming/nextjs/app/api/chat/engine/index.ts rename to templates/types/streaming/nextjs/app/api/chat/engine/chat.ts diff --git a/templates/types/streaming/nextjs/app/api/chat/llamaindex-stream.ts b/templates/types/streaming/nextjs/app/api/chat/llamaindex-stream.ts index 6ddd8eae..f0c9d80c 100644 --- a/templates/types/streaming/nextjs/app/api/chat/llamaindex-stream.ts +++ b/templates/types/streaming/nextjs/app/api/chat/llamaindex-stream.ts @@ -6,7 +6,7 @@ import { trimStartOfStreamHelper, type AIStreamCallbacksAndOptions, } from "ai"; -import { Response } from "llamaindex"; +import { Response, StreamingAgentChatResponse } from "llamaindex"; type ParserOptions = { image_url?: string; @@ -52,13 +52,17 @@ function createParser( } export function LlamaIndexStream( - res: AsyncIterable<Response>, + response: StreamingAgentChatResponse | AsyncIterable<Response>, opts?: { callbacks?: AIStreamCallbacksAndOptions; parserOptions?: ParserOptions; }, ): { stream: ReadableStream; data: experimental_StreamData } { const data = new experimental_StreamData(); + const res = + response instanceof StreamingAgentChatResponse + ? response.response + : response; return { stream: createParser(res, data, opts?.parserOptions) .pipeThrough(createCallbacksTransformer(opts?.callbacks)) diff --git a/templates/types/streaming/nextjs/app/api/chat/route.ts b/templates/types/streaming/nextjs/app/api/chat/route.ts index 32b9bb16..484262f2 100644 --- a/templates/types/streaming/nextjs/app/api/chat/route.ts +++ b/templates/types/streaming/nextjs/app/api/chat/route.ts @@ -2,7 +2,7 @@ import { initObservability } from "@/app/observability"; import { StreamingTextResponse } from "ai"; import { ChatMessage, MessageContent, OpenAI } from "llamaindex"; import { NextRequest, NextResponse } from "next/server"; -import { createChatEngine } from "./engine"; +import { createChatEngine } from "./engine/chat"; import { LlamaIndexStream } from "./llamaindex-stream"; initObservability(); -- GitLab