Skip to content
Snippets Groups Projects
Commit 7cc658ec authored by Marcus Schiesser, committed by GitHub
Browse files

Feat: Add vector DB to create-llama (starting with MongoDB) (#279)


* feat: add selection for vector DB
* feat: add mongo datasource 
* fix: remove not implemented vector dbs

---------

Co-authored-by: Thuc Pham <51660321+thucpn@users.noreply.github.com>
parent 7c947022
No related branches found
No related tags found
No related merge requests found
...@@ -32,6 +32,7 @@ export async function createApp({ ...@@ -32,6 +32,7 @@ export async function createApp({
openAiKey, openAiKey,
model, model,
communityProjectPath, communityProjectPath,
vectorDb,
}: InstallAppArgs): Promise<void> { }: InstallAppArgs): Promise<void> {
const root = path.resolve(appPath); const root = path.resolve(appPath);
...@@ -71,6 +72,7 @@ export async function createApp({ ...@@ -71,6 +72,7 @@ export async function createApp({
openAiKey, openAiKey,
model, model,
communityProjectPath, communityProjectPath,
vectorDb,
}; };
if (frontend) { if (frontend) {
......
...@@ -209,6 +209,7 @@ async function run(): Promise<void> { ...@@ -209,6 +209,7 @@ async function run(): Promise<void> {
openAiKey: program.openAiKey, openAiKey: program.openAiKey,
model: program.model, model: program.model,
communityProjectPath: program.communityProjectPath, communityProjectPath: program.communityProjectPath,
vectorDb: program.vectorDb,
}); });
conf.set("preferences", preferences); conf.set("preferences", preferences);
} }
......
...@@ -227,15 +227,15 @@ export const askQuestions = async ( ...@@ -227,15 +227,15 @@ export const askQuestions = async (
{ {
type: "select", type: "select",
name: "engine", name: "engine",
message: "Which chat engine would you like to use?", message: "Which data source would you like to use?",
choices: [ choices: [
{ title: "ContextChatEngine", value: "context" },
{ {
title: "SimpleChatEngine (no data, just chat)", title: "No data, just a simple chat",
value: "simple", value: "simple",
}, },
{ title: "Use an example PDF", value: "context" },
], ],
initial: 0, initial: 1,
}, },
handlers, handlers,
); );
...@@ -243,6 +243,30 @@ export const askQuestions = async ( ...@@ -243,6 +243,30 @@ export const askQuestions = async (
preferences.engine = engine; preferences.engine = engine;
} }
} }
if (program.engine !== "simple" && !program.vectorDb) {
if (ciInfo.isCI) {
program.vectorDb = getPrefOrDefault("vectorDb");
} else {
const { vectorDb } = await prompts(
{
type: "select",
name: "vectorDb",
message: "Would you like to use a vector database?",
choices: [
{
title: "No, just store the data in the file system",
value: "none",
},
{ title: "MongoDB", value: "mongo" },
],
initial: 0,
},
handlers,
);
program.vectorDb = vectorDb;
preferences.vectorDb = vectorDb;
}
}
} }
if (!program.openAiKey) { if (!program.openAiKey) {
......
/* eslint-disable turbo/no-undeclared-env-vars */
import * as dotenv from "dotenv";
import {
MongoDBAtlasVectorSearch,
SimpleDirectoryReader,
VectorStoreIndex,
storageContextFromDefaults,
} from "llamaindex";
import { MongoClient } from "mongodb";
import { STORAGE_DIR, checkRequiredEnvVars } from "./shared.mjs";
// Load variables from a local .env file into process.env before they are read below.
dotenv.config();
// MongoDB connection settings, read once at module load.
// Any of these may be undefined here; the script entry point calls
// checkRequiredEnvVars() before loadAndIndex() uses them.
const mongoUri = process.env.MONGODB_URI;
const databaseName = process.env.MONGODB_DATABASE;
// Name of the collection that will hold the generated embeddings.
const vectorCollectionName = process.env.MONGODB_VECTORS;
// Name of the vector search index passed to MongoDBAtlasVectorSearch.
const indexName = process.env.MONGODB_VECTOR_INDEX;
/**
 * Reads every document under STORAGE_DIR, embeds it and stores the result in
 * the MongoDB collection configured through the MONGODB_* environment variables.
 *
 * The MongoDB client is always closed before this function settles, including
 * when loading or indexing fails, so a failed run does not leave an open
 * connection behind.
 *
 * @returns {Promise<void>} resolves once the embeddings are stored and the
 *   client is closed; rejects with the underlying error otherwise.
 */
async function loadAndIndex() {
  // Create a new client; it is released in the `finally` block below.
  const client = new MongoClient(mongoUri);

  try {
    // load objects from storage and convert them into LlamaIndex Document objects
    const documents = await new SimpleDirectoryReader().loadData({
      directoryPath: STORAGE_DIR,
    });

    // create Atlas as a vector store
    const vectorStore = new MongoDBAtlasVectorSearch({
      mongodbClient: client,
      dbName: databaseName,
      collectionName: vectorCollectionName, // this is where your embeddings will be stored
      indexName: indexName, // this is the name of the index you will need to create
    });

    // now create an index from all the Documents and store them in Atlas
    const storageContext = await storageContextFromDefaults({ vectorStore });
    await VectorStoreIndex.fromDocuments(documents, { storageContext });
    console.log(
      `Successfully created embeddings in the MongoDB collection ${vectorCollectionName}.`,
    );
  } finally {
    // Close on success and on failure alike so the connection is never leaked.
    await client.close();
  }
}
// Script entry point: verify the MongoDB environment variables first, then
// build the index. The function name is only for readable stack traces.
(async function main() {
  checkRequiredEnvVars();
  await loadAndIndex();
  console.log("Finished generating storage.");
})();
/* eslint-disable turbo/no-undeclared-env-vars */
import {
ContextChatEngine,
LLM,
MongoDBAtlasVectorSearch,
serviceContextFromDefaults,
VectorStoreIndex,
} from "llamaindex";
import { MongoClient } from "mongodb";
import { checkRequiredEnvVars, CHUNK_OVERLAP, CHUNK_SIZE } from "./shared.mjs";
/**
 * Opens the existing MongoDB Atlas vector store as a VectorStoreIndex.
 *
 * The connection settings come from the MONGODB_* environment variables,
 * which are validated up front by checkRequiredEnvVars().
 *
 * @param llm - model placed into the service context used by the index.
 * @returns the index created by VectorStoreIndex.fromVectorStore.
 */
async function getDataSource(llm: LLM) {
  checkRequiredEnvVars();

  const {
    MONGODB_URI: uri,
    MONGODB_DATABASE: dbName,
    MONGODB_VECTORS: collectionName,
    MONGODB_VECTOR_INDEX: indexName,
  } = process.env;

  // checkRequiredEnvVars() has already rejected a missing/empty MONGODB_URI.
  const mongodbClient = new MongoClient(uri!);
  const serviceContext = serviceContextFromDefaults({
    llm,
    chunkSize: CHUNK_SIZE,
    chunkOverlap: CHUNK_OVERLAP,
  });
  const vectorStore = new MongoDBAtlasVectorSearch({
    mongodbClient,
    dbName,
    collectionName,
    indexName,
  });

  return await VectorStoreIndex.fromVectorStore(vectorStore, serviceContext);
}
/**
 * Creates a ContextChatEngine whose retriever is backed by the index returned
 * from getDataSource().
 *
 * @param llm - model used both as the chat model and for building the data source.
 * @returns a ContextChatEngine using a retriever configured with similarityTopK = 5.
 */
export async function createChatEngine(llm: LLM) {
  const dataSource = await getDataSource(llm);
  return new ContextChatEngine({
    chatModel: llm,
    retriever: dataSource.asRetriever({ similarityTopK: 5 }),
  });
}
// Directory that holds the source documents to be indexed (relative path).
export const STORAGE_DIR = "./data";
// Chunking settings shared by the code that builds and loads the index.
// NOTE(review): units are presumably tokens — confirm against the llamaindex docs.
export const CHUNK_SIZE = 512;
export const CHUNK_OVERLAP = 20;
// Environment variables that must be set (and non-empty) to reach MongoDB.
const REQUIRED_ENV_VARS = [
  "MONGODB_URI",
  "MONGODB_DATABASE",
  "MONGODB_VECTORS",
  "MONGODB_VECTOR_INDEX",
];

/**
 * Verifies that every entry of REQUIRED_ENV_VARS has a truthy value in
 * process.env. Unset and empty-string values both count as missing.
 *
 * @throws {Error} listing the missing variable names; the same list is also
 *   printed with console.log before throwing.
 */
export function checkRequiredEnvVars() {
  const missing = [];
  for (const name of REQUIRED_ENV_VARS) {
    if (!process.env[name]) {
      missing.push(name);
    }
  }

  // Nothing missing: return quietly.
  if (missing.length === 0) {
    return;
  }

  const names = missing.join(", ");
  console.log(
    `The following environment variables are required but missing: ${names}`,
  );
  throw new Error(`Missing environment variables: ${names}`);
}
...@@ -14,16 +14,34 @@ import { ...@@ -14,16 +14,34 @@ import {
InstallTemplateArgs, InstallTemplateArgs,
TemplateEngine, TemplateEngine,
TemplateFramework, TemplateFramework,
TemplateVectorDB,
} from "./types"; } from "./types";
const createEnvLocalFile = async (root: string, openAiKey?: string) => { const createEnvLocalFile = async (
root: string,
openAiKey?: string,
vectorDb?: TemplateVectorDB,
) => {
const envFileName = ".env";
let content = "";
if (openAiKey) { if (openAiKey) {
const envFileName = ".env"; content += `OPENAI_API_KEY=${openAiKey}\n`;
await fs.writeFile( }
path.join(root, envFileName),
`OPENAI_API_KEY=${openAiKey}\n`, switch (vectorDb) {
); case "mongo": {
console.log(`Created '${envFileName}' file containing OPENAI_API_KEY`); content += `MONGODB_URI=\n`;
content += `MONGODB_DATABASE=\n`;
content += `MONGODB_VECTORS=\n`;
content += `MONGODB_VECTOR_INDEX=\n`;
break;
}
}
if (content) {
await fs.writeFile(path.join(root, envFileName), content);
console.log(`Created '${envFileName}' file. Please check the settings.`);
} }
}; };
...@@ -33,6 +51,7 @@ const copyTestData = async ( ...@@ -33,6 +51,7 @@ const copyTestData = async (
packageManager?: PackageManager, packageManager?: PackageManager,
engine?: TemplateEngine, engine?: TemplateEngine,
openAiKey?: string, openAiKey?: string,
vectorDb?: TemplateVectorDB,
) => { ) => {
if (framework === "nextjs") { if (framework === "nextjs") {
// XXX: This is a hack to make the build for nextjs work with pdf-parse // XXX: This is a hack to make the build for nextjs work with pdf-parse
...@@ -53,21 +72,29 @@ const copyTestData = async ( ...@@ -53,21 +72,29 @@ const copyTestData = async (
} }
if (packageManager && engine === "context") { if (packageManager && engine === "context") {
if (openAiKey || process.env["OPENAI_API_KEY"]) { const hasOpenAiKey = openAiKey || process.env["OPENAI_API_KEY"];
const hasVectorDb = vectorDb && vectorDb !== "none";
const shouldRunGenerateAfterInstall = hasOpenAiKey && vectorDb === "none";
if (shouldRunGenerateAfterInstall) {
console.log( console.log(
`\nRunning ${cyan( `\nRunning ${cyan(
`${packageManager} run generate`, `${packageManager} run generate`,
)} to generate the context data.\n`, )} to generate the context data.\n`,
); );
await callPackageManager(packageManager, true, ["run", "generate"]); await callPackageManager(packageManager, true, ["run", "generate"]);
console.log(); return console.log();
} else {
console.log(
`\nAfter setting your OpenAI key, run ${cyan(
`${packageManager} run generate`,
)} to generate the context data.\n`,
);
} }
const settings = [];
if (!hasOpenAiKey) settings.push("your OpenAI key");
if (hasVectorDb) settings.push("your Vector DB environment variables");
const generateMessage = `run ${cyan(
`${packageManager} run generate`,
)} to generate the context data.\n`;
const message = settings.length
? `After setting ${settings.join(" and ")}, ${generateMessage}`
: generateMessage;
console.log(`\n${message}\n`);
} }
}; };
...@@ -104,6 +131,7 @@ const installTSTemplate = async ({ ...@@ -104,6 +131,7 @@ const installTSTemplate = async ({
customApiPath, customApiPath,
forBackend, forBackend,
model, model,
vectorDb,
}: InstallTemplateArgs) => { }: InstallTemplateArgs) => {
console.log(bold(`Using ${packageManager}.`)); console.log(bold(`Using ${packageManager}.`));
...@@ -148,14 +176,22 @@ const installTSTemplate = async ({ ...@@ -148,14 +176,22 @@ const installTSTemplate = async ({
const compPath = path.join(__dirname, "components"); const compPath = path.join(__dirname, "components");
if (engine && (framework === "express" || framework === "nextjs")) { if (engine && (framework === "express" || framework === "nextjs")) {
console.log("\nUsing chat engine:", engine, "\n"); console.log("\nUsing chat engine:", engine, "\n");
const enginePath = path.join(compPath, "engines", engine);
let vectorDBFolder: string = engine;
if (engine !== "simple" && vectorDb) {
console.log("\nUsing vector DB:", vectorDb, "\n");
vectorDBFolder = vectorDb;
}
const VectorDBPath = path.join(compPath, "vectordbs", vectorDBFolder);
relativeEngineDestPath = relativeEngineDestPath =
framework === "nextjs" framework === "nextjs"
? path.join("app", "api", "chat") ? path.join("app", "api", "chat")
: path.join("src", "controllers"); : path.join("src", "controllers");
await copy("**", path.join(root, relativeEngineDestPath, "engine"), { await copy("**", path.join(root, relativeEngineDestPath, "engine"), {
parents: true, parents: true,
cwd: enginePath, cwd: VectorDBPath,
}); });
} }
...@@ -341,7 +377,7 @@ export const installTemplate = async ( ...@@ -341,7 +377,7 @@ export const installTemplate = async (
// This is a backend, so we need to copy the test data and create the env file. // This is a backend, so we need to copy the test data and create the env file.
// Copy the environment file to the target directory. // Copy the environment file to the target directory.
await createEnvLocalFile(props.root, props.openAiKey); await createEnvLocalFile(props.root, props.openAiKey, props.vectorDb);
// Copy test pdf file // Copy test pdf file
await copyTestData( await copyTestData(
...@@ -350,6 +386,7 @@ export const installTemplate = async ( ...@@ -350,6 +386,7 @@ export const installTemplate = async (
props.packageManager, props.packageManager,
props.engine, props.engine,
props.openAiKey, props.openAiKey,
props.vectorDb,
); );
} }
}; };
......
...@@ -4,6 +4,7 @@ export type TemplateType = "simple" | "streaming" | "community"; ...@@ -4,6 +4,7 @@ export type TemplateType = "simple" | "streaming" | "community";
export type TemplateFramework = "nextjs" | "express" | "fastapi"; export type TemplateFramework = "nextjs" | "express" | "fastapi";
export type TemplateEngine = "simple" | "context"; export type TemplateEngine = "simple" | "context";
export type TemplateUI = "html" | "shadcn"; export type TemplateUI = "html" | "shadcn";
export type TemplateVectorDB = "none" | "mongo";
export interface InstallTemplateArgs { export interface InstallTemplateArgs {
appName: string; appName: string;
...@@ -20,4 +21,5 @@ export interface InstallTemplateArgs { ...@@ -20,4 +21,5 @@ export interface InstallTemplateArgs {
forBackend?: string; forBackend?: string;
model: string; model: string;
communityProjectPath?: string; communityProjectPath?: string;
vectorDb?: TemplateVectorDB;
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment