From 4394c7f11e907c4a7c9926ae98eb53e6d60a1619 Mon Sep 17 00:00:00 2001 From: Max Jakob <max.jakob@elastic.co> Date: Thu, 21 Mar 2024 05:47:34 +0100 Subject: [PATCH] Elasticsearch: preserve vector store tests (#12079) When migrating to an integration package, the tests had been lost. --- .../vector_stores/elasticsearch/base.py | 5 +- .../pyproject.toml | 1 + .../tests/docker-compose.yml | 20 + .../tests/test_vector_stores_elasticsearch.py | 494 +++++++++++++++++- 4 files changed, 515 insertions(+), 5 deletions(-) create mode 100644 llama-index-integrations/vector_stores/llama-index-vector-stores-elasticsearch/tests/docker-compose.yml diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-elasticsearch/llama_index/vector_stores/elasticsearch/base.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-elasticsearch/llama_index/vector_stores/elasticsearch/base.py index e79445f527..c32056f71f 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-elasticsearch/llama_index/vector_stores/elasticsearch/base.py +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-elasticsearch/llama_index/vector_stores/elasticsearch/base.py @@ -228,7 +228,10 @@ class ElasticsearchStore(BasePydanticVectorStore): @staticmethod def get_user_agent() -> str: """Get user agent for elasticsearch client.""" - return "llama_index-py-vs" + import llama_index.core + + version = getattr(llama_index.core, "__version__", "") + return f"llama_index-py-vs/{version}" async def _create_index_if_not_exists( self, index_name: str, dims_length: Optional[int] = None diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-elasticsearch/pyproject.toml b/llama-index-integrations/vector_stores/llama-index-vector-stores-elasticsearch/pyproject.toml index 5bec75d4ae..12670ff21c 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-elasticsearch/pyproject.toml +++ 
b/llama-index-integrations/vector_stores/llama-index-vector-stores-elasticsearch/pyproject.toml @@ -41,6 +41,7 @@ mypy = "0.991" pre-commit = "3.2.0" pylint = "2.15.10" pytest = "7.2.1" +pytest-asyncio = "0.23.6" pytest-mock = "3.11.1" ruff = "0.0.292" tree-sitter-languages = "^1.8.0" diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-elasticsearch/tests/docker-compose.yml b/llama-index-integrations/vector_stores/llama-index-vector-stores-elasticsearch/tests/docker-compose.yml new file mode 100644 index 0000000000..b225fb27cf --- /dev/null +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-elasticsearch/tests/docker-compose.yml @@ -0,0 +1,20 @@ +version: "3" + +services: + elasticsearch: + image: docker.elastic.co/elasticsearch/elasticsearch:8.12.1 # https://www.docker.elastic.co/r/elasticsearch/elasticsearch + environment: + - discovery.type=single-node + - xpack.security.enabled=false # security has been disabled, so no login or password is required. 
+ - xpack.security.http.ssl.enabled=false + - xpack.license.self_generated.type=trial + ports: + - "9200:9200" + healthcheck: + test: + [ + "CMD-SHELL", + "curl --silent --fail http://localhost:9200/_cluster/health || exit 1", + ] + interval: 10s + retries: 60 diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-elasticsearch/tests/test_vector_stores_elasticsearch.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-elasticsearch/tests/test_vector_stores_elasticsearch.py index 0c5629f4d2..40da870537 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-elasticsearch/tests/test_vector_stores_elasticsearch.py +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-elasticsearch/tests/test_vector_stores_elasticsearch.py @@ -1,7 +1,493 @@ -from llama_index.core.vector_stores.types import BasePydanticVectorStore +import logging +import os +import re +import uuid +from typing import Dict, Generator, List, Union + +import pandas as pd +import pytest + +from llama_index.core.schema import NodeRelationship, RelatedNodeInfo, TextNode +from llama_index.core.vector_stores.types import ( + ExactMatchFilter, + MetadataFilters, + VectorStoreQuery, + VectorStoreQueryMode, +) from llama_index.vector_stores.elasticsearch import ElasticsearchStore +## +# Start Elasticsearch locally +# cd tests +# docker-compose up elasticsearch +# +# Run tests +# cd tests +# pytest test_vector_stores_elasticsearch.py + + +logging.basicConfig(level=logging.DEBUG) + +try: + import elasticsearch + + es_client = elasticsearch.Elasticsearch("http://localhost:9200") + es_client.info() + + elasticsearch_not_available = False + + es_license = es_client.license.get() + basic_license: bool = es_license["license"]["type"] == "basic" +except (ImportError, Exception) as err: + elasticsearch_not_available = True + basic_license = True + + +@pytest.fixture() +def index_name() -> str: + """Return the index name.""" + return 
f"test_{uuid.uuid4().hex}" + + +@pytest.fixture(scope="session") +def elasticsearch_connection() -> Union[dict, Generator[dict, None, None]]: + # Running this integration test with Elastic Cloud + # Required for in-stack inference testing (ELSER + model_id) + from elasticsearch import Elasticsearch + + es_url = os.environ.get("ES_URL", "http://localhost:9200") + cloud_id = os.environ.get("ES_CLOUD_ID") + es_username = os.environ.get("ES_USERNAME", "elastic") + es_password = os.environ.get("ES_PASSWORD", "changeme") + + if cloud_id: + yield { + "es_cloud_id": cloud_id, + "es_user": es_username, + "es_password": es_password, + } + es = Elasticsearch(cloud_id=cloud_id, basic_auth=(es_username, es_password)) + + else: + # Running this integration test with local docker instance + yield { + "es_url": es_url, + } + es = Elasticsearch(hosts=es_url) + + # Clear all indexes + index_names = es.indices.get(index="_all").keys() + for index_name in index_names: + if index_name.startswith("test_"): + es.indices.delete(index=index_name) + es.indices.refresh(index="_all") + return {} + + +@pytest.fixture(scope="session") +def node_embeddings() -> List[TextNode]: + return [ + TextNode( + text="lorem ipsum", + id_="c330d77f-90bd-4c51-9ed2-57d8d693b3b0", + relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="test-0")}, + metadata={ + "author": "Stephen King", + "theme": "Friendship", + }, + embedding=[1.0, 0.0, 0.0], + ), + TextNode( + text="lorem ipsum", + id_="c3d1e1dd-8fb4-4b8f-b7ea-7fa96038d39d", + relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="test-1")}, + metadata={ + "director": "Francis Ford Coppola", + "theme": "Mafia", + }, + embedding=[0.0, 1.0, 0.0], + ), + TextNode( + text="lorem ipsum", + id_="c3ew11cd-8fb4-4b8f-b7ea-7fa96038d39d", + relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="test-2")}, + metadata={ + "director": "Christopher Nolan", + }, + embedding=[0.0, 0.0, 1.0], + ), + TextNode( + text="I was taught that the way 
of progress was neither swift nor easy.", + id_="0b31ae71-b797-4e88-8495-031371a7752e", + relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="text-3")}, + metadata={ + "author": "Marie Curie", + }, + embedding=[0.0, 0.0, 0.9], + ), + TextNode( + text=( + "The important thing is not to stop questioning." + + " Curiosity has its own reason for existing." + ), + id_="bd2e080b-159a-4030-acc3-d98afd2ba49b", + relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="text-4")}, + metadata={ + "author": "Albert Einstein", + }, + embedding=[0.0, 0.0, 0.5], + ), + TextNode( + text=( + "I am no bird; and no net ensnares me;" + + " I am a free human being with an independent will." + ), + id_="f658de3b-8cef-4d1c-8bed-9a263c907251", + relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="text-5")}, + metadata={ + "author": "Charlotte Bronte", + }, + embedding=[0.0, 0.0, 0.3], + ), + ] + + +@pytest.mark.skipif( + elasticsearch_not_available, reason="elasticsearch is not available" +) +def test_instance_creation(index_name: str, elasticsearch_connection: Dict) -> None: + es_store = ElasticsearchStore( + **elasticsearch_connection, + index_name=index_name, + ) + assert isinstance(es_store, ElasticsearchStore) + + +@pytest.fixture() +def es_store(index_name: str, elasticsearch_connection: Dict) -> ElasticsearchStore: + return ElasticsearchStore( + **elasticsearch_connection, + index_name=index_name, + distance_strategy="EUCLIDEAN_DISTANCE", + ) + + +@pytest.mark.skipif( + elasticsearch_not_available, reason="elasticsearch is not available" +) +@pytest.mark.asyncio() +@pytest.mark.parametrize("use_async", [True, False]) +async def test_add_to_es_and_query( + es_store: ElasticsearchStore, + node_embeddings: List[TextNode], + use_async: bool, +) -> None: + if use_async: + await es_store.async_add(node_embeddings) + res = await es_store.aquery( + VectorStoreQuery(query_embedding=[1.0, 0.0, 0.0], similarity_top_k=1) + ) + else: + 
es_store.add(node_embeddings) + res = es_store.query( + VectorStoreQuery(query_embedding=[1.0, 0.0, 0.0], similarity_top_k=1) + ) + assert res.nodes + assert res.nodes[0].get_content() == "lorem ipsum" + + +@pytest.mark.skipif( + elasticsearch_not_available, reason="elasticsearch is not available" +) +@pytest.mark.asyncio() +@pytest.mark.parametrize("use_async", [True, False]) +async def test_add_to_es_and_text_query( + es_store: ElasticsearchStore, + node_embeddings: List[TextNode], + use_async: bool, +) -> None: + if use_async: + await es_store.async_add(node_embeddings) + res = await es_store.aquery( + VectorStoreQuery( + query_str="lorem", + mode=VectorStoreQueryMode.TEXT_SEARCH, + similarity_top_k=1, + ) + ) + else: + es_store.add(node_embeddings) + res = es_store.query( + VectorStoreQuery( + query_str="lorem", + mode=VectorStoreQueryMode.TEXT_SEARCH, + similarity_top_k=1, + ) + ) + assert res.nodes + assert res.nodes[0].get_content() == "lorem ipsum" + + +@pytest.mark.skipif( + elasticsearch_not_available, + basic_license, + reason="elasticsearch is not available or license is basic", +) +@pytest.mark.asyncio() +@pytest.mark.parametrize("use_async", [True, False]) +async def test_add_to_es_and_hybrid_query( + es_store: ElasticsearchStore, + node_embeddings: List[TextNode], + use_async: bool, +) -> None: + if use_async: + await es_store.async_add(node_embeddings) + res = await es_store.aquery( + VectorStoreQuery( + query_str="lorem", + query_embedding=[1.0, 0.0, 0.0], + mode=VectorStoreQueryMode.HYBRID, + similarity_top_k=1, + ) + ) + else: + es_store.add(node_embeddings) + res = es_store.query( + VectorStoreQuery( + query_str="lorem", + query_embedding=[1.0, 0.0, 0.0], + mode=VectorStoreQueryMode.HYBRID, + similarity_top_k=1, + ) + ) + assert res.nodes + assert res.nodes[0].get_content() == "lorem ipsum" + + +@pytest.mark.skipif( + elasticsearch_not_available, reason="elasticsearch is not available" +) +@pytest.mark.asyncio() 
+@pytest.mark.parametrize("use_async", [True, False]) +async def test_add_to_es_query_with_filters( + es_store: ElasticsearchStore, + node_embeddings: List[TextNode], + use_async: bool, +) -> None: + filters = MetadataFilters( + filters=[ExactMatchFilter(key="author", value="Stephen King")] + ) + q = VectorStoreQuery( + query_embedding=[1.0, 0.0, 0.0], similarity_top_k=10, filters=filters + ) + if use_async: + await es_store.async_add(node_embeddings) + res = await es_store.aquery(q) + else: + es_store.add(node_embeddings) + res = es_store.query(q) + assert res.nodes + assert len(res.nodes) == 1 + assert res.nodes[0].node_id == "c330d77f-90bd-4c51-9ed2-57d8d693b3b0" + + +@pytest.mark.skipif( + elasticsearch_not_available, reason="elasticsearch is not available" +) +@pytest.mark.asyncio() +@pytest.mark.parametrize("use_async", [True, False]) +async def test_add_to_es_query_with_es_filters( + es_store: ElasticsearchStore, + node_embeddings: List[TextNode], + use_async: bool, +) -> None: + q = VectorStoreQuery(query_embedding=[1.0, 0.0, 0.0], similarity_top_k=10) + if use_async: + await es_store.async_add(node_embeddings) + res = await es_store.aquery( + q, es_filter=[{"wildcard": {"metadata.author": "stephe*"}}] + ) + else: + es_store.add(node_embeddings) + res = es_store.query( + q, es_filter=[{"wildcard": {"metadata.author": "stephe*"}}] + ) + assert res.nodes + assert len(res.nodes) == 1 + assert res.nodes[0].node_id == "c330d77f-90bd-4c51-9ed2-57d8d693b3b0" + + +@pytest.mark.skipif( + elasticsearch_not_available, reason="elasticsearch is not available" +) +@pytest.mark.asyncio() +@pytest.mark.parametrize("use_async", [True, False]) +async def test_add_to_es_query_and_delete( + es_store: ElasticsearchStore, + node_embeddings: List[TextNode], + use_async: bool, +) -> None: + q = VectorStoreQuery(query_embedding=[1.0, 0.0, 0.0], similarity_top_k=1) + + if use_async: + await es_store.async_add(node_embeddings) + res = await es_store.aquery(q) + else: + 
es_store.add(node_embeddings) + res = es_store.query(q) + assert res.nodes + assert len(res.nodes) == 1 + assert res.nodes[0].node_id == "c330d77f-90bd-4c51-9ed2-57d8d693b3b0" + + if use_async: + await es_store.adelete("test-0") + res = await es_store.aquery(q) + else: + es_store.delete("test-0") + res = es_store.query(q) + assert res.nodes + assert len(res.nodes) == 1 + assert res.nodes[0].node_id == "f658de3b-8cef-4d1c-8bed-9a263c907251" + + +@pytest.mark.skipif( + elasticsearch_not_available, reason="elasticsearch is not available" +) +@pytest.mark.asyncio() +@pytest.mark.parametrize("use_async", [True, False]) +async def test_add_to_es_and_embed_query_ranked( + es_store: ElasticsearchStore, + node_embeddings: List[TextNode], + use_async: bool, +) -> None: + einstein_bronte_curie = [ + "bd2e080b-159a-4030-acc3-d98afd2ba49b", + "f658de3b-8cef-4d1c-8bed-9a263c907251", + "0b31ae71-b797-4e88-8495-031371a7752e", + ] + query_get_1_first = VectorStoreQuery( + query_embedding=[0.0, 0.0, 0.5], similarity_top_k=3 + ) + await check_top_match( + es_store, node_embeddings, use_async, query_get_1_first, *einstein_bronte_curie + ) + + +@pytest.mark.skipif( + elasticsearch_not_available, reason="elasticsearch is not available" +) +@pytest.mark.asyncio() +@pytest.mark.parametrize("use_async", [True, False]) +async def test_add_to_es_and_text_query_ranked( + es_store: ElasticsearchStore, + node_embeddings: List[TextNode], + use_async: bool, +) -> None: + node1 = "0b31ae71-b797-4e88-8495-031371a7752e" + node2 = "f658de3b-8cef-4d1c-8bed-9a263c907251" + + query_get_1_first = VectorStoreQuery( + query_str="I was", mode=VectorStoreQueryMode.TEXT_SEARCH, similarity_top_k=2 + ) + await check_top_match( + es_store, node_embeddings, use_async, query_get_1_first, node1, node2 + ) + + query_get_2_first = VectorStoreQuery( + query_str="I am", mode=VectorStoreQueryMode.TEXT_SEARCH, similarity_top_k=2 + ) + await check_top_match( + es_store, node_embeddings, use_async, query_get_2_first, 
node2, node1 + ) + + +@pytest.mark.skipif( + elasticsearch_not_available, reason="elasticsearch is not available" +) +@pytest.mark.asyncio() +@pytest.mark.parametrize("use_async", [True, False]) +async def test_add_to_es_and_text_query_ranked_hybrid( + es_store: ElasticsearchStore, + node_embeddings: List[TextNode], + use_async: bool, +) -> None: + node1 = "f658de3b-8cef-4d1c-8bed-9a263c907251" + node2 = "0b31ae71-b797-4e88-8495-031371a7752e" + + query_get_1_first = VectorStoreQuery( + query_str="I was", + query_embedding=[0.0, 0.0, 0.5], + mode=VectorStoreQueryMode.HYBRID, + similarity_top_k=2, + ) + await check_top_match( + es_store, node_embeddings, use_async, query_get_1_first, node1, node2 + ) + + +@pytest.mark.skipif( + elasticsearch_not_available, reason="elasticsearch is not available" +) +def test_check_user_agent( + index_name: str, + node_embeddings: List[TextNode], +) -> None: + from elastic_transport import AsyncTransport + from elasticsearch import AsyncElasticsearch + + class CustomTransport(AsyncTransport): + requests = [] + + async def perform_request(self, *args, **kwargs): # type: ignore + self.requests.append(kwargs) + return await super().perform_request(*args, **kwargs) + + es_client_instance = AsyncElasticsearch( + "http://localhost:9200", + transport_class=CustomTransport, + ) + + es_store = ElasticsearchStore( + es_client=es_client_instance, + index_name=index_name, + distance_strategy="EUCLIDEAN_DISTANCE", + ) + + es_store.add(node_embeddings) + + user_agent = es_client_instance.transport.requests[0]["headers"][ # type: ignore + "user-agent" + ] + pattern = r"^llama_index-py-vs/\d+\.\d+\.\d+(\.post\d+)?$" + match = re.match(pattern, user_agent) + + assert ( + match is not None + ), f"The string '{user_agent}' does not match the expected user-agent." 
+ -def test_class(): - names_of_base_classes = [b.__name__ for b in ElasticsearchStore.__mro__] - assert BasePydanticVectorStore.__name__ in names_of_base_classes +async def check_top_match( + es_store: ElasticsearchStore, + node_embeddings: List[TextNode], + use_async: bool, + query: VectorStoreQuery, + *expected_nodes: str, +) -> None: + if use_async: + await es_store.async_add(node_embeddings) + res = await es_store.aquery(query) + else: + es_store.add(node_embeddings) + res = es_store.query(query) + assert res.nodes + # test the nodes are return in the expected order + for i, node in enumerate(expected_nodes): + assert res.nodes[i].node_id == node + # test the returned order is in descending order w.r.t. similarities + # test similarities are normalized (0, 1) + df = pd.DataFrame({"node": res.nodes, "sim": res.similarities, "id": res.ids}) + sorted_by_sim = df.sort_values(by="sim", ascending=False) + for idx, item in enumerate(sorted_by_sim.itertuples()): + res_node = res.nodes[idx] + assert res_node.node_id == item.id + assert 0 <= item.sim <= 1 -- GitLab