Skip to content
Snippets Groups Projects
Commit 9e30a21f authored by Vits's avatar Vits
Browse files

Linting, formatting and removing unnecessary prints

parent 5115a12a
No related branches found
No related tags found
No related merge requests found
...@@ -32,7 +32,7 @@ class BaseIndex(BaseModel): ...@@ -32,7 +32,7 @@ class BaseIndex(BaseModel):
This method should be implemented by subclasses. This method should be implemented by subclasses.
""" """
raise NotImplementedError("This method should be implemented by subclasses.") raise NotImplementedError("This method should be implemented by subclasses.")
def _remove_and_sync(self, routes_to_delete: dict): def _remove_and_sync(self, routes_to_delete: dict):
""" """
Remove embeddings in a routes syncing process from the index. Remove embeddings in a routes syncing process from the index.
...@@ -86,7 +86,9 @@ class BaseIndex(BaseModel): ...@@ -86,7 +86,9 @@ class BaseIndex(BaseModel):
""" """
raise NotImplementedError("This method should be implemented by subclasses.") raise NotImplementedError("This method should be implemented by subclasses.")
def _sync_index(self, local_route_names: List[str], local_utterances: List[str], dimensions: int): def _sync_index(
self, local_route_names: List[str], local_utterances: List[str], dimensions: int
):
""" """
Synchronize the local index with the remote index based on the specified mode. Synchronize the local index with the remote index based on the specified mode.
Modes: Modes:
......
...@@ -46,7 +46,9 @@ class LocalIndex(BaseIndex): ...@@ -46,7 +46,9 @@ class LocalIndex(BaseIndex):
if self.sync is not None: if self.sync is not None:
logger.warning("Sync remove is not implemented for LocalIndex.") logger.warning("Sync remove is not implemented for LocalIndex.")
def _sync_index(self, local_route_names: List[str], local_utterances: List[str], dimensions: int): def _sync_index(
self, local_route_names: List[str], local_utterances: List[str], dimensions: int
):
if self.sync is not None: if self.sync is not None:
logger.error("Sync remove is not implemented for LocalIndex.") logger.error("Sync remove is not implemented for LocalIndex.")
......
...@@ -201,7 +201,9 @@ class PineconeIndex(BaseIndex): ...@@ -201,7 +201,9 @@ class PineconeIndex(BaseIndex):
logger.warning("Index could not be initialized.") logger.warning("Index could not be initialized.")
self.host = index_stats["host"] if index_stats else None self.host = index_stats["host"] if index_stats else None
def _sync_index(self, local_route_names: List[str], local_utterances: List[str], dimensions: int): def _sync_index(
self, local_route_names: List[str], local_utterances: List[str], dimensions: int
):
if self.index is None: if self.index is None:
self.dimensions = self.dimensions or dimensions self.dimensions = self.dimensions or dimensions
self.index = self._init_index(force_create=True) self.index = self._init_index(force_create=True)
...@@ -253,7 +255,7 @@ class PineconeIndex(BaseIndex): ...@@ -253,7 +255,7 @@ class PineconeIndex(BaseIndex):
layer_routes[route] = list(local_utterances) layer_routes[route] = list(local_utterances)
elif self.sync == "merge-force-remote": elif self.sync == "merge-force-remote":
if route in local_dict and route not in remote_dict: if route in local_dict and route not in remote_dict:
utterances_to_include = local_utterances utterances_to_include = set(local_utterances)
if local_utterances: if local_utterances:
layer_routes[route] = list(local_utterances) layer_routes[route] = list(local_utterances)
else: else:
...@@ -288,8 +290,6 @@ class PineconeIndex(BaseIndex): ...@@ -288,8 +290,6 @@ class PineconeIndex(BaseIndex):
for utterance in utterances_to_include: for utterance in utterances_to_include:
routes_to_add.append((route, utterance)) routes_to_add.append((route, utterance))
logger.info(f"Layer routes: {layer_routes}")
return routes_to_add, routes_to_delete, layer_routes return routes_to_add, routes_to_delete, layer_routes
def _batch_upsert(self, batch: List[Dict]): def _batch_upsert(self, batch: List[Dict]):
......
...@@ -163,8 +163,10 @@ class QdrantIndex(BaseIndex): ...@@ -163,8 +163,10 @@ class QdrantIndex(BaseIndex):
def _remove_and_sync(self, routes_to_delete: dict): def _remove_and_sync(self, routes_to_delete: dict):
if self.sync is not None: if self.sync is not None:
logger.error("Sync remove is not implemented for LocalIndex.") logger.error("Sync remove is not implemented for LocalIndex.")
def _sync_index(self, local_route_names: List[str], local_utterances: List[str], dimensions: int): def _sync_index(
self, local_route_names: List[str], local_utterances: List[str], dimensions: int
):
if self.sync is not None: if self.sync is not None:
logger.error("Sync remove is not implemented for QdrantIndex.") logger.error("Sync remove is not implemented for QdrantIndex.")
......
...@@ -487,14 +487,13 @@ class RouteLayer: ...@@ -487,14 +487,13 @@ class RouteLayer:
utterances=all_utterances, utterances=all_utterances,
) )
def _add_and_sync_routes(self, routes: List[Route]): def _add_and_sync_routes(self, routes: List[Route]):
# create embeddings for all routes and sync at startup with remote ones based on sync setting # create embeddings for all routes and sync at startup with remote ones based on sync setting
local_route_names, local_utterances = self._extract_routes_details(routes) local_route_names, local_utterances = self._extract_routes_details(routes)
routes_to_add, routes_to_delete, layer_routes_dict = self.index._sync_index( routes_to_add, routes_to_delete, layer_routes_dict = self.index._sync_index(
local_route_names=local_route_names, local_route_names=local_route_names,
local_utterances=local_utterances, local_utterances=local_utterances,
dimensions=len(self.encoder(["dummy"])[0]) dimensions=len(self.encoder(["dummy"])[0]),
) )
layer_routes = [ layer_routes = [
...@@ -508,8 +507,10 @@ class RouteLayer: ...@@ -508,8 +507,10 @@ class RouteLayer:
self.index._remove_and_sync(data_to_delete) self.index._remove_and_sync(data_to_delete)
all_utterances_to_add = [utt for _, utt in routes_to_add] all_utterances_to_add = [utt for _, utt in routes_to_add]
embedded_utterances_to_add = self.encoder(all_utterances_to_add) if all_utterances_to_add else [] embedded_utterances_to_add = (
self.encoder(all_utterances_to_add) if all_utterances_to_add else []
)
route_names_to_add = [route for route, _, in routes_to_add] route_names_to_add = [route for route, _, in routes_to_add]
self.index.add( self.index.add(
...@@ -517,14 +518,16 @@ class RouteLayer: ...@@ -517,14 +518,16 @@ class RouteLayer:
routes=route_names_to_add, routes=route_names_to_add,
utterances=all_utterances_to_add, utterances=all_utterances_to_add,
) )
self._set_layer_routes(layer_routes) self._set_layer_routes(layer_routes)
def _extract_routes_details(self, routes: List[Route]) -> Tuple[List[str], List[str]]: def _extract_routes_details(
self, routes: List[Route]
) -> Tuple:
route_names = [route.name for route in routes for _ in route.utterances] route_names = [route.name for route in routes for _ in route.utterances]
utterances = [utterance for route in routes for utterance in route.utterances] utterances = [utterance for route in routes for utterance in route.utterances]
return route_names, utterances return route_names, utterances
def _encode(self, text: str) -> Any: def _encode(self, text: str) -> Any:
"""Given some text, encode it.""" """Given some text, encode it."""
# create query vector # create query vector
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment