From 81cdc9880702b92629723bf6a48899cbd102b36b Mon Sep 17 00:00:00 2001 From: Jules Kuehn <jk@jules.lol> Date: Fri, 8 Mar 2024 21:04:27 -0500 Subject: [PATCH] feat: nested metadata filters (PGVectorStore) (#11778) * feat: nested metadata filters (PGVectorStore) * minimize changes --- docs/examples/vector_stores/postgres.ipynb | 359 +++++++++++++++++- .../llama_index/core/vector_stores/types.py | 8 +- llama-index-core/pyproject.toml | 4 +- .../vector_stores/postgres/base.py | 67 ++-- 4 files changed, 381 insertions(+), 57 deletions(-) diff --git a/docs/examples/vector_stores/postgres.ipynb b/docs/examples/vector_stores/postgres.ipynb index 21183edd68..6643d7e835 100644 --- a/docs/examples/vector_stores/postgres.ipynb +++ b/docs/examples/vector_stores/postgres.ipynb @@ -87,7 +87,7 @@ "import os\n", "\n", "os.environ[\"OPENAI_API_KEY\"] = \"<your key>\"\n", - "openai.api_key = \"<your key>\"" + "openai.api_key = os.environ[\"OPENAI_API_KEY\"]" ] }, { @@ -130,7 +130,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Document ID: d05d1211-b9af-4b05-8da6-956e4b389467\n" + "Document ID: 88efac05-2277-4eda-a94c-c9247c9aca1c\n" ] } ], @@ -185,12 +185,12 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "1358441c0c864d87a860820ed8cf2b2c", + "model_id": "40b4c3becfc64c5184360b8b8e81ca9a", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "Parsing documents into nodes: 0%| | 0/1 [00:00<?, ?it/s]" + "Parsing nodes: 0%| | 0/1 [00:00<?, ?it/s]" ] }, "metadata": {}, @@ -199,12 +199,12 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "69a77c77ff1c48cc8107b445ae4fa0cc", + "model_id": "37763ad7b17f4481a7e67df379304d31", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "Generating embeddings: 0%| | 0/17 [00:00<?, ?it/s]" + "Generating embeddings: 0%| | 0/22 [00:00<?, ?it/s]" ] }, "metadata": {}, @@ -261,11 +261,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "The author worked on writing and programming before college. They wrote short stories and tried\n", - "writing programs on an IBM 1401 computer. They also built a microcomputer and started programming on\n", - "it, writing simple games and a word processor. In college, the author initially planned to study\n", - "philosophy but switched to AI. They were inspired by a novel called The Moon is a Harsh Mistress and\n", - "a PBS documentary featuring Terry Winograd using SHRDLU.\n" + "The author worked on writing and programming before college. Initially, the author wrote short\n", + "stories and later started programming on an IBM 1401 using an early version of Fortran. The author\n", + "then transitioned to working with microcomputers, building a computer kit and eventually getting a\n", + "TRS-80 to further explore programming. In college, the author initially planned to study philosophy\n", + "but switched to studying AI due to a lack of interest in philosophy courses. The author was inspired\n", + "to work on AI after encountering works like Heinlein's novel \"The Moon is a Harsh Mistress\" and\n", + "seeing Terry Winograd using SHRDLU in a PBS documentary.\n" ] } ], @@ -391,16 +393,7 @@ "execution_count": null, "id": "65a7e133-39da-40c5-b2c5-7af2c0a3a792", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/suo/dev/llama_index/llama_index/vector_stores/postgres.py:217: SAWarning: TypeDecorator TSVector() will not produce a cache key because the ``cache_ok`` attribute is not set to True. This can have significant performance implications including some performance degradations in comparison to prior SQLAlchemy versions. Set this attribute to True if this type object's state is safe to use in a cache key, or False to disable this warning. (Background on this warning at: https://sqlalche.me/e/20/cprf)\n", - " session.commit()\n" - ] - } - ], + "outputs": [], "source": [ "from sqlalchemy import make_url\n", "\n", @@ -458,6 +451,330 @@ "print(hybrid_response)" ] }, + { + "cell_type": "markdown", + "id": "2e5e8083", + "metadata": {}, + "source": [ + "### Metadata filters\n", + "\n", + "PGVectorStore supports storing metadata in nodes, and filtering based on that metadata during the retrieval step." + ] + }, + { + "cell_type": "markdown", + "id": "2d0ad3fc", + "metadata": {}, + "source": [ + "#### Download git commits dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "63e90a89", + "metadata": {}, + "outputs": [], + "source": [ + "!mkdir -p 'data/git_commits/'\n", + "!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/csv/commit_history.csv' -O 'data/git_commits/commit_history.csv'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fef41f44", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'commit': '44e41c12ab25e36c202f58e068ced262eadc8d16', 'author': 'Lakshmi Narayanan Sreethar<lakshmi@timescale.com>', 'date': 'Tue Sep 5 21:03:21 2023 +0530', 'change summary': 'Fix segfault in set_integer_now_func', 'change details': 'When an invalid function oid is passed to set_integer_now_func, it finds out that the function oid is invalid but before throwing the error, it calls ReleaseSysCache on an invalid tuple causing a segfault. Fixed that by removing the invalid call to ReleaseSysCache. Fixes #6037 '}\n", + "4167\n" + ] + } + ], + "source": [ + "import csv\n", + "\n", + "with open(\"data/git_commits/commit_history.csv\", \"r\") as f:\n", + " commits = list(csv.DictReader(f))\n", + "\n", + "print(commits[0])\n", + "print(len(commits))" + ] + }, + { + "cell_type": "markdown", + "id": "3b0d9f47", + "metadata": {}, + "source": [ + "#### Add nodes with custom metadata" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3920109b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Node ID: e084ffbd-24e0-4bd9-b7c8-287fe1abd85d\n", + "Text: Fix segfault in set_integer_now_func When an invalid function\n", + "oid is passed to set_integer_now_func, it finds out that the function\n", + "oid is invalid but before throwing the error, it calls ReleaseSysCache\n", + "on an invalid tuple causing a segfault. Fixed that by removing the\n", + "invalid call to ReleaseSysCache. Fixes #6037\n", + "2023-03-22 to 2023-09-05\n", + "{'36882414+akuzm@users.noreply.github.com', 'erik@timescale.com', 'konstantina@timescale.com', 'mats@timescale.com', 'nikhil@timescale.com', 'dmitry@timescale.com', 'jguthrie@timescale.com', 'rafia.sabih@gmail.com', 'engel@sero-systems.de', 'satish.8483@gmail.com', 'me@noctarius.com', 'sven@timescale.com', 'jan@timescale.com', 'lakshmi@timescale.com', 'fabriziomello@gmail.com'}\n" + ] + } + ], + "source": [ + "# Create TextNode for each of the first 100 commits\n", + "from llama_index.core.schema import TextNode\n", + "from datetime import datetime\n", + "\n", + "nodes = []\n", + "dates = set()\n", + "authors = set()\n", + "for commit in commits[:100]:\n", + " author_email = commit[\"author\"].split(\"<\")[1][:-1]\n", + " commit_date = datetime.strptime(\n", + " commit[\"date\"], \"%a %b %d %H:%M:%S %Y %z\"\n", + " ).strftime(\"%Y-%m-%d\")\n", + " commit_text = commit[\"change summary\"]\n", + " if commit[\"change details\"]:\n", + " commit_text += \"\\n\\n\" + commit[\"change details\"]\n", + " nodes.append(\n", + " TextNode(\n", + " text=commit_text,\n", + " metadata={\n", + " \"commit_date\": commit_date,\n", + " \"author\": author_email,\n", + " },\n", + " )\n", + " )\n", + " dates.add(commit_date)\n", + " authors.add(author_email)\n", + "\n", + "print(nodes[0])\n", + "print(min(dates), \"to\", max(dates))\n", + "print(authors)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a638f76a", + "metadata": {}, + "outputs": [], + "source": [ + "vector_store = PGVectorStore.from_params(\n", + " database=db_name,\n", + " host=url.host,\n", + " password=url.password,\n", + " port=url.port,\n", + " user=url.username,\n", + " table_name=\"metadata_filter_demo3\",\n", + " embed_dim=1536, # openai embedding dimension\n", + ")\n", + "\n", + "index = VectorStoreIndex.from_vector_store(vector_store=vector_store)\n", + "index.insert_nodes(nodes)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15f7cf45", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Lakshmi fixed the segfault by removing the invalid call to ReleaseSysCache that was causing the issue.\n" + ] + } + ], + "source": [ + "print(index.as_query_engine().query(\"How did Lakshmi fix the segfault?\"))" + ] + }, + { + "cell_type": "markdown", + "id": "7ab03ed4", + "metadata": {}, + "source": [ + "#### Apply metadata filters\n", + "\n", + "Now we can filter by commit author or by date when retrieving nodes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aa6212e7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'commit_date': '2023-08-07', 'author': 'mats@timescale.com'}\n", + "{'commit_date': '2023-08-07', 'author': 'sven@timescale.com'}\n", + "{'commit_date': '2023-08-15', 'author': 'sven@timescale.com'}\n", + "{'commit_date': '2023-08-23', 'author': 'sven@timescale.com'}\n", + "{'commit_date': '2023-07-13', 'author': 'mats@timescale.com'}\n", + "{'commit_date': '2023-08-27', 'author': 'sven@timescale.com'}\n", + "{'commit_date': '2023-08-21', 'author': 'sven@timescale.com'}\n", + "{'commit_date': '2023-08-30', 'author': 'sven@timescale.com'}\n", + "{'commit_date': '2023-08-10', 'author': 'mats@timescale.com'}\n", + "{'commit_date': '2023-08-20', 'author': 'sven@timescale.com'}\n" + ] + } + ], + "source": [ + "from llama_index.core.vector_stores.types import (\n", + " MetadataFilter,\n", + " MetadataFilters,\n", + ")\n", + "\n", + "filters = MetadataFilters(\n", + " filters=[\n", + " MetadataFilter(key=\"author\", value=\"mats@timescale.com\"),\n", + " MetadataFilter(key=\"author\", value=\"sven@timescale.com\"),\n", + " ],\n", + " condition=\"or\",\n", + ")\n", + "\n", + "retriever = index.as_retriever(\n", + " similarity_top_k=10,\n", + " filters=filters,\n", + ")\n", + "\n", + "retrieved_nodes = retriever.retrieve(\"What is this software project about?\")\n", + "\n", + "for node in retrieved_nodes:\n", + " print(node.node.metadata)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "67c19ec6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'commit_date': '2023-08-23', 'author': 'erik@timescale.com'}\n", + "{'commit_date': '2023-08-15', 'author': '36882414+akuzm@users.noreply.github.com'}\n", + "{'commit_date': '2023-08-17', 'author': 'konstantina@timescale.com'}\n", + "{'commit_date': '2023-08-15', 'author': 'sven@timescale.com'}\n", + "{'commit_date': '2023-08-23', 'author': 'sven@timescale.com'}\n", + "{'commit_date': '2023-08-15', 'author': '36882414+akuzm@users.noreply.github.com'}\n", + "{'commit_date': '2023-08-21', 'author': 'sven@timescale.com'}\n", + "{'commit_date': '2023-08-24', 'author': 'lakshmi@timescale.com'}\n", + "{'commit_date': '2023-08-16', 'author': '36882414+akuzm@users.noreply.github.com'}\n", + "{'commit_date': '2023-08-20', 'author': 'sven@timescale.com'}\n" + ] + } + ], + "source": [ + "filters = MetadataFilters(\n", + " filters=[\n", + " MetadataFilter(key=\"commit_date\", value=\"2023-08-15\", operator=\">=\"),\n", + " MetadataFilter(key=\"commit_date\", value=\"2023-08-25\", operator=\"<=\"),\n", + " ],\n", + " condition=\"and\",\n", + ")\n", + "\n", + "retriever = index.as_retriever(\n", + " similarity_top_k=10,\n", + " filters=filters,\n", + ")\n", + "\n", + "retrieved_nodes = retriever.retrieve(\"What is this software project about?\")\n", + "\n", + "for node in retrieved_nodes:\n", + " print(node.node.metadata)" + ] + }, + { + "cell_type": "markdown", + "id": "4f6e9cdf", + "metadata": {}, + "source": [ + "#### Apply nested filters\n", + "\n", + "In the above examples, we combined multiple filters using AND or OR. We can also combine multiple sets of filters.\n", + "\n", + "e.g. in SQL:\n", + "```sql\n", + "WHERE (commit_date >= '2023-08-01' AND commit_date <= '2023-08-15') AND (author = 'mats@timescale.com' OR author = 'sven@timescale.com')\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "94f20be7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'commit_date': '2023-08-07', 'author': 'mats@timescale.com'}\n", + "{'commit_date': '2023-08-07', 'author': 'sven@timescale.com'}\n", + "{'commit_date': '2023-08-15', 'author': 'sven@timescale.com'}\n", + "{'commit_date': '2023-08-10', 'author': 'mats@timescale.com'}\n" + ] + } + ], + "source": [ + "filters = MetadataFilters(\n", + " filters=[\n", + " MetadataFilters(\n", + " filters=[\n", + " MetadataFilter(\n", + " key=\"commit_date\", value=\"2023-08-01\", operator=\">=\"\n", + " ),\n", + " MetadataFilter(\n", + " key=\"commit_date\", value=\"2023-08-15\", operator=\"<=\"\n", + " ),\n", + " ],\n", + " condition=\"and\",\n", + " ),\n", + " MetadataFilters(\n", + " filters=[\n", + " MetadataFilter(key=\"author\", value=\"mats@timescale.com\"),\n", + " MetadataFilter(key=\"author\", value=\"sven@timescale.com\"),\n", + " ],\n", + " condition=\"or\",\n", + " ),\n", + " ],\n", + " condition=\"and\",\n", + ")\n", + "\n", + "retriever = index.as_retriever(\n", + " similarity_top_k=10,\n", + " filters=filters,\n", + ")\n", + "\n", + "retrieved_nodes = retriever.retrieve(\"What is this software project about?\")\n", + "\n", + "for node in retrieved_nodes:\n", + " print(node.node.metadata)" + ] + }, { "cell_type": "markdown", "id": "2b274ecb", diff --git a/llama-index-core/llama_index/core/vector_stores/types.py b/llama-index-core/llama_index/core/vector_stores/types.py index 76a4bf6d38..403af0ef59 100644 --- a/llama-index-core/llama_index/core/vector_stores/types.py +++ b/llama-index-core/llama_index/core/vector_stores/types.py @@ -119,14 +119,10 @@ ExactMatchFilter = MetadataFilter class MetadataFilters(BaseModel): - """Metadata filters for vector stores. - - Currently only supports exact match filters. - TODO: support more advanced expressions. - """ + """Metadata filters for vector stores.""" # Exact match filters and Advanced filters with operators like >, <, >=, <=, !=, etc. - filters: List[Union[MetadataFilter, ExactMatchFilter]] + filters: List[Union[MetadataFilter, ExactMatchFilter, "MetadataFilters"]] # and/or such conditions for combining different filters condition: Optional[FilterCondition] = FilterCondition.AND diff --git a/llama-index-core/pyproject.toml b/llama-index-core/pyproject.toml index 9af62c102d..39720e9f79 100644 --- a/llama-index-core/pyproject.toml +++ b/llama-index-core/pyproject.toml @@ -65,8 +65,8 @@ typing-extensions = ">=4.5.0" typing-inspect = ">=0.8.0" requests = ">=2.31.0" # Pin to avoid CVE-2023-32681 in requests 2.3 to 2.30 gradientai = {optional = true, version = ">=1.4.0"} -asyncpg = {optional = true, version = "^0.28.0"} -pgvector = {optional = true, version = "^0.1.0"} +asyncpg = {optional = true, version = "^0.29.0"} +pgvector = {optional = true, version = "^0.2.4"} optimum = {extras = ["onnxruntime"], optional = true, version = "^1.13.2"} sentencepiece = {optional = true, version = "^0.1.99"} transformers = {extras = ["torch"], optional = true, version = "^4.33.1"} diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-postgres/llama_index/vector_stores/postgres/base.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-postgres/llama_index/vector_stores/postgres/base.py index 9a495536ea..b7d114ddc3 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-postgres/llama_index/vector_stores/postgres/base.py +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-postgres/llama_index/vector_stores/postgres/base.py @@ -350,12 +350,10 @@ class PGVectorStore(BasePydanticVectorStore): _logger.warning(f"Unknown operator: {operator}, fallback to '='") return "=" - def _apply_filters_and_limit( - self, - stmt: Select, - limit: int, - metadata_filters: Optional[MetadataFilters] = None, - ) -> Any: + def _recursively_apply_filters(self, filters: List[MetadataFilters]) -> Any: + """ + Returns a sqlalchemy where clause. + """ import sqlalchemy sqlalchemy_conditions = { @@ -363,31 +361,44 @@ class PGVectorStore(BasePydanticVectorStore): "and": sqlalchemy.sql.and_, } - if metadata_filters: - if metadata_filters.condition not in sqlalchemy_conditions: - raise ValueError( - f"Invalid condition: {metadata_filters.condition}. " - f"Must be one of {list(sqlalchemy_conditions.keys())}" - ) - stmt = stmt.where( # type: ignore - sqlalchemy_conditions[metadata_filters.condition]( - *( - ( - sqlalchemy.text( - f"metadata_::jsonb->'{filter_.key}' " - f"{self._to_postgres_operator(filter_.operator)} " - f"'[\"{filter_.value}\"]'" - ) - if filter_.operator == FilterOperator.IN - else sqlalchemy.text( - f"metadata_->>'{filter_.key}' " - f"{self._to_postgres_operator(filter_.operator)} " - f"'{filter_.value}'" - ) + if filters.condition not in sqlalchemy_conditions: + raise ValueError( + f"Invalid condition: {filters.condition}. " + f"Must be one of {list(sqlalchemy_conditions.keys())}" + ) + + return sqlalchemy_conditions[filters.condition]( + *( + ( + ( + sqlalchemy.text( + f"metadata_::jsonb->'{filter_.key}' " + f"{self._to_postgres_operator(filter_.operator)} " + f"'[\"{filter_.value}\"]'" + ) + if filter_.operator == FilterOperator.IN + else sqlalchemy.text( + f"metadata_->>'{filter_.key}' " + f"{self._to_postgres_operator(filter_.operator)} " + f"'{filter_.value}'" ) - for filter_ in metadata_filters.filters ) + if not isinstance(filter_, MetadataFilters) + else self._recursively_apply_filters(filter_) ) + for filter_ in filters.filters + ) + ) + + def _apply_filters_and_limit( + self, + stmt: Select, + limit: int, + metadata_filters: Optional[MetadataFilters] = None, + ) -> Any: + if metadata_filters: + stmt = stmt.where( # type: ignore + self._recursively_apply_filters(metadata_filters) ) return stmt.limit(limit) # type: ignore -- GitLab