From 4394c7f11e907c4a7c9926ae98eb53e6d60a1619 Mon Sep 17 00:00:00 2001
From: Max Jakob <max.jakob@elastic.co>
Date: Thu, 21 Mar 2024 05:47:34 +0100
Subject: [PATCH] Elasticsearch: preserve vector store tests (#12079)

During the migration to an integration package, the tests were lost;
this change restores them.
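The Elasticsearch client user agent is also extended to include the
installed llama-index-core version, which the restored user-agent test
verifies.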
---
 .../vector_stores/elasticsearch/base.py       |   5 +-
 .../pyproject.toml                            |   1 +
 .../tests/docker-compose.yml                  |  20 +
 .../tests/test_vector_stores_elasticsearch.py | 494 +++++++++++++++++-
 4 files changed, 515 insertions(+), 5 deletions(-)
 create mode 100644 llama-index-integrations/vector_stores/llama-index-vector-stores-elasticsearch/tests/docker-compose.yml

diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-elasticsearch/llama_index/vector_stores/elasticsearch/base.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-elasticsearch/llama_index/vector_stores/elasticsearch/base.py
index e79445f527..c32056f71f 100644
--- a/llama-index-integrations/vector_stores/llama-index-vector-stores-elasticsearch/llama_index/vector_stores/elasticsearch/base.py
+++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-elasticsearch/llama_index/vector_stores/elasticsearch/base.py
@@ -228,7 +228,10 @@ class ElasticsearchStore(BasePydanticVectorStore):
     @staticmethod
     def get_user_agent() -> str:
         """Get user agent for elasticsearch client."""
-        return "llama_index-py-vs"
+        import llama_index.core
+
+        version = getattr(llama_index.core, "__version__", "")
+        return f"llama_index-py-vs/{version}"
 
     async def _create_index_if_not_exists(
         self, index_name: str, dims_length: Optional[int] = None
diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-elasticsearch/pyproject.toml b/llama-index-integrations/vector_stores/llama-index-vector-stores-elasticsearch/pyproject.toml
index 5bec75d4ae..12670ff21c 100644
--- a/llama-index-integrations/vector_stores/llama-index-vector-stores-elasticsearch/pyproject.toml
+++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-elasticsearch/pyproject.toml
@@ -41,6 +41,7 @@ mypy = "0.991"
 pre-commit = "3.2.0"
 pylint = "2.15.10"
 pytest = "7.2.1"
+pytest-asyncio = "0.23.6"
 pytest-mock = "3.11.1"
 ruff = "0.0.292"
 tree-sitter-languages = "^1.8.0"
diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-elasticsearch/tests/docker-compose.yml b/llama-index-integrations/vector_stores/llama-index-vector-stores-elasticsearch/tests/docker-compose.yml
new file mode 100644
index 0000000000..b225fb27cf
--- /dev/null
+++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-elasticsearch/tests/docker-compose.yml
@@ -0,0 +1,20 @@
+version: "3"
+
+services:
+  elasticsearch:
+    image: docker.elastic.co/elasticsearch/elasticsearch:8.12.1 # https://www.docker.elastic.co/r/elasticsearch/elasticsearch
+    environment:
+      - discovery.type=single-node
+      - xpack.security.enabled=false # security has been disabled, so no login or password is required.
+      - xpack.security.http.ssl.enabled=false
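+      # a trial license is generated so tests that need more than a basic
+      # license (e.g. hybrid search) are not skipped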
+      - xpack.license.self_generated.type=trial
+    ports:
+      - "9200:9200"
+    healthcheck:
+      test:
+        [
+          "CMD-SHELL",
+          "curl --silent --fail http://localhost:9200/_cluster/health || exit 1",
+        ]
+      interval: 10s
+      retries: 60
diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-elasticsearch/tests/test_vector_stores_elasticsearch.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-elasticsearch/tests/test_vector_stores_elasticsearch.py
index 0c5629f4d2..40da870537 100644
--- a/llama-index-integrations/vector_stores/llama-index-vector-stores-elasticsearch/tests/test_vector_stores_elasticsearch.py
+++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-elasticsearch/tests/test_vector_stores_elasticsearch.py
@@ -1,7 +1,493 @@
-from llama_index.core.vector_stores.types import BasePydanticVectorStore
+import logging
+import os
+import re
+import uuid
+from typing import Dict, Generator, List
+
+import pandas as pd
+import pytest
+
+from llama_index.core.schema import NodeRelationship, RelatedNodeInfo, TextNode
+from llama_index.core.vector_stores.types import (
+    ExactMatchFilter,
+    MetadataFilters,
+    VectorStoreQuery,
+    VectorStoreQueryMode,
+)
 from llama_index.vector_stores.elasticsearch import ElasticsearchStore
 
+##
+# Start Elasticsearch locally
+# cd tests
+# docker-compose up elasticsearch
+#
+# Run tests
+# cd tests
+# pytest test_vector_stores_elasticsearch.py
+
+
+logging.basicConfig(level=logging.DEBUG)
+
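+# Probe the local Elasticsearch instance up front so tests can be skipped when
+# no server is reachable, and record whether the license is basic so tests that
+# need a higher license tier can be skipped as well.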
+try:
+    import elasticsearch
+
+    es_client = elasticsearch.Elasticsearch("http://localhost:9200")
+    es_client.info()
+
+    elasticsearch_not_available = False
+
+    es_license = es_client.license.get()
+    basic_license: bool = es_license["license"]["type"] == "basic"
+except Exception:
+    elasticsearch_not_available = True
+    basic_license = True
+
+
+@pytest.fixture()
+def index_name() -> str:
+    """Return the index name."""
+    return f"test_{uuid.uuid4().hex}"
+
+
+@pytest.fixture(scope="session")
+def elasticsearch_connection() -> Generator[dict, None, None]:
+    # When ES_CLOUD_ID is set, the tests run against Elastic Cloud, which is
+    # required for in-stack inference testing (ELSER + model_id)
+    from elasticsearch import Elasticsearch
+
+    es_url = os.environ.get("ES_URL", "http://localhost:9200")
+    cloud_id = os.environ.get("ES_CLOUD_ID")
+    es_username = os.environ.get("ES_USERNAME", "elastic")
+    es_password = os.environ.get("ES_PASSWORD", "changeme")
+
+    if cloud_id:
+        yield {
+            "es_cloud_id": cloud_id,
+            "es_user": es_username,
+            "es_password": es_password,
+        }
+        es = Elasticsearch(cloud_id=cloud_id, basic_auth=(es_username, es_password))
+
+    else:
+        # Running this integration test with local docker instance
+        yield {
+            "es_url": es_url,
+        }
+        es = Elasticsearch(hosts=es_url)
+
+    # Clean up: delete all indices created by the tests (prefixed with "test_")
+    index_names = es.indices.get(index="_all").keys()
+    for index_name in index_names:
+        if index_name.startswith("test_"):
+            es.indices.delete(index=index_name)
+    es.indices.refresh(index="_all")
+
+
+@pytest.fixture(scope="session")
+def node_embeddings() -> List[TextNode]:
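+    # three orthogonal unit vectors plus three vectors along the same axis with
+    # decreasing magnitude, so the ranking tests have a deterministic order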
+    return [
+        TextNode(
+            text="lorem ipsum",
+            id_="c330d77f-90bd-4c51-9ed2-57d8d693b3b0",
+            relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="test-0")},
+            metadata={
+                "author": "Stephen King",
+                "theme": "Friendship",
+            },
+            embedding=[1.0, 0.0, 0.0],
+        ),
+        TextNode(
+            text="lorem ipsum",
+            id_="c3d1e1dd-8fb4-4b8f-b7ea-7fa96038d39d",
+            relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="test-1")},
+            metadata={
+                "director": "Francis Ford Coppola",
+                "theme": "Mafia",
+            },
+            embedding=[0.0, 1.0, 0.0],
+        ),
+        TextNode(
+            text="lorem ipsum",
+            id_="c3ew11cd-8fb4-4b8f-b7ea-7fa96038d39d",
+            relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="test-2")},
+            metadata={
+                "director": "Christopher Nolan",
+            },
+            embedding=[0.0, 0.0, 1.0],
+        ),
+        TextNode(
+            text="I was taught that the way of progress was neither swift nor easy.",
+            id_="0b31ae71-b797-4e88-8495-031371a7752e",
+            relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="text-3")},
+            metadata={
+                "author": "Marie Curie",
+            },
+            embedding=[0.0, 0.0, 0.9],
+        ),
+        TextNode(
+            text=(
+                "The important thing is not to stop questioning."
+                + " Curiosity has its own reason for existing."
+            ),
+            id_="bd2e080b-159a-4030-acc3-d98afd2ba49b",
+            relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="text-4")},
+            metadata={
+                "author": "Albert Einstein",
+            },
+            embedding=[0.0, 0.0, 0.5],
+        ),
+        TextNode(
+            text=(
+                "I am no bird; and no net ensnares me;"
+                + " I am a free human being with an independent will."
+            ),
+            id_="f658de3b-8cef-4d1c-8bed-9a263c907251",
+            relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="text-5")},
+            metadata={
+                "author": "Charlotte Bronte",
+            },
+            embedding=[0.0, 0.0, 0.3],
+        ),
+    ]
+
+
+@pytest.mark.skipif(
+    elasticsearch_not_available, reason="elasticsearch is not available"
+)
+def test_instance_creation(index_name: str, elasticsearch_connection: Dict) -> None:
+    es_store = ElasticsearchStore(
+        **elasticsearch_connection,
+        index_name=index_name,
+    )
+    assert isinstance(es_store, ElasticsearchStore)
+
+
+@pytest.fixture()
+def es_store(index_name: str, elasticsearch_connection: Dict) -> ElasticsearchStore:
+    return ElasticsearchStore(
+        **elasticsearch_connection,
+        index_name=index_name,
+        distance_strategy="EUCLIDEAN_DISTANCE",
+    )
+
+
+@pytest.mark.skipif(
+    elasticsearch_not_available, reason="elasticsearch is not available"
+)
+@pytest.mark.asyncio()
+@pytest.mark.parametrize("use_async", [True, False])
+async def test_add_to_es_and_query(
+    es_store: ElasticsearchStore,
+    node_embeddings: List[TextNode],
+    use_async: bool,
+) -> None:
+    if use_async:
+        await es_store.async_add(node_embeddings)
+        res = await es_store.aquery(
+            VectorStoreQuery(query_embedding=[1.0, 0.0, 0.0], similarity_top_k=1)
+        )
+    else:
+        es_store.add(node_embeddings)
+        res = es_store.query(
+            VectorStoreQuery(query_embedding=[1.0, 0.0, 0.0], similarity_top_k=1)
+        )
+    assert res.nodes
+    assert res.nodes[0].get_content() == "lorem ipsum"
+
+
+@pytest.mark.skipif(
+    elasticsearch_not_available, reason="elasticsearch is not available"
+)
+@pytest.mark.asyncio()
+@pytest.mark.parametrize("use_async", [True, False])
+async def test_add_to_es_and_text_query(
+    es_store: ElasticsearchStore,
+    node_embeddings: List[TextNode],
+    use_async: bool,
+) -> None:
+    if use_async:
+        await es_store.async_add(node_embeddings)
+        res = await es_store.aquery(
+            VectorStoreQuery(
+                query_str="lorem",
+                mode=VectorStoreQueryMode.TEXT_SEARCH,
+                similarity_top_k=1,
+            )
+        )
+    else:
+        es_store.add(node_embeddings)
+        res = es_store.query(
+            VectorStoreQuery(
+                query_str="lorem",
+                mode=VectorStoreQueryMode.TEXT_SEARCH,
+                similarity_top_k=1,
+            )
+        )
+    assert res.nodes
+    assert res.nodes[0].get_content() == "lorem ipsum"
+
+
+@pytest.mark.skipif(
+    elasticsearch_not_available or basic_license,
+    reason="elasticsearch is not available or license is basic",
+)
+@pytest.mark.asyncio()
+@pytest.mark.parametrize("use_async", [True, False])
+async def test_add_to_es_and_hybrid_query(
+    es_store: ElasticsearchStore,
+    node_embeddings: List[TextNode],
+    use_async: bool,
+) -> None:
+    if use_async:
+        await es_store.async_add(node_embeddings)
+        res = await es_store.aquery(
+            VectorStoreQuery(
+                query_str="lorem",
+                query_embedding=[1.0, 0.0, 0.0],
+                mode=VectorStoreQueryMode.HYBRID,
+                similarity_top_k=1,
+            )
+        )
+    else:
+        es_store.add(node_embeddings)
+        res = es_store.query(
+            VectorStoreQuery(
+                query_str="lorem",
+                query_embedding=[1.0, 0.0, 0.0],
+                mode=VectorStoreQueryMode.HYBRID,
+                similarity_top_k=1,
+            )
+        )
+    assert res.nodes
+    assert res.nodes[0].get_content() == "lorem ipsum"
+
+
+@pytest.mark.skipif(
+    elasticsearch_not_available, reason="elasticsearch is not available"
+)
+@pytest.mark.asyncio()
+@pytest.mark.parametrize("use_async", [True, False])
+async def test_add_to_es_query_with_filters(
+    es_store: ElasticsearchStore,
+    node_embeddings: List[TextNode],
+    use_async: bool,
+) -> None:
+    filters = MetadataFilters(
+        filters=[ExactMatchFilter(key="author", value="Stephen King")]
+    )
+    q = VectorStoreQuery(
+        query_embedding=[1.0, 0.0, 0.0], similarity_top_k=10, filters=filters
+    )
+    if use_async:
+        await es_store.async_add(node_embeddings)
+        res = await es_store.aquery(q)
+    else:
+        es_store.add(node_embeddings)
+        res = es_store.query(q)
+    assert res.nodes
+    assert len(res.nodes) == 1
+    assert res.nodes[0].node_id == "c330d77f-90bd-4c51-9ed2-57d8d693b3b0"
+
+
+@pytest.mark.skipif(
+    elasticsearch_not_available, reason="elasticsearch is not available"
+)
+@pytest.mark.asyncio()
+@pytest.mark.parametrize("use_async", [True, False])
+async def test_add_to_es_query_with_es_filters(
+    es_store: ElasticsearchStore,
+    node_embeddings: List[TextNode],
+    use_async: bool,
+) -> None:
+    q = VectorStoreQuery(query_embedding=[1.0, 0.0, 0.0], similarity_top_k=10)
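+    # filter with a raw Elasticsearch query DSL clause via es_filter instead of
+    # the generic MetadataFilters abstraction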
+    if use_async:
+        await es_store.async_add(node_embeddings)
+        res = await es_store.aquery(
+            q, es_filter=[{"wildcard": {"metadata.author": "stephe*"}}]
+        )
+    else:
+        es_store.add(node_embeddings)
+        res = es_store.query(
+            q, es_filter=[{"wildcard": {"metadata.author": "stephe*"}}]
+        )
+    assert res.nodes
+    assert len(res.nodes) == 1
+    assert res.nodes[0].node_id == "c330d77f-90bd-4c51-9ed2-57d8d693b3b0"
+
+
+@pytest.mark.skipif(
+    elasticsearch_not_available, reason="elasticsearch is not available"
+)
+@pytest.mark.asyncio()
+@pytest.mark.parametrize("use_async", [True, False])
+async def test_add_to_es_query_and_delete(
+    es_store: ElasticsearchStore,
+    node_embeddings: List[TextNode],
+    use_async: bool,
+) -> None:
+    q = VectorStoreQuery(query_embedding=[1.0, 0.0, 0.0], similarity_top_k=1)
+
+    if use_async:
+        await es_store.async_add(node_embeddings)
+        res = await es_store.aquery(q)
+    else:
+        es_store.add(node_embeddings)
+        res = es_store.query(q)
+    assert res.nodes
+    assert len(res.nodes) == 1
+    assert res.nodes[0].node_id == "c330d77f-90bd-4c51-9ed2-57d8d693b3b0"
+
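+    # "test-0" is the ref_doc_id (SOURCE relationship) of the node with
+    # embedding [1.0, 0.0, 0.0]; after deleting it, the next-closest node wins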
+    if use_async:
+        await es_store.adelete("test-0")
+        res = await es_store.aquery(q)
+    else:
+        es_store.delete("test-0")
+        res = es_store.query(q)
+    assert res.nodes
+    assert len(res.nodes) == 1
+    assert res.nodes[0].node_id == "f658de3b-8cef-4d1c-8bed-9a263c907251"
+
+
+@pytest.mark.skipif(
+    elasticsearch_not_available, reason="elasticsearch is not available"
+)
+@pytest.mark.asyncio()
+@pytest.mark.parametrize("use_async", [True, False])
+async def test_add_to_es_and_embed_query_ranked(
+    es_store: ElasticsearchStore,
+    node_embeddings: List[TextNode],
+    use_async: bool,
+) -> None:
+    einstein_bronte_curie = [
+        "bd2e080b-159a-4030-acc3-d98afd2ba49b",
+        "f658de3b-8cef-4d1c-8bed-9a263c907251",
+        "0b31ae71-b797-4e88-8495-031371a7752e",
+    ]
+    query_get_1_first = VectorStoreQuery(
+        query_embedding=[0.0, 0.0, 0.5], similarity_top_k=3
+    )
+    await check_top_match(
+        es_store, node_embeddings, use_async, query_get_1_first, *einstein_bronte_curie
+    )
+
+
+@pytest.mark.skipif(
+    elasticsearch_not_available, reason="elasticsearch is not available"
+)
+@pytest.mark.asyncio()
+@pytest.mark.parametrize("use_async", [True, False])
+async def test_add_to_es_and_text_query_ranked(
+    es_store: ElasticsearchStore,
+    node_embeddings: List[TextNode],
+    use_async: bool,
+) -> None:
+    node1 = "0b31ae71-b797-4e88-8495-031371a7752e"
+    node2 = "f658de3b-8cef-4d1c-8bed-9a263c907251"
+
+    query_get_1_first = VectorStoreQuery(
+        query_str="I was", mode=VectorStoreQueryMode.TEXT_SEARCH, similarity_top_k=2
+    )
+    await check_top_match(
+        es_store, node_embeddings, use_async, query_get_1_first, node1, node2
+    )
+
+    query_get_2_first = VectorStoreQuery(
+        query_str="I am", mode=VectorStoreQueryMode.TEXT_SEARCH, similarity_top_k=2
+    )
+    await check_top_match(
+        es_store, node_embeddings, use_async, query_get_2_first, node2, node1
+    )
+
+
+@pytest.mark.skipif(
+    elasticsearch_not_available, reason="elasticsearch is not available"
+)
+@pytest.mark.asyncio()
+@pytest.mark.parametrize("use_async", [True, False])
+async def test_add_to_es_and_text_query_ranked_hybrid(
+    es_store: ElasticsearchStore,
+    node_embeddings: List[TextNode],
+    use_async: bool,
+) -> None:
+    node1 = "f658de3b-8cef-4d1c-8bed-9a263c907251"
+    node2 = "0b31ae71-b797-4e88-8495-031371a7752e"
+
+    query_get_1_first = VectorStoreQuery(
+        query_str="I was",
+        query_embedding=[0.0, 0.0, 0.5],
+        mode=VectorStoreQueryMode.HYBRID,
+        similarity_top_k=2,
+    )
+    await check_top_match(
+        es_store, node_embeddings, use_async, query_get_1_first, node1, node2
+    )
+
+
+@pytest.mark.skipif(
+    elasticsearch_not_available, reason="elasticsearch is not available"
+)
+def test_check_user_agent(
+    index_name: str,
+    node_embeddings: List[TextNode],
+) -> None:
+    from elastic_transport import AsyncTransport
+    from elasticsearch import AsyncElasticsearch
+
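+    # transport subclass that records the kwargs of every request so the test
+    # can inspect the headers (including the user agent) sent by the store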
+    class CustomTransport(AsyncTransport):
+        requests = []
+
+        async def perform_request(self, *args, **kwargs):  # type: ignore
+            self.requests.append(kwargs)
+            return await super().perform_request(*args, **kwargs)
+
+    es_client_instance = AsyncElasticsearch(
+        "http://localhost:9200",
+        transport_class=CustomTransport,
+    )
+
+    es_store = ElasticsearchStore(
+        es_client=es_client_instance,
+        index_name=index_name,
+        distance_strategy="EUCLIDEAN_DISTANCE",
+    )
+
+    es_store.add(node_embeddings)
+
+    user_agent = es_client_instance.transport.requests[0]["headers"][  # type: ignore
+        "user-agent"
+    ]
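+    # the user agent is expected to look like "llama_index-py-vs/<version>",
+    # e.g. "llama_index-py-vs/0.10.20" or "llama_index-py-vs/0.10.20.post1"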
+    pattern = r"^llama_index-py-vs/\d+\.\d+\.\d+(\.post\d+)?$"
+    match = re.match(pattern, user_agent)
+
+    assert (
+        match is not None
+    ), f"The string '{user_agent}' does not match the expected user-agent."
+
 
-def test_class():
-    names_of_base_classes = [b.__name__ for b in ElasticsearchStore.__mro__]
-    assert BasePydanticVectorStore.__name__ in names_of_base_classes
+async def check_top_match(
+    es_store: ElasticsearchStore,
+    node_embeddings: List[TextNode],
+    use_async: bool,
+    query: VectorStoreQuery,
+    *expected_nodes: str,
+) -> None:
+    if use_async:
+        await es_store.async_add(node_embeddings)
+        res = await es_store.aquery(query)
+    else:
+        es_store.add(node_embeddings)
+        res = es_store.query(query)
+    assert res.nodes
+    # test that the nodes are returned in the expected order
+    for i, node in enumerate(expected_nodes):
+        assert res.nodes[i].node_id == node
+    # test that the results are ordered by descending similarity
+    # and that the similarities are normalized to the range [0, 1]
+    df = pd.DataFrame({"node": res.nodes, "sim": res.similarities, "id": res.ids})
+    sorted_by_sim = df.sort_values(by="sim", ascending=False)
+    for idx, item in enumerate(sorted_by_sim.itertuples()):
+        res_node = res.nodes[idx]
+        assert res_node.node_id == item.id
+        assert 0 <= item.sim <= 1
-- 
GitLab