Skip to content
Snippets Groups Projects
Commit 4475b300 authored by Vits's avatar Vits
Browse files

Introducing sync setting and management to sync data between remote and local layers

parent 265f27a7
No related branches found
No related tags found
No related merge requests found
......@@ -18,6 +18,7 @@ class BaseIndex(BaseModel):
utterances: Optional[np.ndarray] = None
dimensions: Union[int, None] = None
type: str = "base"
sync: str = "merge-force-local"
def add(
self, embeddings: List[List[float]], routes: List[str], utterances: List[Any]
......@@ -73,6 +74,21 @@ class BaseIndex(BaseModel):
This method should be implemented by subclasses.
"""
raise NotImplementedError("This method should be implemented by subclasses.")
def _sync_index(self, local_routes: dict):
"""
Synchronize the local index with the remote index based on the specified mode.
Modes:
- "error": Raise an error if local and remote are not synchronized.
- "remote": Take remote as the source of truth and update local to align.
- "local": Take local as the source of truth and update remote to align.
- "merge-force-remote": Merge both local and remote taking only remote routes utterances when a route with same route name is present both locally and remotely.
- "merge-force-local": Merge both local and remote taking only local routes utterances when a route with same route name is present both locally and remotely.
- "merge": Merge both local and remote, merging also local and remote utterances when a route with same route name is present both locally and remotely.
This method should be implemented by subclasses.
"""
raise NotImplementedError("This method should be implemented by subclasses.")
class Config:
arbitrary_types_allowed = True
......@@ -65,6 +65,7 @@ class PineconeIndex(BaseIndex):
host: str = "",
namespace: Optional[str] = "",
base_url: Optional[str] = "https://api.pinecone.io",
sync: str = "merge-force-local",
):
super().__init__()
self.index_name = index_name
......@@ -77,6 +78,7 @@ class PineconeIndex(BaseIndex):
self.type = "pinecone"
self.api_key = api_key or os.getenv("PINECONE_API_KEY")
self.base_url = base_url
self.sync = sync
if self.api_key is None:
raise ValueError("Pinecone API key is required.")
......@@ -195,6 +197,57 @@ class PineconeIndex(BaseIndex):
logger.warning("Index could not be initialized.")
self.host = index_stats["host"] if index_stats else None
def _sync_index(self, local_routes: dict):
remote_routes = self.get_routes()
remote_dict = {route: set() for route, _ in remote_routes}
for route, utterance in remote_routes:
remote_dict[route].add(utterance)
local_dict = {route: set() for route in local_routes['routes']}
for route, utterance in zip(local_routes['routes'], local_routes['utterances']):
local_dict[route].add(utterance)
all_routes = set(remote_dict.keys()).union(local_dict.keys())
routes_to_add = []
routes_to_delete = []
for route in all_routes:
local_utterances = local_dict.get(route, set())
remote_utterances = remote_dict.get(route, set())
if self.sync == "error":
if local_utterances != remote_utterances:
raise ValueError(f"Synchronization error: Differences found in route '{route}'")
utterances_to_include = set()
elif self.sync == "remote":
utterances_to_include = set()
elif self.sync == "local":
utterances_to_include = local_utterances - remote_utterances
routes_to_delete.extend([(route, utterance) for utterance in remote_utterances if utterance not in local_utterances])
elif self.sync == "merge-force-remote":
if route in local_dict and route not in remote_dict:
utterances_to_include = local_utterances
else:
utterances_to_include = set()
elif self.sync == "merge-force-local":
if route in local_dict:
utterances_to_include = local_utterances - remote_utterances
routes_to_delete.extend([(route, utterance) for utterance in remote_utterances if utterance not in local_utterances])
else:
utterances_to_include = set()
elif self.sync == "merge":
utterances_to_include = local_utterances - remote_utterances
else:
raise ValueError("Invalid sync mode specified")
for utterance in utterances_to_include:
indices = [i for i, x in enumerate(local_routes['utterances']) if x == utterance and local_routes['routes'][i] == route]
routes_to_add.extend([(local_routes['embeddings'][idx], route, utterance) for idx in indices])
return routes_to_add, routes_to_delete
def _batch_upsert(self, batch: List[Dict]):
"""Helper method for upserting a single batch of records."""
if self.index is not None:
......@@ -208,15 +261,33 @@ class PineconeIndex(BaseIndex):
routes: List[str],
utterances: List[str],
batch_size: int = 100,
sync: bool = False,
):
"""Add vectors to Pinecone in batches."""
if self.index is None:
self.dimensions = self.dimensions or len(embeddings[0])
self.index = self._init_index(force_create=True)
if sync:
local_routes = {"routes": routes, "utterances": utterances, "embeddings": embeddings}
data_to_upsert, data_to_delete = self._sync_index(local_routes=local_routes)
routes_to_delete = {}
for route, utterance in data_to_delete:
routes_to_delete.setdefault(route, []).append(utterance)
for route, utterances in routes_to_delete.items():
remote_routes = self._get_routes_with_ids(route_name=route)
ids_to_delete = [r["id"] for r in remote_routes if (r["route"], r["utterance"]) in zip([route]*len(utterances), utterances)]
if ids_to_delete:
self.index.delete(ids=ids_to_delete)
else:
data_to_upsert = zip(embeddings, routes, utterances)
vectors_to_upsert = [
PineconeRecord(values=vector, route=route, utterance=utterance).to_dict()
for vector, route, utterance in zip(embeddings, routes, utterances)
for vector, route, utterance in data_to_upsert
]
for i in range(0, len(vectors_to_upsert), batch_size):
......@@ -227,6 +298,15 @@ class PineconeIndex(BaseIndex):
clean_route = clean_route_name(route_name)
ids, _ = self._get_all(prefix=f"{clean_route}#")
return ids
def _get_routes_with_ids(self, route_name: str):
clean_route = clean_route_name(route_name)
ids, _ = self._get_all(prefix=f"{clean_route}#")
route_tuples = []
for id in ids:
res_meta = self.index.fetch(ids=[id], namespace=self.namespace)
route_tuples.extend([{"id": id, "route": x["metadata"]["sr_route"], "utterance": x["metadata"]["sr_utterance"]} for x in res_meta["vectors"].values()])
return route_tuples
def _get_all(self, prefix: Optional[str] = None, include_metadata: bool = False):
"""
......
......@@ -470,6 +470,7 @@ class RouteLayer:
embeddings=embedded_utterances,
routes=route_names,
utterances=all_utterances,
sync=True,
)
def _encode(self, text: str) -> Any:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment