From 748780c53227ba401616d8beb562cb4d8ba9b689 Mon Sep 17 00:00:00 2001
From: Vittorio <vittorio.mayellaro.dev@gmail.com>
Date: Fri, 23 Aug 2024 14:05:06 +0200
Subject: [PATCH] Implemented aget_routes async method for pinecone index

---
 semantic_router/index/base.py     |  11 +++
 semantic_router/index/local.py    |   3 +
 semantic_router/index/pinecone.py | 113 ++++++++++++++++++++++++++++++
 semantic_router/index/postgres.py |   4 ++
 semantic_router/index/qdrant.py   |   3 +
 5 files changed, 134 insertions(+)

diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py
index d0f12ac6..a23f92bf 100644
--- a/semantic_router/index/base.py
+++ b/semantic_router/index/base.py
@@ -90,6 +90,17 @@ class BaseIndex(BaseModel):
         """
         raise NotImplementedError("This method should be implemented by subclasses.")
 
+    async def aget_routes(self):
+        """
+        Asynchronously get a list of route and utterance objects currently stored in the index.
+        This coroutine should be implemented by subclasses; awaiting this base version raises.
+
+        :returns: A list of tuples, each containing a route name and an associated utterance.
+        :rtype: list[tuple]
+        :raises NotImplementedError: If the method is not implemented by the subclass.
+        """
+        raise NotImplementedError("This method should be implemented by subclasses.")
+
     def delete_index(self):
         """
         Deletes or resets the index.
diff --git a/semantic_router/index/local.py b/semantic_router/index/local.py
index 7150b267..f2398618 100644
--- a/semantic_router/index/local.py
+++ b/semantic_router/index/local.py
@@ -128,6 +128,9 @@ class LocalIndex(BaseIndex):
             route_names = [self.routes[i] for i in idx]
         return scores, route_names
 
+    async def aget_routes(self):  # coroutine stub: only logs, resolves to None
+        logger.error("Async get routes is not implemented for LocalIndex.")
+
     def delete(self, route_name: str):
         """
         Delete all records of a specific route from the index.
diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py
index b4f6033f..5bb4b4aa 100644
--- a/semantic_router/index/pinecone.py
+++ b/semantic_router/index/pinecone.py
@@ -528,6 +528,18 @@ class PineconeIndex(BaseIndex):
         route_names = [result["metadata"]["sr_route"] for result in results["matches"]]
         return np.array(scores), route_names
 
+    async def aget_routes(self) -> list[tuple]:
+        """
+        Asynchronously list the (route_name, utterance) pairs stored in the index.
+
+        :return: A list of (route_name, utterance) tuples.
+        :raises ValueError: If the async client or the host is not initialized.
+        """
+        not_ready = self.async_client is None or self.host is None
+        if not_ready:
+            raise ValueError("Async client or host are not initialized.")
+        return await self._async_get_routes()
+
     def delete_index(self):
         self.client.delete_index(self.index_name)
 
@@ -584,5 +596,106 @@ class PineconeIndex(BaseIndex):
         async with self.async_client.get(f"{self.base_url}/indexes/{name}") as response:
             return await response.json(content_type=None)
 
+    async def _async_get_all(
+        self, prefix: str | None = None, include_metadata: bool = False
+    ) -> tuple[list[str], list[dict]]:
+        """
+        Retrieves all vector IDs from the Pinecone index using pagination asynchronously.
+        A non-200 list response is logged and ends pagination early (partial result).
+
+        :param prefix: Only list vector IDs that start with this prefix, if given.
+        :param include_metadata: Also fetch the metadata of every listed vector.
+        :return: (vector IDs, metadata dicts); metadata stays empty unless requested.
+        :raises ValueError: If the index has not been initialized.
+        """
+        # NOTE(review): hoist to module level if the logger is already imported there.
+        from semantic_router.utils.logger import logger
+
+        if self.index is None:
+            raise ValueError("Index is None, could not retrieve vector IDs.")
+
+        all_vector_ids: list[str] = []
+        metadata: list[dict] = []
+        list_url = f"https://{self.host}/vectors/list"
+        # Query values go through `params` rather than being spliced into the URL.
+        params: dict = {}
+        if prefix:
+            params["prefix"] = prefix
+        if self.namespace:
+            params["namespace"] = self.namespace
+
+        while True:
+            async with self.async_client.get(
+                list_url, params=params, headers={"Api-Key": self.api_key}
+            ) as response:
+                if response.status != 200:
+                    body = await response.text()
+                    logger.error(f"Error listing vectors: {response.status} - {body}")
+                    break
+                response_data = await response.json(content_type=None)
+
+            vector_ids = [vec["id"] for vec in response_data.get("vectors", [])]
+            if not vector_ids:
+                break
+            all_vector_ids.extend(vector_ids)
+            if include_metadata:
+                tasks = [self._async_fetch_metadata(vid) for vid in vector_ids]
+                metadata.extend(await asyncio.gather(*tasks))
+
+            next_page_token = response_data.get("pagination", {}).get("next")
+            if not next_page_token:
+                break
+            params["paginationToken"] = next_page_token
+
+        return all_vector_ids, metadata
+
+    async def _async_fetch_metadata(self, vector_id: str) -> dict:
+        """
+        Fetch metadata for a single vector ID asynchronously using the async_client.
+
+        :param vector_id: ID of the vector whose metadata is requested.
+        :return: The vector's metadata, or an empty dict when the request fails,
+            the body cannot be decoded, or the response holds no metadata for it.
+        """
+        # NOTE(review): hoist to module level if the logger is already imported there.
+        from semantic_router.utils.logger import logger
+
+        url = f"https://{self.host}/vectors/fetch"
+        params = {"ids": [vector_id]}
+        headers = {"Api-Key": self.api_key}
+
+        async with self.async_client.get(
+            url, params=params, headers=headers
+        ) as response:
+            if response.status != 200:
+                error_text = await response.text()
+                logger.error(
+                    f"Error fetching metadata for vector {vector_id}: {response.status} - {error_text}"
+                )
+                return {}
+
+            try:
+                response_data = await response.json(content_type=None)
+            except Exception as e:
+                # Best-effort: an undecodable body yields no metadata, not a crash.
+                logger.error(f"Failed to decode JSON for vector {vector_id}: {e}")
+                return {}
+
+        # A missing "vectors", vector-id or "metadata" entry collapses to {}.
+        return response_data.get("vectors", {}).get(vector_id, {}).get("metadata", {})
+
+    async def _async_get_routes(self) -> list[tuple]:
+        """
+        Gets the (route_name, utterance) tuples currently stored in the index.
+
+        Metadata lacking either key (e.g. {} from a failed fetch) is skipped.
+        """
+        _, metadata = await self._async_get_all(include_metadata=True)
+        return [
+            (x["sr_route"], x["sr_utterance"])
+            for x in metadata
+            if "sr_route" in x and "sr_utterance" in x
+        ]
+
     def __len__(self):
         return self.index.describe_index_stats()["total_vector_count"]
diff --git a/semantic_router/index/postgres.py b/semantic_router/index/postgres.py
index 4c971d4d..52afdeec 100644
--- a/semantic_router/index/postgres.py
+++ b/semantic_router/index/postgres.py
@@ -9,6 +9,7 @@ from pydantic import BaseModel
 
 from semantic_router.index.base import BaseIndex
 from semantic_router.schema import Metric
+from semantic_router.utils.logger import logger
 
 
 class MetricPgVecOperatorMap(Enum):
@@ -456,6 +457,9 @@ class PostgresIndex(BaseIndex):
             cur.execute(f"DROP TABLE IF EXISTS {table_name}")
             self.conn.commit()
 
+    async def aget_routes(self):  # coroutine stub: only logs, resolves to None
+        logger.error("Async get routes is not implemented for PostgresIndex.")
+
     def __len__(self):
         """
         Returns the total number of vectors in the index.
diff --git a/semantic_router/index/qdrant.py b/semantic_router/index/qdrant.py
index c1a5e28b..3077da33 100644
--- a/semantic_router/index/qdrant.py
+++ b/semantic_router/index/qdrant.py
@@ -317,6 +317,9 @@ class QdrantIndex(BaseIndex):
         route_names = [result.payload[SR_ROUTE_PAYLOAD_KEY] for result in results]
         return np.array(scores), route_names
 
+    async def aget_routes(self):  # coroutine stub: only logs, resolves to None
+        logger.error("Async get routes is not implemented for QdrantIndex.")
+
     def delete_index(self):
         self.client.delete_collection(self.index_name)
 
-- 
GitLab