diff --git a/docs/source/conf.py b/docs/source/conf.py
index 74b46335b157b6acaf1f58381e147d0bdc509293..0174df009fc6a70dd6934a30918a29c79ebc771b 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -15,7 +15,7 @@ sys.path.insert(0, os.path.abspath("../..")) # Source code dir relative to this
 project = "Semantic Router"
 copyright = "2024, Aurelio AI"
 author = "Aurelio AI"
-release = "0.0.60"
+release = "0.0.61"
 
 # -- General configuration ---------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
diff --git a/pyproject.toml b/pyproject.toml
index f69ceca032b158cca8e0ac1bf2eb59b559e84db4..7b294000957b93d96e1fa8dd487875fcac058881 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "semantic-router"
-version = "0.0.60"
+version = "0.0.61"
 description = "Super fast semantic router for AI decision making"
 authors = [
     "James Briggs <james@aurelio.ai>",
diff --git a/semantic_router/__init__.py b/semantic_router/__init__.py
index 0664671f657918c37119b7cee93282318f21f6a7..e8ed01d54d6497bf1fc1ba60a2b3b611fe6e6b1a 100644
--- a/semantic_router/__init__.py
+++ b/semantic_router/__init__.py
@@ -4,4 +4,4 @@ from semantic_router.route import Route
 
 __all__ = ["RouteLayer", "HybridRouteLayer", "Route", "LayerConfig"]
 
-__version__ = "0.0.60"
+__version__ = "0.0.61"
diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py
index d0f12ac682ab1686dd9d931e73082d7695a2e1fe..a23f92bfc5977c55db8924dc735127f912dd7215 100644
--- a/semantic_router/index/base.py
+++ b/semantic_router/index/base.py
@@ -90,6 +90,17 @@ class BaseIndex(BaseModel):
         """
         raise NotImplementedError("This method should be implemented by subclasses.")
 
+    def aget_routes(self):
+        """
+        Asynchronously get a list of route and utterance objects currently stored in the index.
+        This method should be implemented by subclasses.
+
+        :returns: A list of tuples, each containing a route name and an associated utterance.
+        :rtype: list[tuple]
+        :raises NotImplementedError: If the method is not implemented by the subclass.
+        """
+        raise NotImplementedError("This method should be implemented by subclasses.")
+
     def delete_index(self):
         """
         Deletes or resets the index.
diff --git a/semantic_router/index/local.py b/semantic_router/index/local.py
index 7150b267587715d09124eb6f94ab242a7795c3f9..f239861866583a35deff52b90a6b8ad28426914c 100644
--- a/semantic_router/index/local.py
+++ b/semantic_router/index/local.py
@@ -128,6 +128,9 @@ class LocalIndex(BaseIndex):
         route_names = [self.routes[i] for i in idx]
         return scores, route_names
 
+    def aget_routes(self):
+        logger.error("Async get is not implemented for LocalIndex.")
+
     def delete(self, route_name: str):
         """
         Delete all records of a specific route from the index.
diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py
index 931bb19c05aedb9d0fdafc93afe96a39f2283e68..a923852cd011c6909ab9b91e77d88385b0f34c6d 100644
--- a/semantic_router/index/pinecone.py
+++ b/semantic_router/index/pinecone.py
@@ -528,6 +528,18 @@ class PineconeIndex(BaseIndex):
         route_names = [result["metadata"]["sr_route"] for result in results["matches"]]
         return np.array(scores), route_names
 
+    async def aget_routes(self) -> list[tuple]:
+        """
+        Asynchronously get a list of route and utterance objects currently stored in the index.
+
+        Returns:
+            List[Tuple]: A list of (route_name, utterance) objects.
+        """
+        if self.async_client is None or self.host is None:
+            raise ValueError("Async client or host are not initialized.")
+
+        return await self._async_get_routes()
+
     def delete_index(self):
         self.client.delete_index(self.index_name)
 
@@ -584,5 +596,101 @@ class PineconeIndex(BaseIndex):
         async with self.async_client.get(f"{self.base_url}/indexes/{name}") as response:
             return await response.json(content_type=None)
 
+    async def _async_get_all(
+        self, prefix: Optional[str] = None, include_metadata: bool = False
+    ) -> tuple[list[str], list[dict]]:
+        """
+        Retrieves all vector IDs from the Pinecone index using pagination asynchronously.
+        """
+        if self.index is None:
+            raise ValueError("Index is None, could not retrieve vector IDs.")
+
+        all_vector_ids = []
+        next_page_token = None
+
+        if prefix:
+            prefix_str = f"?prefix={prefix}"
+        else:
+            prefix_str = ""
+
+        list_url = f"https://{self.host}/vectors/list{prefix_str}"
+        params: dict = {}
+        if self.namespace:
+            params["namespace"] = self.namespace
+        metadata = []
+
+        while True:
+            if next_page_token:
+                params["paginationToken"] = next_page_token
+
+            async with self.async_client.get(
+                list_url, params=params, headers={"Api-Key": self.api_key}
+            ) as response:
+                if response.status != 200:
+                    error_text = await response.text()
+                    logger.error(f"Error fetching vectors: {error_text}")
+                    break
+
+                response_data = await response.json(content_type=None)
+
+            vector_ids = [vec["id"] for vec in response_data.get("vectors", [])]
+            if not vector_ids:
+                break
+            all_vector_ids.extend(vector_ids)
+
+            if include_metadata:
+                metadata_tasks = [self._async_fetch_metadata(id) for id in vector_ids]
+                metadata_results = await asyncio.gather(*metadata_tasks)
+                metadata.extend(metadata_results)
+
+            next_page_token = response_data.get("pagination", {}).get("next")
+            if not next_page_token:
+                break
+
+        return all_vector_ids, metadata
+
+    async def _async_fetch_metadata(self, vector_id: str) -> dict:
+        """
+        Fetch metadata for a single vector ID asynchronously using the async_client.
+        """
+        url = f"https://{self.host}/vectors/fetch"
+
+        params = {
+            "ids": [vector_id],
+        }
+
+        headers = {
+            "Api-Key": self.api_key,
+        }
+
+        async with self.async_client.get(
+            url, params=params, headers=headers
+        ) as response:
+            if response.status != 200:
+                error_text = await response.text()
+                logger.error(f"Error fetching metadata: {error_text}")
+                return {}
+
+            try:
+                response_data = await response.json(content_type=None)
+            except Exception as e:
+                logger.warning(f"No metadata found for vector {vector_id}: {e}")
+                return {}
+
+            return (
+                response_data.get("vectors", {}).get(vector_id, {}).get("metadata", {})
+            )
+
+    async def _async_get_routes(self) -> list[tuple]:
+        """
+        Gets a list of route and utterance objects currently stored in the index.
+
+        Returns:
+            List[Tuple]: A list of (route_name, utterance) objects.
+        """
+        _, metadata = await self._async_get_all(include_metadata=True)
+        route_tuples = [(x["sr_route"], x["sr_utterance"]) for x in metadata]
+        return route_tuples
+
     def __len__(self):
         return self.index.describe_index_stats()["total_vector_count"]
diff --git a/semantic_router/index/postgres.py b/semantic_router/index/postgres.py
index 4c971d4d1a71ea51a05f0698b4f7aa971aeeef5d..52afdeec961242475cc803fa861c5b7cd9eaed17 100644
--- a/semantic_router/index/postgres.py
+++ b/semantic_router/index/postgres.py
@@ -9,6 +9,7 @@ from pydantic import BaseModel
 
 from semantic_router.index.base import BaseIndex
 from semantic_router.schema import Metric
+from semantic_router.utils.logger import logger
 
 
 class MetricPgVecOperatorMap(Enum):
@@ -456,6 +457,9 @@ class PostgresIndex(BaseIndex):
             cur.execute(f"DROP TABLE IF EXISTS {table_name}")
             self.conn.commit()
 
+    def aget_routes(self):
+        logger.error("Async get is not implemented for PostgresIndex.")
+
     def __len__(self):
         """
         Returns the total number of vectors in the index.
diff --git a/semantic_router/index/qdrant.py b/semantic_router/index/qdrant.py
index c1a5e28b58792091e4fb568248827954547e37a5..3077da3395fc102a5f34c55a12209f5f908a9e2c 100644
--- a/semantic_router/index/qdrant.py
+++ b/semantic_router/index/qdrant.py
@@ -317,6 +317,9 @@ class QdrantIndex(BaseIndex):
         route_names = [result.payload[SR_ROUTE_PAYLOAD_KEY] for result in results]
         return np.array(scores), route_names
 
+    def aget_routes(self):
+        logger.error("Async get is not implemented for QdrantIndex.")
+
     def delete_index(self):
         self.client.delete_collection(self.index_name)
 
diff --git a/tests/unit/encoders/test_vit.py b/tests/unit/encoders/test_vit.py
index 848093b35e1226813cdab30e8f87ac6481d54e46..85e66a9f961aa6b68d5b8f330988d713faab5f1f 100644
--- a/tests/unit/encoders/test_vit.py
+++ b/tests/unit/encoders/test_vit.py
@@ -7,7 +7,6 @@ from PIL import Image
 from semantic_router.encoders import VitEncoder
 
 test_model_name = "aurelio-ai/sr-test-vit"
-vit_encoder = VitEncoder(name=test_model_name)
 embed_dim = 32
 
 if torch.cuda.is_available():
@@ -44,15 +43,11 @@ class TestVitEncoder:
         with pytest.raises(ImportError):
             VitEncoder()
 
-    def test_vit_encoder__import_errors_torchvision(self, mocker):
-        mocker.patch.dict("sys.modules", {"torchvision": None})
-        with pytest.raises(ImportError):
-            VitEncoder()
-
     @pytest.mark.skipif(
         os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
     )
     def test_vit_encoder_initialization(self):
+        vit_encoder = VitEncoder(name=test_model_name)
         assert vit_encoder.name == test_model_name
         assert vit_encoder.type == "huggingface"
         assert vit_encoder.score_threshold == 0.5
@@ -62,6 +57,7 @@ class TestVitEncoder:
         os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
     )
     def test_vit_encoder_call(self, dummy_pil_image):
+        vit_encoder = VitEncoder(name=test_model_name)
         encoded_images = vit_encoder([dummy_pil_image] * 3)
 
         assert len(encoded_images) == 3
@@ -71,6 +67,7 @@ class TestVitEncoder:
         os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
     )
     def test_vit_encoder_call_misshaped(self, dummy_pil_image, misshaped_pil_image):
+        vit_encoder = VitEncoder(name=test_model_name)
         encoded_images = vit_encoder([dummy_pil_image, misshaped_pil_image])
 
         assert len(encoded_images) == 2
@@ -80,6 +77,7 @@ class TestVitEncoder:
         os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
    )
     def test_vit_encoder_process_images_device(self, dummy_pil_image):
+        vit_encoder = VitEncoder(name=test_model_name)
         imgs = vit_encoder._process_images([dummy_pil_image] * 3)["pixel_values"]
 
         assert imgs.device.type == device
@@ -88,6 +86,7 @@ class TestVitEncoder:
         os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
     )
     def test_vit_encoder_ensure_rgb(self, dummy_black_and_white_img):
+        vit_encoder = VitEncoder(name=test_model_name)
         rgb_image = vit_encoder._ensure_rgb(dummy_black_and_white_img)
 
         assert rgb_image.mode == "RGB"
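
A minimal sketch of how the new async read path could be exercised once this change lands. The index name, environment setup, and event-loop wrapper below are illustrative assumptions rather than part of the diff; PineconeIndex.aget_routes raises a ValueError if the async client or host has not been initialized, and the other index types currently only log that the async path is unimplemented.

    # Hypothetical usage sketch: assumes PINECONE_API_KEY is set and the
    # Pinecone index already contains routes written by a RouteLayer.
    import asyncio

    from semantic_router.index.pinecone import PineconeIndex


    async def main():
        index = PineconeIndex(index_name="routes")  # "routes" is an assumed index name
        # aget_routes reads all vectors via the async HTTP client and returns
        # a list of (route_name, utterance) tuples built from their metadata.
        route_tuples = await index.aget_routes()
        for route_name, utterance in route_tuples:
            print(route_name, utterance)


    asyncio.run(main())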