From 748780c53227ba401616d8beb562cb4d8ba9b689 Mon Sep 17 00:00:00 2001
From: Vittorio <vittorio.mayellaro.dev@gmail.com>
Date: Fri, 23 Aug 2024 14:05:06 +0200
Subject: [PATCH] Implemented aget_routes async method for pinecone index

---
 semantic_router/index/base.py     |  11 +++
 semantic_router/index/local.py    |   3 +
 semantic_router/index/pinecone.py | 113 ++++++++++++++++++++++++++++++
 semantic_router/index/postgres.py |   4 ++
 semantic_router/index/qdrant.py   |   3 +
 5 files changed, 134 insertions(+)

diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py
index d0f12ac6..a23f92bf 100644
--- a/semantic_router/index/base.py
+++ b/semantic_router/index/base.py
@@ -90,6 +90,17 @@ class BaseIndex(BaseModel):
         """
         raise NotImplementedError("This method should be implemented by subclasses.")
 
+    async def aget_routes(self):
+        """
+        Asynchronously get a list of route and utterance objects currently stored in the index.
+        This method should be implemented by subclasses.
+
+        :returns: A list of tuples, each containing a route name and an associated utterance.
+        :rtype: list[tuple]
+        :raises NotImplementedError: If the method is not implemented by the subclass.
+        """
+        raise NotImplementedError("This method should be implemented by subclasses.")
+
     def delete_index(self):
         """
         Deletes or resets the index.
diff --git a/semantic_router/index/local.py b/semantic_router/index/local.py
index 7150b267..f2398618 100644
--- a/semantic_router/index/local.py
+++ b/semantic_router/index/local.py
@@ -128,6 +128,9 @@ class LocalIndex(BaseIndex):
             route_names = [self.routes[i] for i in idx]
             return scores, route_names
 
+    async def aget_routes(self):
+        logger.error("Async get routes is not implemented for LocalIndex.")
+
     def delete(self, route_name: str):
         """
         Delete all records of a specific route from the index.
diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py
index b4f6033f..5bb4b4aa 100644
--- a/semantic_router/index/pinecone.py
+++ b/semantic_router/index/pinecone.py
@@ -528,6 +528,18 @@ class PineconeIndex(BaseIndex):
         route_names = [result["metadata"]["sr_route"] for result in results["matches"]]
         return np.array(scores), route_names
 
+    async def aget_routes(self) -> list[tuple]:
+        """
+        Asynchronously get a list of route and utterance objects currently stored in the index.
+
+        Returns:
+            List[Tuple]: A list of (route_name, utterance) objects.
+        """
+        if self.async_client is None or self.host is None:
+            raise ValueError("Async client or host are not initialized.")
+
+        return await self._async_get_routes()
+
     def delete_index(self):
         self.client.delete_index(self.index_name)
 
@@ -584,5 +596,106 @@ class PineconeIndex(BaseIndex):
         async with self.async_client.get(f"{self.base_url}/indexes/{name}") as response:
             return await response.json(content_type=None)
 
+    async def _async_get_all(
+        self, prefix: str | None = None, include_metadata: bool = False
+    ) -> tuple[list[str], list[dict]]:
+        """
+        Retrieves all vector IDs from the Pinecone index using pagination asynchronously.
+        """
+        if self.index is None:
+            raise ValueError("Index is None, could not retrieve vector IDs.")
+
+        all_vector_ids = []
+        metadata = []
+        next_page_token = None
+
+        # Pass the prefix as a query parameter so it gets URL-encoded; route
+        # prefixes may contain reserved characters such as "#".
+        list_url = f"https://{self.host}/vectors/list"
+        params: dict = {}
+        if prefix:
+            params["prefix"] = prefix
+        if self.namespace:
+            params["namespace"] = self.namespace
+
+        while True:
+            if next_page_token:
+                params["paginationToken"] = next_page_token
+
+            async with self.async_client.get(
+                list_url, params=params, headers={"Api-Key": self.api_key}
+            ) as response:
+                if response.status != 200:
+                    error_text = await response.text()
+                    logger.error(
+                        f"Error listing vectors: {response.status} - {error_text}"
+                    )
+                    break
+
+                response_data = await response.json(content_type=None)
+
+            vector_ids = [vec["id"] for vec in response_data.get("vectors", [])]
+            if not vector_ids:
+                break
+            all_vector_ids.extend(vector_ids)
+
+            if include_metadata:
+                # One fetch per vector, issued concurrently for this page.
+                metadata_tasks = [
+                    self._async_fetch_metadata(vector_id) for vector_id in vector_ids
+                ]
+                metadata.extend(await asyncio.gather(*metadata_tasks))
+
+            next_page_token = response_data.get("pagination", {}).get("next")
+            if not next_page_token:
+                break
+
+        return all_vector_ids, metadata
+
+    async def _async_fetch_metadata(self, vector_id: str) -> dict:
+        """
+        Fetch metadata for a single vector ID asynchronously using the async_client.
+        """
+        url = f"https://{self.host}/vectors/fetch"
+
+        params: dict = {"ids": [vector_id]}
+        # Vectors must be fetched from the same namespace they were listed in.
+        if self.namespace:
+            params["namespace"] = self.namespace
+
+        headers = {"Api-Key": self.api_key}
+
+        async with self.async_client.get(
+            url, params=params, headers=headers
+        ) as response:
+            if response.status != 200:
+                error_text = await response.text()
+                logger.error(
+                    f"Error fetching metadata for vector {vector_id}: {response.status} - {error_text}"
+                )
+                return {}
+
+            try:
+                response_data = await response.json(content_type=None)
+            except Exception as e:
+                logger.warning(f"Failed to decode JSON for vector {vector_id}: {e}")
+                return {}
+
+            return (
+                response_data.get("vectors", {}).get(vector_id, {}).get("metadata", {})
+            )
+
+    async def _async_get_routes(self) -> list[tuple]:
+        """
+        Gets a list of route and utterance objects currently stored in the index.
+
+        Returns:
+            List[Tuple]: A list of (route_name, utterance) objects.
+        """
+        _, metadata = await self._async_get_all(include_metadata=True)
+        # Skip vectors whose metadata could not be fetched (empty dicts).
+        route_tuples = [(x["sr_route"], x["sr_utterance"]) for x in metadata if x]
+        return route_tuples
+
     def __len__(self):
         return self.index.describe_index_stats()["total_vector_count"]
diff --git a/semantic_router/index/postgres.py b/semantic_router/index/postgres.py
index 4c971d4d..52afdeec 100644
--- a/semantic_router/index/postgres.py
+++ b/semantic_router/index/postgres.py
@@ -9,6 +9,7 @@ from pydantic import BaseModel
 
 from semantic_router.index.base import BaseIndex
 from semantic_router.schema import Metric
+from semantic_router.utils.logger import logger
 
 
 class MetricPgVecOperatorMap(Enum):
@@ -456,6 +457,9 @@ class PostgresIndex(BaseIndex):
             cur.execute(f"DROP TABLE IF EXISTS {table_name}")
             self.conn.commit()
 
+    async def aget_routes(self):
+        logger.error("Async get routes is not implemented for PostgresIndex.")
+
     def __len__(self):
         """
         Returns the total number of vectors in the index.
diff --git a/semantic_router/index/qdrant.py b/semantic_router/index/qdrant.py
index c1a5e28b..3077da33 100644
--- a/semantic_router/index/qdrant.py
+++ b/semantic_router/index/qdrant.py
@@ -317,6 +317,9 @@ class QdrantIndex(BaseIndex):
         route_names = [result.payload[SR_ROUTE_PAYLOAD_KEY] for result in results]
         return np.array(scores), route_names
 
+    async def aget_routes(self):
+        logger.error("Async get routes is not implemented for QdrantIndex.")
+
     def delete_index(self):
         self.client.delete_collection(self.index_name)
 
-- 
GitLab