From 16f444ce760c89d8cc86cfe2cfc7b9e0cd45ce29 Mon Sep 17 00:00:00 2001 From: fbpo23 <77107111+fbpo23@users.noreply.github.com> Date: Fri, 9 Feb 2024 16:23:31 +0000 Subject: [PATCH] Adding the possibility to use the IN operator for PGVectorStore (#10547) --- llama_index/vector_stores/postgres.py | 18 +++++++++--- tests/vector_stores/test_postgres.py | 40 +++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 4 deletions(-) diff --git a/llama_index/vector_stores/postgres.py b/llama_index/vector_stores/postgres.py index 13ab927c21..f5686afb7d 100644 --- a/llama_index/vector_stores/postgres.py +++ b/llama_index/vector_stores/postgres.py @@ -348,6 +348,8 @@ class PGVectorStore(BasePydanticVectorStore): return ">=" elif operator == FilterOperator.LTE: return "<=" + elif operator == FilterOperator.IN: + return "@>" else: _logger.warning(f"Unknown operator: {operator}, fallback to '='") return "=" @@ -374,10 +376,18 @@ class PGVectorStore(BasePydanticVectorStore): stmt = stmt.where( # type: ignore sqlalchemy_conditions[metadata_filters.condition]( *( - sqlalchemy.text( - f"metadata_->>'{filter_.key}' " - f"{self._to_postgres_operator(filter_.operator)} " - f"'{filter_.value}'" + ( + sqlalchemy.text( + f"metadata_::jsonb->'{filter_.key}' " + f"{self._to_postgres_operator(filter_.operator)} " + f"'[\"{filter_.value}\"]'" + ) + if filter_.operator == FilterOperator.IN + else sqlalchemy.text( + f"metadata_->>'{filter_.key}' " + f"{self._to_postgres_operator(filter_.operator)} " + f"'{filter_.value}'" + ) ) for filter_ in metadata_filters.filters ) diff --git a/tests/vector_stores/test_postgres.py b/tests/vector_stores/test_postgres.py index fcb0cf4f8a..5150db239d 100644 --- a/tests/vector_stores/test_postgres.py +++ b/tests/vector_stores/test_postgres.py @@ -13,6 +13,8 @@ from llama_index.vector_stores import PGVectorStore from llama_index.vector_stores.loading import load_vector_store from llama_index.vector_stores.types import ( ExactMatchFilter, + FilterOperator, + MetadataFilter, MetadataFilters, VectorStoreQuery, VectorStoreQueryMode, @@ -124,6 +126,13 @@ def node_embeddings() -> List[TextNode]: extra_info={"test_key": "test_value"}, embedding=_get_sample_vector(0.1), ), + TextNode( + text="consectetur adipiscing elit", + id_="ccc", + relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="ccc")}, + extra_info={"test_key_list": ["test_value"]}, + embedding=_get_sample_vector(0.1), + ), ] @@ -246,6 +255,37 @@ async def test_add_to_db_and_query_with_metadata_filters( assert res.nodes[0].node_id == "bbb" +@pytest.mark.skipif(postgres_not_available, reason="postgres db is not available") +@pytest.mark.asyncio() +@pytest.mark.parametrize("use_async", [True, False]) +async def test_add_to_db_and_query_with_metadata_filters_with_in_operator( + pg: PGVectorStore, node_embeddings: List[TextNode], use_async: bool +) -> None: + if use_async: + await pg.async_add(node_embeddings) + else: + pg.add(node_embeddings) + assert isinstance(pg, PGVectorStore) + assert hasattr(pg, "_engine") + filters = MetadataFilters( + filters=[ + MetadataFilter( + key="test_key_list", value="test_value", operator=FilterOperator.IN + ) + ] + ) + q = VectorStoreQuery( + query_embedding=_get_sample_vector(0.5), similarity_top_k=10, filters=filters + ) + if use_async: + res = await pg.aquery(q) + else: + res = pg.query(q) + assert res.nodes + assert len(res.nodes) == 1 + assert res.nodes[0].node_id == "ccc" + + @pytest.mark.skipif(postgres_not_available, reason="postgres db is not available") @pytest.mark.asyncio() @pytest.mark.parametrize("use_async", [True, False]) -- GitLab