diff --git a/create-app.ts b/create-app.ts index 29e2d837fbce10ea06ecaed8e769633a315d164b..f37c4c40ad12dfe94443d54619a8c15ace24586c 100644 --- a/create-app.ts +++ b/create-app.ts @@ -8,18 +8,24 @@ import { getOnline } from "./helpers/is-online"; import { isWriteable } from "./helpers/is-writeable"; import { makeDir } from "./helpers/make-dir"; -import type { TemplateFramework, TemplateType } from "./templates"; +import type { + TemplateEngine, + TemplateFramework, + TemplateType, +} from "./templates"; import { installTemplate } from "./templates"; export async function createApp({ template, framework, + engine, appPath, packageManager, eslint, }: { template: TemplateType; framework: TemplateFramework; + engine: TemplateEngine; appPath: string; packageManager: PackageManager; eslint: boolean; @@ -51,15 +57,12 @@ export async function createApp({ process.chdir(root); - /** - * If an example repository is not provided for cloning, proceed - * by installing from a template. - */ await installTemplate({ appName, root, template, framework, + engine, packageManager, isOnline, eslint, diff --git a/index.ts b/index.ts index ad53b6d43c6943777d30d1e5a6971e7f76138884..5131e9227a48d7d420236281919267d1a421f23a 100644 --- a/index.ts +++ b/index.ts @@ -179,6 +179,7 @@ async function run(): Promise<void> { const defaults: typeof preferences = { template: "simple", framework: "nextjs", + engine: "simple", eslint: true, }; const getPrefOrDefault = (field: string) => @@ -194,10 +195,10 @@ async function run(): Promise<void> { name: "template", message: "Which template would you like to use?", choices: [ - { title: "Simple chat without streaming", value: "simple" }, - { title: "Simple chat with streaming", value: "streaming" }, + { title: "Chat without streaming", value: "simple" }, + { title: "Chat with streaming", value: "streaming" }, ], - initial: 0, + initial: 1, }, { onCancel: () => { @@ -238,6 +239,33 @@ async function run(): Promise<void> { } } + if (!program.engine) { + if (ciInfo.isCI) { + program.engine = getPrefOrDefault("engine"); + } else { + const { engine } = await prompts( + { + type: "select", + name: "engine", + message: "Which chat engine would you like to use?", + choices: [ + { title: "SimpleChatEngine", value: "simple" }, + { title: "ContextChatEngine", value: "context" }, + ], + initial: 0, + }, + { + onCancel: () => { + console.error("Exiting."); + process.exit(1); + }, + }, + ); + program.engine = engine; + preferences.engine = engine; + } + } + if ( !process.argv.includes("--eslint") && !process.argv.includes("--no-eslint") @@ -263,6 +291,7 @@ async function run(): Promise<void> { await createApp({ template: program.template, framework: program.framework, + engine: program.engine, appPath: resolvedProjectPath, packageManager, eslint: program.eslint, diff --git a/templates/engines/context/constants.mjs b/templates/engines/context/constants.mjs new file mode 100644 index 0000000000000000000000000000000000000000..8cfb403c3790b0c5088a802071896d2fdca98ded --- /dev/null +++ b/templates/engines/context/constants.mjs @@ -0,0 +1,4 @@ +export const STORAGE_DIR = "./data"; +export const STORAGE_CACHE_DIR = "./cache"; +export const CHUNK_SIZE = 512; +export const CHUNK_OVERLAP = 20; diff --git a/templates/engines/context/generate.mjs b/templates/engines/context/generate.mjs new file mode 100644 index 0000000000000000000000000000000000000000..8420dd5f81eab52574ec51770cf0e09ccb139563 --- /dev/null +++ b/templates/engines/context/generate.mjs @@ -0,0 +1,48 @@ +import { + serviceContextFromDefaults, + SimpleDirectoryReader, + storageContextFromDefaults, + VectorStoreIndex, +} from "llamaindex"; + +import { + CHUNK_OVERLAP, + CHUNK_SIZE, + STORAGE_CACHE_DIR, + STORAGE_DIR, +} from "./constants.mjs"; + +async function getRuntime(func) { + const start = Date.now(); + await func(); + const end = Date.now(); + return end - start; +} + +async function generateDatasource(serviceContext) { + console.log(`Generating storage context...`); + // Split documents, create embeddings and store them in the storage context + const ms = await getRuntime(async () => { + const storageContext = await storageContextFromDefaults({ + persistDir: STORAGE_CACHE_DIR, + }); + const documents = await new SimpleDirectoryReader().loadData({ + directoryPath: STORAGE_DIR, + }); + await VectorStoreIndex.fromDocuments(documents, { + storageContext, + serviceContext, + }); + }); + console.log(`Storage context successfully generated in ${ms / 1000}s.`); +} + +(async () => { + const serviceContext = serviceContextFromDefaults({ + chunkSize: CHUNK_SIZE, + chunkOverlap: CHUNK_OVERLAP, + }); + + await generateDatasource(serviceContext); + console.log("Finished generating storage."); +})(); diff --git a/templates/engines/context/index.ts b/templates/engines/context/index.ts new file mode 100644 index 0000000000000000000000000000000000000000..cdd93809dcd7d9939aba8798fdac3639151e4ac5 --- /dev/null +++ b/templates/engines/context/index.ts @@ -0,0 +1,44 @@ +import { + ContextChatEngine, + LLM, + serviceContextFromDefaults, + SimpleDocumentStore, + storageContextFromDefaults, + VectorStoreIndex, +} from "llamaindex"; +import { CHUNK_OVERLAP, CHUNK_SIZE, STORAGE_CACHE_DIR } from "./constants.mjs"; + +async function getDataSource(llm: LLM) { + const serviceContext = serviceContextFromDefaults({ + llm, + chunkSize: CHUNK_SIZE, + chunkOverlap: CHUNK_OVERLAP, + }); + let storageContext = await storageContextFromDefaults({ + persistDir: `${STORAGE_CACHE_DIR}`, + }); + + const numberOfDocs = Object.keys( + (storageContext.docStore as SimpleDocumentStore).toDict(), + ).length; + if (numberOfDocs === 0) { + throw new Error( + `StorageContext is empty - call 'npm run generate' to generate the storage first`, + ); + } + return await VectorStoreIndex.init({ + storageContext, + serviceContext, + }); +} + +export async function createChatEngine(llm: LLM) { + const index = await getDataSource(llm); + const retriever = index.asRetriever(); + retriever.similarityTopK = 5; + + return new ContextChatEngine({ + chatModel: llm, + retriever, + }); +} diff --git a/templates/engines/simple/index.ts b/templates/engines/simple/index.ts new file mode 100644 index 0000000000000000000000000000000000000000..abb02e90cd2ce91096791bf10c4665afcbe11d38 --- /dev/null +++ b/templates/engines/simple/index.ts @@ -0,0 +1,7 @@ +import { LLM, SimpleChatEngine } from "llamaindex"; + +export async function createChatEngine(llm: LLM) { + return new SimpleChatEngine({ + llm, + }); +} diff --git a/templates/index.ts b/templates/index.ts index a6cb26d61d6750866e3cfb39827042b225aec0c1..28eb7ef42edd5c312c19fecc0dd0421c47da2d51 100644 --- a/templates/index.ts +++ b/templates/index.ts @@ -19,6 +19,7 @@ export const installTemplate = async ({ isOnline, template, framework, + engine, eslint, }: InstallTemplateArgs) => { console.log(bold(`Using ${packageManager}.`)); @@ -52,6 +53,24 @@ export const installTemplate = async ({ }, }); + /** + * Copy the selected chat engine files to the target directory and reference it. + */ + console.log("\nUsing chat engine:", engine, "\n"); + const enginePath = path.join(__dirname, "engines", engine); + const engineDestPath = path.join(root, "app", "api", "chat", "engine"); + await copy("**", engineDestPath, { + parents: true, + cwd: enginePath, + }); + const routeFile = path.join(engineDestPath, "..", "route.ts"); + const routeFileContent = await fs.readFile(routeFile, "utf8"); + const newContent = routeFileContent.replace( + /^import { createChatEngine }.*$/m, + 'import { createChatEngine } from "./engine"\n', + ); + await fs.writeFile(routeFile, newContent); + /** * Update the package.json scripts. */ @@ -67,6 +86,14 @@ export const installTemplate = async ({ llamaindex: version, }; + if (engine === "context") { + // add generate script if using context engine + packageJson.scripts = { + ...packageJson.scripts, + generate: "node ./app/api/chat/engine/generate.mjs", + }; + } + if (!eslint) { // Remove packages starting with "eslint" from devDependencies packageJson.devDependencies = Object.fromEntries( diff --git a/templates/simple/nextjs/app/api/chat/route.ts b/templates/simple/nextjs/app/api/chat/route.ts index 8647e7043abcd298d3c7fb2d90dfb9659432428c..1d98c832e44c56c645cffce96fa638d9c9522f51 100644 --- a/templates/simple/nextjs/app/api/chat/route.ts +++ b/templates/simple/nextjs/app/api/chat/route.ts @@ -1,5 +1,6 @@ -import { ChatMessage, OpenAI, SimpleChatEngine } from "llamaindex"; +import { ChatMessage, OpenAI } from "llamaindex"; import { NextRequest, NextResponse } from "next/server"; +import { createChatEngine } from "../../../../../engines/context"; export const runtime = "nodejs"; export const dynamic = "force-dynamic"; @@ -23,9 +24,7 @@ export async function POST(request: NextRequest) { model: "gpt-3.5-turbo", }); - const chatEngine = new SimpleChatEngine({ - llm, - }); + const chatEngine = await createChatEngine(llm); const response = await chatEngine.chat(lastMessage.content, messages); const result: ChatMessage = { diff --git a/templates/streaming/nextjs/app/api/chat/route.ts b/templates/streaming/nextjs/app/api/chat/route.ts index fb54dbc803da9de5ddb3c2dd72fa073acb01e47f..461f3118c6567510f9c59f8b4d010cda0f388d52 100644 --- a/templates/streaming/nextjs/app/api/chat/route.ts +++ b/templates/streaming/nextjs/app/api/chat/route.ts @@ -1,6 +1,7 @@ import { Message, StreamingTextResponse } from "ai"; -import { OpenAI, SimpleChatEngine } from "llamaindex"; +import { OpenAI } from "llamaindex"; import { NextRequest, NextResponse } from "next/server"; +import { createChatEngine } from "../../../../../engines/context"; import { LlamaIndexStream } from "./llamaindex-stream"; export const runtime = "nodejs"; @@ -25,9 +26,7 @@ export async function POST(request: NextRequest) { model: "gpt-3.5-turbo", }); - const chatEngine = new SimpleChatEngine({ - llm, - }); + const chatEngine = await createChatEngine(llm); const response = await chatEngine.chat(lastMessage.content, messages, true); diff --git a/templates/types.ts b/templates/types.ts index 2847856898886016f689bbeda0a3ffc4479ebb17..4c9ad3c6538353d3ca1d74eb98f541d0ef6f9b34 100644 --- a/templates/types.ts +++ b/templates/types.ts @@ -2,6 +2,7 @@ import { PackageManager } from "../helpers/get-pkg-manager"; export type TemplateType = "simple" | "streaming"; export type TemplateFramework = "nextjs" | "express"; +export type TemplateEngine = "simple" | "context"; export interface InstallTemplateArgs { appName: string; @@ -10,5 +11,6 @@ export interface InstallTemplateArgs { isOnline: boolean; template: TemplateType; framework: TemplateFramework; + engine: TemplateEngine; eslint: boolean; }