Skip to content
Snippets Groups Projects
Commit 54df32bb authored by Vits's avatar Vits
Browse files

Added _add_and_sync to replace add for index syncing when adding routes at startup

parent 4ffe29c6
No related branches found
No related tags found
No related merge requests found
......@@ -18,14 +18,13 @@ class BaseIndex(BaseModel):
utterances: Optional[np.ndarray] = None
dimensions: Union[int, None] = None
type: str = "base"
sync: str = "merge-force-local"
sync: Union[str, None] = None
def add(
self,
embeddings: List[List[float]],
routes: List[str],
utterances: List[Any],
sync: bool = False,
):
"""
Add embeddings to the index.
......@@ -33,6 +32,18 @@ class BaseIndex(BaseModel):
"""
raise NotImplementedError("This method should be implemented by subclasses.")
def _add_and_sync(
self,
embeddings: List[List[float]],
routes: List[str],
utterances: List[Any],
):
"""
Add embeddings to the index and manage index syncing if necessary.
This method should be implemented by subclasses.
"""
raise NotImplementedError("This method should be implemented by subclasses.")
def delete(self, route_name: str):
"""
Deletes route by route name.
......
......@@ -4,6 +4,7 @@ import numpy as np
from semantic_router.index.base import BaseIndex
from semantic_router.linear import similarity_matrix, top_scores
from semantic_router.utils.logger import logger
class LocalIndex(BaseIndex):
......@@ -25,7 +26,6 @@ class LocalIndex(BaseIndex):
embeddings: List[List[float]],
routes: List[str],
utterances: List[str],
sync: bool = False,
):
embeds = np.array(embeddings) # type: ignore
routes_arr = np.array(routes)
......@@ -42,6 +42,16 @@ class LocalIndex(BaseIndex):
self.routes = np.concatenate([self.routes, routes_arr])
self.utterances = np.concatenate([self.utterances, utterances_arr])
def _add_and_sync(
self,
embeddings: List[List[float]],
routes: List[str],
utterances: List[str],
):
if self.sync is not None:
logger.warning("Sync add is not implemented for LocalIndex.")
self.add(embeddings, routes, utterances)
def get_routes(self) -> List[Tuple]:
"""
Gets a list of route and utterance objects currently stored in the index.
......
......@@ -65,7 +65,7 @@ class PineconeIndex(BaseIndex):
host: str = "",
namespace: Optional[str] = "",
base_url: Optional[str] = "https://api.pinecone.io",
sync: str = "merge-force-local",
sync: str = "local",
):
super().__init__()
self.index_name = index_name
......@@ -282,7 +282,6 @@ class PineconeIndex(BaseIndex):
embeddings: List[List[float]],
routes: List[str],
utterances: List[str],
sync: bool = False,
batch_size: int = 100,
):
"""Add vectors to Pinecone in batches."""
......@@ -290,14 +289,34 @@ class PineconeIndex(BaseIndex):
self.dimensions = self.dimensions or len(embeddings[0])
self.index = self._init_index(force_create=True)
if sync:
local_routes = {
"routes": routes,
"utterances": utterances,
"embeddings": embeddings,
}
data_to_upsert, data_to_delete = self._sync_index(local_routes=local_routes)
vectors_to_upsert = [
PineconeRecord(values=vector, route=route, utterance=utterance).to_dict()
for vector, route, utterance in zip(embeddings, routes, utterances)
]
for i in range(0, len(vectors_to_upsert), batch_size):
batch = vectors_to_upsert[i : i + batch_size]
self._batch_upsert(batch)
def _add_and_sync(
self,
embeddings: List[List[float]],
routes: List[str],
utterances: List[str],
batch_size: int = 100,
):
"""Add vectors to Pinecone in batches."""
if self.index is None:
self.dimensions = self.dimensions or len(embeddings[0])
self.index = self._init_index(force_create=True)
local_routes = {
"routes": routes,
"utterances": utterances,
"embeddings": embeddings,
}
if self.sync is not None:
data_to_upsert, data_to_delete = self._sync_index(local_routes=local_routes)
routes_to_delete: dict = {}
for route, utterance in data_to_delete:
routes_to_delete.setdefault(route, []).append(utterance)
......@@ -312,9 +331,11 @@ class PineconeIndex(BaseIndex):
]
if ids_to_delete and self.index:
self.index.delete(ids=ids_to_delete)
else:
data_to_upsert = zip(embeddings, routes, utterances)
data_to_upsert = [
(vector, route, utterance)
for vector, route, utterance in zip(embeddings, routes, utterances)
]
vectors_to_upsert = [
PineconeRecord(values=vector, route=route, utterance=utterance).to_dict()
......@@ -389,7 +410,9 @@ class PineconeIndex(BaseIndex):
if self.index
else {}
)
metadata.extend([x["metadata"] for x in res_meta["vectors"].values()])
metadata.extend(
[x["metadata"] for x in res_meta["vectors"].values()]
)
# extract metadata only
# Check if there's a next page token; if not, break the loop
......
......@@ -160,16 +160,24 @@ class QdrantIndex(BaseIndex):
**self.config,
)
def add(
def _add_and_sync(
self,
embeddings: List[List[float]],
routes: List[str],
utterances: List[str],
sync: bool = False,
batch_size: int = DEFAULT_UPLOAD_BATCH_SIZE,
):
if sync:
if self.sync is not None:
logger.warning("Sync add is not implemented for QdrantIndex")
self.add(embeddings, routes, utterances, batch_size)
def add(
self,
embeddings: List[List[float]],
routes: List[str],
utterances: List[str],
batch_size: int = DEFAULT_UPLOAD_BATCH_SIZE,
):
self.dimensions = self.dimensions or len(embeddings[0])
self._init_collection()
......
......@@ -466,11 +466,10 @@ class RouteLayer:
# create route array
route_names = [route.name for route in routes for _ in route.utterances]
# add everything to the index
self.index.add(
self.index._add_and_sync(
embeddings=embedded_utterances,
routes=route_names,
utterances=all_utterances,
sync=True,
)
def _encode(self, text: str) -> Any:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment