From 02a0f5e96c7de027f786777b2dc6717467364eb1 Mon Sep 17 00:00:00 2001 From: Marcus Schiesser <mail@marcusschiesser.de> Date: Fri, 15 Dec 2023 15:49:00 +0700 Subject: [PATCH] Feat: Add vector DB to create-llama (starting with MongoDB) (#279) * feat: add selection for vector DB * feat: add mongo datasource * fix: remove not implemented vector dbs --------- Co-authored-by: Thuc Pham <51660321+thucpn@users.noreply.github.com> --- packages/create-llama/create-app.ts | 2 + packages/create-llama/index.ts | 1 + packages/create-llama/questions.ts | 32 +++++++- .../components/vectordbs/mongo/generate.mjs | 49 +++++++++++++ .../components/vectordbs/mongo/index.ts | 37 ++++++++++ .../components/vectordbs/mongo/shared.mjs | 27 +++++++ .../context => vectordbs/none}/constants.mjs | 0 .../context => vectordbs/none}/generate.mjs | 0 .../context => vectordbs/none}/index.ts | 0 packages/create-llama/templates/index.ts | 73 ++++++++++++++----- packages/create-llama/templates/types.ts | 2 + 11 files changed, 201 insertions(+), 22 deletions(-) create mode 100644 packages/create-llama/templates/components/vectordbs/mongo/generate.mjs create mode 100644 packages/create-llama/templates/components/vectordbs/mongo/index.ts create mode 100644 packages/create-llama/templates/components/vectordbs/mongo/shared.mjs rename packages/create-llama/templates/components/{engines/context => vectordbs/none}/constants.mjs (100%) rename packages/create-llama/templates/components/{engines/context => vectordbs/none}/generate.mjs (100%) rename packages/create-llama/templates/components/{engines/context => vectordbs/none}/index.ts (100%) diff --git a/packages/create-llama/create-app.ts b/packages/create-llama/create-app.ts index cdaa6dbc4..d835af8ac 100644 --- a/packages/create-llama/create-app.ts +++ b/packages/create-llama/create-app.ts @@ -32,6 +32,7 @@ export async function createApp({ openAiKey, model, communityProjectPath, + vectorDb, }: InstallAppArgs): Promise<void> { const root = path.resolve(appPath); @@ -71,6 +72,7 @@ export async function createApp({ openAiKey, model, communityProjectPath, + vectorDb, }; if (frontend) { diff --git a/packages/create-llama/index.ts b/packages/create-llama/index.ts index 764be112f..12d06f3d5 100644 --- a/packages/create-llama/index.ts +++ b/packages/create-llama/index.ts @@ -209,6 +209,7 @@ async function run(): Promise<void> { openAiKey: program.openAiKey, model: program.model, communityProjectPath: program.communityProjectPath, + vectorDb: program.vectorDb, }); conf.set("preferences", preferences); } diff --git a/packages/create-llama/questions.ts b/packages/create-llama/questions.ts index 5b0f098ad..109ac9fd8 100644 --- a/packages/create-llama/questions.ts +++ b/packages/create-llama/questions.ts @@ -227,15 +227,15 @@ export const askQuestions = async ( { type: "select", name: "engine", - message: "Which chat engine would you like to use?", + message: "Which data source would you like to use?", choices: [ - { title: "ContextChatEngine", value: "context" }, { - title: "SimpleChatEngine (no data, just chat)", + title: "No data, just a simple chat", value: "simple", }, + { title: "Use an example PDF", value: "context" }, ], - initial: 0, + initial: 1, }, handlers, ); @@ -243,6 +243,30 @@ export const askQuestions = async ( preferences.engine = engine; } } + if (program.engine !== "simple" && !program.vectorDb) { + if (ciInfo.isCI) { + program.vectorDb = getPrefOrDefault("vectorDb"); + } else { + const { vectorDb } = await prompts( + { + type: "select", + name: "vectorDb", + message: "Would 
you like to use a vector database?", + choices: [ + { + title: "No, just store the data in the file system", + value: "none", + }, + { title: "MongoDB", value: "mongo" }, + ], + initial: 0, + }, + handlers, + ); + program.vectorDb = vectorDb; + preferences.vectorDb = vectorDb; + } + } } if (!program.openAiKey) { diff --git a/packages/create-llama/templates/components/vectordbs/mongo/generate.mjs b/packages/create-llama/templates/components/vectordbs/mongo/generate.mjs new file mode 100644 index 000000000..e7751e2ed --- /dev/null +++ b/packages/create-llama/templates/components/vectordbs/mongo/generate.mjs @@ -0,0 +1,49 @@ +/* eslint-disable turbo/no-undeclared-env-vars */ +import * as dotenv from "dotenv"; +import { + MongoDBAtlasVectorSearch, + SimpleDirectoryReader, + VectorStoreIndex, + storageContextFromDefaults, +} from "llamaindex"; +import { MongoClient } from "mongodb"; +import { STORAGE_DIR, checkRequiredEnvVars } from "./shared.mjs"; + +dotenv.config(); + +const mongoUri = process.env.MONGODB_URI; +const databaseName = process.env.MONGODB_DATABASE; +const vectorCollectionName = process.env.MONGODB_VECTORS; +const indexName = process.env.MONGODB_VECTOR_INDEX; + +async function loadAndIndex() { + // Create a new client and connect to the server + const client = new MongoClient(mongoUri); + + // load objects from storage and convert them into LlamaIndex Document objects + const documents = await new SimpleDirectoryReader().loadData({ + directoryPath: STORAGE_DIR, + }); + + // create Atlas as a vector store + const vectorStore = new MongoDBAtlasVectorSearch({ + mongodbClient: client, + dbName: databaseName, + collectionName: vectorCollectionName, // this is where your embeddings will be stored + indexName: indexName, // this is the name of the index you will need to create + }); + + // now create an index from all the Documents and store them in Atlas + const storageContext = await storageContextFromDefaults({ vectorStore }); + await VectorStoreIndex.fromDocuments(documents, { storageContext }); + console.log( + `Successfully created embeddings in the MongoDB collection ${vectorCollectionName}.`, + ); + await client.close(); +} + +(async () => { + checkRequiredEnvVars(); + await loadAndIndex(); + console.log("Finished generating storage."); +})(); diff --git a/packages/create-llama/templates/components/vectordbs/mongo/index.ts b/packages/create-llama/templates/components/vectordbs/mongo/index.ts new file mode 100644 index 000000000..68482f87d --- /dev/null +++ b/packages/create-llama/templates/components/vectordbs/mongo/index.ts @@ -0,0 +1,37 @@ +/* eslint-disable turbo/no-undeclared-env-vars */ +import { + ContextChatEngine, + LLM, + MongoDBAtlasVectorSearch, + serviceContextFromDefaults, + VectorStoreIndex, +} from "llamaindex"; +import { MongoClient } from "mongodb"; +import { checkRequiredEnvVars, CHUNK_OVERLAP, CHUNK_SIZE } from "./shared.mjs"; + +async function getDataSource(llm: LLM) { + checkRequiredEnvVars(); + const client = new MongoClient(process.env.MONGODB_URI!); + const serviceContext = serviceContextFromDefaults({ + llm, + chunkSize: CHUNK_SIZE, + chunkOverlap: CHUNK_OVERLAP, + }); + const store = new MongoDBAtlasVectorSearch({ + mongodbClient: client, + dbName: process.env.MONGODB_DATABASE, + collectionName: process.env.MONGODB_VECTORS, + indexName: process.env.MONGODB_VECTOR_INDEX, + }); + + return await VectorStoreIndex.fromVectorStore(store, serviceContext); +} + +export async function createChatEngine(llm: LLM) { + const index = await getDataSource(llm); + const 
retriever = index.asRetriever({ similarityTopK: 5 }); + return new ContextChatEngine({ + chatModel: llm, + retriever, + }); +} diff --git a/packages/create-llama/templates/components/vectordbs/mongo/shared.mjs b/packages/create-llama/templates/components/vectordbs/mongo/shared.mjs new file mode 100644 index 000000000..5d45eba62 --- /dev/null +++ b/packages/create-llama/templates/components/vectordbs/mongo/shared.mjs @@ -0,0 +1,27 @@ +export const STORAGE_DIR = "./data"; +export const CHUNK_SIZE = 512; +export const CHUNK_OVERLAP = 20; + +const REQUIRED_ENV_VARS = [ + "MONGODB_URI", + "MONGODB_DATABASE", + "MONGODB_VECTORS", + "MONGODB_VECTOR_INDEX", +]; + +export function checkRequiredEnvVars() { + const missingEnvVars = REQUIRED_ENV_VARS.filter((envVar) => { + return !process.env[envVar]; + }); + + if (missingEnvVars.length > 0) { + console.log( + `The following environment variables are required but missing: ${missingEnvVars.join( + ", ", + )}`, + ); + throw new Error( + `Missing environment variables: ${missingEnvVars.join(", ")}`, + ); + } +} diff --git a/packages/create-llama/templates/components/engines/context/constants.mjs b/packages/create-llama/templates/components/vectordbs/none/constants.mjs similarity index 100% rename from packages/create-llama/templates/components/engines/context/constants.mjs rename to packages/create-llama/templates/components/vectordbs/none/constants.mjs diff --git a/packages/create-llama/templates/components/engines/context/generate.mjs b/packages/create-llama/templates/components/vectordbs/none/generate.mjs similarity index 100% rename from packages/create-llama/templates/components/engines/context/generate.mjs rename to packages/create-llama/templates/components/vectordbs/none/generate.mjs diff --git a/packages/create-llama/templates/components/engines/context/index.ts b/packages/create-llama/templates/components/vectordbs/none/index.ts similarity index 100% rename from packages/create-llama/templates/components/engines/context/index.ts rename to packages/create-llama/templates/components/vectordbs/none/index.ts diff --git a/packages/create-llama/templates/index.ts b/packages/create-llama/templates/index.ts index 300078768..a1b9c46a3 100644 --- a/packages/create-llama/templates/index.ts +++ b/packages/create-llama/templates/index.ts @@ -14,16 +14,34 @@ import { InstallTemplateArgs, TemplateEngine, TemplateFramework, + TemplateVectorDB, } from "./types"; -const createEnvLocalFile = async (root: string, openAiKey?: string) => { +const createEnvLocalFile = async ( + root: string, + openAiKey?: string, + vectorDb?: TemplateVectorDB, +) => { + const envFileName = ".env"; + let content = ""; + if (openAiKey) { - const envFileName = ".env"; - await fs.writeFile( - path.join(root, envFileName), - `OPENAI_API_KEY=${openAiKey}\n`, - ); - console.log(`Created '${envFileName}' file containing OPENAI_API_KEY`); + content += `OPENAI_API_KEY=${openAiKey}\n`; + } + + switch (vectorDb) { + case "mongo": { + content += `MONGODB_URI=\n`; + content += `MONGODB_DATABASE=\n`; + content += `MONGODB_VECTORS=\n`; + content += `MONGODB_VECTOR_INDEX=\n`; + break; + } + } + + if (content) { + await fs.writeFile(path.join(root, envFileName), content); + console.log(`Created '${envFileName}' file. 
Please check the settings.`); } }; @@ -33,6 +51,7 @@ const copyTestData = async ( packageManager?: PackageManager, engine?: TemplateEngine, openAiKey?: string, + vectorDb?: TemplateVectorDB, ) => { if (framework === "nextjs") { // XXX: This is a hack to make the build for nextjs work with pdf-parse @@ -53,21 +72,29 @@ const copyTestData = async ( } if (packageManager && engine === "context") { - if (openAiKey || process.env["OPENAI_API_KEY"]) { + const hasOpenAiKey = openAiKey || process.env["OPENAI_API_KEY"]; + const hasVectorDb = vectorDb && vectorDb !== "none"; + const shouldRunGenerateAfterInstall = hasOpenAiKey && vectorDb === "none"; + if (shouldRunGenerateAfterInstall) { console.log( `\nRunning ${cyan( `${packageManager} run generate`, )} to generate the context data.\n`, ); await callPackageManager(packageManager, true, ["run", "generate"]); - console.log(); - } else { - console.log( - `\nAfter setting your OpenAI key, run ${cyan( - `${packageManager} run generate`, - )} to generate the context data.\n`, - ); + return console.log(); } + + const settings = []; + if (!hasOpenAiKey) settings.push("your OpenAI key"); + if (hasVectorDb) settings.push("your Vector DB environment variables"); + const generateMessage = `run ${cyan( + `${packageManager} run generate`, + )} to generate the context data.\n`; + const message = settings.length + ? `After setting ${settings.join(" and ")}, ${generateMessage}` + : generateMessage; + console.log(`\n${message}\n`); } }; @@ -104,6 +131,7 @@ const installTSTemplate = async ({ customApiPath, forBackend, model, + vectorDb, }: InstallTemplateArgs) => { console.log(bold(`Using ${packageManager}.`)); @@ -148,14 +176,22 @@ const installTSTemplate = async ({ const compPath = path.join(__dirname, "components"); if (engine && (framework === "express" || framework === "nextjs")) { console.log("\nUsing chat engine:", engine, "\n"); - const enginePath = path.join(compPath, "engines", engine); + + let vectorDBFolder: string = engine; + + if (engine !== "simple" && vectorDb) { + console.log("\nUsing vector DB:", vectorDb, "\n"); + vectorDBFolder = vectorDb; + } + + const VectorDBPath = path.join(compPath, "vectordbs", vectorDBFolder); relativeEngineDestPath = framework === "nextjs" ? path.join("app", "api", "chat") : path.join("src", "controllers"); await copy("**", path.join(root, relativeEngineDestPath, "engine"), { parents: true, - cwd: enginePath, + cwd: VectorDBPath, }); } @@ -341,7 +377,7 @@ export const installTemplate = async ( // This is a backend, so we need to copy the test data and create the env file. // Copy the environment file to the target directory. 
- await createEnvLocalFile(props.root, props.openAiKey); + await createEnvLocalFile(props.root, props.openAiKey, props.vectorDb); // Copy test pdf file await copyTestData( @@ -350,6 +386,7 @@ export const installTemplate = async ( props.packageManager, props.engine, props.openAiKey, + props.vectorDb, ); } }; diff --git a/packages/create-llama/templates/types.ts b/packages/create-llama/templates/types.ts index eaab3951e..4b905ef96 100644 --- a/packages/create-llama/templates/types.ts +++ b/packages/create-llama/templates/types.ts @@ -4,6 +4,7 @@ export type TemplateType = "simple" | "streaming" | "community"; export type TemplateFramework = "nextjs" | "express" | "fastapi"; export type TemplateEngine = "simple" | "context"; export type TemplateUI = "html" | "shadcn"; +export type TemplateVectorDB = "none" | "mongo"; export interface InstallTemplateArgs { appName: string; @@ -20,4 +21,5 @@ export interface InstallTemplateArgs { forBackend?: string; model: string; communityProjectPath?: string; + vectorDb?: TemplateVectorDB; } -- GitLab
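
For reviewers who want to exercise the new template end to end, below is a minimal smoke-test sketch for a project generated with the "mongo" vector DB option. It assumes the Next.js layout (this patch copies the engine to app/api/chat/engine), that the MONGODB_URI, MONGODB_DATABASE, MONGODB_VECTORS, MONGODB_VECTOR_INDEX and OPENAI_API_KEY entries of the generated .env have been filled in, and that "npm run generate" has already embedded the ./data documents via generate.mjs. The file name, model choice, and the shape of the chat() call are illustrative assumptions and may need adjusting to the installed llamaindex version; they are not part of this patch.

// try-mongo-engine.ts -- hypothetical smoke test, not part of this patch
import { OpenAI } from "llamaindex";
import { createChatEngine } from "./app/api/chat/engine";

async function main() {
  // checkRequiredEnvVars() inside createChatEngine throws if any of the
  // MONGODB_* variables is missing, so misconfiguration surfaces early.
  const llm = new OpenAI({ model: "gpt-3.5-turbo" });
  const chatEngine = await createChatEngine(llm);

  // The chat() signature differs between llamaindex releases; the
  // (message, chatHistory) form matches the templates at the time of this patch.
  const response = await chatEngine.chat("What is in the example PDF?", []);
  console.log(response.toString());
}

main().catch(console.error);

The same sketch should work for the Express template by importing createChatEngine from src/controllers/engine instead, since that is where this patch copies the selected vectordbs component for that framework.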