From b067b58bc4ea53f7158a13cea79a4aa4d55f57f3 Mon Sep 17 00:00:00 2001 From: Rendy Febry <rendyfebry@economist.com> Date: Mon, 11 Dec 2023 00:24:20 +0700 Subject: [PATCH] Feat/PgVector Support custom hnsw.ef_search and ivfflat.probes (#9420) --- docs/examples/vector_stores/postgres.ipynb | 58 ++++++++++++++++++++++ llama_index/vector_stores/postgres.py | 54 +++++++++++++++++--- 2 files changed, 104 insertions(+), 8 deletions(-) diff --git a/docs/examples/vector_stores/postgres.ipynb b/docs/examples/vector_stores/postgres.ipynb index 8cb7d3e170..707e9d5ba8 100644 --- a/docs/examples/vector_stores/postgres.ipynb +++ b/docs/examples/vector_stores/postgres.ipynb @@ -447,6 +447,64 @@ "source": [ "print(hybrid_response)" ] + }, + { + "cell_type": "markdown", + "id": "2b274ecb", + "metadata": {}, + "source": [ + "### PgVector Query Options" + ] + }, + { + "cell_type": "markdown", + "id": "a490a0fa", + "metadata": {}, + "source": [ + "#### IVFFlat Probes\n", + "\n", + "Specify the number of [IVFFlat probes](https://github.com/pgvector/pgvector?tab=readme-ov-file#query-options) (1 by default)\n", + "\n", + "When retrieving from the index, you can specify an appropriate number of IVFFlat probes (higher is better for recall, lower is better for speed)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "111a3682", + "metadata": {}, + "outputs": [], + "source": [ + "retriever = index.as_retriever(\n", + " vector_store_query_mode=query_mode,\n", + " similarity_top_k=top_k,\n", + " vector_store_kwargs={\"ivfflat_probes\": 10},\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "6104ef8d", + "metadata": {}, + "source": [ + "#### HNSW EF Search\n", + "\n", + "Specify the size of the dynamic [candidate list](https://github.com/pgvector/pgvector?tab=readme-ov-file#query-options-1) for search (40 by default)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f3a44758", + "metadata": {}, + "outputs": [], + "source": [ + "retriever = index.as_retriever(\n", + " vector_store_query_mode=query_mode,\n", + " similarity_top_k=top_k,\n", + " vector_store_kwargs={\"hnsw_ef_search\": 300},\n", + ")" + ] } ], "metadata": { diff --git a/llama_index/vector_stores/postgres.py b/llama_index/vector_stores/postgres.py index 42d59168fb..74a6ba15eb 100644 --- a/llama_index/vector_stores/postgres.py +++ b/llama_index/vector_stores/postgres.py @@ -393,9 +393,21 @@ class PGVectorStore(BasePydanticVectorStore): embedding: Optional[List[float]], limit: int = 10, metadata_filters: Optional[MetadataFilters] = None, + **kwargs: Any, ) -> List[DBEmbeddingRow]: stmt = self._build_query(embedding, limit, metadata_filters) with self._session() as session, session.begin(): + from sqlalchemy import text + + if kwargs.get("ivfflat_probes"): + session.execute( + text(f"SET ivfflat.probes = {kwargs.get('ivfflat_probes')}") + ) + if kwargs.get("hnsw_ef_search"): + session.execute( + text(f"SET hnsw.ef_search = {kwargs.get('hnsw_ef_search')}") + ) + res = session.execute( stmt, ) @@ -414,9 +426,21 @@ class PGVectorStore(BasePydanticVectorStore): embedding: Optional[List[float]], limit: int = 10, metadata_filters: Optional[MetadataFilters] = None, + **kwargs: Any, ) -> List[DBEmbeddingRow]: stmt = self._build_query(embedding, limit, metadata_filters) async with self._async_session() as async_session, async_session.begin(): + from sqlalchemy import text + + if kwargs.get("hnsw_ef_search"): + await async_session.execute( + text(f"SET hnsw.ef_search = {kwargs.get('hnsw_ef_search')}") + ) + if kwargs.get("ivfflat_probes"): + await async_session.execute( + text(f"SET ivfflat.probes = {kwargs.get('ivfflat_probes')}") + ) + res = await async_session.execute(stmt) return [ DBEmbeddingRow( @@ -495,7 +519,7 @@ class PGVectorStore(BasePydanticVectorStore): ] async def _async_hybrid_query( - self, query: VectorStoreQuery + self, query: VectorStoreQuery, **kwargs: Any ) -> List[DBEmbeddingRow]: import asyncio @@ -506,7 +530,10 @@ class PGVectorStore(BasePydanticVectorStore): results = await asyncio.gather( self._aquery_with_score( - query.query_embedding, query.similarity_top_k, query.filters + query.query_embedding, + query.similarity_top_k, + query.filters, + **kwargs, ), self._async_sparse_query_with_rank( query.query_str, sparse_top_k, query.filters @@ -517,14 +544,19 @@ class PGVectorStore(BasePydanticVectorStore): all_results = dense_results + sparse_results return _dedup_results(all_results) - def _hybrid_query(self, query: VectorStoreQuery) -> List[DBEmbeddingRow]: + def _hybrid_query( + self, query: VectorStoreQuery, **kwargs: Any + ) -> List[DBEmbeddingRow]: if query.alpha is not None: _logger.warning("postgres hybrid search does not support alpha parameter.") sparse_top_k = query.sparse_top_k or query.similarity_top_k dense_results = self._query_with_score( - query.query_embedding, query.similarity_top_k, query.filters + query.query_embedding, + query.similarity_top_k, + query.filters, + **kwargs, ) sparse_results = self._sparse_query_with_rank( @@ -566,7 +598,7 @@ class PGVectorStore(BasePydanticVectorStore): ) -> VectorStoreQueryResult: self._initialize() if query.mode == VectorStoreQueryMode.HYBRID: - results = await self._async_hybrid_query(query) + results = await self._async_hybrid_query(query, **kwargs) elif query.mode in [ VectorStoreQueryMode.SPARSE, VectorStoreQueryMode.TEXT_SEARCH, @@ -577,7 +609,10 @@ class PGVectorStore(BasePydanticVectorStore): ) elif query.mode == VectorStoreQueryMode.DEFAULT: results = await self._aquery_with_score( - query.query_embedding, query.similarity_top_k, query.filters + query.query_embedding, + query.similarity_top_k, + query.filters, + **kwargs, ) else: raise ValueError(f"Invalid query mode: {query.mode}") @@ -587,7 +622,7 @@ class PGVectorStore(BasePydanticVectorStore): def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult: self._initialize() if query.mode == VectorStoreQueryMode.HYBRID: - results = self._hybrid_query(query) + results = self._hybrid_query(query, **kwargs) elif query.mode in [ VectorStoreQueryMode.SPARSE, VectorStoreQueryMode.TEXT_SEARCH, @@ -598,7 +633,10 @@ class PGVectorStore(BasePydanticVectorStore): ) elif query.mode == VectorStoreQueryMode.DEFAULT: results = self._query_with_score( - query.query_embedding, query.similarity_top_k, query.filters + query.query_embedding, + query.similarity_top_k, + query.filters, + **kwargs, ) else: raise ValueError(f"Invalid query mode: {query.mode}") -- GitLab