From cb1001de957fad116b69099f19d21c4fc56f6c25 Mon Sep 17 00:00:00 2001
From: "Huu Le (Lee)" <39040748+leehuwuj@users.noreply.github.com>
Date: Tue, 14 May 2024 15:42:01 +0700
Subject: [PATCH] feat: add support for ChromaDB vector store (#82)

---
 .changeset/unlucky-bikes-eat.md               |  5 +++
 helpers/env-variables.ts                      | 34 +++++++++++++++--
 helpers/python.ts                             | 14 +++++++
 helpers/types.ts                              |  3 +-
 questions.ts                                  |  1 +
 .../vectordbs/python/chroma/generate.py       | 33 +++++++++++++++++
 .../vectordbs/python/chroma/index.py          | 14 +++++++
 .../vectordbs/python/chroma/vectordb.py       | 24 ++++++++++++
 .../vectordbs/typescript/chroma/generate.ts   | 37 +++++++++++++++++++
 .../vectordbs/typescript/chroma/index.ts      | 16 ++++++++
 .../vectordbs/typescript/chroma/shared.ts     | 18 +++++++++
 .../types/streaming/express/package.json      |  4 +-
 templates/types/streaming/nextjs/package.json |  2 +-
 13 files changed, 198 insertions(+), 7 deletions(-)
 create mode 100644 .changeset/unlucky-bikes-eat.md
 create mode 100644 templates/components/vectordbs/python/chroma/generate.py
 create mode 100644 templates/components/vectordbs/python/chroma/index.py
 create mode 100644 templates/components/vectordbs/python/chroma/vectordb.py
 create mode 100644 templates/components/vectordbs/typescript/chroma/generate.ts
 create mode 100644 templates/components/vectordbs/typescript/chroma/index.ts
 create mode 100644 templates/components/vectordbs/typescript/chroma/shared.ts

diff --git a/.changeset/unlucky-bikes-eat.md b/.changeset/unlucky-bikes-eat.md
new file mode 100644
index 00000000..b25461b5
--- /dev/null
+++ b/.changeset/unlucky-bikes-eat.md
@@ -0,0 +1,5 @@
+---
+"create-llama": patch
+---
+
+Add ChromaDB vector store
diff --git a/helpers/env-variables.ts b/helpers/env-variables.ts
index d410edfa..2d69f227 100644
--- a/helpers/env-variables.ts
+++ b/helpers/env-variables.ts
@@ -29,8 +29,11 @@ const renderEnvVar = (envVars: EnvVar[]): string => {
   );
 };
 
-const getVectorDBEnvs = (vectorDb?: TemplateVectorDB): EnvVar[] => {
-  if (!vectorDb) {
+const getVectorDBEnvs = (
+  vectorDb?: TemplateVectorDB,
+  framework?: TemplateFramework,
+): EnvVar[] => {
+  if (!vectorDb || !framework) {
     return [];
   }
   switch (vectorDb) {
@@ -129,6 +132,31 @@ const getVectorDBEnvs = (vectorDb?: TemplateVectorDB): EnvVar[] => {
             "Optional API key for authenticating requests to Qdrant.",
         },
       ];
+    case "chroma":
+      const envs = [
+        {
+          name: "CHROMA_COLLECTION",
+          description: "The name of the collection in your Chroma database",
+        },
+        {
+          name: "CHROMA_HOST",
+          description: "The API endpoint for your Chroma database",
+        },
+        {
+          name: "CHROMA_PORT",
+          description: "The port for your Chroma database",
+        },
+      ];
+      // TS Version doesn't support config local storage path
+      if (framework === "fastapi") {
+        envs.push({
+          name: "CHROMA_PATH",
+          description: `The local path to the Chroma database. 
+Specify this if you are using a local Chroma database. 
+Otherwise, use CHROMA_HOST and CHROMA_PORT config above`,
+        });
+      }
+      return envs;
     default:
       return [];
   }
@@ -257,7 +285,7 @@ export const createBackendEnvFile = async (
     // Add engine environment variables
     ...getEngineEnvs(),
     // Add vector database environment variables
-    ...getVectorDBEnvs(opts.vectorDb),
+    ...getVectorDBEnvs(opts.vectorDb, opts.framework),
     ...getFrameworkEnvs(opts.framework, opts.port),
   ];
   // Render and write env file
diff --git a/helpers/python.ts b/helpers/python.ts
index 51d4802f..ccc4cec9 100644
--- a/helpers/python.ts
+++ b/helpers/python.ts
@@ -70,6 +70,20 @@ const getAdditionalDependencies = (
       });
       break;
     }
+    case "qdrant": {
+      dependencies.push({
+        name: "llama-index-vector-stores-qdrant",
+        version: "^0.2.8",
+      });
+      break;
+    }
+    case "chroma": {
+      dependencies.push({
+        name: "llama-index-vector-stores-chroma",
+        version: "^0.1.8",
+      });
+      break;
+    }
   }
 
   // Add data source dependencies
diff --git a/helpers/types.ts b/helpers/types.ts
index b70f586a..b26e7088 100644
--- a/helpers/types.ts
+++ b/helpers/types.ts
@@ -20,7 +20,8 @@ export type TemplateVectorDB =
   | "pinecone"
   | "milvus"
   | "astra"
-  | "qdrant";
+  | "qdrant"
+  | "chroma";
 export type TemplatePostInstallAction =
   | "none"
   | "VSCode"
diff --git a/questions.ts b/questions.ts
index 3e046084..8d1f07dd 100644
--- a/questions.ts
+++ b/questions.ts
@@ -97,6 +97,7 @@ const getVectorDbChoices = (framework: TemplateFramework) => {
     { title: "Milvus", value: "milvus" },
     { title: "Astra", value: "astra" },
     { title: "Qdrant", value: "qdrant" },
+    { title: "ChromaDB", value: "chroma" },
   ];
 
   const vectordbLang = framework === "fastapi" ? "python" : "typescript";
diff --git a/templates/components/vectordbs/python/chroma/generate.py b/templates/components/vectordbs/python/chroma/generate.py
new file mode 100644
index 00000000..2665d8a6
--- /dev/null
+++ b/templates/components/vectordbs/python/chroma/generate.py
@@ -0,0 +1,33 @@
+from dotenv import load_dotenv
+
+load_dotenv()
+
+import os
+import logging
+from llama_index.core.storage import StorageContext
+from llama_index.core.indices import VectorStoreIndex
+from app.settings import init_settings
+from app.engine.loaders import get_documents
+from app.engine.vectordb import get_vector_store
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger()
+
+
+def generate_datasource():
+    init_settings()
+    logger.info("Creating new index")
+    # load the documents and create the index
+    documents = get_documents()
+    store = get_vector_store()
+    storage_context = StorageContext.from_defaults(vector_store=store)
+    VectorStoreIndex.from_documents(
+        documents,
+        storage_context=storage_context,
+        show_progress=True,  # this will show you a progress bar as the embeddings are created
+    )
+    logger.info("Successfully created embeddings in the ChromaDB")
+
+
+if __name__ == "__main__":
+    generate_datasource()
diff --git a/templates/components/vectordbs/python/chroma/index.py b/templates/components/vectordbs/python/chroma/index.py
new file mode 100644
index 00000000..f476c00d
--- /dev/null
+++ b/templates/components/vectordbs/python/chroma/index.py
@@ -0,0 +1,14 @@
+import logging
+
+from llama_index.core.indices import VectorStoreIndex
+from app.engine.vectordb import get_vector_store
+
+logger = logging.getLogger("uvicorn")
+
+
+def get_index():
+    logger.info("Connecting to ChromaDB..")
+    store = get_vector_store()
+    index = VectorStoreIndex.from_vector_store(store, use_async=False)
+    logger.info("Finished connecting to ChromaDB.")
+    return index
diff --git a/templates/components/vectordbs/python/chroma/vectordb.py b/templates/components/vectordbs/python/chroma/vectordb.py
new file mode 100644
index 00000000..2a71e0a2
--- /dev/null
+++ b/templates/components/vectordbs/python/chroma/vectordb.py
@@ -0,0 +1,24 @@
+import os
+from llama_index.vector_stores.chroma import ChromaVectorStore
+
+
+def get_vector_store():
+    collection_name = os.getenv("CHROMA_COLLECTION", "default")
+    chroma_path = os.getenv("CHROMA_PATH")
+    # if CHROMA_PATH is set, use a local ChromaVectorStore from the path
+    # otherwise, use a remote ChromaVectorStore (ChromaDB Cloud is not supported yet)
+    if chroma_path:
+        store = ChromaVectorStore.from_params(
+            persist_dir=chroma_path, collection_name=collection_name
+        )
+    else:
+        if not os.getenv("CHROMA_HOST") or not os.getenv("CHROMA_PORT"):
+            raise ValueError(
+                "Please provide either CHROMA_PATH or CHROMA_HOST and CHROMA_PORT"
+            )
+        store = ChromaVectorStore.from_params(
+            host=os.getenv("CHROMA_HOST"),
+            port=int(os.getenv("CHROMA_PORT")),
+            collection_name=collection_name,
+        )
+    return store
diff --git a/templates/components/vectordbs/typescript/chroma/generate.ts b/templates/components/vectordbs/typescript/chroma/generate.ts
new file mode 100644
index 00000000..83e8ea16
--- /dev/null
+++ b/templates/components/vectordbs/typescript/chroma/generate.ts
@@ -0,0 +1,37 @@
+/* eslint-disable turbo/no-undeclared-env-vars */
+import * as dotenv from "dotenv";
+import { VectorStoreIndex, storageContextFromDefaults } from "llamaindex";
+import { ChromaVectorStore } from "llamaindex/storage/vectorStore/ChromaVectorStore";
+import { getDocuments } from "./loader";
+import { initSettings } from "./settings";
+import { checkRequiredEnvVars } from "./shared";
+
+dotenv.config();
+
+async function loadAndIndex() {
+  // load objects from storage and convert them into LlamaIndex Document objects
+  const documents = await getDocuments();
+
+  // create vector store
+  const chromaUri = `http://${process.env.CHROMA_HOST}:${process.env.CHROMA_PORT}`;
+
+  const vectorStore = new ChromaVectorStore({
+    collectionName: process.env.CHROMA_COLLECTION,
+    chromaClientParams: { path: chromaUri },
+  });
+
+  // create index from all the Documentss and store them in Pinecone
+  console.log("Start creating embeddings...");
+  const storageContext = await storageContextFromDefaults({ vectorStore });
+  await VectorStoreIndex.fromDocuments(documents, { storageContext });
+  console.log(
+    "Successfully created embeddings and save to your ChromaDB index.",
+  );
+}
+
+(async () => {
+  checkRequiredEnvVars();
+  initSettings();
+  await loadAndIndex();
+  console.log("Finished generating storage.");
+})();
diff --git a/templates/components/vectordbs/typescript/chroma/index.ts b/templates/components/vectordbs/typescript/chroma/index.ts
new file mode 100644
index 00000000..1d36e643
--- /dev/null
+++ b/templates/components/vectordbs/typescript/chroma/index.ts
@@ -0,0 +1,16 @@
+/* eslint-disable turbo/no-undeclared-env-vars */
+import { VectorStoreIndex } from "llamaindex";
+import { ChromaVectorStore } from "llamaindex/storage/vectorStore/ChromaVectorStore";
+import { checkRequiredEnvVars } from "./shared";
+
+export async function getDataSource() {
+  checkRequiredEnvVars();
+  const chromaUri = `http://${process.env.CHROMA_HOST}:${process.env.CHROMA_PORT}`;
+
+  const store = new ChromaVectorStore({
+    collectionName: process.env.CHROMA_COLLECTION,
+    chromaClientParams: { path: chromaUri },
+  });
+
+  return await VectorStoreIndex.fromVectorStore(store);
+}
diff --git a/templates/components/vectordbs/typescript/chroma/shared.ts b/templates/components/vectordbs/typescript/chroma/shared.ts
new file mode 100644
index 00000000..4c884175
--- /dev/null
+++ b/templates/components/vectordbs/typescript/chroma/shared.ts
@@ -0,0 +1,18 @@
+const REQUIRED_ENV_VARS = ["CHROMA_COLLECTION", "CHROMA_HOST", "CHROMA_PORT"];
+
+export function checkRequiredEnvVars() {
+  const missingEnvVars = REQUIRED_ENV_VARS.filter((envVar) => {
+    return !process.env[envVar];
+  });
+
+  if (missingEnvVars.length > 0) {
+    console.log(
+      `The following environment variables are required but missing: ${missingEnvVars.join(
+        ", ",
+      )}`,
+    );
+    throw new Error(
+      `Missing environment variables: ${missingEnvVars.join(", ")}`,
+    );
+  }
+}
diff --git a/templates/types/streaming/express/package.json b/templates/types/streaming/express/package.json
index 0dd1a363..d61b7fba 100644
--- a/templates/types/streaming/express/package.json
+++ b/templates/types/streaming/express/package.json
@@ -1,13 +1,13 @@
 {
   "name": "llama-index-express-streaming",
   "version": "1.0.0",
-  "main": "dist/index.mjs",
+  "main": "dist/index.js",
   "scripts": {
     "format": "prettier --ignore-unknown --cache --check .",
     "format:write": "prettier --ignore-unknown --write .",
     "build": "tsup index.ts --format cjs --dts",
     "start": "node dist/index.js",
-    "dev": "concurrently \"tsup index.ts --format cjs --dts --watch\" \"nodemon -q dist/index.mjs\""
+    "dev": "concurrently \"tsup index.ts --format cjs --dts --watch\" \"nodemon -q dist/index.js\""
   },
   "dependencies": {
     "ai": "^3.0.21",
diff --git a/templates/types/streaming/nextjs/package.json b/templates/types/streaming/nextjs/package.json
index 3093201b..3bcaf0cd 100644
--- a/templates/types/streaming/nextjs/package.json
+++ b/templates/types/streaming/nextjs/package.json
@@ -18,7 +18,7 @@
     "class-variance-authority": "^0.7.0",
     "clsx": "^2.1.1",
     "dotenv": "^16.3.1",
-    "llamaindex": "0.3.9",
+    "llamaindex": "0.3.8",
     "lucide-react": "^0.294.0",
     "next": "^14.0.3",
     "pdf2json": "3.0.5",
-- 
GitLab