Skip to content
Snippets Groups Projects
Unverified Commit 8f5aad38 authored by James Briggs's avatar James Briggs Committed by GitHub
Browse files

Merge pull request #344 from aurelio-labs/vittorio/340-add-sync-setting-parameter-to-index

feat: add sync setting parameter to index
parents 0169f52d 54df32bb
No related branches found
No related tags found
No related merge requests found
[tool.poetry]
name = "semantic-router"
version = "0.0.50"
version = "0.0.51"
description = "Super fast semantic router for AI decision making"
authors = [
"James Briggs <james@aurelio.ai>",
......
......@@ -18,9 +18,13 @@ class BaseIndex(BaseModel):
utterances: Optional[np.ndarray] = None
dimensions: Union[int, None] = None
type: str = "base"
sync: Union[str, None] = None
def add(
self, embeddings: List[List[float]], routes: List[str], utterances: List[Any]
self,
embeddings: List[List[float]],
routes: List[str],
utterances: List[Any],
):
"""
Add embeddings to the index.
......@@ -28,6 +32,18 @@ class BaseIndex(BaseModel):
"""
raise NotImplementedError("This method should be implemented by subclasses.")
def _add_and_sync(
    self,
    embeddings: List[List[float]],
    routes: List[str],
    utterances: List[Any],
):
    """
    Insert embeddings into the index, handling index synchronization
    when it is required.

    Abstract hook: the base implementation always raises, so every
    concrete index must provide its own version.

    :param embeddings: One embedding vector per utterance.
    :param routes: Route name for each embedding.
    :param utterances: Utterance for each embedding.
    :raises NotImplementedError: Always, in this base implementation.
    """
    message = "This method should be implemented by subclasses."
    raise NotImplementedError(message)
def delete(self, route_name: str):
"""
Deletes route by route name.
......@@ -74,5 +90,20 @@ class BaseIndex(BaseModel):
"""
raise NotImplementedError("This method should be implemented by subclasses.")
def _sync_index(self, local_routes: dict):
    """
    Reconcile the local index with the remote index according to the
    configured sync mode.

    Supported modes:
    - "error": raise if local and remote are not synchronized.
    - "remote": remote is the source of truth; local is updated to match.
    - "local": local is the source of truth; remote is updated to match.
    - "merge-force-remote": merge local and remote; when a route name
      exists on both sides, keep only the remote utterances for it.
    - "merge-force-local": merge local and remote; when a route name
      exists on both sides, keep only the local utterances for it.
    - "merge": merge local and remote; when a route name exists on both
      sides, its local and remote utterances are merged as well.

    Abstract hook: the base implementation always raises, so every
    concrete index must provide its own version.

    :param local_routes: Description of the local routes to reconcile.
    :raises NotImplementedError: Always, in this base implementation.
    """
    message = "This method should be implemented by subclasses."
    raise NotImplementedError(message)
class Config:
    # Model configuration flag: permit field types that are not natively
    # supported by the model framework (presumably needed for the
    # np.ndarray fields declared on the index -- confirm against the class).
    arbitrary_types_allowed = True
......@@ -4,6 +4,7 @@ import numpy as np
from semantic_router.index.base import BaseIndex
from semantic_router.linear import similarity_matrix, top_scores
from semantic_router.utils.logger import logger
class LocalIndex(BaseIndex):
......@@ -21,7 +22,10 @@ class LocalIndex(BaseIndex):
arbitrary_types_allowed = True
def add(
self, embeddings: List[List[float]], routes: List[str], utterances: List[str]
self,
embeddings: List[List[float]],
routes: List[str],
utterances: List[str],
):
embeds = np.array(embeddings) # type: ignore
routes_arr = np.array(routes)
......@@ -38,6 +42,16 @@ class LocalIndex(BaseIndex):
self.routes = np.concatenate([self.routes, routes_arr])
self.utterances = np.concatenate([self.utterances, utterances_arr])
def _add_and_sync(
    self,
    embeddings: List[List[float]],
    routes: List[str],
    utterances: List[str],
):
    """
    Add embeddings to the local index.

    Syncing is not supported here: if a sync mode is configured a warning
    is logged, and the records are then added exactly as `add` would.

    :param embeddings: One embedding vector per utterance.
    :param routes: Route name for each embedding.
    :param utterances: Utterance text for each embedding.
    """
    sync_requested = self.sync is not None
    if sync_requested:
        logger.warning("Sync add is not implemented for LocalIndex.")
    # Regardless of the sync setting, fall back to a plain add.
    self.add(embeddings, routes, utterances)
def get_routes(self) -> List[Tuple]:
"""
Gets a list of route and utterance objects currently stored in the index.
......
......@@ -65,6 +65,7 @@ class PineconeIndex(BaseIndex):
host: str = "",
namespace: Optional[str] = "",
base_url: Optional[str] = "https://api.pinecone.io",
sync: str = "local",
):
super().__init__()
self.index_name = index_name
......@@ -77,6 +78,7 @@ class PineconeIndex(BaseIndex):
self.type = "pinecone"
self.api_key = api_key or os.getenv("PINECONE_API_KEY")
self.base_url = base_url
self.sync = sync
if self.api_key is None:
raise ValueError("Pinecone API key is required.")
......@@ -195,6 +197,79 @@ class PineconeIndex(BaseIndex):
logger.warning("Index could not be initialized.")
self.host = index_stats["host"] if index_stats else None
def _sync_index(self, local_routes: dict):
    """
    Work out which records must be added to and removed from the remote
    index so that it lines up with the local routes under ``self.sync``.

    Nothing is written here; the method only returns a plan.

    :param local_routes: Dict holding three parallel sequences under the
        keys "routes", "utterances" and "embeddings".
    :return: Tuple ``(routes_to_add, routes_to_delete)``. ``routes_to_add``
        is a list of ``(embedding, route, utterance)`` tuples and
        ``routes_to_delete`` is a list of ``(route, utterance)`` tuples.
    :raises ValueError: If ``self.sync`` is not a known mode, or if the
        mode is "error" and a route's local and remote utterances differ.
    """
    valid_modes = (
        "error",
        "remote",
        "local",
        "merge-force-remote",
        "merge-force-local",
        "merge",
    )
    # Validate once up front so a bad mode is reported even when there are
    # no routes to iterate over.
    if self.sync not in valid_modes:
        raise ValueError("Invalid sync mode specified")

    local_names = local_routes["routes"]
    local_texts = local_routes["utterances"]

    remote_dict: dict = {}
    for route, utterance in self.get_routes():
        remote_dict.setdefault(route, set()).add(utterance)

    local_dict: dict = {route: set() for route in local_names}
    # (route, utterance) -> positions in the local sequences, built in one
    # pass so embeddings can be looked up without rescanning the inputs.
    positions: dict = {}
    for idx, (route, utterance) in enumerate(zip(local_names, local_texts)):
        local_dict[route].add(utterance)
        positions.setdefault((route, utterance), []).append(idx)

    routes_to_add = []
    routes_to_delete = []

    for route in set(remote_dict).union(local_dict):
        local_utterances = local_dict.get(route, set())
        remote_utterances = remote_dict.get(route, set())
        utterances_to_include: set = set()
        prune_remote = False

        if self.sync == "error":
            if local_utterances != remote_utterances:
                raise ValueError(
                    f"Synchronization error: Differences found in route '{route}'"
                )
        elif self.sync == "local":
            utterances_to_include = local_utterances - remote_utterances
            prune_remote = True
        elif self.sync == "merge-force-local":
            # Only routes known locally override the remote copy; routes
            # that exist remotely only are left untouched.
            if route in local_dict:
                utterances_to_include = local_utterances - remote_utterances
                prune_remote = True
        elif self.sync == "merge-force-remote":
            # Local utterances are pushed only for routes the remote does
            # not have at all.
            if route in local_dict and route not in remote_dict:
                utterances_to_include = local_utterances
        elif self.sync == "merge":
            utterances_to_include = local_utterances - remote_utterances
        # "remote": nothing is added or deleted.

        if prune_remote:
            routes_to_delete.extend(
                (route, utterance)
                for utterance in remote_utterances
                if utterance not in local_utterances
            )

        for utterance in utterances_to_include:
            routes_to_add.extend(
                (local_routes["embeddings"][idx], route, utterance)
                for idx in positions[(route, utterance)]
            )

    return routes_to_add, routes_to_delete
def _batch_upsert(self, batch: List[Dict]):
"""Helper method for upserting a single batch of records."""
if self.index is not None:
......@@ -223,11 +298,73 @@ class PineconeIndex(BaseIndex):
batch = vectors_to_upsert[i : i + batch_size]
self._batch_upsert(batch)
def _add_and_sync(
    self,
    embeddings: List[List[float]],
    routes: List[str],
    utterances: List[str],
    batch_size: int = 100,
):
    """
    Upsert routes into Pinecone in batches, reconciling with the remote
    index first when a sync mode is configured.

    If no index handle exists yet, one is created with
    ``_init_index(force_create=True)``; the length of the first embedding
    is used as the dimensionality when ``self.dimensions`` is unset.

    With ``self.sync`` set, ``_sync_index`` decides which records are
    upserted and which remote ``(route, utterance)`` pairs are deleted.
    With ``self.sync`` of ``None`` every given record is upserted.

    :param embeddings: One embedding vector per utterance.
    :param routes: Route name for each embedding.
    :param utterances: Utterance text for each embedding.
    :param batch_size: Maximum number of records sent per upsert call.
    """
    if self.index is None:
        self.dimensions = self.dimensions or len(embeddings[0])
        self.index = self._init_index(force_create=True)

    if self.sync is not None:
        local_routes = {
            "routes": routes,
            "utterances": utterances,
            "embeddings": embeddings,
        }
        data_to_upsert, data_to_delete = self._sync_index(local_routes=local_routes)

        # Group stale utterances per route so each route's remote records
        # are fetched only once.
        stale_by_route: dict = {}
        for route, utterance in data_to_delete:
            stale_by_route.setdefault(route, set()).add(utterance)

        for route, stale_utterances in stale_by_route.items():
            ids_to_delete = [
                record["id"]
                for record in self._get_routes_with_ids(route_name=route)
                if record["route"] == route
                and record["utterance"] in stale_utterances
            ]
            if ids_to_delete and self.index:
                # Delete from the same namespace the records are read
                # from; otherwise stale vectors outside the default
                # namespace would never be removed.
                self.index.delete(ids=ids_to_delete, namespace=self.namespace)
    else:
        data_to_upsert = list(zip(embeddings, routes, utterances))

    vectors_to_upsert = [
        PineconeRecord(values=vector, route=route, utterance=utterance).to_dict()
        for vector, route, utterance in data_to_upsert
    ]

    for start in range(0, len(vectors_to_upsert), batch_size):
        self._batch_upsert(vectors_to_upsert[start : start + batch_size])
def _get_route_ids(self, route_name: str):
    """
    Return the vector IDs stored for a route.

    IDs are looked up by prefix: the cleaned route name followed by "#".

    :param route_name: Name of the route whose vector IDs are wanted.
    :return: The IDs returned by ``_get_all`` for that prefix.
    """
    id_prefix = f"{clean_route_name(route_name)}#"
    route_ids, _unused_metadata = self._get_all(prefix=id_prefix)
    return route_ids
def _get_routes_with_ids(self, route_name: str):
    """
    Return the remote records of a route together with their vector IDs.

    Records are looked up by prefix (the cleaned route name followed by
    "#") and their metadata is fetched alongside the IDs.

    :param route_name: Name of the route whose records are wanted.
    :return: List of dicts with the keys "id", "route" and "utterance";
        the latter two are read from the "sr_route" and "sr_utterance"
        metadata fields.
    """
    clean_route = clean_route_name(route_name)
    ids, metadata = self._get_all(prefix=f"{clean_route}#", include_metadata=True)
    # `vector_id` rather than `id`, so the builtin is not shadowed.
    return [
        {
            "id": vector_id,
            "route": data["sr_route"],
            "utterance": data["sr_utterance"],
        }
        for vector_id, data in zip(ids, metadata)
    ]
def _get_all(self, prefix: Optional[str] = None, include_metadata: bool = False):
"""
Retrieves all vector IDs from the Pinecone index using pagination.
......@@ -267,9 +404,16 @@ class PineconeIndex(BaseIndex):
# if we need metadata, we fetch it
if include_metadata:
res_meta = self.index.fetch(ids=vector_ids, namespace=self.namespace)
for id in vector_ids:
res_meta = (
self.index.fetch(ids=[id], namespace=self.namespace)
if self.index
else {}
)
metadata.extend(
[x["metadata"] for x in res_meta["vectors"].values()]
)
# extract metadata only
metadata.extend([x["metadata"] for x in res_meta["vectors"].values()])
# Check if there's a next page token; if not, break the loop
next_page_token = response_data.get("pagination", {}).get("next")
......
......@@ -160,6 +160,17 @@ class QdrantIndex(BaseIndex):
**self.config,
)
def _add_and_sync(
    self,
    embeddings: List[List[float]],
    routes: List[str],
    utterances: List[str],
    batch_size: int = DEFAULT_UPLOAD_BATCH_SIZE,
):
    """
    Add embeddings to the Qdrant index.

    Syncing is not supported here: if a sync mode is configured a warning
    is logged, and the records are then added exactly as `add` would.

    :param embeddings: One embedding vector per utterance.
    :param routes: Route name for each embedding.
    :param utterances: Utterance text for each embedding.
    :param batch_size: Batch size forwarded to `add`.
    """
    sync_requested = self.sync is not None
    if sync_requested:
        logger.warning("Sync add is not implemented for QdrantIndex")
    # Regardless of the sync setting, fall back to a plain add.
    self.add(embeddings, routes, utterances, batch_size)
def add(
self,
embeddings: List[List[float]],
......
......@@ -466,7 +466,7 @@ class RouteLayer:
# create route array
route_names = [route.name for route in routes for _ in route.utterances]
# add everything to the index
self.index.add(
self.index._add_and_sync(
embeddings=embedded_utterances,
routes=route_names,
utterances=all_utterances,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment