diff --git a/helpers/tools.ts b/helpers/tools.ts
index 49559253776d3fef308b130e14a067fd6481db6a..f2e44bd079fce2706592e3d95b59814a6172baf7 100644
--- a/helpers/tools.ts
+++ b/helpers/tools.ts
@@ -1,10 +1,12 @@
 import { red } from "picocolors";
+import { TemplateFramework } from "./types";
 
 export type Tool = {
   display: string;
   name: string;
   config?: Record<string, any>;
   dependencies?: ToolDependencies[];
+  supportedFrameworks?: Array<TemplateFramework>;
 };
 export type ToolDependencies = {
   name: string;
@@ -27,6 +29,7 @@ export const supportedTools: Tool[] = [
         version: "0.1.2",
       },
     ],
+    supportedFrameworks: ["fastapi"],
   },
   {
     display: "Wikipedia",
@@ -37,6 +40,7 @@ export const supportedTools: Tool[] = [
         version: "0.1.2",
       },
     ],
+    supportedFrameworks: ["fastapi", "express", "nextjs"],
   },
 ];
 
diff --git a/helpers/typescript.ts b/helpers/typescript.ts
index 91dd5b8250a12cd1dbdd85c38e74d556c624a74a..4fdbca9874330dd919ca8db01c72a961ef6e5836 100644
--- a/helpers/typescript.ts
+++ b/helpers/typescript.ts
@@ -64,6 +64,7 @@ export const installTSTemplate = async ({
   postInstallAction,
   backend,
   observability,
+  tools,
   dataSource,
 }: InstallTemplateArgs & { backend: boolean }) => {
   console.log(bold(`Using ${packageManager}.`));
@@ -186,6 +187,30 @@ export const installTSTemplate = async ({
        cwd: path.join(compPath, "loaders", "typescript", loaderFolder),
      });
    }
+
+    // copy tools component
+    if (tools?.length) {
+      await copy("**", enginePath, {
+        parents: true,
+        cwd: path.join(compPath, "engines", "typescript", "agent"),
+      });
+
+      // Write tools_config.json
+      const configContent: Record<string, any> = {};
+      tools.forEach((tool) => {
+        configContent[tool.name] = tool.config ?? {};
+      });
+      const configFilePath = path.join(enginePath, "tools_config.json");
+      await fs.writeFile(
+        configFilePath,
+        JSON.stringify(configContent, null, 2),
+      );
+    } else if (engine !== "simple") {
+      await copy("**", enginePath, {
+        parents: true,
+        cwd: path.join(compPath, "engines", "typescript", "chat"),
+      });
+    }
  }
 
  /**
diff --git a/questions.ts b/questions.ts
index bbf9322af8b680d749a5f8f672378a27b6b3a689..cb97fae743ffe0fa46f2c3e7aa3fbdc63304b710 100644
--- a/questions.ts
+++ b/questions.ts
@@ -805,15 +805,14 @@ export const askQuestions = async (
     }
   }
 
-  if (
-    !program.tools &&
-    program.framework === "fastapi" &&
-    program.engine === "context"
-  ) {
+  if (!program.tools && program.engine === "context") {
     if (ciInfo.isCI) {
       program.tools = getPrefOrDefault("tools");
     } else {
-      const toolChoices = supportedTools.map((tool) => ({
+      const options = supportedTools.filter((t) =>
+        t.supportedFrameworks?.includes(program.framework),
+      );
+      const toolChoices = options.map((tool) => ({
         title: tool.display,
         value: tool.name,
       }));
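
Annotation (not part of the diff): the `installTSTemplate` hunk above writes a `tools_config.json` keyed by each selected tool's `name`, with the tool's `config` (or `{}`) as the value; the agent engine below consumes it via `ToolFactory.createTools`. A minimal sketch of the expected shape — the key shown is hypothetical, since the actual `name` values fall outside these hunks:

```ts
// Hypothetical tools_config.json content for one selected tool:
// key = Tool.name, value = Tool.config ?? {}.
const exampleToolsConfig: Record<string, Record<string, unknown>> = {
  "wikipedia.WikipediaToolSpec": {}, // illustrative key; this tool needs no config
};
```
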
diff --git a/templates/components/engines/typescript/agent/chat.ts b/templates/components/engines/typescript/agent/chat.ts
new file mode 100644
index 0000000000000000000000000000000000000000..f3660c2514bd391957598577af80592100cd2047
--- /dev/null
+++ b/templates/components/engines/typescript/agent/chat.ts
@@ -0,0 +1,26 @@
+import { OpenAI, OpenAIAgent, QueryEngineTool, ToolFactory } from "llamaindex";
+import { STORAGE_CACHE_DIR } from "./constants.mjs";
+import { getDataSource } from "./index";
+import config from "./tools_config.json";
+
+export async function createChatEngine(llm: OpenAI) {
+  const index = await getDataSource(llm);
+  const queryEngine = index.asQueryEngine();
+  const queryEngineTool = new QueryEngineTool({
+    queryEngine: queryEngine,
+    metadata: {
+      name: "data_query_engine",
+      description: `A query engine for documents in storage folder: ${STORAGE_CACHE_DIR}`,
+    },
+  });
+
+  const externalTools = await ToolFactory.createTools(config);
+
+  const agent = new OpenAIAgent({
+    tools: [queryEngineTool, ...externalTools],
+    verbose: true,
+    llm,
+  });
+
+  return agent;
+}
diff --git a/templates/components/engines/typescript/chat/chat.ts b/templates/components/engines/typescript/chat/chat.ts
new file mode 100644
index 0000000000000000000000000000000000000000..cf77edb3379245c008ae817c19f9c0a9dab668e6
--- /dev/null
+++ b/templates/components/engines/typescript/chat/chat.ts
@@ -0,0 +1,13 @@
+import { ContextChatEngine, LLM } from "llamaindex";
+import { getDataSource } from "./index";
+
+export async function createChatEngine(llm: LLM) {
+  const index = await getDataSource(llm);
+  const retriever = index.asRetriever();
+  retriever.similarityTopK = 3;
+
+  return new ContextChatEngine({
+    chatModel: llm,
+    retriever,
+  });
+}
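
Annotation (not part of the diff): both new engine variants export the same `createChatEngine` entry point, which is why the controllers below can import from `./engine/chat` regardless of which component gets copied in. A hedged usage sketch — the object-style `chat` call and the `response` field are assumptions about the LlamaIndexTS API of this era, not code from this PR:

```ts
// Hedged sketch: calling either engine variant through the shared entry point.
import { OpenAI } from "llamaindex";
import { createChatEngine } from "./engine/chat";

export async function answer(question: string): Promise<string> {
  const llm = new OpenAI({ model: "gpt-3.5-turbo" }); // model name is illustrative
  const chatEngine = await createChatEngine(llm);
  // Assumed object-style chat API; both the agent and the context chat
  // engine resolve to a result carrying a `response` string.
  const result = await chatEngine.chat({ message: question });
  return result.response;
}
```
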
diff --git a/templates/components/vectordbs/typescript/none/index.ts b/templates/components/vectordbs/typescript/none/index.ts
index e335446cfd72e9da910eab3228848e32e1e0475a..528d605727fa4ac832bff2b20c1e5d41e3cb5c49 100644
--- a/templates/components/vectordbs/typescript/none/index.ts
+++ b/templates/components/vectordbs/typescript/none/index.ts
@@ -1,5 +1,4 @@
 import {
-  ContextChatEngine,
   LLM,
   serviceContextFromDefaults,
   SimpleDocumentStore,
@@ -8,7 +7,7 @@ import {
 } from "llamaindex";
 import { CHUNK_OVERLAP, CHUNK_SIZE, STORAGE_CACHE_DIR } from "./constants.mjs";
 
-async function getDataSource(llm: LLM) {
+export async function getDataSource(llm: LLM) {
   const serviceContext = serviceContextFromDefaults({
     llm,
     chunkSize: CHUNK_SIZE,
@@ -31,14 +30,3 @@
     serviceContext,
   });
 }
-
-export async function createChatEngine(llm: LLM) {
-  const index = await getDataSource(llm);
-  const retriever = index.asRetriever();
-  retriever.similarityTopK = 3;
-
-  return new ContextChatEngine({
-    chatModel: llm,
-    retriever,
-  });
-}
diff --git a/templates/types/simple/express/src/controllers/chat.controller.ts b/templates/types/simple/express/src/controllers/chat.controller.ts
index 9f9639b72724669e42ddd2ad25dcc2d31368c07c..5fdb88255b223c71ce8aab8789614596205defd7 100644
--- a/templates/types/simple/express/src/controllers/chat.controller.ts
+++ b/templates/types/simple/express/src/controllers/chat.controller.ts
@@ -1,6 +1,6 @@
 import { Request, Response } from "express";
 import { ChatMessage, MessageContent, OpenAI } from "llamaindex";
-import { createChatEngine } from "./engine";
+import { createChatEngine } from "./engine/chat";
 
 const convertMessageContent = (
   textMessage: string,
diff --git a/templates/types/simple/express/src/controllers/engine/index.ts b/templates/types/simple/express/src/controllers/engine/chat.ts
similarity index 100%
rename from templates/types/simple/express/src/controllers/engine/index.ts
rename to templates/types/simple/express/src/controllers/engine/chat.ts
diff --git a/templates/types/streaming/express/src/controllers/chat.controller.ts b/templates/types/streaming/express/src/controllers/chat.controller.ts
index 9d1eb0c69b50ae4f258e5321eec1ffcc5d0bda1e..1d9cd56a79b3b9ce4ea89f82b6030aeaf8a84625 100644
--- a/templates/types/streaming/express/src/controllers/chat.controller.ts
+++ b/templates/types/streaming/express/src/controllers/chat.controller.ts
@@ -1,7 +1,7 @@
 import { streamToResponse } from "ai";
 import { Request, Response } from "express";
 import { ChatMessage, MessageContent, OpenAI } from "llamaindex";
-import { createChatEngine } from "./engine";
+import { createChatEngine } from "./engine/chat";
 import { LlamaIndexStream } from "./llamaindex-stream";
 
 const convertMessageContent = (
diff --git a/templates/types/streaming/express/src/controllers/engine/index.ts b/templates/types/streaming/express/src/controllers/engine/chat.ts
similarity index 100%
rename from templates/types/streaming/express/src/controllers/engine/index.ts
rename to templates/types/streaming/express/src/controllers/engine/chat.ts
diff --git a/templates/types/streaming/express/src/controllers/llamaindex-stream.ts b/templates/types/streaming/express/src/controllers/llamaindex-stream.ts
index 6ddd8eae68bc199188d07a0af8f27a12b2a6abb3..f0c9d80cc056a0a81d2a9ca8fbcb893e8af6c5c7 100644
--- a/templates/types/streaming/express/src/controllers/llamaindex-stream.ts
+++ b/templates/types/streaming/express/src/controllers/llamaindex-stream.ts
@@ -6,7 +6,7 @@ import {
   trimStartOfStreamHelper,
   type AIStreamCallbacksAndOptions,
 } from "ai";
-import { Response } from "llamaindex";
+import { Response, StreamingAgentChatResponse } from "llamaindex";
 
 type ParserOptions = {
   image_url?: string;
@@ -52,13 +52,17 @@ function createParser(
 }
 
 export function LlamaIndexStream(
-  res: AsyncIterable<Response>,
+  response: StreamingAgentChatResponse | AsyncIterable<Response>,
   opts?: {
     callbacks?: AIStreamCallbacksAndOptions;
     parserOptions?: ParserOptions;
   },
 ): { stream: ReadableStream; data: experimental_StreamData } {
   const data = new experimental_StreamData();
+  const res =
+    response instanceof StreamingAgentChatResponse
+      ? response.response
+      : response;
   return {
     stream: createParser(res, data, opts?.parserOptions)
       .pipeThrough(createCallbacksTransformer(opts?.callbacks))
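
Annotation (not part of the diff): `LlamaIndexStream` is widened because the agent engine streams as a `StreamingAgentChatResponse` (whose `.response` is the underlying `AsyncIterable<Response>`), while the context chat engine yields the iterable directly. A hedged controller-side sketch — the `stream: true` chat call shape is an assumption about the era's LlamaIndexTS API:

```ts
// Hedged sketch: a controller can pass either streaming result through unchanged.
import { streamToResponse } from "ai";
import { Response } from "express";
import { ChatMessage, OpenAI } from "llamaindex";
import { createChatEngine } from "./engine/chat";
import { LlamaIndexStream } from "./llamaindex-stream";

export async function streamChat(
  res: Response,
  message: string,
  chatHistory: ChatMessage[],
) {
  const chatEngine = await createChatEngine(new OpenAI({ model: "gpt-3.5-turbo" }));
  // StreamingAgentChatResponse (agent) or AsyncIterable<Response> (context chat):
  const response = await chatEngine.chat({ message, chatHistory, stream: true });
  const { stream } = LlamaIndexStream(response);
  streamToResponse(stream, res);
}
```
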
diff --git a/templates/types/streaming/nextjs/app/api/chat/engine/index.ts b/templates/types/streaming/nextjs/app/api/chat/engine/chat.ts
similarity index 100%
rename from templates/types/streaming/nextjs/app/api/chat/engine/index.ts
rename to templates/types/streaming/nextjs/app/api/chat/engine/chat.ts
diff --git a/templates/types/streaming/nextjs/app/api/chat/llamaindex-stream.ts b/templates/types/streaming/nextjs/app/api/chat/llamaindex-stream.ts
index 6ddd8eae68bc199188d07a0af8f27a12b2a6abb3..f0c9d80cc056a0a81d2a9ca8fbcb893e8af6c5c7 100644
--- a/templates/types/streaming/nextjs/app/api/chat/llamaindex-stream.ts
+++ b/templates/types/streaming/nextjs/app/api/chat/llamaindex-stream.ts
@@ -6,7 +6,7 @@ import {
   trimStartOfStreamHelper,
   type AIStreamCallbacksAndOptions,
 } from "ai";
-import { Response } from "llamaindex";
+import { Response, StreamingAgentChatResponse } from "llamaindex";
 
 type ParserOptions = {
   image_url?: string;
@@ -52,13 +52,17 @@ function createParser(
 }
 
 export function LlamaIndexStream(
-  res: AsyncIterable<Response>,
+  response: StreamingAgentChatResponse | AsyncIterable<Response>,
  opts?: {
    callbacks?: AIStreamCallbacksAndOptions;
    parserOptions?: ParserOptions;
  },
 ): { stream: ReadableStream; data: experimental_StreamData } {
   const data = new experimental_StreamData();
+  const res =
+    response instanceof StreamingAgentChatResponse
+      ? response.response
+      : response;
   return {
     stream: createParser(res, data, opts?.parserOptions)
       .pipeThrough(createCallbacksTransformer(opts?.callbacks))
diff --git a/templates/types/streaming/nextjs/app/api/chat/route.ts b/templates/types/streaming/nextjs/app/api/chat/route.ts
index 32b9bb163d19e158e186e2648347da54c02d5fe5..484262f2c3f97cad8754a90ee02f2a071e5cf0c4 100644
--- a/templates/types/streaming/nextjs/app/api/chat/route.ts
+++ b/templates/types/streaming/nextjs/app/api/chat/route.ts
@@ -2,7 +2,7 @@ import { initObservability } from "@/app/observability";
 import { StreamingTextResponse } from "ai";
 import { ChatMessage, MessageContent, OpenAI } from "llamaindex";
 import { NextRequest, NextResponse } from "next/server";
-import { createChatEngine } from "./engine";
+import { createChatEngine } from "./engine/chat";
 import { LlamaIndexStream } from "./llamaindex-stream";
 
 initObservability();