From 7ee16f8136f13de7b18df6fadff8349277a18e03 Mon Sep 17 00:00:00 2001 From: Motoki saito <stmtk1@users.noreply.github.com> Date: Mon, 22 Jan 2024 02:26:35 +0900 Subject: [PATCH] add from_collection method to ChromaVectorStore class (#10167) add from_collection_method --- .gitignore | 3 +++ llama_index/vector_stores/chroma.py | 14 +++++++++++++- tests/vector_stores/test_chromadb.py | 14 ++++++++++++-- 3 files changed, 28 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 990c18de22..b0956e9a96 100644 --- a/.gitignore +++ b/.gitignore @@ -151,3 +151,6 @@ Pipfile.lock # pyright pyrightconfig.json + +# persist dir for chromadb test +/data/ diff --git a/llama_index/vector_stores/chroma.py b/llama_index/vector_stores/chroma.py index c883ef0edb..fb4c392dbb 100644 --- a/llama_index/vector_stores/chroma.py +++ b/llama_index/vector_stores/chroma.py @@ -165,6 +165,18 @@ class ChromaVectorStore(BasePydanticVectorStore): collection_kwargs=collection_kwargs or {}, ) + @classmethod + def from_collection(cls, collection: Any) -> "ChromaVectorStore": + try: + from chromadb import Collection + except ImportError: + raise ImportError(import_err_msg) + + if not isinstance(collection, Collection): + raise Exception("argument is not chromadb collection instance") + + return cls(chroma_collection=collection) + @classmethod def from_params( cls, @@ -174,7 +186,7 @@ class ChromaVectorStore(BasePydanticVectorStore): ssl: bool = False, headers: Optional[Dict[str, str]] = None, persist_dir: Optional[str] = None, - collection_kwargs: Optional[dict] = {}, + collection_kwargs: dict = {}, **kwargs: Any, ) -> "ChromaVectorStore": try: diff --git a/tests/vector_stores/test_chromadb.py b/tests/vector_stores/test_chromadb.py index ca38955150..a8291a39ba 100644 --- a/tests/vector_stores/test_chromadb.py +++ b/tests/vector_stores/test_chromadb.py @@ -38,13 +38,13 @@ except (ImportError, Exception): def test_instance_creation_from_collection() -> None: connection = chromadb.HttpClient(**PARAMS) collection = connection.get_collection(COLLECTION_NAME) - store = ChromaVectorStore(chroma_collection=collection) + store = ChromaVectorStore.from_collection(collection) assert isinstance(store, ChromaVectorStore) @pytest.mark.skipif(chromadb_not_available, reason="chromadb is not available") def test_instance_creation_from_http_params() -> None: - store = ChromaVectorStore( + store = ChromaVectorStore.from_params( host=PARAMS["host"], port=PARAMS["port"], collection_name=COLLECTION_NAME, @@ -53,6 +53,16 @@ def test_instance_creation_from_http_params() -> None: assert isinstance(store, ChromaVectorStore) +@pytest.mark.skipif(chromadb_not_available, reason="chromadb is not available") +def test_instance_creation_from_persist_dir() -> None: + store = ChromaVectorStore.from_params( + persist_dir="./data", + collection_name=COLLECTION_NAME, + collection_kwargs={}, + ) + assert isinstance(store, ChromaVectorStore) + + @pytest.fixture() def vector_store() -> ChromaVectorStore: connection = chromadb.HttpClient(**PARAMS) -- GitLab