From b067b58bc4ea53f7158a13cea79a4aa4d55f57f3 Mon Sep 17 00:00:00 2001
From: Rendy Febry <rendyfebry@economist.com>
Date: Mon, 11 Dec 2023 00:24:20 +0700
Subject: [PATCH] Feat/PgVector Support custom hnsw.ef_search and
 ivfflat.probes (#9420)

---
 docs/examples/vector_stores/postgres.ipynb | 58 ++++++++++++++++++++++
 llama_index/vector_stores/postgres.py      | 54 +++++++++++++++++---
 2 files changed, 104 insertions(+), 8 deletions(-)

diff --git a/docs/examples/vector_stores/postgres.ipynb b/docs/examples/vector_stores/postgres.ipynb
index 8cb7d3e170..707e9d5ba8 100644
--- a/docs/examples/vector_stores/postgres.ipynb
+++ b/docs/examples/vector_stores/postgres.ipynb
@@ -447,6 +447,64 @@
    "source": [
     "print(hybrid_response)"
    ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2b274ecb",
+   "metadata": {},
+   "source": [
+    "### PgVector Query Options"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "a490a0fa",
+   "metadata": {},
+   "source": [
+    "#### IVFFlat Probes\n",
+    "\n",
+    "Specify the number of [IVFFlat probes](https://github.com/pgvector/pgvector?tab=readme-ov-file#query-options) (1 by default).\n",
+    "\n",
+    "When retrieving from the index, pass `ivfflat_probes` in `vector_store_kwargs` to specify an appropriate number of IVFFlat probes (higher is better for recall, lower is better for speed)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "111a3682",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "retriever = index.as_retriever(\n",
+    "    vector_store_query_mode=query_mode,\n",
+    "    similarity_top_k=top_k,\n",
+    "    vector_store_kwargs={\"ivfflat_probes\": 10},\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "6104ef8d",
+   "metadata": {},
+   "source": [
+    "#### HNSW EF Search\n",
+    "\n",
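+    "Specify the size of the dynamic [candidate list](https://github.com/pgvector/pgvector?tab=readme-ov-file#query-options-1) for search (40 by default). When retrieving from the index, pass `hnsw_ef_search` in `vector_store_kwargs` to set it (higher is better for recall, lower is better for speed)."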
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f3a44758",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "retriever = index.as_retriever(\n",
+    "    vector_store_query_mode=query_mode,\n",
+    "    similarity_top_k=top_k,\n",
+    "    vector_store_kwargs={\"hnsw_ef_search\": 300},\n",
+    ")"
+   ]
   }
  ],
  "metadata": {
diff --git a/llama_index/vector_stores/postgres.py b/llama_index/vector_stores/postgres.py
index 42d59168fb..74a6ba15eb 100644
--- a/llama_index/vector_stores/postgres.py
+++ b/llama_index/vector_stores/postgres.py
@@ -393,9 +393,21 @@ class PGVectorStore(BasePydanticVectorStore):
         embedding: Optional[List[float]],
         limit: int = 10,
         metadata_filters: Optional[MetadataFilters] = None,
+        **kwargs: Any,
     ) -> List[DBEmbeddingRow]:
         stmt = self._build_query(embedding, limit, metadata_filters)
         with self._session() as session, session.begin():
+            from sqlalchemy import text
+
+            if kwargs.get("ivfflat_probes"):
+                session.execute(
+                    text(f"SET ivfflat.probes = {kwargs.get('ivfflat_probes')}")
+                )
+            if kwargs.get("hnsw_ef_search"):
+                session.execute(
+                    text(f"SET hnsw.ef_search = {kwargs.get('hnsw_ef_search')}")
+                )
+
             res = session.execute(
                 stmt,
             )
@@ -414,9 +426,21 @@ class PGVectorStore(BasePydanticVectorStore):
         embedding: Optional[List[float]],
         limit: int = 10,
         metadata_filters: Optional[MetadataFilters] = None,
+        **kwargs: Any,
     ) -> List[DBEmbeddingRow]:
         stmt = self._build_query(embedding, limit, metadata_filters)
         async with self._async_session() as async_session, async_session.begin():
+            from sqlalchemy import text
+
+            if kwargs.get("hnsw_ef_search"):
+                await async_session.execute(
+                    text(f"SET hnsw.ef_search = {kwargs.get('hnsw_ef_search')}")
+                )
+            if kwargs.get("ivfflat_probes"):
+                await async_session.execute(
+                    text(f"SET ivfflat.probes = {kwargs.get('ivfflat_probes')}")
+                )
+
             res = await async_session.execute(stmt)
             return [
                 DBEmbeddingRow(
@@ -495,7 +519,7 @@ class PGVectorStore(BasePydanticVectorStore):
             ]
 
     async def _async_hybrid_query(
-        self, query: VectorStoreQuery
+        self, query: VectorStoreQuery, **kwargs: Any
     ) -> List[DBEmbeddingRow]:
         import asyncio
 
@@ -506,7 +530,10 @@ class PGVectorStore(BasePydanticVectorStore):
 
         results = await asyncio.gather(
             self._aquery_with_score(
-                query.query_embedding, query.similarity_top_k, query.filters
+                query.query_embedding,
+                query.similarity_top_k,
+                query.filters,
+                **kwargs,
             ),
             self._async_sparse_query_with_rank(
                 query.query_str, sparse_top_k, query.filters
@@ -517,14 +544,19 @@ class PGVectorStore(BasePydanticVectorStore):
         all_results = dense_results + sparse_results
         return _dedup_results(all_results)
 
-    def _hybrid_query(self, query: VectorStoreQuery) -> List[DBEmbeddingRow]:
+    def _hybrid_query(
+        self, query: VectorStoreQuery, **kwargs: Any
+    ) -> List[DBEmbeddingRow]:
         if query.alpha is not None:
             _logger.warning("postgres hybrid search does not support alpha parameter.")
 
         sparse_top_k = query.sparse_top_k or query.similarity_top_k
 
         dense_results = self._query_with_score(
-            query.query_embedding, query.similarity_top_k, query.filters
+            query.query_embedding,
+            query.similarity_top_k,
+            query.filters,
+            **kwargs,
         )
 
         sparse_results = self._sparse_query_with_rank(
@@ -566,7 +598,7 @@ class PGVectorStore(BasePydanticVectorStore):
     ) -> VectorStoreQueryResult:
         self._initialize()
         if query.mode == VectorStoreQueryMode.HYBRID:
-            results = await self._async_hybrid_query(query)
+            results = await self._async_hybrid_query(query, **kwargs)
         elif query.mode in [
             VectorStoreQueryMode.SPARSE,
             VectorStoreQueryMode.TEXT_SEARCH,
@@ -577,7 +609,10 @@ class PGVectorStore(BasePydanticVectorStore):
             )
         elif query.mode == VectorStoreQueryMode.DEFAULT:
             results = await self._aquery_with_score(
-                query.query_embedding, query.similarity_top_k, query.filters
+                query.query_embedding,
+                query.similarity_top_k,
+                query.filters,
+                **kwargs,
             )
         else:
             raise ValueError(f"Invalid query mode: {query.mode}")
@@ -587,7 +622,7 @@ class PGVectorStore(BasePydanticVectorStore):
     def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
         self._initialize()
         if query.mode == VectorStoreQueryMode.HYBRID:
-            results = self._hybrid_query(query)
+            results = self._hybrid_query(query, **kwargs)
         elif query.mode in [
             VectorStoreQueryMode.SPARSE,
             VectorStoreQueryMode.TEXT_SEARCH,
@@ -598,7 +633,10 @@ class PGVectorStore(BasePydanticVectorStore):
             )
         elif query.mode == VectorStoreQueryMode.DEFAULT:
             results = self._query_with_score(
-                query.query_embedding, query.similarity_top_k, query.filters
+                query.query_embedding,
+                query.similarity_top_k,
+                query.filters,
+                **kwargs,
             )
         else:
             raise ValueError(f"Invalid query mode: {query.mode}")
-- 
GitLab