Skip to content
Snippets Groups Projects
Commit 20aaf35f authored by Marcus Schiesser's avatar Marcus Schiesser
Browse files

add ContextChatEngine and generator for different chat engines

parent 151a63a1
No related branches found
No related tags found
No related merge requests found
...@@ -8,18 +8,24 @@ import { getOnline } from "./helpers/is-online"; ...@@ -8,18 +8,24 @@ import { getOnline } from "./helpers/is-online";
import { isWriteable } from "./helpers/is-writeable"; import { isWriteable } from "./helpers/is-writeable";
import { makeDir } from "./helpers/make-dir"; import { makeDir } from "./helpers/make-dir";
import type { TemplateFramework, TemplateType } from "./templates"; import type {
TemplateEngine,
TemplateFramework,
TemplateType,
} from "./templates";
import { installTemplate } from "./templates"; import { installTemplate } from "./templates";
export async function createApp({ export async function createApp({
template, template,
framework, framework,
engine,
appPath, appPath,
packageManager, packageManager,
eslint, eslint,
}: { }: {
template: TemplateType; template: TemplateType;
framework: TemplateFramework; framework: TemplateFramework;
engine: TemplateEngine;
appPath: string; appPath: string;
packageManager: PackageManager; packageManager: PackageManager;
eslint: boolean; eslint: boolean;
...@@ -51,15 +57,12 @@ export async function createApp({ ...@@ -51,15 +57,12 @@ export async function createApp({
process.chdir(root); process.chdir(root);
/**
* If an example repository is not provided for cloning, proceed
* by installing from a template.
*/
await installTemplate({ await installTemplate({
appName, appName,
root, root,
template, template,
framework, framework,
engine,
packageManager, packageManager,
isOnline, isOnline,
eslint, eslint,
......
...@@ -179,6 +179,7 @@ async function run(): Promise<void> { ...@@ -179,6 +179,7 @@ async function run(): Promise<void> {
const defaults: typeof preferences = { const defaults: typeof preferences = {
template: "simple", template: "simple",
framework: "nextjs", framework: "nextjs",
engine: "simple",
eslint: true, eslint: true,
}; };
const getPrefOrDefault = (field: string) => const getPrefOrDefault = (field: string) =>
...@@ -194,10 +195,10 @@ async function run(): Promise<void> { ...@@ -194,10 +195,10 @@ async function run(): Promise<void> {
name: "template", name: "template",
message: "Which template would you like to use?", message: "Which template would you like to use?",
choices: [ choices: [
{ title: "Simple chat without streaming", value: "simple" }, { title: "Chat without streaming", value: "simple" },
{ title: "Simple chat with streaming", value: "streaming" }, { title: "Chat with streaming", value: "streaming" },
], ],
initial: 0, initial: 1,
}, },
{ {
onCancel: () => { onCancel: () => {
...@@ -238,6 +239,33 @@ async function run(): Promise<void> { ...@@ -238,6 +239,33 @@ async function run(): Promise<void> {
} }
} }
if (!program.engine) {
if (ciInfo.isCI) {
program.engine = getPrefOrDefault("engine");
} else {
const { engine } = await prompts(
{
type: "select",
name: "engine",
message: "Which chat engine would you like to use?",
choices: [
{ title: "SimpleChatEngine", value: "simple" },
{ title: "ContextChatEngine", value: "context" },
],
initial: 0,
},
{
onCancel: () => {
console.error("Exiting.");
process.exit(1);
},
},
);
program.engine = engine;
preferences.engine = engine;
}
}
if ( if (
!process.argv.includes("--eslint") && !process.argv.includes("--eslint") &&
!process.argv.includes("--no-eslint") !process.argv.includes("--no-eslint")
...@@ -263,6 +291,7 @@ async function run(): Promise<void> { ...@@ -263,6 +291,7 @@ async function run(): Promise<void> {
await createApp({ await createApp({
template: program.template, template: program.template,
framework: program.framework, framework: program.framework,
engine: program.engine,
appPath: resolvedProjectPath, appPath: resolvedProjectPath,
packageManager, packageManager,
eslint: program.eslint, eslint: program.eslint,
......
export const STORAGE_DIR = "./data";
export const STORAGE_CACHE_DIR = "./cache";
export const CHUNK_SIZE = 512;
export const CHUNK_OVERLAP = 20;
import {
serviceContextFromDefaults,
SimpleDirectoryReader,
storageContextFromDefaults,
VectorStoreIndex,
} from "llamaindex";
import {
CHUNK_OVERLAP,
CHUNK_SIZE,
STORAGE_CACHE_DIR,
STORAGE_DIR,
} from "./constants.mjs";
async function getRuntime(func) {
const start = Date.now();
await func();
const end = Date.now();
return end - start;
}
async function generateDatasource(serviceContext) {
console.log(`Generating storage context...`);
// Split documents, create embeddings and store them in the storage context
const ms = await getRuntime(async () => {
const storageContext = await storageContextFromDefaults({
persistDir: STORAGE_CACHE_DIR,
});
const documents = await new SimpleDirectoryReader().loadData({
directoryPath: STORAGE_DIR,
});
await VectorStoreIndex.fromDocuments(documents, {
storageContext,
serviceContext,
});
});
console.log(`Storage context successfully generated in ${ms / 1000}s.`);
}
(async () => {
const serviceContext = serviceContextFromDefaults({
chunkSize: CHUNK_SIZE,
chunkOverlap: CHUNK_OVERLAP,
});
await generateDatasource(serviceContext);
console.log("Finished generating storage.");
})();
import {
ContextChatEngine,
LLM,
serviceContextFromDefaults,
SimpleDocumentStore,
storageContextFromDefaults,
VectorStoreIndex,
} from "llamaindex";
import { CHUNK_OVERLAP, CHUNK_SIZE, STORAGE_CACHE_DIR } from "./constants.mjs";
async function getDataSource(llm: LLM) {
const serviceContext = serviceContextFromDefaults({
llm,
chunkSize: CHUNK_SIZE,
chunkOverlap: CHUNK_OVERLAP,
});
let storageContext = await storageContextFromDefaults({
persistDir: `${STORAGE_CACHE_DIR}`,
});
const numberOfDocs = Object.keys(
(storageContext.docStore as SimpleDocumentStore).toDict(),
).length;
if (numberOfDocs === 0) {
throw new Error(
`StorageContext is empty - call 'npm run generate' to generate the storage first`,
);
}
return await VectorStoreIndex.init({
storageContext,
serviceContext,
});
}
export async function createChatEngine(llm: LLM) {
const index = await getDataSource(llm);
const retriever = index.asRetriever();
retriever.similarityTopK = 5;
return new ContextChatEngine({
chatModel: llm,
retriever,
});
}
import { LLM, SimpleChatEngine } from "llamaindex";
export async function createChatEngine(llm: LLM) {
return new SimpleChatEngine({
llm,
});
}
...@@ -19,6 +19,7 @@ export const installTemplate = async ({ ...@@ -19,6 +19,7 @@ export const installTemplate = async ({
isOnline, isOnline,
template, template,
framework, framework,
engine,
eslint, eslint,
}: InstallTemplateArgs) => { }: InstallTemplateArgs) => {
console.log(bold(`Using ${packageManager}.`)); console.log(bold(`Using ${packageManager}.`));
...@@ -52,6 +53,24 @@ export const installTemplate = async ({ ...@@ -52,6 +53,24 @@ export const installTemplate = async ({
}, },
}); });
/**
* Copy the selected chat engine files to the target directory and reference it.
*/
console.log("\nUsing chat engine:", engine, "\n");
const enginePath = path.join(__dirname, "engines", engine);
const engineDestPath = path.join(root, "app", "api", "chat", "engine");
await copy("**", engineDestPath, {
parents: true,
cwd: enginePath,
});
const routeFile = path.join(engineDestPath, "..", "route.ts");
const routeFileContent = await fs.readFile(routeFile, "utf8");
const newContent = routeFileContent.replace(
/^import { createChatEngine }.*$/m,
'import { createChatEngine } from "./engine"\n',
);
await fs.writeFile(routeFile, newContent);
/** /**
* Update the package.json scripts. * Update the package.json scripts.
*/ */
...@@ -67,6 +86,14 @@ export const installTemplate = async ({ ...@@ -67,6 +86,14 @@ export const installTemplate = async ({
llamaindex: version, llamaindex: version,
}; };
if (engine === "context") {
// add generate script if using context engine
packageJson.scripts = {
...packageJson.scripts,
generate: "node ./app/api/chat/engine/generate.mjs",
};
}
if (!eslint) { if (!eslint) {
// Remove packages starting with "eslint" from devDependencies // Remove packages starting with "eslint" from devDependencies
packageJson.devDependencies = Object.fromEntries( packageJson.devDependencies = Object.fromEntries(
......
import { ChatMessage, OpenAI, SimpleChatEngine } from "llamaindex"; import { ChatMessage, OpenAI } from "llamaindex";
import { NextRequest, NextResponse } from "next/server"; import { NextRequest, NextResponse } from "next/server";
import { createChatEngine } from "../../../../../engines/context";
export const runtime = "nodejs"; export const runtime = "nodejs";
export const dynamic = "force-dynamic"; export const dynamic = "force-dynamic";
...@@ -23,9 +24,7 @@ export async function POST(request: NextRequest) { ...@@ -23,9 +24,7 @@ export async function POST(request: NextRequest) {
model: "gpt-3.5-turbo", model: "gpt-3.5-turbo",
}); });
const chatEngine = new SimpleChatEngine({ const chatEngine = await createChatEngine(llm);
llm,
});
const response = await chatEngine.chat(lastMessage.content, messages); const response = await chatEngine.chat(lastMessage.content, messages);
const result: ChatMessage = { const result: ChatMessage = {
......
import { Message, StreamingTextResponse } from "ai"; import { Message, StreamingTextResponse } from "ai";
import { OpenAI, SimpleChatEngine } from "llamaindex"; import { OpenAI } from "llamaindex";
import { NextRequest, NextResponse } from "next/server"; import { NextRequest, NextResponse } from "next/server";
import { createChatEngine } from "../../../../../engines/context";
import { LlamaIndexStream } from "./llamaindex-stream"; import { LlamaIndexStream } from "./llamaindex-stream";
export const runtime = "nodejs"; export const runtime = "nodejs";
...@@ -25,9 +26,7 @@ export async function POST(request: NextRequest) { ...@@ -25,9 +26,7 @@ export async function POST(request: NextRequest) {
model: "gpt-3.5-turbo", model: "gpt-3.5-turbo",
}); });
const chatEngine = new SimpleChatEngine({ const chatEngine = await createChatEngine(llm);
llm,
});
const response = await chatEngine.chat(lastMessage.content, messages, true); const response = await chatEngine.chat(lastMessage.content, messages, true);
......
...@@ -2,6 +2,7 @@ import { PackageManager } from "../helpers/get-pkg-manager"; ...@@ -2,6 +2,7 @@ import { PackageManager } from "../helpers/get-pkg-manager";
export type TemplateType = "simple" | "streaming"; export type TemplateType = "simple" | "streaming";
export type TemplateFramework = "nextjs" | "express"; export type TemplateFramework = "nextjs" | "express";
export type TemplateEngine = "simple" | "context";
export interface InstallTemplateArgs { export interface InstallTemplateArgs {
appName: string; appName: string;
...@@ -10,5 +11,6 @@ export interface InstallTemplateArgs { ...@@ -10,5 +11,6 @@ export interface InstallTemplateArgs {
isOnline: boolean; isOnline: boolean;
template: TemplateType; template: TemplateType;
framework: TemplateFramework; framework: TemplateFramework;
engine: TemplateEngine;
eslint: boolean; eslint: boolean;
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment