diff --git a/llama_index/vector_stores/postgres.py b/llama_index/vector_stores/postgres.py index 13ab927c21520b2e10ac4a107f768cb025f9af29..f5686afb7d3385c1d7397e05cc28b03ceea80da6 100644 --- a/llama_index/vector_stores/postgres.py +++ b/llama_index/vector_stores/postgres.py @@ -348,6 +348,8 @@ class PGVectorStore(BasePydanticVectorStore): return ">=" elif operator == FilterOperator.LTE: return "<=" + elif operator == FilterOperator.IN: + return "@>" else: _logger.warning(f"Unknown operator: {operator}, fallback to '='") return "=" @@ -374,10 +376,18 @@ class PGVectorStore(BasePydanticVectorStore): stmt = stmt.where( # type: ignore sqlalchemy_conditions[metadata_filters.condition]( *( - sqlalchemy.text( - f"metadata_->>'{filter_.key}' " - f"{self._to_postgres_operator(filter_.operator)} " - f"'{filter_.value}'" + ( + sqlalchemy.text( + f"metadata_::jsonb->'{filter_.key}' " + f"{self._to_postgres_operator(filter_.operator)} " + f"'[\"{filter_.value}\"]'" + ) + if filter_.operator == FilterOperator.IN + else sqlalchemy.text( + f"metadata_->>'{filter_.key}' " + f"{self._to_postgres_operator(filter_.operator)} " + f"'{filter_.value}'" + ) ) for filter_ in metadata_filters.filters ) diff --git a/tests/vector_stores/test_postgres.py b/tests/vector_stores/test_postgres.py index fcb0cf4f8a137654155f27d1588845bb51377f61..5150db239d353e964810af588adb5ba8f53f406b 100644 --- a/tests/vector_stores/test_postgres.py +++ b/tests/vector_stores/test_postgres.py @@ -13,6 +13,8 @@ from llama_index.vector_stores import PGVectorStore from llama_index.vector_stores.loading import load_vector_store from llama_index.vector_stores.types import ( ExactMatchFilter, + FilterOperator, + MetadataFilter, MetadataFilters, VectorStoreQuery, VectorStoreQueryMode, @@ -124,6 +126,13 @@ def node_embeddings() -> List[TextNode]: extra_info={"test_key": "test_value"}, embedding=_get_sample_vector(0.1), ), + TextNode( + text="consectetur adipiscing elit", + id_="ccc", + relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="ccc")}, + extra_info={"test_key_list": ["test_value"]}, + embedding=_get_sample_vector(0.1), + ), ] @@ -246,6 +255,37 @@ async def test_add_to_db_and_query_with_metadata_filters( assert res.nodes[0].node_id == "bbb" +@pytest.mark.skipif(postgres_not_available, reason="postgres db is not available") +@pytest.mark.asyncio() +@pytest.mark.parametrize("use_async", [True, False]) +async def test_add_to_db_and_query_with_metadata_filters_with_in_operator( + pg: PGVectorStore, node_embeddings: List[TextNode], use_async: bool +) -> None: + if use_async: + await pg.async_add(node_embeddings) + else: + pg.add(node_embeddings) + assert isinstance(pg, PGVectorStore) + assert hasattr(pg, "_engine") + filters = MetadataFilters( + filters=[ + MetadataFilter( + key="test_key_list", value="test_value", operator=FilterOperator.IN + ) + ] + ) + q = VectorStoreQuery( + query_embedding=_get_sample_vector(0.5), similarity_top_k=10, filters=filters + ) + if use_async: + res = await pg.aquery(q) + else: + res = pg.query(q) + assert res.nodes + assert len(res.nodes) == 1 + assert res.nodes[0].node_id == "ccc" + + @pytest.mark.skipif(postgres_not_available, reason="postgres db is not available") @pytest.mark.asyncio() @pytest.mark.parametrize("use_async", [True, False])