diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-chroma/tests/test_chromadb.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-chroma/tests/test_chromadb.py index d0d718dda9ba897a136a5ccd590b0b286d32a5db..ccf10b02eacc3eb257c3092402d0b6e3f8ba9eee 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-chroma/tests/test_chromadb.py +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-chroma/tests/test_chromadb.py @@ -1,16 +1,12 @@ import os -from typing import Dict, List +from typing import Dict, List, Generator import pytest from llama_index.core.schema import NodeRelationship, RelatedNodeInfo, TextNode -from llama_index.core.vector_stores.types import VectorStoreQuery from llama_index.vector_stores.chroma import ChromaVectorStore +from llama_index.core.vector_stores.types import VectorStoreQuery ## -# Start chromadb locally -# cd tests -# docker-compose up -# # Run tests # cd tests/vector_stores # pytest test_chromadb.py @@ -29,20 +25,19 @@ try: conn__ = chromadb.HttpClient(**PARAMS) # type: ignore conn__.get_or_create_collection(COLLECTION_NAME) - chromadb_not_available = False + http_client_chromadb_mode = True except (ImportError, Exception): - chromadb_not_available = True + http_client_chromadb_mode = False -@pytest.mark.skipif(chromadb_not_available, reason="chromadb is not available") -def test_instance_creation_from_collection() -> None: - connection = chromadb.HttpClient(**PARAMS) - collection = connection.get_collection(COLLECTION_NAME) - store = ChromaVectorStore.from_collection(collection) - assert isinstance(store, ChromaVectorStore) - - -@pytest.mark.skipif(chromadb_not_available, reason="chromadb is not available") +# To test chromadb http-client functionality do: +# cd tests +# docker-compose up +# +@pytest.mark.skipif( + http_client_chromadb_mode is False, + reason="chromadb is not running in http client mode", +) def test_instance_creation_from_http_params() -> None: store = ChromaVectorStore.from_params( host=PARAMS["host"], @@ -53,7 +48,13 @@ def test_instance_creation_from_http_params() -> None: assert isinstance(store, ChromaVectorStore) -@pytest.mark.skipif(chromadb_not_available, reason="chromadb is not available") +def test_instance_creation_from_collection() -> None: + chroma_client = chromadb.Client() + collection = chroma_client.get_or_create_collection(COLLECTION_NAME) + store = ChromaVectorStore.from_collection(collection) + assert isinstance(store, ChromaVectorStore) + + def test_instance_creation_from_persist_dir() -> None: store = ChromaVectorStore.from_params( persist_dir="./data", @@ -64,10 +65,11 @@ def test_instance_creation_from_persist_dir() -> None: @pytest.fixture() -def vector_store() -> ChromaVectorStore: - connection = chromadb.HttpClient(**PARAMS) - collection = connection.get_collection(COLLECTION_NAME) - return ChromaVectorStore(chroma_collection=collection) +def vector_store() -> Generator[ChromaVectorStore, None, None]: + chroma_client = chromadb.Client() + collection = chroma_client.get_or_create_collection(COLLECTION_NAME) + yield ChromaVectorStore(chroma_collection=collection) + chroma_client.delete_collection(name=COLLECTION_NAME) @pytest.fixture(scope="session") @@ -106,7 +108,7 @@ def node_embeddings() -> List[TextNode]: text="I was taught that the way of progress was neither swift nor easy.", id_="0b31ae71-b797-4e88-8495-031371a7752e", relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="text-3")}, - metadate={ + metadata={ "author": "Marie Curie", }, embedding=[0.0, 0.0, 0.9], @@ -118,7 +120,7 @@ def node_embeddings() -> List[TextNode]: ), id_="bd2e080b-159a-4030-acc3-d98afd2ba49b", relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="text-4")}, - metadate={ + metadata={ "author": "Albert Einstein", }, embedding=[0.0, 0.0, 0.5], @@ -130,7 +132,7 @@ def node_embeddings() -> List[TextNode]: ), id_="f658de3b-8cef-4d1c-8bed-9a263c907251", relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="text-5")}, - metadate={ + metadata={ "author": "Charlotte Bronte", }, embedding=[0.0, 0.0, 0.3], @@ -138,7 +140,6 @@ def node_embeddings() -> List[TextNode]: ] -@pytest.mark.skipif(chromadb_not_available, reason="chromadb is not available") @pytest.mark.asyncio() @pytest.mark.parametrize("use_async", [True, False]) async def test_add_to_chromadb_and_query(