Unverified commit 26ab74bd, authored by Thuc Pham and committed by GitHub

feat: support agent in typescript templates (#5)

parent 17afc918
Showing 89 additions and 26 deletions
import { red } from "picocolors";
import { TemplateFramework } from "./types";
export type Tool = {
display: string;
name: string;
config?: Record<string, any>;
dependencies?: ToolDependencies[];
supportedFrameworks?: Array<TemplateFramework>;
};
export type ToolDependencies = {
name: string;
@@ -27,6 +29,7 @@ export const supportedTools: Tool[] = [
version: "0.1.2",
},
],
supportedFrameworks: ["fastapi"],
},
{
display: "Wikipedia",
@@ -37,6 +40,7 @@ export const supportedTools: Tool[] = [
version: "0.1.2",
},
],
supportedFrameworks: ["fastapi", "express", "nextjs"],
},
];
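Each tool now declares the frameworks it supports via supportedFrameworks, so the CLI can filter the registry per framework. A minimal sketch of such a filter (the helper name is hypothetical; the actual inline filter appears in askQuestions below):
// Hypothetical helper: list the tools available for a given framework.
// Tools without a supportedFrameworks field are treated as unavailable,
// mirroring the filter used in askQuestions.
export function toolsForFramework(framework: TemplateFramework): Tool[] {
  return supportedTools.filter((tool) =>
    tool.supportedFrameworks?.includes(framework),
  );
}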
@@ -64,6 +64,7 @@ export const installTSTemplate = async ({
postInstallAction,
backend,
observability,
tools,
dataSource,
}: InstallTemplateArgs & { backend: boolean }) => {
console.log(bold(`Using ${packageManager}.`));
@@ -186,6 +187,30 @@ export const installTSTemplate = async ({
cwd: path.join(compPath, "loaders", "typescript", loaderFolder),
});
}
// copy tools component
if (tools?.length) {
await copy("**", enginePath, {
parents: true,
cwd: path.join(compPath, "engines", "typescript", "agent"),
});
// Write tools_config.json
const configContent: Record<string, any> = {};
tools.forEach((tool) => {
configContent[tool.name] = tool.config ?? {};
});
const configFilePath = path.join(enginePath, "tools_config.json");
await fs.writeFile(
configFilePath,
JSON.stringify(configContent, null, 2),
);
} else if (engine !== "simple") {
await copy("**", enginePath, {
parents: true,
cwd: path.join(compPath, "engines", "typescript", "chat"),
});
}
}
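The tools_config.json written above maps each selected tool's name to its config, defaulting to an empty object. Selecting a single tool named "wikipedia" with no config (the key is illustrative; real keys come from each tool's name field) would produce a file like:
{
  "wikipedia": {}
}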
/**
@@ -805,15 +805,14 @@ export const askQuestions = async (
}
}
- if (
-   !program.tools &&
-   program.framework === "fastapi" &&
-   program.engine === "context"
- ) {
+ if (!program.tools && program.engine === "context") {
if (ciInfo.isCI) {
program.tools = getPrefOrDefault("tools");
} else {
- const toolChoices = supportedTools.map((tool) => ({
+ const options = supportedTools.filter((t) =>
+   t.supportedFrameworks?.includes(program.framework),
+ );
+ const toolChoices = options.map((tool) => ({
title: tool.display,
value: tool.name,
}));
......
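The elided code presumably passes toolChoices to an interactive prompt. A sketch of that selection step, assuming the prompts package (the message text and option shape are assumptions):
// Hypothetical sketch: let the user pick from the framework-filtered tools.
const { tools } = await prompts({
  type: "multiselect",
  name: "tools",
  message: "Which tools would you like to use?",
  choices: toolChoices,
});
program.tools = tools;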
import { OpenAI, OpenAIAgent, QueryEngineTool, ToolFactory } from "llamaindex";
import { STORAGE_CACHE_DIR } from "./constants.mjs";
import { getDataSource } from "./index";
import config from "./tools_config.json";
export async function createChatEngine(llm: OpenAI) {
const index = await getDataSource(llm);
const queryEngine = index.asQueryEngine();
const queryEngineTool = new QueryEngineTool({
queryEngine: queryEngine,
metadata: {
name: "data_query_engine",
description: `A query engine for documents in storage folder: ${STORAGE_CACHE_DIR}`,
},
});
const externalTools = await ToolFactory.createTools(config);
const agent = new OpenAIAgent({
tools: [queryEngineTool, ...externalTools],
verbose: true,
llm,
});
return agent;
}
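A usage sketch for the new agent engine; the model name and the chat() call shape are assumptions, since the exact signature varies across llamaindex versions:
// Hypothetical usage of the agent-based chat engine:
const agent = await createChatEngine(new OpenAI({ model: "gpt-3.5-turbo" }));
const result = await agent.chat({ message: "Summarize the documents in storage." });
console.log(result.response);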
import { ContextChatEngine, LLM } from "llamaindex";
import { getDataSource } from "./index";
export async function createChatEngine(llm: LLM) {
const index = await getDataSource(llm);
const retriever = index.asRetriever();
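// Limit retrieved context to the three most similar chunks.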
retriever.similarityTopK = 3;
return new ContextChatEngine({
chatModel: llm,
retriever,
});
}
import {
- ContextChatEngine,
LLM,
serviceContextFromDefaults,
SimpleDocumentStore,
@@ -8,7 +7,7 @@ import {
} from "llamaindex";
import { CHUNK_OVERLAP, CHUNK_SIZE, STORAGE_CACHE_DIR } from "./constants.mjs";
- async function getDataSource(llm: LLM) {
+ export async function getDataSource(llm: LLM) {
const serviceContext = serviceContextFromDefaults({
llm,
chunkSize: CHUNK_SIZE,
@@ -31,14 +30,3 @@ async function getDataSource(llm: LLM) {
serviceContext,
});
}
- export async function createChatEngine(llm: LLM) {
-   const index = await getDataSource(llm);
-   const retriever = index.asRetriever();
-   retriever.similarityTopK = 3;
-   return new ContextChatEngine({
-     chatModel: llm,
-     retriever,
-   });
- }
import { Request, Response } from "express";
import { ChatMessage, MessageContent, OpenAI } from "llamaindex";
- import { createChatEngine } from "./engine";
+ import { createChatEngine } from "./engine/chat";
const convertMessageContent = (
textMessage: string,
......
import { streamToResponse } from "ai";
import { Request, Response } from "express";
import { ChatMessage, MessageContent, OpenAI } from "llamaindex";
- import { createChatEngine } from "./engine";
+ import { createChatEngine } from "./engine/chat";
import { LlamaIndexStream } from "./llamaindex-stream";
const convertMessageContent = (
......
@@ -6,7 +6,7 @@ import {
trimStartOfStreamHelper,
type AIStreamCallbacksAndOptions,
} from "ai";
- import { Response } from "llamaindex";
+ import { Response, StreamingAgentChatResponse } from "llamaindex";
type ParserOptions = {
image_url?: string;
@@ -52,13 +52,17 @@ function createParser(
}
export function LlamaIndexStream(
- res: AsyncIterable<Response>,
+ response: StreamingAgentChatResponse | AsyncIterable<Response>,
opts?: {
callbacks?: AIStreamCallbacksAndOptions;
parserOptions?: ParserOptions;
},
): { stream: ReadableStream; data: experimental_StreamData } {
const data = new experimental_StreamData();
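// A streaming agent response wraps the underlying AsyncIterable<Response>;
// unwrap it so createParser below can consume either engine's output.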
const res =
response instanceof StreamingAgentChatResponse
? response.response
: response;
return {
stream: createParser(res, data, opts?.parserOptions)
.pipeThrough(createCallbacksTransformer(opts?.callbacks))
......
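With this signature, the streaming controllers can hand either engine's output straight to LlamaIndexStream. A sketch of the call site (variable names and the chat() signature are assumptions):
// Hypothetical call site in a streaming chat controller:
const chatEngine = await createChatEngine(llm);
const response = await chatEngine.chat({
  message: userMessageContent,
  chatHistory: messages,
  stream: true,
});
const { stream } = LlamaIndexStream(response, {
  parserOptions: { image_url: imageUrl },
});
streamToResponse(stream, res);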
@@ -6,7 +6,7 @@ import {
trimStartOfStreamHelper,
type AIStreamCallbacksAndOptions,
} from "ai";
- import { Response } from "llamaindex";
+ import { Response, StreamingAgentChatResponse } from "llamaindex";
type ParserOptions = {
image_url?: string;
@@ -52,13 +52,17 @@ function createParser(
}
export function LlamaIndexStream(
- res: AsyncIterable<Response>,
+ response: StreamingAgentChatResponse | AsyncIterable<Response>,
opts?: {
callbacks?: AIStreamCallbacksAndOptions;
parserOptions?: ParserOptions;
},
): { stream: ReadableStream; data: experimental_StreamData } {
const data = new experimental_StreamData();
const res =
response instanceof StreamingAgentChatResponse
? response.response
: response;
return {
stream: createParser(res, data, opts?.parserOptions)
.pipeThrough(createCallbacksTransformer(opts?.callbacks))
......
@@ -2,7 +2,7 @@ import { initObservability } from "@/app/observability";
import { StreamingTextResponse } from "ai";
import { ChatMessage, MessageContent, OpenAI } from "llamaindex";
import { NextRequest, NextResponse } from "next/server";
- import { createChatEngine } from "./engine";
+ import { createChatEngine } from "./engine/chat";
import { LlamaIndexStream } from "./llamaindex-stream";
initObservability();
......