From e1d513d399769f6f5bb693b9beac4496089798d6 Mon Sep 17 00:00:00 2001 From: Eric Hare <ericrhare@gmail.com> Date: Wed, 29 Nov 2023 08:42:48 -0800 Subject: [PATCH] Update Astra DB integration for API changes (#9193) * Update Astra DB integration for API changes * Update astra.py * Fix issue in specification of metadata filter --- CHANGELOG.md | 2 + llama_index/vector_stores/astra.py | 109 +++++++++++++++++++++++++---- 2 files changed, 99 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f7a4f3999d..1687342dea 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ ### New Features - Add new abstractions for `LlamaDataset`'s (#9165) +- Add metadata filtering and MMR mode support for `AstraDBVectorStore` (#9193) ### Breaking Changes / Deprecations @@ -13,6 +14,7 @@ ### Bug Fixes / Nits - Use `azure_deployment` kwarg in `AzureOpenAILLM` (#9174) +- Fix similarity score return for `AstraDBVectorStore` Integration (#9193) ## [0.9.8] - 2023-11-26 diff --git a/llama_index/vector_stores/astra.py b/llama_index/vector_stores/astra.py index 092626484e..5b0b4f872e 100644 --- a/llama_index/vector_stores/astra.py +++ b/llama_index/vector_stores/astra.py @@ -5,12 +5,16 @@ powered by the astrapy library """ import logging -from typing import Any, List, Optional, cast +from typing import Any, Dict, List, Optional, cast +from llama_index.indices.query.embedding_utils import get_top_k_mmr_embeddings from llama_index.schema import BaseNode, MetadataMode from llama_index.vector_stores.types import ( + ExactMatchFilter, + MetadataFilters, VectorStore, VectorStoreQuery, + VectorStoreQueryMode, VectorStoreQueryResult, ) from llama_index.vector_stores.utils import ( @@ -20,6 +24,7 @@ from llama_index.vector_stores.utils import ( _logger = logging.getLogger(__name__) +DEFAULT_MMR_PREFETCH_FACTOR = 4.0 MAX_INSERT_BATCH_SIZE = 20 @@ -58,7 +63,9 @@ class AstraDBVectorStore(VectorStore): namespace: Optional[str] = None, ttl_seconds: Optional[int] = None, ) -> None: - import_err_msg = "`astrapy` package not found, please run `pip install astrapy`" + import_err_msg = ( + "`astrapy` package not found, please run `pip install --upgrade astrapy`" + ) # Try to import astrapy for use try: @@ -153,25 +160,103 @@ class AstraDBVectorStore(VectorStore): """Return the underlying Astra vector table object.""" return self._astra_db_collection + @staticmethod + def _query_filters_to_dict(query_filters: MetadataFilters) -> Dict[str, Any]: + if any(not isinstance(f, ExactMatchFilter) for f in query_filters.filters): + raise NotImplementedError("Only `ExactMatchFilter` filters are supported") + return {f"metadata.{f.key}": f.value for f in query_filters.filters} + def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult: """Query index for top k most similar nodes.""" + # Get the currently available query modes + _available_query_modes = [ + VectorStoreQueryMode.DEFAULT, + VectorStoreQueryMode.MMR, + ] + + # Reject query if not available + if query.mode not in _available_query_modes: + raise NotImplementedError(f"Query mode {query.mode} not available.") + # Get the query embedding query_embedding = cast(List[float], query.query_embedding) - # Set the parameters accordingly - sort = {"$vector": query_embedding} - options = {"limit": query.similarity_top_k} - projection = {"$vector": 1, "$similarity": 1, "content": 1} + # Process the metadata filters as needed + if query.filters is not None: + query_metadata = self._query_filters_to_dict(query.filters) + else: + query_metadata = {} + + # Get the scores depending on the query mode + if query.mode == VectorStoreQueryMode.DEFAULT: + # Call the vector_find method of AstraPy + matches = self._astra_db_collection.vector_find( + vector=query_embedding, + limit=query.similarity_top_k, + filter=query_metadata, + ) + + # Get the scores associated with each + top_k_scores = [match["$similarity"] for match in matches] + elif query.mode == VectorStoreQueryMode.MMR: + # Querying a larger number of vectors and then doing MMR on them. + if ( + kwargs.get("mmr_prefetch_factor") is not None + and kwargs.get("mmr_prefetch_k") is not None + ): + raise ValueError( + "'mmr_prefetch_factor' and 'mmr_prefetch_k' " + "cannot coexist in a call to query()" + ) + else: + if kwargs.get("mmr_prefetch_k") is not None: + prefetch_k0 = int(kwargs["mmr_prefetch_k"]) + else: + prefetch_k0 = int( + query.similarity_top_k + * kwargs.get("mmr_prefetch_factor", DEFAULT_MMR_PREFETCH_FACTOR) + ) + # Get the most we can possibly need to fetch + prefetch_k = max(prefetch_k0, query.similarity_top_k) + + # Call AstraPy to fetch them + prefetch_matches = self._astra_db_collection.vector_find( + vector=query_embedding, + limit=prefetch_k, + filter=query_metadata, + ) + + # Get the MMR threshold + mmr_threshold = query.mmr_threshold or kwargs.get("mmr_threshold") + + # If we have found documents, we can proceed + if prefetch_matches: + pf_match_indices, pf_match_embeddings = zip( + *enumerate(match["$vector"] for match in prefetch_matches) + ) + else: + pf_match_indices, pf_match_embeddings = [], [] + + # Create lists for the indices and embeddings + pf_match_indices = list(pf_match_indices) + pf_match_embeddings = list(pf_match_embeddings) + + # Call the Llama utility function to get the top k + mmr_similarities, mmr_indices = get_top_k_mmr_embeddings( + query_embedding, + pf_match_embeddings, + similarity_top_k=query.similarity_top_k, + embedding_ids=pf_match_indices, + mmr_threshold=mmr_threshold, + ) - # Call the find method of the Astra API - matches = self._astra_db_collection.find( - sort=sort, options=options, projection=projection - )["data"]["documents"] + # Finally, build the final results based on the mmr values + matches = [prefetch_matches[mmr_index] for mmr_index in mmr_indices] + top_k_scores = mmr_similarities # We have three lists to return top_k_nodes = [] top_k_ids = [] - top_k_scores = [] # Get every match for my_match in matches: @@ -184,8 +269,8 @@ class AstraDBVectorStore(VectorStore): # Append to the respective lists top_k_nodes.append(node) top_k_ids.append(my_match["_id"]) - top_k_scores.append(my_match["$similarity"]) + # return our final result return VectorStoreQueryResult( nodes=top_k_nodes, similarities=top_k_scores, -- GitLab