diff --git a/.gitignore b/.gitignore index 990c18de229088f55c6c514fd0f2d49981d1b0e7..b0956e9a96d8110367aa7bbc67705a7da091722e 100644 --- a/.gitignore +++ b/.gitignore @@ -151,3 +151,6 @@ Pipfile.lock # pyright pyrightconfig.json + +# persist dir for chromadb test +/data/ diff --git a/llama_index/vector_stores/chroma.py b/llama_index/vector_stores/chroma.py index c883ef0edb11455a58dd8419781f68ef4027ad14..fb4c392dbbef0a3014526438170edf4f34c6987c 100644 --- a/llama_index/vector_stores/chroma.py +++ b/llama_index/vector_stores/chroma.py @@ -165,6 +165,18 @@ class ChromaVectorStore(BasePydanticVectorStore): collection_kwargs=collection_kwargs or {}, ) + @classmethod + def from_collection(cls, collection: Any) -> "ChromaVectorStore": + try: + from chromadb import Collection + except ImportError: + raise ImportError(import_err_msg) + + if not isinstance(collection, Collection): + raise Exception("argument is not chromadb collection instance") + + return cls(chroma_collection=collection) + @classmethod def from_params( cls, @@ -174,7 +186,7 @@ class ChromaVectorStore(BasePydanticVectorStore): ssl: bool = False, headers: Optional[Dict[str, str]] = None, persist_dir: Optional[str] = None, - collection_kwargs: Optional[dict] = {}, + collection_kwargs: dict = {}, **kwargs: Any, ) -> "ChromaVectorStore": try: diff --git a/tests/vector_stores/test_chromadb.py b/tests/vector_stores/test_chromadb.py index ca389551500ae74431bb148a059165c219e5f31e..a8291a39ba2e1a856770dd6a30cf00c2fbcc3afd 100644 --- a/tests/vector_stores/test_chromadb.py +++ b/tests/vector_stores/test_chromadb.py @@ -38,13 +38,13 @@ except (ImportError, Exception): def test_instance_creation_from_collection() -> None: connection = chromadb.HttpClient(**PARAMS) collection = connection.get_collection(COLLECTION_NAME) - store = ChromaVectorStore(chroma_collection=collection) + store = ChromaVectorStore.from_collection(collection) assert isinstance(store, ChromaVectorStore) @pytest.mark.skipif(chromadb_not_available, reason="chromadb is not available") def test_instance_creation_from_http_params() -> None: - store = ChromaVectorStore( + store = ChromaVectorStore.from_params( host=PARAMS["host"], port=PARAMS["port"], collection_name=COLLECTION_NAME, @@ -53,6 +53,16 @@ def test_instance_creation_from_http_params() -> None: assert isinstance(store, ChromaVectorStore) +@pytest.mark.skipif(chromadb_not_available, reason="chromadb is not available") +def test_instance_creation_from_persist_dir() -> None: + store = ChromaVectorStore.from_params( + persist_dir="./data", + collection_name=COLLECTION_NAME, + collection_kwargs={}, + ) + assert isinstance(store, ChromaVectorStore) + + @pytest.fixture() def vector_store() -> ChromaVectorStore: connection = chromadb.HttpClient(**PARAMS)