diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py
index e6638fb11305de77706065e135663c21be4aeaff..6d0969fc9c482ea1a9c6f4181e63b73f0d5d9bf3 100644
--- a/semantic_router/index/base.py
+++ b/semantic_router/index/base.py
@@ -18,6 +18,7 @@ class BaseIndex(BaseModel):
     utterances: Optional[np.ndarray] = None
     dimensions: Union[int, None] = None
     type: str = "base"
+    sync: str = "merge-force-local"
 
     def add(
         self, embeddings: List[List[float]], routes: List[str], utterances: List[Any]
@@ -73,6 +74,25 @@ class BaseIndex(BaseModel):
         This method should be implemented by subclasses.
         """
         raise NotImplementedError("This method should be implemented by subclasses.")
+
+    def _sync_index(self, local_routes: dict):
+        """Synchronize the local routes with the remote index.
+
+        The behavior is selected by ``self.sync``:
+
+        - "error": raise an error if local and remote are not synchronized.
+        - "remote": treat remote as the source of truth and align local to it.
+        - "local": treat local as the source of truth and align remote to it.
+        - "merge-force-remote": merge local and remote; when a route name
+          exists on both sides, keep only the remote utterances.
+        - "merge-force-local": merge local and remote; when a route name
+          exists on both sides, keep only the local utterances.
+        - "merge": merge local and remote, including the utterances of
+          routes that exist on both sides.
+
+        This method should be implemented by subclasses.
+        """
+        raise NotImplementedError("This method should be implemented by subclasses.")
 
     class Config:
         arbitrary_types_allowed = True
diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py
index 7d3828f44c2a0f89b384775a0137b3483359886d..48c186abc635f98f3f5ec26a42ae7fd7c585c304 100644
--- a/semantic_router/index/pinecone.py
+++ b/semantic_router/index/pinecone.py
@@ -65,6 +65,7 @@ class PineconeIndex(BaseIndex):
         host: str = "",
         namespace: Optional[str] = "",
         base_url: Optional[str] = "https://api.pinecone.io",
+        sync: str = "merge-force-local",
     ):
         super().__init__()
         self.index_name = index_name
@@ -77,6 +78,7 @@ class PineconeIndex(BaseIndex):
         self.type = "pinecone"
         self.api_key = api_key or os.getenv("PINECONE_API_KEY")
         self.base_url = base_url
+        self.sync = sync
 
         if self.api_key is None:
             raise ValueError("Pinecone API key is required.")
@@ -195,6 +197,86 @@ class PineconeIndex(BaseIndex):
             logger.warning("Index could not be initialized.")
         self.host = index_stats["host"] if index_stats else None
 
+    def _sync_index(self, local_routes: dict):
+        """Compute the changes needed to synchronize local and remote routes.
+
+        Remote routes are read with ``self.get_routes()`` and compared with
+        ``local_routes`` according to the mode stored in ``self.sync``.
+
+        :param local_routes: Dictionary with the parallel lists "routes",
+            "utterances" and "embeddings".
+        :return: Tuple ``(routes_to_add, routes_to_delete)``: a list of
+            ``(embedding, route, utterance)`` tuples to upsert and a list of
+            ``(route, utterance)`` tuples to delete from the remote index.
+        :raises ValueError: If the sync mode is invalid, or if the mode is
+            "error" and local and remote utterances differ for a route.
+        """
+        valid_modes = (
+            "error",
+            "remote",
+            "local",
+            "merge-force-remote",
+            "merge-force-local",
+            "merge",
+        )
+        # Validate up front so an invalid mode is reported even when there
+        # are no routes to compare.
+        if self.sync not in valid_modes:
+            raise ValueError("Invalid sync mode specified")
+
+        remote_dict: dict = {}
+        for route, utterance in self.get_routes():
+            remote_dict.setdefault(route, set()).add(utterance)
+
+        # Group local embeddings by (route, utterance) in a single pass so
+        # they can be looked up without rescanning the local lists.
+        local_dict: dict = {}
+        local_embeddings: dict = {}
+        for embedding, route, utterance in zip(
+            local_routes["embeddings"],
+            local_routes["routes"],
+            local_routes["utterances"],
+        ):
+            local_dict.setdefault(route, set()).add(utterance)
+            local_embeddings.setdefault((route, utterance), []).append(embedding)
+
+        routes_to_add = []
+        routes_to_delete = []
+
+        for route in set(remote_dict) | set(local_dict):
+            local_utterances = local_dict.get(route, set())
+            remote_utterances = remote_dict.get(route, set())
+            utterances_to_include: set = set()
+
+            if self.sync == "error":
+                if local_utterances != remote_utterances:
+                    raise ValueError(
+                        f"Synchronization error: Differences found in route '{route}'"
+                    )
+            elif self.sync == "merge-force-remote":
+                # Only routes that do not exist remotely are uploaded.
+                if route not in remote_dict:
+                    utterances_to_include = local_utterances
+            elif self.sync == "merge":
+                utterances_to_include = local_utterances - remote_utterances
+            elif self.sync in ("local", "merge-force-local"):
+                # "merge-force-local" leaves remote-only routes untouched.
+                if self.sync == "local" or route in local_dict:
+                    utterances_to_include = local_utterances - remote_utterances
+                    routes_to_delete.extend(
+                        (route, utterance)
+                        for utterance in remote_utterances - local_utterances
+                    )
+            # In "remote" mode nothing is uploaded or deleted.
+
+            for utterance in utterances_to_include:
+                routes_to_add.extend(
+                    (embedding, route, utterance)
+                    for embedding in local_embeddings[(route, utterance)]
+                )
+
+        return routes_to_add, routes_to_delete
+
     def _batch_upsert(self, batch: List[Dict]):
         """Helper method for upserting a single batch of records."""
         if self.index is not None:
@@ -208,15 +290,50 @@ class PineconeIndex(BaseIndex):
         routes: List[str],
         utterances: List[str],
         batch_size: int = 100,
+        sync: bool = False,
     ):
-        """Add vectors to Pinecone in batches."""
+        """Add vectors to Pinecone in batches.
+
+        When ``sync`` is True, the input is first reconciled with the remote
+        index via ``_sync_index`` (using the mode in ``self.sync``): only the
+        records it returns are upserted, and stale remote records are deleted.
+        """
         if self.index is None:
             self.dimensions = self.dimensions or len(embeddings[0])
             self.index = self._init_index(force_create=True)
 
+        if sync:
+            local_routes = {
+                "routes": routes,
+                "utterances": utterances,
+                "embeddings": embeddings,
+            }
+            data_to_upsert, data_to_delete = self._sync_index(
+                local_routes=local_routes
+            )
+
+            # Group stale utterances per route without shadowing the
+            # ``utterances`` argument.
+            utterances_to_delete: dict = {}
+            for route, utterance in data_to_delete:
+                utterances_to_delete.setdefault(route, set()).add(utterance)
+
+            for route, stale_utterances in utterances_to_delete.items():
+                ids_to_delete = [
+                    record["id"]
+                    for record in self._get_routes_with_ids(route_name=route)
+                    if record["route"] == route
+                    and record["utterance"] in stale_utterances
+                ]
+                if ids_to_delete:
+                    # Delete from the same namespace the records were read from.
+                    self.index.delete(ids=ids_to_delete, namespace=self.namespace)
+        else:
+            data_to_upsert = zip(embeddings, routes, utterances)
+
         vectors_to_upsert = [
             PineconeRecord(values=vector, route=route, utterance=utterance).to_dict()
-            for vector, route, utterance in zip(embeddings, routes, utterances)
+            for vector, route, utterance in data_to_upsert
         ]
 
         for i in range(0, len(vectors_to_upsert), batch_size):
@@ -227,6 +344,30 @@ class PineconeIndex(BaseIndex):
         clean_route = clean_route_name(route_name)
         ids, _ = self._get_all(prefix=f"{clean_route}#")
         return ids
+
+    def _get_routes_with_ids(self, route_name: str):
+        """Return the id, route and utterance of the records of a route.
+
+        Record ids are listed by the ``"<clean route name>#"`` prefix and the
+        metadata of each record is then fetched from ``self.namespace``.
+
+        :param route_name: Name of the route whose records are returned.
+        :return: List of dictionaries with keys "id", "route" and "utterance".
+        """
+        clean_route = clean_route_name(route_name)
+        ids, _ = self._get_all(prefix=f"{clean_route}#")
+        route_tuples = []
+        for vec_id in ids:
+            res_meta = self.index.fetch(ids=[vec_id], namespace=self.namespace)
+            route_tuples.extend(
+                {
+                    "id": vec_id,
+                    "route": vector["metadata"]["sr_route"],
+                    "utterance": vector["metadata"]["sr_utterance"],
+                }
+                for vector in res_meta["vectors"].values()
+            )
+        return route_tuples
 
     def _get_all(self, prefix: Optional[str] = None, include_metadata: bool = False):
         """
diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index c87c79358b291caaf5fc7d24eafde67830fcbae7..3ac1596c58c615200989983722bd122e82122623 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -470,6 +470,7 @@ class RouteLayer:
             embeddings=embedded_utterances,
             routes=route_names,
             utterances=all_utterances,
+            sync=True,
         )
 
     def _encode(self, text: str) -> Any: