diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py index 9023061e4d20b11d3e38dc4088205594562e6a58..d98ae19e404d49040f08ac828cc252298fa70dcd 100644 --- a/semantic_router/index/base.py +++ b/semantic_router/index/base.py @@ -14,6 +14,7 @@ from semantic_router.utils.logger import logger RETRY_WAIT_TIME = 2.5 + class BaseIndex(BaseModel): """ Base class for indices using Pydantic's BaseModel. @@ -38,12 +39,31 @@ class BaseIndex(BaseModel): function_schemas: Optional[List[Dict[str, Any]]] = None, metadata_list: List[Dict[str, Any]] = [], ): - """ - Add embeddings to the index. + """Add embeddings to the index. This method should be implemented by subclasses. """ raise NotImplementedError("This method should be implemented by subclasses.") + async def aadd( + self, + embeddings: List[List[float]], + routes: List[str], + utterances: List[str], + function_schemas: Optional[Optional[List[Dict[str, Any]]]] = None, + metadata_list: List[Dict[str, Any]] = [], + ): + """Add vectors to the index asynchronously. + This method should be implemented by subclasses. + """ + logger.warning("Async method not implemented.") + return self.add( + embeddings=embeddings, + routes=routes, + utterances=utterances, + function_schemas=function_schemas, + metadata_list=metadata_list, + ) + def get_utterances(self) -> List[Utterance]: """Gets a list of route and utterance objects currently stored in the index, including additional metadata. @@ -58,7 +78,7 @@ class BaseIndex(BaseModel): _, metadata = self._get_all(include_metadata=True) route_tuples = parse_route_info(metadata=metadata) return [Utterance.from_tuple(x) for x in route_tuples] - + async def aget_utterances(self) -> List[Utterance]: """Gets a list of route and utterance objects currently stored in the index, including additional metadata. @@ -108,6 +128,14 @@ class BaseIndex(BaseModel): """ raise NotImplementedError("This method should be implemented by subclasses.") + async def _async_remove_and_sync(self, routes_to_delete: dict): + """ + Remove embeddings in a routes syncing process from the index asynchronously. + This method should be implemented by subclasses. + """ + logger.warning("Async method not implemented.") + return self._remove_and_sync(routes_to_delete=routes_to_delete) + def delete(self, route_name: str): """ Deletes route by route name. @@ -197,8 +225,10 @@ class BaseIndex(BaseModel): value="", scope=scope, ) - - async def _async_read_config(self, field: str, scope: str | None = None) -> ConfigParameter: + + async def _async_read_config( + self, field: str, scope: str | None = None + ) -> ConfigParameter: """Read a config parameter from the index asynchronously. :param field: The field to read. @@ -221,7 +251,7 @@ class BaseIndex(BaseModel): """ logger.warning("This method should be implemented by subclasses.") return config - + async def _async_write_config(self, config: ConfigParameter) -> ConfigParameter: """Write a config parameter to the index asynchronously. @@ -232,9 +262,9 @@ class BaseIndex(BaseModel): """ logger.warning("Async method not implemented.") return self._write_config(config=config) - + # _________________________ END CONFIG _________________________ - + def _read_hash(self) -> ConfigParameter: """Read the hash of the previously written index. @@ -242,7 +272,7 @@ class BaseIndex(BaseModel): :rtype: ConfigParameter """ return self._read_config(field="sr_hash") - + async def _async_read_hash(self) -> ConfigParameter: """Read the hash of the previously written index asynchronously. @@ -266,7 +296,7 @@ class BaseIndex(BaseModel): return False else: raise ValueError(f"Invalid lock value: {lock_config.value}") - + async def _ais_locked(self, scope: str | None = None) -> bool: """Check if the index is locked for a given scope (if applicable). @@ -282,7 +312,7 @@ class BaseIndex(BaseModel): return False else: raise ValueError(f"Invalid lock value: {lock_config.value}") - + def lock( self, value: bool, wait: int = 0, scope: str | None = None ) -> ConfigParameter: @@ -316,8 +346,10 @@ class BaseIndex(BaseModel): ) self._write_config(lock_param) return lock_param - - async def alock(self, value: bool, wait: int = 0, scope: str | None = None) -> ConfigParameter: + + async def alock( + self, value: bool, wait: int = 0, scope: str | None = None + ) -> ConfigParameter: """Lock/unlock the index for a given scope (if applicable). If index already locked/unlocked, raises ValueError. """ diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py index 06141bcc0f0b77467a19a39e3978ad5a1e0875f1..9ed1938c03c8a1f1b812b880121b7c6280020369 100644 --- a/semantic_router/index/pinecone.py +++ b/semantic_router/index/pinecone.py @@ -227,7 +227,7 @@ class PineconeIndex(BaseIndex): def _batch_upsert(self, batch: List[Dict]): """Helper method for upserting a single batch of records. - + :param batch: The batch of records to upsert. :type batch: List[Dict] """ @@ -235,10 +235,10 @@ class PineconeIndex(BaseIndex): self.index.upsert(vectors=batch, namespace=self.namespace) else: raise ValueError("Index is None, could not upsert.") - + async def _async_batch_upsert(self, batch: List[Dict]): """Helper method for upserting a single batch of records asynchronously. - + :param batch: The batch of records to upsert. :type batch: List[Dict] """ @@ -351,7 +351,9 @@ class PineconeIndex(BaseIndex): in zip([route] * len(utterances), utterances) ] if ids_to_delete and self.index: - await self._async_delete(ids=ids_to_delete, namespace=self.namespace or "") + await self._async_delete( + ids=ids_to_delete, namespace=self.namespace or "" + ) def _get_route_ids(self, route_name: str): clean_route = clean_route_name(route_name) @@ -376,13 +378,17 @@ class PineconeIndex(BaseIndex): } ) return route_tuples - + async def _async_get_routes_with_ids(self, route_name: str): clean_route = clean_route_name(route_name) - ids, metadata = await self._async_get_all(prefix=f"{clean_route}#", include_metadata=True) + ids, metadata = await self._async_get_all( + prefix=f"{clean_route}#", include_metadata=True + ) route_tuples = [] for id, data in zip(ids, metadata): - route_tuples.append({"id": id, "route": data["sr_route"], "utterance": data["sr_utterance"]}) + route_tuples.append( + {"id": id, "route": data["sr_route"], "utterance": data["sr_utterance"]} + ) return route_tuples def _get_all(self, prefix: Optional[str] = None, include_metadata: bool = False): @@ -532,7 +538,7 @@ class PineconeIndex(BaseIndex): namespace="sr_config", ) return config - + async def _async_write_config(self, config: ConfigParameter) -> ConfigParameter: """Method to write a config parameter to the remote Pinecone index. @@ -646,7 +652,7 @@ class PineconeIndex(BaseIndex): async def _async_list_indexes(self): async with self.async_client.get(f"{self.base_url}/indexes") as response: return await response.json(content_type=None) - + async def _async_upsert( self, vectors: list[dict], @@ -682,13 +688,15 @@ class PineconeIndex(BaseIndex): json=params, ) as response: return await response.json(content_type=None) - + async def _async_delete(self, ids: list[str], namespace: str = ""): params = { "ids": ids, "namespace": namespace, } - async with self.async_client.post(f"{self.base_url}/vectors/delete", json=params) as response: + async with self.async_client.post( + f"{self.base_url}/vectors/delete", json=params + ) as response: return await response.json(content_type=None) async def _async_describe_index(self, name: str): diff --git a/semantic_router/routers/base.py b/semantic_router/routers/base.py index 5dc6aa4c9923c7d353813011c1a4c59bdf3897c9..bd02ee6f98be34c061e8909c08b126af34ebb24a 100644 --- a/semantic_router/routers/base.py +++ b/semantic_router/routers/base.py @@ -585,8 +585,10 @@ class BaseRouter(BaseModel): # unlock index after sync _ = self.index.lock(value=False) return diff.to_utterance_str() - - async def async_sync(self, sync_mode: str, force: bool = False, wait: int = 0) -> List[str]: + + async def async_sync( + self, sync_mode: str, force: bool = False, wait: int = 0 + ) -> List[str]: """Runs a sync of the local routes with the remote index. :param sync_mode: The mode to sync the routes with the remote index. @@ -660,7 +662,9 @@ class BaseRouter(BaseModel): # update hash self._write_hash() - async def _async_execute_sync_strategy(self, strategy: Dict[str, Dict[str, List[Utterance]]]): + async def _async_execute_sync_strategy( + self, strategy: Dict[str, Dict[str, List[Utterance]]] + ): """Executes the provided sync strategy, either deleting or upserting routes from the local and remote instances as defined in the strategy. @@ -806,7 +810,7 @@ class BaseRouter(BaseModel): :type route: Route """ raise NotImplementedError("This method must be implemented by subclasses.") - + async def aadd(self, routes: List[Route] | Route): """Add a route to the local SemanticRouter and index asynchronously. @@ -929,7 +933,7 @@ class BaseRouter(BaseModel): hash_config = config.get_hash() self.index._write_config(config=hash_config) return hash_config - + async def _async_write_hash(self) -> ConfigParameter: config = self.to_config() hash_config = config.get_hash() @@ -951,7 +955,7 @@ class BaseRouter(BaseModel): return True else: return False - + async def async_is_synced(self) -> bool: """Check if the local and remote route layer instances are synchronized asynchronously. @@ -997,7 +1001,7 @@ class BaseRouter(BaseModel): local_utterances=local_utterances, remote_utterances=remote_utterances ) return diff_obj.to_utterance_str(include_metadata=include_metadata) - + async def aget_utterance_diff(self, include_metadata: bool = False) -> List[str]: """Get the difference between the local and remote utterances asynchronously. Returns a list of strings showing what is different in the remote when diff --git a/tests/unit/test_sync.py b/tests/unit/test_sync.py index 6b8b98ae41e73f00bb69373cc6b7f6470a2b41d6..9151dc715af25a1a94ff4e231c5d02ca7ceabf5b 100644 --- a/tests/unit/test_sync.py +++ b/tests/unit/test_sync.py @@ -883,4 +883,4 @@ class TestAsyncSemanticRouter: assert await route_layer.async_is_synced() # clear index - route_layer.index.index.delete(namespace="", delete_all=True) \ No newline at end of file + route_layer.index.index.delete(namespace="", delete_all=True)