From 648482b0f1ebba312e73d721deae1c9d6034a23f Mon Sep 17 00:00:00 2001
From: Thuc Pham <51660321+thucpn@users.noreply.github.com>
Date: Fri, 12 Jan 2024 13:35:00 +0700
Subject: [PATCH] Feat: Add support for ChromaDB (#310)

Co-authored-by: Aarav Navani <38411399+oofmeister27@users.noreply.github.com>
---
 .changeset/purple-camels-walk.md              |   5 +
 examples/astradb/load.ts                      |   2 +-
 examples/chromadb/README.md                   |  13 ++
 examples/chromadb/test.ts                     |  40 +++++
 examples/{astradb => }/data/movie_reviews.csv | Bin 19413 -> 19313 bytes
 examples/package.json                         |   5 +-
 packages/core/package.json                    |   1 +
 packages/core/src/storage/index.ts            |   1 +
 .../storage/vectorStore/ChromaVectorStore.ts  | 148 ++++++++++++++++++
 pnpm-lock.yaml                                |  42 ++++-
 10 files changed, 253 insertions(+), 4 deletions(-)
 create mode 100644 .changeset/purple-camels-walk.md
 create mode 100644 examples/chromadb/README.md
 create mode 100644 examples/chromadb/test.ts
 rename examples/{astradb => }/data/movie_reviews.csv (99%)
 create mode 100644 packages/core/src/storage/vectorStore/ChromaVectorStore.ts

diff --git a/.changeset/purple-camels-walk.md b/.changeset/purple-camels-walk.md
new file mode 100644
index 000000000..6d972ce3e
--- /dev/null
+++ b/.changeset/purple-camels-walk.md
@@ -0,0 +1,5 @@
+---
+"llamaindex": patch
+---
+
+Feat: Add support for Chroma DB as a vector store
diff --git a/examples/astradb/load.ts b/examples/astradb/load.ts
index 298ad7551..6c422fb14 100644
--- a/examples/astradb/load.ts
+++ b/examples/astradb/load.ts
@@ -10,7 +10,7 @@ const collectionName = "movie_reviews";
 async function main() {
   try {
     const reader = new PapaCSVReader(false);
-    const docs = await reader.loadData("astradb/data/movie_reviews.csv");
+    const docs = await reader.loadData("../data/movie_reviews.csv");
 
     const astraVS = new AstraDBVectorStore();
     await astraVS.create(collectionName, {
diff --git a/examples/chromadb/README.md b/examples/chromadb/README.md
new file mode 100644
index 000000000..5b1c6d7b2
--- /dev/null
+++ b/examples/chromadb/README.md
@@ -0,0 +1,13 @@
+# Chroma Vector Store Example
+
+How to run `examples/chromadb/test.ts`:
+
+Export your OpenAI API Key using `export OPEN_API_KEY=insert your api key here`
+
+If you haven't installed chromadb, run `pip install chromadb`. Start the server using `chroma run`.
+
+Now, open a new terminal window and inside `examples`, run `pnpx ts-node chromadb/test.ts`.
+
+Here's the output for the input query `Tell me about Godfrey Cheshire's rating of La Sapienza.`:
+
+`Godfrey Cheshire gave La Sapienza a rating of 4 out of 4, describing it as fresh and the most astonishing and important movie to emerge from France in quite some time.`
diff --git a/examples/chromadb/test.ts b/examples/chromadb/test.ts
new file mode 100644
index 000000000..51d2b2692
--- /dev/null
+++ b/examples/chromadb/test.ts
@@ -0,0 +1,40 @@
+import {
+  ChromaVectorStore,
+  PapaCSVReader,
+  storageContextFromDefaults,
+  VectorStoreIndex,
+} from "llamaindex";
+
+const collectionName = "movie_reviews";
+
+async function main() {
+  const sourceFile: string = "./data/movie_reviews.csv";
+
+  try {
+    console.log(`Loading data from ${sourceFile}`);
+    const reader = new PapaCSVReader(false, ", ", "\n", {
+      header: true,
+    });
+    const docs = await reader.loadData(sourceFile);
+
+    console.log("Creating ChromaDB vector store");
+    const chromaVS = new ChromaVectorStore({ collectionName });
+    const ctx = await storageContextFromDefaults({ vectorStore: chromaVS });
+
+    console.log("Embedding documents and adding to index");
+    const index = await VectorStoreIndex.fromDocuments(docs, {
+      storageContext: ctx,
+    });
+
+    console.log("Querying index");
+    const queryEngine = index.asQueryEngine();
+    const response = await queryEngine.query(
+      "Tell me about Godfrey Cheshire's rating of La Sapienza.",
+    );
+    console.log(response.toString());
+  } catch (e) {
+    console.error(e);
+  }
+}
+
+main();
diff --git a/examples/astradb/data/movie_reviews.csv b/examples/data/movie_reviews.csv
similarity index 99%
rename from examples/astradb/data/movie_reviews.csv
rename to examples/data/movie_reviews.csv
index d605bdbc0ae9345d56282507076849748b7460ee..eaebe6dc672fcaa58b45f608dd3b17c0a4c61843 100644
GIT binary patch
delta 499
zcmcaQo$=!|#tGgNeQr<IV7xckn^|qME|V<dWN+s4lciazCQoCzIyr~+#pF=7r<1MO
z7fimxesS_Tj`fq{IQMUs<I-lF?8SY0a-D$WWPP5OVA_cH1(-JByD+&<KzgzQ|BA_V
z0y2}I^3T{@C!om)<{S~62&RjKc7heT3tyWoCvtbPoQTY13DHlJ>jcCmD~es8{8Q}R
z<`?3OjFVqU?3%n(^3CLp(xQ`>ORb;0QCe(ruC(*yjnWR2H%g07=9JL@%el*Z1Jl8>
zOTjds97NnlP7TcOkV~EXPR?|)ihSkddGc2_7c1~HZf;g|V4QqJ>EPrx<tLM~RqQ9P
zR@pu|S#{>*zpD2q-&K1#`M7%dWP6QMlesjPZ{DQI#W;Dd){)Ik+Kh~oTXX^^&(pP@
zEU4Qvd7kd>$x(W3lOO8sojhNE-{d%hU6UmZpG^K|xMuQdqvFZ(##<&|F)p9%V-hlX
zp^5F}`zCiL?=gKjd70VP$uZ_XCWl%yOulCEfATfUzmu<69hf}T`sL<&8*awQ-L{RB
zdF^61XWJ<<PVTjzHd)2t?PMXxU6ao`?w{Q2G;6Yk^Ww>8oDWZ);IeJ9yKBbe4X)cJ
fr@QT*tnB`9Gph$X<78gXGm~$5p4hzD%ZL#GIl|Ng

delta 751
zcmex3jq&Pq#tGhxyc4}|GxAPWW4y=6JK3FCZL=1WEF+kCo{@L57)uo+@8k(AR~dOH
zr?S3a<elux_LPx#vJv|NM&8Mn*e^2jP8Q@6p1h4?4wRM3xrUK<vmlo?$Pkc}Klcts
z-pL<D#U?B8u}oIv0T}}0Df2#u%Uy=D8U$D;^YO2M^VlX|;-3NGZEg?{1{n*|-XOp>
zdA;C7h{R+NFGgr5oX0xZK=>LGPf_F$67QeLXE5(ABk$yUV(%DvH=hw_1Q~o*VizOt
z<aWt7AkWB&PVSIe4`umCyFgh>rM01~r_x$bmWs?bFiUK*f$S0}>y@lJm}LznB_^lH
zB{A|&J|<_%$UFJBTqPs#<P!NSjJ%sY6!;k#c{hhEI)Gd-UFjer@8k&OCycz4ZB*<T
zc_&w^YzGCn>P$x7$){ECGxAPerS_7Mck*QQaz@_CTpFhsc_&}dSkB11xmuG8WJIgh
z5k}t4p4wm^dFcc)@=l(ua~w=sPoAsWz{opUR<Dhbck*Jry^Oq*6ZH2n@=lgB0Qvle
z0m$AxhHDskCubNHGxAQpZM21vcXGcm$SyXM5Juj~Q6{#GypyM!++pOMTx<$*NsQT6
zkRQ!|F!E03v;djkY4M+tcXEg2Uq;@^tyTvZc_+JCzhvaytY^awa-pSdBO~wRUAD1|
zyqm@C6hT53_R|=7C!eziMagalkXuR}_cQWNHgTE-5_ek6$UC{v`7lVtWg8>!<X0{k
ujJ%VBTtRFOx801qlMlN+WaQnv#GM^v+-i?AjJ%V}Jx?(5ZnpO_VgvvwG~_z~

diff --git a/examples/package.json b/examples/package.json
index 07e116249..465ea579d 100644
--- a/examples/package.json
+++ b/examples/package.json
@@ -3,12 +3,13 @@
   "private": true,
   "name": "examples",
   "dependencies": {
-    "@notionhq/client": "^2.2.13",
     "@datastax/astra-db-ts": "^0.1.2",
+    "@notionhq/client": "^2.2.13",
     "@pinecone-database/pinecone": "^1.1.2",
+    "chromadb": "^1.7.3",
     "commander": "^11.1.0",
-    "llamaindex": "latest",
     "dotenv": "^16.3.1",
+    "llamaindex": "latest",
     "mongodb": "^6.2.0"
   },
   "devDependencies": {
diff --git a/packages/core/package.json b/packages/core/package.json
index f94497dfd..733474e9c 100644
--- a/packages/core/package.json
+++ b/packages/core/package.json
@@ -10,6 +10,7 @@
     "@pinecone-database/pinecone": "^1.1.2",
     "@xenova/transformers": "^2.10.0",
     "assemblyai": "^4.0.0",
+    "chromadb": "^1.7.3",
     "file-type": "^18.7.0",
     "js-tiktoken": "^1.0.8",
     "lodash": "^4.17.21",
diff --git a/packages/core/src/storage/index.ts b/packages/core/src/storage/index.ts
index 83cedd389..796f8fb9f 100644
--- a/packages/core/src/storage/index.ts
+++ b/packages/core/src/storage/index.ts
@@ -8,6 +8,7 @@ export * from "./indexStore/types";
 export { SimpleKVStore } from "./kvStore/SimpleKVStore";
 export * from "./kvStore/types";
 export { AstraDBVectorStore } from "./vectorStore/AstraDBVectorStore";
+export { ChromaVectorStore } from "./vectorStore/ChromaVectorStore";
 export { MongoDBAtlasVectorSearch } from "./vectorStore/MongoDBAtlasVectorStore";
 export { PGVectorStore } from "./vectorStore/PGVectorStore";
 export { PineconeVectorStore } from "./vectorStore/PineconeVectorStore";
diff --git a/packages/core/src/storage/vectorStore/ChromaVectorStore.ts b/packages/core/src/storage/vectorStore/ChromaVectorStore.ts
new file mode 100644
index 000000000..d6d8d52de
--- /dev/null
+++ b/packages/core/src/storage/vectorStore/ChromaVectorStore.ts
@@ -0,0 +1,148 @@
+import {
+  AddParams,
+  ChromaClient,
+  ChromaClientParams,
+  Collection,
+  IncludeEnum,
+  QueryResponse,
+  Where,
+  WhereDocument,
+} from "chromadb";
+import { BaseNode, MetadataMode } from "../../Node";
+import {
+  VectorStore,
+  VectorStoreQuery,
+  VectorStoreQueryMode,
+  VectorStoreQueryResult,
+} from "./types";
+import { metadataDictToNode, nodeToMetadata } from "./utils";
+
+type ChromaDeleteOptions = {
+  where?: Where;
+  whereDocument?: WhereDocument;
+};
+
+type ChromaQueryOptions = {
+  whereDocument?: WhereDocument;
+};
+
+const DEFAULT_TEXT_KEY = "text";
+
+export class ChromaVectorStore implements VectorStore {
+  storesText: boolean = true;
+  flatMetadata: boolean = true;
+  textKey: string;
+  private chromaClient: ChromaClient;
+  private collection: Collection | null = null;
+  private collectionName: string;
+
+  constructor(init: {
+    collectionName: string;
+    textKey?: string;
+    chromaClientParams?: ChromaClientParams;
+  }) {
+    this.collectionName = init.collectionName;
+    this.chromaClient = new ChromaClient(init.chromaClientParams);
+    this.textKey = init.textKey ?? DEFAULT_TEXT_KEY;
+  }
+
+  client(): ChromaClient {
+    return this.chromaClient;
+  }
+
+  async getCollection(): Promise<Collection> {
+    if (!this.collection) {
+      const coll = await this.chromaClient.createCollection({
+        name: this.collectionName,
+      });
+      this.collection = coll;
+    }
+    return this.collection;
+  }
+
+  private getDataToInsert(nodes: BaseNode[]): AddParams {
+    const metadatas = nodes.map((node) =>
+      nodeToMetadata(node, true, this.textKey, this.flatMetadata),
+    );
+    return {
+      embeddings: nodes.map((node) => node.getEmbedding()),
+      ids: nodes.map((node) => node.id_),
+      metadatas,
+      documents: nodes.map((node) => node.getContent(MetadataMode.NONE)),
+    };
+  }
+
+  async add(nodes: BaseNode[]): Promise<string[]> {
+    if (!nodes || nodes.length === 0) {
+      return [];
+    }
+
+    const dataToInsert = this.getDataToInsert(nodes);
+    const collection = await this.getCollection();
+    await collection.add(dataToInsert);
+    return nodes.map((node) => node.id_);
+  }
+
+  async delete(
+    refDocId: string,
+    deleteOptions?: ChromaDeleteOptions,
+  ): Promise<void> {
+    const collection = await this.getCollection();
+    await collection.delete({
+      ids: [refDocId],
+      where: deleteOptions?.where,
+      whereDocument: deleteOptions?.whereDocument,
+    });
+  }
+
+  async query(
+    query: VectorStoreQuery,
+    options?: ChromaQueryOptions,
+  ): Promise<VectorStoreQueryResult> {
+    if (query.docIds) {
+      throw new Error("ChromaDB does not support querying by docIDs");
+    }
+    if (query.mode != VectorStoreQueryMode.DEFAULT) {
+      throw new Error("ChromaDB does not support querying by mode");
+    }
+
+    const chromaWhere: { [x: string]: string | number | boolean } = {};
+    if (query.filters) {
+      query.filters.filters.map((filter) => {
+        const filterKey = filter.key;
+        const filterValue = filter.value;
+        chromaWhere[filterKey] = filterValue;
+      });
+    }
+
+    const collection = await this.getCollection();
+    const queryResponse: QueryResponse = await collection.query({
+      queryEmbeddings: query.queryEmbedding ?? undefined,
+      queryTexts: query.queryStr ?? undefined,
+      nResults: query.similarityTopK,
+      where: Object.keys(chromaWhere).length ? chromaWhere : undefined,
+      whereDocument: options?.whereDocument,
+      //ChromaDB doesn't return the result embeddings by default so we need to include them
+      include: [
+        IncludeEnum.Distances,
+        IncludeEnum.Metadatas,
+        IncludeEnum.Documents,
+        IncludeEnum.Embeddings,
+      ],
+    });
+    const vectorStoreQueryResult: VectorStoreQueryResult = {
+      nodes: queryResponse.ids[0].map((id, index) => {
+        const text = (queryResponse.documents as string[][])[0][index];
+        const metaData = queryResponse.metadatas[0][index] ?? {};
+        const node = metadataDictToNode(metaData);
+        node.setContent(text);
+        return node;
+      }),
+      similarities: (queryResponse.distances as number[][])[0].map(
+        (distance) => 1 - distance,
+      ),
+      ids: queryResponse.ids[0],
+    };
+    return vectorStoreQueryResult;
+  }
+}
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index cb1ee0bed..214a3896d 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -115,6 +115,9 @@ importers:
       '@pinecone-database/pinecone':
         specifier: ^1.1.2
         version: 1.1.2
+      chromadb:
+        specifier: ^1.7.3
+        version: 1.7.3(openai@4.20.1)
       commander:
         specifier: ^11.1.0
         version: 11.1.0
@@ -158,6 +161,9 @@ importers:
       assemblyai:
         specifier: ^4.0.0
         version: 4.0.0
+      chromadb:
+        specifier: ^1.7.3
+        version: 1.7.3(openai@4.20.1)
       file-type:
         specifier: ^18.7.0
         version: 18.7.0
@@ -6252,6 +6258,28 @@ packages:
     resolution: {integrity: sha512-bIomtDF5KGpdogkLd9VspvFzk9KfpyyGlS8YFVZl7TGPBHL5snIOnxeshwVgPteQ9b4Eydl+pVbIyE1DcvCWgQ==}
     engines: {node: '>=10'}
 
+  /chromadb@1.7.3(openai@4.20.1):
+    resolution: {integrity: sha512-3GgvQjpqgk5C89x5EuTDaXKbfrdqYDJ5UVyLQ3ZmwxnpetNc+HhRDGjkvXa5KSvpQ3lmKoyDoqnN4tZepfFkbw==}
+    engines: {node: '>=14.17.0'}
+    peerDependencies:
+      '@google/generative-ai': ^0.1.1
+      cohere-ai: ^5.0.0 || ^6.0.0 || ^7.0.0
+      openai: ^3.0.0 || ^4.0.0
+    peerDependenciesMeta:
+      '@google/generative-ai':
+        optional: true
+      cohere-ai:
+        optional: true
+      openai:
+        optional: true
+    dependencies:
+      cliui: 8.0.1
+      isomorphic-fetch: 3.0.0
+      openai: 4.20.1
+    transitivePeerDependencies:
+      - encoding
+    dev: false
+
   /chrome-trace-event@1.0.3:
     resolution: {integrity: sha512-p3KULyQg4S7NIHixdwbGX+nFHkoBiA4YQmyWtjb8XngSKV124nJmRysgAeujbUVb15vh+RvFUfCPqU7rXk+hZg==}
     engines: {node: '>=6.0'}
@@ -6356,7 +6384,6 @@ packages:
       string-width: 4.2.3
       strip-ansi: 6.0.1
       wrap-ansi: 7.0.0
-    dev: true
 
   /clone-deep@4.0.1:
     resolution: {integrity: sha512-neHB9xuzh/wk0dIHweyAXv2aPGZIVk3pLMe+/RNzINf17fe0OG96QroktYAUm7SM1PBnzTabaLboqqxDyMU+SQ==}
@@ -10117,6 +10144,15 @@ packages:
     resolution: {integrity: sha512-WhB9zCku7EGTj/HQQRz5aUQEUeoQZH2bWcltRErOpymJ4boYE6wL9Tbr23krRPSZ+C5zqNSrSw+Cc7sZZ4b7vg==}
     engines: {node: '>=0.10.0'}
 
+  /isomorphic-fetch@3.0.0:
+    resolution: {integrity: sha512-qvUtwJ3j6qwsF3jLxkZ72qCgjMysPzDfeV240JHiGZsANBYd+EEuu35v7dfrJ9Up0Ak07D7GGSkGhCHTqg/5wA==}
+    dependencies:
+      node-fetch: 2.7.0(encoding@0.1.13)
+      whatwg-fetch: 3.6.20
+    transitivePeerDependencies:
+      - encoding
+    dev: false
+
   /isomorphic-timers-promises@1.0.1:
     resolution: {integrity: sha512-u4sej9B1LPSxTGKB/HiuzvEQnXH0ECYkSVQU39koSwmFAxhlEAFl9RdTvLv4TOTQUgBS5O3O5fwUxk6byBZ+IQ==}
     engines: {node: '>=10'}
@@ -16377,6 +16413,10 @@ packages:
     engines: {node: '>=0.8.0'}
     dev: false
 
+  /whatwg-fetch@3.6.20:
+    resolution: {integrity: sha512-EqhiFU6daOA8kpjOWTL0olhVOF3i7OrFzSYiGsEMB8GcXS+RrzauAERX65xMeNWVqxA6HXH2m69Z9LaKKdisfg==}
+    dev: false
+
   /whatwg-url@13.0.0:
     resolution: {integrity: sha512-9WWbymnqj57+XEuqADHrCJ2eSXzn8WXIW/YSGaZtb2WKAInQ6CHfaUUcTyyver0p8BDg5StLQq8h1vtZuwmOig==}
     engines: {node: '>=16'}
-- 
GitLab