From 02a0f5e96c7de027f786777b2dc6717467364eb1 Mon Sep 17 00:00:00 2001
From: Marcus Schiesser <mail@marcusschiesser.de>
Date: Fri, 15 Dec 2023 15:49:00 +0700
Subject: [PATCH] Feat: Add vector DB to create-llama (starting with MongoDB)
 (#279)

* feat: add selection for vector DB
* feat: add mongo datasource
* fix: remove not implemented vector dbs

---------

Co-authored-by: Thuc Pham <51660321+thucpn@users.noreply.github.com>
---
 packages/create-llama/create-app.ts           |  2 +
 packages/create-llama/index.ts                |  1 +
 packages/create-llama/questions.ts            | 32 +++++++-
 .../components/vectordbs/mongo/generate.mjs   | 49 +++++++++++++
 .../components/vectordbs/mongo/index.ts       | 37 ++++++++++
 .../components/vectordbs/mongo/shared.mjs     | 27 +++++++
 .../context => vectordbs/none}/constants.mjs  |  0
 .../context => vectordbs/none}/generate.mjs   |  0
 .../context => vectordbs/none}/index.ts       |  0
 packages/create-llama/templates/index.ts      | 73 ++++++++++++++-----
 packages/create-llama/templates/types.ts      |  2 +
 11 files changed, 201 insertions(+), 22 deletions(-)
 create mode 100644 packages/create-llama/templates/components/vectordbs/mongo/generate.mjs
 create mode 100644 packages/create-llama/templates/components/vectordbs/mongo/index.ts
 create mode 100644 packages/create-llama/templates/components/vectordbs/mongo/shared.mjs
 rename packages/create-llama/templates/components/{engines/context => vectordbs/none}/constants.mjs (100%)
 rename packages/create-llama/templates/components/{engines/context => vectordbs/none}/generate.mjs (100%)
 rename packages/create-llama/templates/components/{engines/context => vectordbs/none}/index.ts (100%)

diff --git a/packages/create-llama/create-app.ts b/packages/create-llama/create-app.ts
index cdaa6dbc4..d835af8ac 100644
--- a/packages/create-llama/create-app.ts
+++ b/packages/create-llama/create-app.ts
@@ -32,6 +32,7 @@ export async function createApp({
   openAiKey,
   model,
   communityProjectPath,
+  vectorDb,
 }: InstallAppArgs): Promise<void> {
   const root = path.resolve(appPath);
 
@@ -71,6 +72,7 @@ export async function createApp({
     openAiKey,
     model,
     communityProjectPath,
+    vectorDb,
   };
 
   if (frontend) {
diff --git a/packages/create-llama/index.ts b/packages/create-llama/index.ts
index 764be112f..12d06f3d5 100644
--- a/packages/create-llama/index.ts
+++ b/packages/create-llama/index.ts
@@ -209,6 +209,7 @@ async function run(): Promise<void> {
     openAiKey: program.openAiKey,
     model: program.model,
     communityProjectPath: program.communityProjectPath,
+    vectorDb: program.vectorDb,
   });
   conf.set("preferences", preferences);
 }
diff --git a/packages/create-llama/questions.ts b/packages/create-llama/questions.ts
index 5b0f098ad..109ac9fd8 100644
--- a/packages/create-llama/questions.ts
+++ b/packages/create-llama/questions.ts
@@ -227,15 +227,15 @@ export const askQuestions = async (
           {
             type: "select",
             name: "engine",
-            message: "Which chat engine would you like to use?",
+            message: "Which data source would you like to use?",
             choices: [
-              { title: "ContextChatEngine", value: "context" },
               {
-                title: "SimpleChatEngine (no data, just chat)",
+                title: "No data, just a simple chat",
                 value: "simple",
               },
+              { title: "Use an example PDF", value: "context" },
             ],
-            initial: 0,
+            initial: 1,
           },
           handlers,
         );
@@ -243,6 +243,30 @@ export const askQuestions = async (
         preferences.engine = engine;
       }
     }
+    if (program.engine !== "simple" && !program.vectorDb) {
+      if (ciInfo.isCI) {
+        program.vectorDb = getPrefOrDefault("vectorDb");
+      } else {
+        const { vectorDb } = await prompts(
+          {
+            type: "select",
+            name: "vectorDb",
+            message: "Would you like to use a vector database?",
+            choices: [
+              {
+                title: "No, just store the data in the file system",
+                value: "none",
+              },
+              { title: "MongoDB", value: "mongo" },
+            ],
+            initial: 0,
+          },
+          handlers,
+        );
+        program.vectorDb = vectorDb;
+        preferences.vectorDb = vectorDb;
+      }
+    }
   }
 
   if (!program.openAiKey) {
diff --git a/packages/create-llama/templates/components/vectordbs/mongo/generate.mjs b/packages/create-llama/templates/components/vectordbs/mongo/generate.mjs
new file mode 100644
index 000000000..e7751e2ed
--- /dev/null
+++ b/packages/create-llama/templates/components/vectordbs/mongo/generate.mjs
@@ -0,0 +1,49 @@
+/* eslint-disable turbo/no-undeclared-env-vars */
+import * as dotenv from "dotenv";
+import {
+  MongoDBAtlasVectorSearch,
+  SimpleDirectoryReader,
+  VectorStoreIndex,
+  storageContextFromDefaults,
+} from "llamaindex";
+import { MongoClient } from "mongodb";
+import { STORAGE_DIR, checkRequiredEnvVars } from "./shared.mjs";
+
+dotenv.config();
+
+const mongoUri = process.env.MONGODB_URI;
+const databaseName = process.env.MONGODB_DATABASE;
+const vectorCollectionName = process.env.MONGODB_VECTORS;
+const indexName = process.env.MONGODB_VECTOR_INDEX;
+
+async function loadAndIndex() {
+  // Create a new client and connect to the server
+  const client = new MongoClient(mongoUri);
+
+  // load objects from storage and convert them into LlamaIndex Document objects
+  const documents = await new SimpleDirectoryReader().loadData({
+    directoryPath: STORAGE_DIR,
+  });
+
+  // create Atlas as a vector store
+  const vectorStore = new MongoDBAtlasVectorSearch({
+    mongodbClient: client,
+    dbName: databaseName,
+    collectionName: vectorCollectionName, // this is where your embeddings will be stored
+    indexName: indexName, // this is the name of the index you will need to create
+  });
+
+  // now create an index from all the Documents and store them in Atlas
+  const storageContext = await storageContextFromDefaults({ vectorStore });
+  await VectorStoreIndex.fromDocuments(documents, { storageContext });
+  console.log(
+    `Successfully created embeddings in the MongoDB collection ${vectorCollectionName}.`,
+  );
+  await client.close();
+}
+
+(async () => {
+  checkRequiredEnvVars();
+  await loadAndIndex();
+  console.log("Finished generating storage.");
+})();
diff --git a/packages/create-llama/templates/components/vectordbs/mongo/index.ts b/packages/create-llama/templates/components/vectordbs/mongo/index.ts
new file mode 100644
index 000000000..68482f87d
--- /dev/null
+++ b/packages/create-llama/templates/components/vectordbs/mongo/index.ts
@@ -0,0 +1,37 @@
+/* eslint-disable turbo/no-undeclared-env-vars */
+import {
+  ContextChatEngine,
+  LLM,
+  MongoDBAtlasVectorSearch,
+  serviceContextFromDefaults,
+  VectorStoreIndex,
+} from "llamaindex";
+import { MongoClient } from "mongodb";
+import { checkRequiredEnvVars, CHUNK_OVERLAP, CHUNK_SIZE } from "./shared.mjs";
+
+async function getDataSource(llm: LLM) {
+  checkRequiredEnvVars();
+  const client = new MongoClient(process.env.MONGODB_URI!);
+  const serviceContext = serviceContextFromDefaults({
+    llm,
+    chunkSize: CHUNK_SIZE,
+    chunkOverlap: CHUNK_OVERLAP,
+  });
+  const store = new MongoDBAtlasVectorSearch({
+    mongodbClient: client,
+    dbName: process.env.MONGODB_DATABASE,
+    collectionName: process.env.MONGODB_VECTORS,
+    indexName: process.env.MONGODB_VECTOR_INDEX,
+  });
+
+  return await VectorStoreIndex.fromVectorStore(store, serviceContext);
+}
+
+export async function createChatEngine(llm: LLM) {
+  const index = await getDataSource(llm);
+  const retriever = index.asRetriever({ similarityTopK: 5 });
+  return new ContextChatEngine({
+    chatModel: llm,
+    retriever,
+  });
+}
diff --git a/packages/create-llama/templates/components/vectordbs/mongo/shared.mjs b/packages/create-llama/templates/components/vectordbs/mongo/shared.mjs
new file mode 100644
index 000000000..5d45eba62
--- /dev/null
+++ b/packages/create-llama/templates/components/vectordbs/mongo/shared.mjs
@@ -0,0 +1,27 @@
+export const STORAGE_DIR = "./data";
+export const CHUNK_SIZE = 512;
+export const CHUNK_OVERLAP = 20;
+
+const REQUIRED_ENV_VARS = [
+  "MONGODB_URI",
+  "MONGODB_DATABASE",
+  "MONGODB_VECTORS",
+  "MONGODB_VECTOR_INDEX",
+];
+
+export function checkRequiredEnvVars() {
+  const missingEnvVars = REQUIRED_ENV_VARS.filter((envVar) => {
+    return !process.env[envVar];
+  });
+
+  if (missingEnvVars.length > 0) {
+    console.log(
+      `The following environment variables are required but missing: ${missingEnvVars.join(
+        ", ",
+      )}`,
+    );
+    throw new Error(
+      `Missing environment variables: ${missingEnvVars.join(", ")}`,
+    );
+  }
+}
diff --git a/packages/create-llama/templates/components/engines/context/constants.mjs b/packages/create-llama/templates/components/vectordbs/none/constants.mjs
similarity index 100%
rename from packages/create-llama/templates/components/engines/context/constants.mjs
rename to packages/create-llama/templates/components/vectordbs/none/constants.mjs
diff --git a/packages/create-llama/templates/components/engines/context/generate.mjs b/packages/create-llama/templates/components/vectordbs/none/generate.mjs
similarity index 100%
rename from packages/create-llama/templates/components/engines/context/generate.mjs
rename to packages/create-llama/templates/components/vectordbs/none/generate.mjs
diff --git a/packages/create-llama/templates/components/engines/context/index.ts b/packages/create-llama/templates/components/vectordbs/none/index.ts
similarity index 100%
rename from packages/create-llama/templates/components/engines/context/index.ts
rename to packages/create-llama/templates/components/vectordbs/none/index.ts
diff --git a/packages/create-llama/templates/index.ts b/packages/create-llama/templates/index.ts
index 300078768..a1b9c46a3 100644
--- a/packages/create-llama/templates/index.ts
+++ b/packages/create-llama/templates/index.ts
@@ -14,16 +14,34 @@ import {
   InstallTemplateArgs,
   TemplateEngine,
   TemplateFramework,
+  TemplateVectorDB,
 } from "./types";
 
-const createEnvLocalFile = async (root: string, openAiKey?: string) => {
+const createEnvLocalFile = async (
+  root: string,
+  openAiKey?: string,
+  vectorDb?: TemplateVectorDB,
+) => {
+  const envFileName = ".env";
+  let content = "";
+
   if (openAiKey) {
-    const envFileName = ".env";
-    await fs.writeFile(
-      path.join(root, envFileName),
-      `OPENAI_API_KEY=${openAiKey}\n`,
-    );
-    console.log(`Created '${envFileName}' file containing OPENAI_API_KEY`);
+    content += `OPENAI_API_KEY=${openAiKey}\n`;
+  }
+
+  switch (vectorDb) {
+    case "mongo": {
+      content += `MONGODB_URI=\n`;
+      content += `MONGODB_DATABASE=\n`;
+      content += `MONGODB_VECTORS=\n`;
+      content += `MONGODB_VECTOR_INDEX=\n`;
+      break;
+    }
+  }
+
+  if (content) {
+    await fs.writeFile(path.join(root, envFileName), content);
+    console.log(`Created '${envFileName}' file. Please check the settings.`);
   }
 };
 
@@ -33,6 +51,7 @@ const copyTestData = async (
   packageManager?: PackageManager,
   engine?: TemplateEngine,
   openAiKey?: string,
+  vectorDb?: TemplateVectorDB,
 ) => {
   if (framework === "nextjs") {
     // XXX: This is a hack to make the build for nextjs work with pdf-parse
@@ -53,21 +72,29 @@ const copyTestData = async (
   }
 
   if (packageManager && engine === "context") {
-    if (openAiKey || process.env["OPENAI_API_KEY"]) {
+    const hasOpenAiKey = openAiKey || process.env["OPENAI_API_KEY"];
+    const hasVectorDb = vectorDb && vectorDb !== "none";
+    const shouldRunGenerateAfterInstall = hasOpenAiKey && vectorDb === "none";
+    if (shouldRunGenerateAfterInstall) {
       console.log(
         `\nRunning ${cyan(
           `${packageManager} run generate`,
         )} to generate the context data.\n`,
       );
       await callPackageManager(packageManager, true, ["run", "generate"]);
-      console.log();
-    } else {
-      console.log(
-        `\nAfter setting your OpenAI key, run ${cyan(
-          `${packageManager} run generate`,
-        )} to generate the context data.\n`,
-      );
+      return console.log();
     }
+
+    const settings = [];
+    if (!hasOpenAiKey) settings.push("your OpenAI key");
+    if (hasVectorDb) settings.push("your Vector DB environment variables");
+    const generateMessage = `run ${cyan(
+      `${packageManager} run generate`,
+    )} to generate the context data.\n`;
+    const message = settings.length
+      ? `After setting ${settings.join(" and ")}, ${generateMessage}`
+      : generateMessage;
+    console.log(`\n${message}`);
   }
 };
 
@@ -104,6 +131,7 @@ const installTSTemplate = async ({
   customApiPath,
   forBackend,
   model,
+  vectorDb,
 }: InstallTemplateArgs) => {
   console.log(bold(`Using ${packageManager}.`));
 
@@ -148,14 +176,22 @@ const installTSTemplate = async ({
   const compPath = path.join(__dirname, "components");
   if (engine && (framework === "express" || framework === "nextjs")) {
     console.log("\nUsing chat engine:", engine, "\n");
-    const enginePath = path.join(compPath, "engines", engine);
+
+    let vectorDBFolder: string = engine;
+
+    if (engine !== "simple" && vectorDb) {
+      console.log("\nUsing vector DB:", vectorDb, "\n");
+      vectorDBFolder = vectorDb;
+    }
+
+    const vectorDBPath = path.join(compPath, "vectordbs", vectorDBFolder);
     relativeEngineDestPath =
       framework === "nextjs"
         ? path.join("app", "api", "chat")
         : path.join("src", "controllers");
     await copy("**", path.join(root, relativeEngineDestPath, "engine"), {
       parents: true,
-      cwd: enginePath,
+      cwd: vectorDBPath,
     });
   }
 
@@ -341,7 +377,7 @@ export const installTemplate = async (
     // This is a backend, so we need to copy the test data and create the env file.
 
     // Copy the environment file to the target directory.
-    await createEnvLocalFile(props.root, props.openAiKey);
+    await createEnvLocalFile(props.root, props.openAiKey, props.vectorDb);
 
     // Copy test pdf file
     await copyTestData(
@@ -350,6 +386,7 @@ export const installTemplate = async (
       props.packageManager,
       props.engine,
       props.openAiKey,
+      props.vectorDb,
     );
   }
 };
diff --git a/packages/create-llama/templates/types.ts b/packages/create-llama/templates/types.ts
index eaab3951e..4b905ef96 100644
--- a/packages/create-llama/templates/types.ts
+++ b/packages/create-llama/templates/types.ts
@@ -4,6 +4,7 @@ export type TemplateType = "simple" | "streaming" | "community";
 export type TemplateFramework = "nextjs" | "express" | "fastapi";
 export type TemplateEngine = "simple" | "context";
 export type TemplateUI = "html" | "shadcn";
+export type TemplateVectorDB = "none" | "mongo";
 
 export interface InstallTemplateArgs {
   appName: string;
@@ -20,4 +21,5 @@ export interface InstallTemplateArgs {
   forBackend?: string;
   model: string;
   communityProjectPath?: string;
+  vectorDb?: TemplateVectorDB;
 }
-- 
GitLab