diff --git a/docs/examples/vector_stores/postgres.ipynb b/docs/examples/vector_stores/postgres.ipynb index 21183edd68600c5ac841351de016dd970f1e50ad..6643d7e83555247ad6e7fd40dc3b164c3ca43211 100644 --- a/docs/examples/vector_stores/postgres.ipynb +++ b/docs/examples/vector_stores/postgres.ipynb @@ -87,7 +87,7 @@ "import os\n", "\n", "os.environ[\"OPENAI_API_KEY\"] = \"<your key>\"\n", - "openai.api_key = \"<your key>\"" + "openai.api_key = os.environ[\"OPENAI_API_KEY\"]" ] }, { @@ -130,7 +130,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Document ID: d05d1211-b9af-4b05-8da6-956e4b389467\n" + "Document ID: 88efac05-2277-4eda-a94c-c9247c9aca1c\n" ] } ], @@ -185,12 +185,12 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "1358441c0c864d87a860820ed8cf2b2c", + "model_id": "40b4c3becfc64c5184360b8b8e81ca9a", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "Parsing documents into nodes: 0%| | 0/1 [00:00<?, ?it/s]" + "Parsing nodes: 0%| | 0/1 [00:00<?, ?it/s]" ] }, "metadata": {}, @@ -199,12 +199,12 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "69a77c77ff1c48cc8107b445ae4fa0cc", + "model_id": "37763ad7b17f4481a7e67df379304d31", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "Generating embeddings: 0%| | 0/17 [00:00<?, ?it/s]" + "Generating embeddings: 0%| | 0/22 [00:00<?, ?it/s]" ] }, "metadata": {}, @@ -261,11 +261,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "The author worked on writing and programming before college. They wrote short stories and tried\n", - "writing programs on an IBM 1401 computer. They also built a microcomputer and started programming on\n", - "it, writing simple games and a word processor. In college, the author initially planned to study\n", - "philosophy but switched to AI. They were inspired by a novel called The Moon is a Harsh Mistress and\n", - "a PBS documentary featuring Terry Winograd using SHRDLU.\n" + "The author worked on writing and programming before college. Initially, the author wrote short\n", + "stories and later started programming on an IBM 1401 using an early version of Fortran. The author\n", + "then transitioned to working with microcomputers, building a computer kit and eventually getting a\n", + "TRS-80 to further explore programming. In college, the author initially planned to study philosophy\n", + "but switched to studying AI due to a lack of interest in philosophy courses. The author was inspired\n", + "to work on AI after encountering works like Heinlein's novel \"The Moon is a Harsh Mistress\" and\n", + "seeing Terry Winograd using SHRDLU in a PBS documentary.\n" ] } ], @@ -391,16 +393,7 @@ "execution_count": null, "id": "65a7e133-39da-40c5-b2c5-7af2c0a3a792", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/suo/dev/llama_index/llama_index/vector_stores/postgres.py:217: SAWarning: TypeDecorator TSVector() will not produce a cache key because the ``cache_ok`` attribute is not set to True. This can have significant performance implications including some performance degradations in comparison to prior SQLAlchemy versions. Set this attribute to True if this type object's state is safe to use in a cache key, or False to disable this warning. (Background on this warning at: https://sqlalche.me/e/20/cprf)\n", - " session.commit()\n" - ] - } - ], + "outputs": [], "source": [ "from sqlalchemy import make_url\n", "\n", @@ -458,6 +451,330 @@ "print(hybrid_response)" ] }, + { + "cell_type": "markdown", + "id": "2e5e8083", + "metadata": {}, + "source": [ + "### Metadata filters\n", + "\n", + "PGVectorStore supports storing metadata in nodes, and filtering based on that metadata during the retrieval step." + ] + }, + { + "cell_type": "markdown", + "id": "2d0ad3fc", + "metadata": {}, + "source": [ + "#### Download git commits dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "63e90a89", + "metadata": {}, + "outputs": [], + "source": [ + "!mkdir -p 'data/git_commits/'\n", + "!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/csv/commit_history.csv' -O 'data/git_commits/commit_history.csv'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fef41f44", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'commit': '44e41c12ab25e36c202f58e068ced262eadc8d16', 'author': 'Lakshmi Narayanan Sreethar<lakshmi@timescale.com>', 'date': 'Tue Sep 5 21:03:21 2023 +0530', 'change summary': 'Fix segfault in set_integer_now_func', 'change details': 'When an invalid function oid is passed to set_integer_now_func, it finds out that the function oid is invalid but before throwing the error, it calls ReleaseSysCache on an invalid tuple causing a segfault. Fixed that by removing the invalid call to ReleaseSysCache. Fixes #6037 '}\n", + "4167\n" + ] + } + ], + "source": [ + "import csv\n", + "\n", + "with open(\"data/git_commits/commit_history.csv\", \"r\") as f:\n", + " commits = list(csv.DictReader(f))\n", + "\n", + "print(commits[0])\n", + "print(len(commits))" + ] + }, + { + "cell_type": "markdown", + "id": "3b0d9f47", + "metadata": {}, + "source": [ + "#### Add nodes with custom metadata" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3920109b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Node ID: e084ffbd-24e0-4bd9-b7c8-287fe1abd85d\n", + "Text: Fix segfault in set_integer_now_func When an invalid function\n", + "oid is passed to set_integer_now_func, it finds out that the function\n", + "oid is invalid but before throwing the error, it calls ReleaseSysCache\n", + "on an invalid tuple causing a segfault. Fixed that by removing the\n", + "invalid call to ReleaseSysCache. Fixes #6037\n", + "2023-03-22 to 2023-09-05\n", + "{'36882414+akuzm@users.noreply.github.com', 'erik@timescale.com', 'konstantina@timescale.com', 'mats@timescale.com', 'nikhil@timescale.com', 'dmitry@timescale.com', 'jguthrie@timescale.com', 'rafia.sabih@gmail.com', 'engel@sero-systems.de', 'satish.8483@gmail.com', 'me@noctarius.com', 'sven@timescale.com', 'jan@timescale.com', 'lakshmi@timescale.com', 'fabriziomello@gmail.com'}\n" + ] + } + ], + "source": [ + "# Create TextNode for each of the first 100 commits\n", + "from llama_index.core.schema import TextNode\n", + "from datetime import datetime\n", + "\n", + "nodes = []\n", + "dates = set()\n", + "authors = set()\n", + "for commit in commits[:100]:\n", + " author_email = commit[\"author\"].split(\"<\")[1][:-1]\n", + " commit_date = datetime.strptime(\n", + " commit[\"date\"], \"%a %b %d %H:%M:%S %Y %z\"\n", + " ).strftime(\"%Y-%m-%d\")\n", + " commit_text = commit[\"change summary\"]\n", + " if commit[\"change details\"]:\n", + " commit_text += \"\\n\\n\" + commit[\"change details\"]\n", + " nodes.append(\n", + " TextNode(\n", + " text=commit_text,\n", + " metadata={\n", + " \"commit_date\": commit_date,\n", + " \"author\": author_email,\n", + " },\n", + " )\n", + " )\n", + " dates.add(commit_date)\n", + " authors.add(author_email)\n", + "\n", + "print(nodes[0])\n", + "print(min(dates), \"to\", max(dates))\n", + "print(authors)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a638f76a", + "metadata": {}, + "outputs": [], + "source": [ + "vector_store = PGVectorStore.from_params(\n", + " database=db_name,\n", + " host=url.host,\n", + " password=url.password,\n", + " port=url.port,\n", + " user=url.username,\n", + " table_name=\"metadata_filter_demo3\",\n", + " embed_dim=1536, # openai embedding dimension\n", + ")\n", + "\n", + "index = VectorStoreIndex.from_vector_store(vector_store=vector_store)\n", + "index.insert_nodes(nodes)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15f7cf45", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Lakshmi fixed the segfault by removing the invalid call to ReleaseSysCache that was causing the issue.\n" + ] + } + ], + "source": [ + "print(index.as_query_engine().query(\"How did Lakshmi fix the segfault?\"))" + ] + }, + { + "cell_type": "markdown", + "id": "7ab03ed4", + "metadata": {}, + "source": [ + "#### Apply metadata filters\n", + "\n", + "Now we can filter by commit author or by date when retrieving nodes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aa6212e7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'commit_date': '2023-08-07', 'author': 'mats@timescale.com'}\n", + "{'commit_date': '2023-08-07', 'author': 'sven@timescale.com'}\n", + "{'commit_date': '2023-08-15', 'author': 'sven@timescale.com'}\n", + "{'commit_date': '2023-08-23', 'author': 'sven@timescale.com'}\n", + "{'commit_date': '2023-07-13', 'author': 'mats@timescale.com'}\n", + "{'commit_date': '2023-08-27', 'author': 'sven@timescale.com'}\n", + "{'commit_date': '2023-08-21', 'author': 'sven@timescale.com'}\n", + "{'commit_date': '2023-08-30', 'author': 'sven@timescale.com'}\n", + "{'commit_date': '2023-08-10', 'author': 'mats@timescale.com'}\n", + "{'commit_date': '2023-08-20', 'author': 'sven@timescale.com'}\n" + ] + } + ], + "source": [ + "from llama_index.core.vector_stores.types import (\n", + " MetadataFilter,\n", + " MetadataFilters,\n", + ")\n", + "\n", + "filters = MetadataFilters(\n", + " filters=[\n", + " MetadataFilter(key=\"author\", value=\"mats@timescale.com\"),\n", + " MetadataFilter(key=\"author\", value=\"sven@timescale.com\"),\n", + " ],\n", + " condition=\"or\",\n", + ")\n", + "\n", + "retriever = index.as_retriever(\n", + " similarity_top_k=10,\n", + " filters=filters,\n", + ")\n", + "\n", + "retrieved_nodes = retriever.retrieve(\"What is this software project about?\")\n", + "\n", + "for node in retrieved_nodes:\n", + " print(node.node.metadata)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "67c19ec6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'commit_date': '2023-08-23', 'author': 'erik@timescale.com'}\n", + "{'commit_date': '2023-08-15', 'author': '36882414+akuzm@users.noreply.github.com'}\n", + "{'commit_date': '2023-08-17', 'author': 'konstantina@timescale.com'}\n", + "{'commit_date': '2023-08-15', 'author': 'sven@timescale.com'}\n", + "{'commit_date': '2023-08-23', 'author': 'sven@timescale.com'}\n", + "{'commit_date': '2023-08-15', 'author': '36882414+akuzm@users.noreply.github.com'}\n", + "{'commit_date': '2023-08-21', 'author': 'sven@timescale.com'}\n", + "{'commit_date': '2023-08-24', 'author': 'lakshmi@timescale.com'}\n", + "{'commit_date': '2023-08-16', 'author': '36882414+akuzm@users.noreply.github.com'}\n", + "{'commit_date': '2023-08-20', 'author': 'sven@timescale.com'}\n" + ] + } + ], + "source": [ + "filters = MetadataFilters(\n", + " filters=[\n", + " MetadataFilter(key=\"commit_date\", value=\"2023-08-15\", operator=\">=\"),\n", + " MetadataFilter(key=\"commit_date\", value=\"2023-08-25\", operator=\"<=\"),\n", + " ],\n", + " condition=\"and\",\n", + ")\n", + "\n", + "retriever = index.as_retriever(\n", + " similarity_top_k=10,\n", + " filters=filters,\n", + ")\n", + "\n", + "retrieved_nodes = retriever.retrieve(\"What is this software project about?\")\n", + "\n", + "for node in retrieved_nodes:\n", + " print(node.node.metadata)" + ] + }, + { + "cell_type": "markdown", + "id": "4f6e9cdf", + "metadata": {}, + "source": [ + "#### Apply nested filters\n", + "\n", + "In the above examples, we combined multiple filters using AND or OR. We can also combine multiple sets of filters.\n", + "\n", + "e.g. in SQL:\n", + "```sql\n", + "WHERE (commit_date >= '2023-08-01' AND commit_date <= '2023-08-15') AND (author = 'mats@timescale.com' OR author = 'sven@timescale.com')\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "94f20be7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'commit_date': '2023-08-07', 'author': 'mats@timescale.com'}\n", + "{'commit_date': '2023-08-07', 'author': 'sven@timescale.com'}\n", + "{'commit_date': '2023-08-15', 'author': 'sven@timescale.com'}\n", + "{'commit_date': '2023-08-10', 'author': 'mats@timescale.com'}\n" + ] + } + ], + "source": [ + "filters = MetadataFilters(\n", + " filters=[\n", + " MetadataFilters(\n", + " filters=[\n", + " MetadataFilter(\n", + " key=\"commit_date\", value=\"2023-08-01\", operator=\">=\"\n", + " ),\n", + " MetadataFilter(\n", + " key=\"commit_date\", value=\"2023-08-15\", operator=\"<=\"\n", + " ),\n", + " ],\n", + " condition=\"and\",\n", + " ),\n", + " MetadataFilters(\n", + " filters=[\n", + " MetadataFilter(key=\"author\", value=\"mats@timescale.com\"),\n", + " MetadataFilter(key=\"author\", value=\"sven@timescale.com\"),\n", + " ],\n", + " condition=\"or\",\n", + " ),\n", + " ],\n", + " condition=\"and\",\n", + ")\n", + "\n", + "retriever = index.as_retriever(\n", + " similarity_top_k=10,\n", + " filters=filters,\n", + ")\n", + "\n", + "retrieved_nodes = retriever.retrieve(\"What is this software project about?\")\n", + "\n", + "for node in retrieved_nodes:\n", + " print(node.node.metadata)" + ] + }, { "cell_type": "markdown", "id": "2b274ecb", diff --git a/llama-index-core/llama_index/core/vector_stores/types.py b/llama-index-core/llama_index/core/vector_stores/types.py index 76a4bf6d3819de1a7ac831cdde8783c410889894..403af0ef59666c541af430e5b0ac80890b436f77 100644 --- a/llama-index-core/llama_index/core/vector_stores/types.py +++ b/llama-index-core/llama_index/core/vector_stores/types.py @@ -119,14 +119,10 @@ ExactMatchFilter = MetadataFilter class MetadataFilters(BaseModel): - """Metadata filters for vector stores. - - Currently only supports exact match filters. - TODO: support more advanced expressions. - """ + """Metadata filters for vector stores.""" # Exact match filters and Advanced filters with operators like >, <, >=, <=, !=, etc. - filters: List[Union[MetadataFilter, ExactMatchFilter]] + filters: List[Union[MetadataFilter, ExactMatchFilter, "MetadataFilters"]] # and/or such conditions for combining different filters condition: Optional[FilterCondition] = FilterCondition.AND diff --git a/llama-index-core/pyproject.toml b/llama-index-core/pyproject.toml index 9af62c102dc105a24af8a88f4b43d22e76773138..39720e9f79478440f7fadd5223782346cca01799 100644 --- a/llama-index-core/pyproject.toml +++ b/llama-index-core/pyproject.toml @@ -65,8 +65,8 @@ typing-extensions = ">=4.5.0" typing-inspect = ">=0.8.0" requests = ">=2.31.0" # Pin to avoid CVE-2023-32681 in requests 2.3 to 2.30 gradientai = {optional = true, version = ">=1.4.0"} -asyncpg = {optional = true, version = "^0.28.0"} -pgvector = {optional = true, version = "^0.1.0"} +asyncpg = {optional = true, version = "^0.29.0"} +pgvector = {optional = true, version = "^0.2.4"} optimum = {extras = ["onnxruntime"], optional = true, version = "^1.13.2"} sentencepiece = {optional = true, version = "^0.1.99"} transformers = {extras = ["torch"], optional = true, version = "^4.33.1"} diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-postgres/llama_index/vector_stores/postgres/base.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-postgres/llama_index/vector_stores/postgres/base.py index 9a495536ea0ba422724cbb40246d40ed07ba21cb..b7d114ddc3986ea632e3f9d233f835361704e4c1 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-postgres/llama_index/vector_stores/postgres/base.py +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-postgres/llama_index/vector_stores/postgres/base.py @@ -350,12 +350,10 @@ class PGVectorStore(BasePydanticVectorStore): _logger.warning(f"Unknown operator: {operator}, fallback to '='") return "=" - def _apply_filters_and_limit( - self, - stmt: Select, - limit: int, - metadata_filters: Optional[MetadataFilters] = None, - ) -> Any: + def _recursively_apply_filters(self, filters: List[MetadataFilters]) -> Any: + """ + Returns a sqlalchemy where clause. + """ import sqlalchemy sqlalchemy_conditions = { @@ -363,31 +361,44 @@ class PGVectorStore(BasePydanticVectorStore): "and": sqlalchemy.sql.and_, } - if metadata_filters: - if metadata_filters.condition not in sqlalchemy_conditions: - raise ValueError( - f"Invalid condition: {metadata_filters.condition}. " - f"Must be one of {list(sqlalchemy_conditions.keys())}" - ) - stmt = stmt.where( # type: ignore - sqlalchemy_conditions[metadata_filters.condition]( - *( - ( - sqlalchemy.text( - f"metadata_::jsonb->'{filter_.key}' " - f"{self._to_postgres_operator(filter_.operator)} " - f"'[\"{filter_.value}\"]'" - ) - if filter_.operator == FilterOperator.IN - else sqlalchemy.text( - f"metadata_->>'{filter_.key}' " - f"{self._to_postgres_operator(filter_.operator)} " - f"'{filter_.value}'" - ) + if filters.condition not in sqlalchemy_conditions: + raise ValueError( + f"Invalid condition: {filters.condition}. " + f"Must be one of {list(sqlalchemy_conditions.keys())}" + ) + + return sqlalchemy_conditions[filters.condition]( + *( + ( + ( + sqlalchemy.text( + f"metadata_::jsonb->'{filter_.key}' " + f"{self._to_postgres_operator(filter_.operator)} " + f"'[\"{filter_.value}\"]'" + ) + if filter_.operator == FilterOperator.IN + else sqlalchemy.text( + f"metadata_->>'{filter_.key}' " + f"{self._to_postgres_operator(filter_.operator)} " + f"'{filter_.value}'" ) - for filter_ in metadata_filters.filters ) + if not isinstance(filter_, MetadataFilters) + else self._recursively_apply_filters(filter_) ) + for filter_ in filters.filters + ) + ) + + def _apply_filters_and_limit( + self, + stmt: Select, + limit: int, + metadata_filters: Optional[MetadataFilters] = None, + ) -> Any: + if metadata_filters: + stmt = stmt.where( # type: ignore + self._recursively_apply_filters(metadata_filters) ) return stmt.limit(limit) # type: ignore