From 7ee16f8136f13de7b18df6fadff8349277a18e03 Mon Sep 17 00:00:00 2001
From: Motoki saito <stmtk1@users.noreply.github.com>
Date: Mon, 22 Jan 2024 02:26:35 +0900
Subject: [PATCH] add from_collection method to ChromaVectorStore class
 (#10167)

add from_collection_method
---
 .gitignore                           |  3 +++
 llama_index/vector_stores/chroma.py  | 14 +++++++++++++-
 tests/vector_stores/test_chromadb.py | 14 ++++++++++++--
 3 files changed, 28 insertions(+), 3 deletions(-)

diff --git a/.gitignore b/.gitignore
index 990c18de22..b0956e9a96 100644
--- a/.gitignore
+++ b/.gitignore
@@ -151,3 +151,6 @@ Pipfile.lock
 
 # pyright
 pyrightconfig.json
+
+# persist dir for chromadb test
+/data/
diff --git a/llama_index/vector_stores/chroma.py b/llama_index/vector_stores/chroma.py
index c883ef0edb..fb4c392dbb 100644
--- a/llama_index/vector_stores/chroma.py
+++ b/llama_index/vector_stores/chroma.py
@@ -165,6 +165,18 @@ class ChromaVectorStore(BasePydanticVectorStore):
             collection_kwargs=collection_kwargs or {},
         )
 
+    @classmethod
+    def from_collection(cls, collection: Any) -> "ChromaVectorStore":
+        try:
+            from chromadb import Collection
+        except ImportError:
+            raise ImportError(import_err_msg)
+
+        if not isinstance(collection, Collection):
+            raise Exception("argument is not chromadb collection instance")
+
+        return cls(chroma_collection=collection)
+
     @classmethod
     def from_params(
         cls,
@@ -174,7 +186,7 @@ class ChromaVectorStore(BasePydanticVectorStore):
         ssl: bool = False,
         headers: Optional[Dict[str, str]] = None,
         persist_dir: Optional[str] = None,
-        collection_kwargs: Optional[dict] = {},
+        collection_kwargs: dict = {},
         **kwargs: Any,
     ) -> "ChromaVectorStore":
         try:
diff --git a/tests/vector_stores/test_chromadb.py b/tests/vector_stores/test_chromadb.py
index ca38955150..a8291a39ba 100644
--- a/tests/vector_stores/test_chromadb.py
+++ b/tests/vector_stores/test_chromadb.py
@@ -38,13 +38,13 @@ except (ImportError, Exception):
 def test_instance_creation_from_collection() -> None:
     connection = chromadb.HttpClient(**PARAMS)
     collection = connection.get_collection(COLLECTION_NAME)
-    store = ChromaVectorStore(chroma_collection=collection)
+    store = ChromaVectorStore.from_collection(collection)
     assert isinstance(store, ChromaVectorStore)
 
 
 @pytest.mark.skipif(chromadb_not_available, reason="chromadb is not available")
 def test_instance_creation_from_http_params() -> None:
-    store = ChromaVectorStore(
+    store = ChromaVectorStore.from_params(
         host=PARAMS["host"],
         port=PARAMS["port"],
         collection_name=COLLECTION_NAME,
@@ -53,6 +53,16 @@ def test_instance_creation_from_http_params() -> None:
     assert isinstance(store, ChromaVectorStore)
 
 
+@pytest.mark.skipif(chromadb_not_available, reason="chromadb is not available")
+def test_instance_creation_from_persist_dir() -> None:
+    store = ChromaVectorStore.from_params(
+        persist_dir="./data",
+        collection_name=COLLECTION_NAME,
+        collection_kwargs={},
+    )
+    assert isinstance(store, ChromaVectorStore)
+
+
 @pytest.fixture()
 def vector_store() -> ChromaVectorStore:
     connection = chromadb.HttpClient(**PARAMS)
-- 
GitLab