From e1d513d399769f6f5bb693b9beac4496089798d6 Mon Sep 17 00:00:00 2001
From: Eric Hare <ericrhare@gmail.com>
Date: Wed, 29 Nov 2023 08:42:48 -0800
Subject: [PATCH] Update Astra DB integration for API changes (#9193)

* Update Astra DB integration for API changes

* Update astra.py

* Fix issue in specification of metadata filter
---
 CHANGELOG.md                       |   2 +
 llama_index/vector_stores/astra.py | 109 +++++++++++++++++++++++++----
 2 files changed, 99 insertions(+), 12 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index f7a4f3999d..1687342dea 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,7 @@
 ### New Features
 
 - Add new abstractions for `LlamaDataset`'s (#9165)
+- Add metadata filtering and MMR mode support for `AstraDBVectorStore` (#9193)
 
 ### Breaking Changes / Deprecations
 
@@ -13,6 +14,7 @@
 ### Bug Fixes / Nits
 
 - Use `azure_deployment` kwarg in `AzureOpenAILLM` (#9174)
+- Fix similarity score return for `AstraDBVectorStore` Integration (#9193)
 
 ## [0.9.8] - 2023-11-26
 
diff --git a/llama_index/vector_stores/astra.py b/llama_index/vector_stores/astra.py
index 092626484e..5b0b4f872e 100644
--- a/llama_index/vector_stores/astra.py
+++ b/llama_index/vector_stores/astra.py
@@ -5,12 +5,16 @@ powered by the astrapy library
 
 """
 import logging
-from typing import Any, List, Optional, cast
+from typing import Any, Dict, List, Optional, cast
 
+from llama_index.indices.query.embedding_utils import get_top_k_mmr_embeddings
 from llama_index.schema import BaseNode, MetadataMode
 from llama_index.vector_stores.types import (
+    ExactMatchFilter,
+    MetadataFilters,
     VectorStore,
     VectorStoreQuery,
+    VectorStoreQueryMode,
     VectorStoreQueryResult,
 )
 from llama_index.vector_stores.utils import (
@@ -20,6 +24,7 @@ from llama_index.vector_stores.utils import (
 
 _logger = logging.getLogger(__name__)
 
+DEFAULT_MMR_PREFETCH_FACTOR = 4.0
 MAX_INSERT_BATCH_SIZE = 20
 
 
@@ -58,7 +63,9 @@ class AstraDBVectorStore(VectorStore):
         namespace: Optional[str] = None,
         ttl_seconds: Optional[int] = None,
     ) -> None:
-        import_err_msg = "`astrapy` package not found, please run `pip install astrapy`"
+        import_err_msg = (
+            "`astrapy` package not found, please run `pip install --upgrade astrapy`"
+        )
 
         # Try to import astrapy for use
         try:
@@ -153,25 +160,103 @@ class AstraDBVectorStore(VectorStore):
         """Return the underlying Astra vector table object."""
         return self._astra_db_collection
 
+    @staticmethod
+    def _query_filters_to_dict(query_filters: MetadataFilters) -> Dict[str, Any]:
+        if any(not isinstance(f, ExactMatchFilter) for f in query_filters.filters):
+            raise NotImplementedError("Only `ExactMatchFilter` filters are supported")
+        return {f"metadata.{f.key}": f.value for f in query_filters.filters}
+
     def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
         """Query index for top k most similar nodes."""
+        # Get the currently available query modes
+        _available_query_modes = [
+            VectorStoreQueryMode.DEFAULT,
+            VectorStoreQueryMode.MMR,
+        ]
+
+        # Reject query if not available
+        if query.mode not in _available_query_modes:
+            raise NotImplementedError(f"Query mode {query.mode} not available.")
+
         # Get the query embedding
         query_embedding = cast(List[float], query.query_embedding)
 
-        # Set the parameters accordingly
-        sort = {"$vector": query_embedding}
-        options = {"limit": query.similarity_top_k}
-        projection = {"$vector": 1, "$similarity": 1, "content": 1}
+        # Process the metadata filters as needed
+        if query.filters is not None:
+            query_metadata = self._query_filters_to_dict(query.filters)
+        else:
+            query_metadata = {}
+
+        # Get the scores depending on the query mode
+        if query.mode == VectorStoreQueryMode.DEFAULT:
+            # Call the vector_find method of AstraPy
+            matches = self._astra_db_collection.vector_find(
+                vector=query_embedding,
+                limit=query.similarity_top_k,
+                filter=query_metadata,
+            )
+
+            # Get the scores associated with each
+            top_k_scores = [match["$similarity"] for match in matches]
+        elif query.mode == VectorStoreQueryMode.MMR:
+            # Querying a larger number of vectors and then doing MMR on them.
+            if (
+                kwargs.get("mmr_prefetch_factor") is not None
+                and kwargs.get("mmr_prefetch_k") is not None
+            ):
+                raise ValueError(
+                    "'mmr_prefetch_factor' and 'mmr_prefetch_k' "
+                    "cannot coexist in a call to query()"
+                )
+            else:
+                if kwargs.get("mmr_prefetch_k") is not None:
+                    prefetch_k0 = int(kwargs["mmr_prefetch_k"])
+                else:
+                    prefetch_k0 = int(
+                        query.similarity_top_k
+                        * kwargs.get("mmr_prefetch_factor", DEFAULT_MMR_PREFETCH_FACTOR)
+                    )
+            # Get the most we can possibly need to fetch
+            prefetch_k = max(prefetch_k0, query.similarity_top_k)
+
+            # Call AstraPy to fetch them
+            prefetch_matches = self._astra_db_collection.vector_find(
+                vector=query_embedding,
+                limit=prefetch_k,
+                filter=query_metadata,
+            )
+
+            # Get the MMR threshold
+            mmr_threshold = query.mmr_threshold or kwargs.get("mmr_threshold")
+
+            # If we have found documents, we can proceed
+            if prefetch_matches:
+                pf_match_indices, pf_match_embeddings = zip(
+                    *enumerate(match["$vector"] for match in prefetch_matches)
+                )
+            else:
+                pf_match_indices, pf_match_embeddings = [], []
+
+            # Create lists for the indices and embeddings
+            pf_match_indices = list(pf_match_indices)
+            pf_match_embeddings = list(pf_match_embeddings)
+
+            # Call the Llama utility function to get the top k
+            mmr_similarities, mmr_indices = get_top_k_mmr_embeddings(
+                query_embedding,
+                pf_match_embeddings,
+                similarity_top_k=query.similarity_top_k,
+                embedding_ids=pf_match_indices,
+                mmr_threshold=mmr_threshold,
+            )
 
-        # Call the find method of the Astra API
-        matches = self._astra_db_collection.find(
-            sort=sort, options=options, projection=projection
-        )["data"]["documents"]
+            # Finally, build the final results based on the mmr values
+            matches = [prefetch_matches[mmr_index] for mmr_index in mmr_indices]
+            top_k_scores = mmr_similarities
 
         # We have three lists to return
         top_k_nodes = []
         top_k_ids = []
-        top_k_scores = []
 
         # Get every match
         for my_match in matches:
@@ -184,8 +269,8 @@ class AstraDBVectorStore(VectorStore):
             # Append to the respective lists
             top_k_nodes.append(node)
             top_k_ids.append(my_match["_id"])
-            top_k_scores.append(my_match["$similarity"])
 
+        # return our final result
         return VectorStoreQueryResult(
             nodes=top_k_nodes,
             similarities=top_k_scores,
-- 
GitLab