Skip to content
Snippets Groups Projects
Commit 40546927 authored by James Briggs's avatar James Briggs
Browse files

feat: continued refactoring for sync features

parent e27f4446
No related branches found
No related tags found
No related merge requests found
......@@ -4,7 +4,7 @@ import json
import numpy as np
from pydantic.v1 import BaseModel
from semantic_router.schema import ConfigParameter
from semantic_router.schema import ConfigParameter, Utterance
from semantic_router.route import Route
from semantic_router.utils.logger import logger
......@@ -40,7 +40,7 @@ class BaseIndex(BaseModel):
"""
raise NotImplementedError("This method should be implemented by subclasses.")
def get_utterances(self) -> List[Tuple]:
def get_utterances(self) -> List[Utterance]:
"""Gets a list of route and utterance objects currently stored in the
index, including additional metadata.
......@@ -50,7 +50,7 @@ class BaseIndex(BaseModel):
"""
_, metadata = self._get_all(include_metadata=True)
route_tuples = parse_route_info(metadata=metadata)
return route_tuples
return [Utterance.from_tuple(x) for x in route_tuples]
def get_routes(self) -> List[Route]:
"""Gets a list of route objects currently stored in the index.
......
......@@ -2,7 +2,7 @@ from typing import List, Optional, Tuple, Dict
import numpy as np
from semantic_router.schema import ConfigParameter
from semantic_router.schema import ConfigParameter, Utterance
from semantic_router.index.base import BaseIndex
from semantic_router.linear import similarity_matrix, top_scores
from semantic_router.utils.logger import logger
......@@ -61,7 +61,7 @@ class LocalIndex(BaseIndex):
if self.sync is not None:
logger.error("Sync remove is not implemented for LocalIndex.")
def get_utterances(self) -> List[Tuple]:
def get_utterances(self) -> List[Utterance]:
"""
Gets a list of route and utterance objects currently stored in the index.
......@@ -70,7 +70,9 @@ class LocalIndex(BaseIndex):
"""
if self.routes is None or self.utterances is None:
return []
return list(zip(self.routes, self.utterances))
return [
Utterance.from_tuple(x) for x in zip(self.routes, self.utterances)
]
def describe(self) -> Dict:
return {
......
......@@ -15,7 +15,7 @@ from semantic_router.index.base import BaseIndex
from semantic_router.index.local import LocalIndex
from semantic_router.llms import BaseLLM, OpenAILLM
from semantic_router.route import Route
from semantic_router.schema import ConfigParameter, EncoderType, RouteChoice
from semantic_router.schema import ConfigParameter, EncoderType, RouteChoice, Utterance, UtteranceDiff
from semantic_router.utils.defaults import EncoderDefault
from semantic_router.utils.logger import logger
......@@ -218,31 +218,23 @@ class LayerConfig:
elif ext in [".yaml", ".yml"]:
yaml.safe_dump(self.to_dict(), f)
def _get_diff(self, other: "LayerConfig") -> List[str]:
"""Get the difference between two LayerConfigs.
def to_utterances(self) -> List[Utterance]:
"""Convert the routes to a list of Utterance objects.
:param other: The LayerConfig to compare to.
:type other: LayerConfig
:return: A list of differences between the two LayerConfigs.
:rtype: List[Dict[str, Any]]
:return: A list of Utterance objects.
:rtype: List[Utterance]
"""
# TODO: formalize diffs into likely LayerDiff objects that can then
# output different formats as required to enable smarter syncs
self_yaml = yaml.dump(self.to_dict())
other_yaml = yaml.dump(other.to_dict())
differ = Differ()
return list(differ.compare(self_yaml.splitlines(), other_yaml.splitlines()))
def show_diff(self, other: "LayerConfig") -> str:
"""Show the difference between two LayerConfigs.
:param other: The LayerConfig to compare to.
:type other: LayerConfig
:return: A string showing the difference between the two LayerConfigs.
:rtype: str
"""
diff = self._get_diff(other)
return "\n".join(diff)
utterances = []
for route in self.routes:
utterances.extend([
Utterance(
route=route.name,
utterance=x,
function_schemas=route.function_schemas,
metadata=route.metadata
) for x in route.utterances
])
return utterances
def add(self, route: Route):
self.routes.append(route)
......@@ -283,6 +275,7 @@ class RouteLayer:
index: Optional[BaseIndex] = None, # type: ignore
top_k: int = 5,
aggregation: str = "sum",
auto_sync: Optional[str] = None,
):
self.index: BaseIndex = index if index is not None else LocalIndex()
if encoder is None:
......@@ -310,14 +303,17 @@ class RouteLayer:
f"Unsupported aggregation method chosen: {aggregation}. Choose either 'SUM', 'MEAN', or 'MAX'."
)
self.aggregation_method = self._set_aggregation_method(self.aggregation)
self.auto_sync = auto_sync
# set route score thresholds if not already set
for route in self.routes:
if route.score_threshold is None:
route.score_threshold = self.score_threshold
# if routes list has been passed, we initialize index now
if self.index.sync:
if self.auto_sync:
# initialize index now
dims = self.encoder.dimensions
self.index._init_index(force_create=True, dimensions=dims)
if len(self.routes) > 0:
self._add_and_sync_routes(routes=self.routes)
else:
......@@ -447,6 +443,100 @@ class RouteLayer:
return route_choices
def sync(self, sync_mode: str, force: bool = False) -> List[str]:
"""Runs a sync of the local routes with the remote index.
:param sync_mode: The mode to sync the routes with the remote index.
:type sync_mode: str
:param force: Whether to force the sync even if the local and remote
hashes already match. Defaults to False.
:type force: bool, optional
:return: A list of diffs describing the addressed differences between
the local and remote route layers.
:rtype: List[str]
"""
if not force and self.is_synced():
logger.warning("Local and remote route layers are already synchronized.")
# create utterance diff to return, but just using local instance
# for speed
local_utterances = self.to_config().to_utterances()
diff = UtteranceDiff.from_utterances(
local_utterances=local_utterances,
remote_utterances=local_utterances,
)
return diff.to_utterance_str()
# otherwise we continue with the sync, first creating a diff
local_utterances = self.to_config().to_utterances()
remote_utterances = self.index.get_utterances()
diff = UtteranceDiff.from_utterances(
local_utterances=local_utterances,
remote_utterances=remote_utterances,
)
# generate sync strategy
sync_strategy = diff.to_sync_strategy()
# and execute
self._execute_sync_strategy(sync_strategy)
return diff.to_utterance_str()
def _execute_sync_strategy(self, strategy: Dict[str, Dict[str, List[Utterance]]]):
"""Executes the provided sync strategy, either deleting or upserting
routes from the local and remote instances as defined in the strategy.
:param strategy: The sync strategy to execute.
:type strategy: Dict[str, Dict[str, List[Utterance]]]
"""
if strategy["remote"]["delete"]:
data_to_delete = {} # type: ignore
for utt_obj in strategy["remote"]["delete"]:
data_to_delete.setdefault(
utt_obj.route, []
).append(utt_obj.utterance)
self.index._remove_and_sync(data_to_delete)
if strategy["remote"]["upsert"]:
utterances_text = [utt.utterance for utt in strategy["remote"]["upsert"]]
self.index.add(
embeddings=self.encoder(utterances_text),
routes=[utt.route for utt in strategy["remote"]["upsert"]],
utterances=utterances_text,
function_schemas=[utt.function_schemas for utt in strategy["remote"]["upsert"]],
metadata_list=[utt.metadata for utt in strategy["remote"]["upsert"]],
)
if strategy["local"]["delete"]:
self._local_delete(utterances=strategy["local"]["delete"])
if strategy["local"]["upsert"]:
self._local_upsert(utterances=strategy["local"]["upsert"])
# update hash
self._write_hash()
def _local_upsert(self, utterances: List[Utterance]):
"""Adds new routes to the RouteLayer.
:param utterances: The utterances to add to the local RouteLayer.
:type utterances: List[Utterance]
"""
new_routes = {}
for utt_obj in utterances:
if utt_obj.route not in new_routes.keys():
new_routes[utt_obj.route] = Route(
name=utt_obj.route,
utterances=[utt_obj.utterance],
function_schemas=utt_obj.function_schemas,
metadata=utt_obj.metadata
)
else:
new_routes[utt_obj.route].utterances.append(utt_obj.utterance)
self.routes.extend(list(new_routes.values()))
def _local_delete(self, utterances: List[Utterance]):
"""Deletes routes from the local RouteLayer.
:param utterances: The utterances to delete from the local RouteLayer.
:type utterances: List[Utterance]
"""
route_names = set([utt.route for utt in utterances])
self.routes = [route for route in self.routes if route.name not in route_names]
def _retrieve_top_route(
self, vector: List[float], route_filter: Optional[List[str]] = None
) -> Tuple[Optional[Route], List[float]]:
......@@ -735,97 +825,27 @@ class RouteLayer:
"route2: utterance4", which do not exist locally.
"""
# first we get remote and local utterances
remote_utterances = [f"{x[0]}: {x[1]}" for x in self.index.get_utterances()]
local_routes, local_utterance_arr, _ = self._extract_routes_details(
self.routes, include_metadata=False
)
local_utterances = [
f"{x[0]}: {x[1]}" for x in zip(local_routes, local_utterance_arr)
]
# sort local and remote utterances
local_utterances.sort()
remote_utterances.sort()
# now get diff
differ = Differ()
diff = list(differ.compare(local_utterances, remote_utterances))
return diff
def _add_and_sync_routes(self, routes: List[Route]):
# get current local hash
current_local_hash = self._get_hash()
current_remote_hash = self.index._read_hash()
if current_remote_hash.value == "":
# if remote hash is empty, the index is to be initialized
current_remote_hash = current_local_hash
# create embeddings for all routes and sync at startup with remote ones based on sync setting
local_route_names, local_utterances, local_function_schemas, local_metadata = (
self._extract_routes_details(routes, include_metadata=True)
)
remote_utterances = self.index.get_utterances()
local_utterances = self.to_config().to_utterances()
routes_to_add, routes_to_delete, layer_routes_dict = self.index._sync_index(
local_route_names,
local_utterances,
local_function_schemas,
local_metadata,
dimensions=self.index.dimensions or len(self.encoder(["dummy"])[0]),
diff_obj = UtteranceDiff.from_utterances(
local_utterances=local_utterances, remote_utterances=remote_utterances
)
return diff_obj.to_utterance_str()
data_to_delete = {} # type: ignore
for route, utterance in routes_to_delete:
data_to_delete.setdefault(route, []).append(utterance)
self.index._remove_and_sync(data_to_delete)
# Prepare data for addition
if routes_to_add:
(
route_names_to_add,
all_utterances_to_add,
function_schemas_to_add,
metadata_to_add,
) = map(list, zip(*routes_to_add))
else:
(
route_names_to_add,
all_utterances_to_add,
function_schemas_to_add,
metadata_to_add,
) = ([], [], [], [])
embedded_utterances_to_add = (
self.encoder(all_utterances_to_add) if all_utterances_to_add else []
)
def _add_and_sync_routes(self, routes: List[Route]):
self.routes.extend(routes)
# first we get remote and local utterances
remote_utterances = self.index.get_utterances()
local_utterances = self.to_config().to_utterances()
self.index.add(
embeddings=embedded_utterances_to_add,
routes=route_names_to_add,
utterances=all_utterances_to_add,
function_schemas=function_schemas_to_add,
metadata_list=metadata_to_add,
diff_obj = UtteranceDiff.from_utterances(
local_utterances=local_utterances, remote_utterances=remote_utterances
)
# Update local route layer state
self.routes = []
for route, data in layer_routes_dict.items():
function_schemas = data.get("function_schemas", None)
if function_schemas is not None:
function_schemas = [function_schemas]
self.routes.append(
Route(
name=route,
utterances=data.get("utterances", []),
function_schemas=function_schemas,
metadata=data.get("metadata", {}),
)
)
# update hash IF index and local hash were aligned
if current_local_hash.value == current_remote_hash.value:
self._write_hash()
else:
logger.warning(
"Local and remote route layers were not aligned. Remote hash "
"not updated. Use `RouteLayer.get_utterance_diff()` to see "
"details."
)
sync_strategy = diff_obj.get_sync_strategy(sync_mode=self.auto_sync)
self._execute_sync_strategy(strategy=sync_strategy)
# update remote hash
self._write_hash()
def _extract_routes_details(
self, routes: List[Route], include_metadata: bool = False
......
from datetime import datetime
from difflib import Differ
from enum import Enum
from typing import List, Optional, Union, Any, Dict
from typing import List, Optional, Union, Any, Dict, Tuple
from pydantic.v1 import BaseModel, Field
......@@ -18,6 +19,9 @@ class EncoderType(Enum):
GOOGLE = "google"
BEDROCK = "bedrock"
def to_list():
return [encoder.value for encoder in EncoderType]
class EncoderInfo(BaseModel):
name: str
......@@ -86,6 +90,234 @@ class ConfigParameter(BaseModel):
}
class Utterance(BaseModel):
route: str
utterance: str
function_schemas: Optional[List[Dict]] = None
metadata: Optional[Dict] = None
diff_tag: str = " "
@classmethod
def from_tuple(cls, tuple_obj: Tuple):
"""Create an Utterance object from a tuple. The tuple must contain
route and utterance as the first two elements. Then optionally
function schemas and metadata as the third and fourth elements
respectively. If this order is not followed an invalid Utterance
object will be returned.
:param tuple_obj: A tuple containing route, utterance, function schemas and metadata.
:type tuple_obj: Tuple
:return: An Utterance object.
:rtype: Utterance
"""
route, utterance = tuple_obj[0], tuple_obj[1]
function_schemas = tuple_obj[2] if len(tuple_obj) > 2 else None
metadata = tuple_obj[3] if len(tuple_obj) > 3 else None
return cls(
route=route,
utterance=utterance,
function_schemas=function_schemas,
metadata=metadata
)
def to_tuple(self):
return (
self.route,
self.utterance,
self.function_schemas,
self.metadata,
)
def to_str(self, include_metadata: bool = False):
if include_metadata:
return f"{self.route}: {self.utterance} | {self.function_schemas} | {self.metadata}"
return f"{self.route}: {self.utterance}"
def to_diff_str(self):
return f"{self.diff_tag} {self.to_str()}"
class SyncMode(Enum):
"""Synchronization modes for local (route layer) and remote (index)
instances.
"""
ERROR = "error"
REMOTE = "remote"
LOCAL = "local"
MERGE_FORCE_REMOTE = "merge-force-remote"
MERGE_FORCE_LOCAL = "merge-force-local"
MERGE = "merge"
def to_list() -> List[str]:
return [mode.value for mode in SyncMode]
class UtteranceDiff(BaseModel):
diff: List[Utterance]
@classmethod
def from_utterances(
cls,
local_utterances: List[Utterance],
remote_utterances: List[Utterance]
):
local_utterances_map = {x.to_str(): x for x in local_utterances}
remote_utterances_map = {x.to_str(): x for x in remote_utterances}
# sort local and remote utterances
local_utterances_str = list(local_utterances_map.keys())
local_utterances_str.sort()
remote_utterances_str = list(remote_utterances_map.keys())
remote_utterances_str.sort()
# get diff
differ = Differ()
diff_obj = list(differ.compare(local_utterances_str, remote_utterances_str))
# create UtteranceDiff list
utterance_diffs = []
for line in diff_obj:
utterance_str = line[2:]
utterance_diff_tag = line[0]
utterance = remote_utterances_map[utterance_str] if utterance_diff_tag == "+" else local_utterances_map[utterance_str]
utterance.diff_tag = utterance_diff_tag
utterance_diffs.append(utterance)
return UtteranceDiff(diff=utterance_diffs)
def to_utterance_str(self) -> List[str]:
"""Outputs the utterance diff as a list of diff strings. Returns a list
of strings showing what is different in the remote when compared to the
local. For example:
[" route1: utterance1",
" route1: utterance2",
"- route2: utterance3",
"- route2: utterance4"]
Tells us that the remote is missing "route2: utterance3" and "route2:
utterance4", which do exist locally. If we see:
[" route1: utterance1",
" route1: utterance2",
"+ route2: utterance3",
"+ route2: utterance4"]
This diff tells us that the remote has "route2: utterance3" and
"route2: utterance4", which do not exist locally.
"""
return [x.to_diff_str() for x in self.diff]
def get_tag(self, diff_tag: str) -> List[Utterance]:
"""Get all utterances with a given diff tag.
:param diff_tag: The diff tag to filter by. Must be one of "+", "-", or
" ".
:type diff_tag: str
:return: A list of Utterance objects.
:rtype: List[Utterance]
"""
if diff_tag not in ["+", "-", " "]:
raise ValueError("diff_tag must be one of '+', '-', or ' '")
return [x for x in self.diff if x.diff_tag == diff_tag]
def get_sync_strategy(self, sync_mode: str) -> dict:
"""Generates the optimal synchronization plan for local and remote
instances.
:param sync_mode: The mode to sync the routes with the remote index.
:type sync_mode: str
:return: A dictionary describing the synchronization strategy.
:rtype: dict
"""
if sync_mode not in SyncMode.to_list():
raise ValueError(f"sync_mode must be one of {SyncMode.to_list()}")
local_only = self.get_tag("-")
remote_only = self.get_tag("+")
local_and_remote = self.get_tag(" ")
if sync_mode == "error":
if len(local_only) > 0 or len(remote_only) > 0:
raise ValueError(
"There are utterances that exist in the local or remote "
"instance that do not exist in the other instance. Please "
"sync the routes before running this command."
)
else:
return {
"remote": {
"upsert": [],
"delete": []
},
"local": {
"upsert": [],
"delete": []
}
}
elif sync_mode == "local":
return {
"remote": {
"upsert": local_only,
"delete": remote_only
},
"local": {
"upsert": [],
"delete": []
}
}
elif sync_mode == "remote":
return {
"remote": {
"upsert": [],
"delete": []
},
"local": {
"upsert": remote_only,
"delete": local_only
}
}
elif sync_mode == "merge-force-remote":
# get set of route names that exist in both local and remote
routes_in_both = set([utt.route for utt in local_and_remote])
# get remote utterances that belong to routes_in_both
remote_to_keep = [utt for utt in remote_only if utt.route in routes_in_both]
# get remote utterances that do NOT belong to routes_in_both
remote_to_delete = [utt for utt in remote_only if utt.route not in routes_in_both]
return {
"remote": {
"upsert": local_only,
"delete": remote_to_delete
},
"local": {
"upsert": remote_to_keep,
"delete": []
}
}
elif sync_mode == "merge-force-local":
# get set of route names that exist in both local and remote
routes_in_both = set([utt.route for utt in local_and_remote])
# get local utterances that belong to routes_in_both
local_to_keep = [utt for utt in local_only if utt.route in routes_in_both]
# get local utterances that do NOT belong to routes_in_both
local_to_delete = [utt for utt in local_only if utt.route not in routes_in_both]
return {
"remote": {
"upsert": local_to_keep,
"delete": []
},
"local": {
"upsert": remote_only,
"delete": local_to_delete
}
}
elif sync_mode == "merge":
return {
"remote": {
"upsert": local_only,
"delete": []
},
"local": {
"upsert": remote_only,
"delete": []
}
}
class Metric(Enum):
COSINE = "cosine"
DOTPRODUCT = "dotproduct"
......
......@@ -238,6 +238,7 @@ class TestRouteLayer:
assert route_layer_openai.score_threshold == 0.3
def test_delete_index(self, openai_encoder, routes, index_cls):
# TODO merge .delete_index() and .delete_all() and get working
index = init_index(index_cls)
route_layer = RouteLayer(encoder=openai_encoder, routes=routes, index=index)
if index_cls is PineconeIndex:
......
......@@ -237,3 +237,101 @@ class TestRouteLayer:
assert "+ Route 2: Bye" in diff
assert "+ Route 2: Goodbye" in diff
assert " Route 2: Hi" in diff
@pytest.mark.skipif(
os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
)
def test_auto_sync_local(self, openai_encoder, routes, routes_2, routes_4, index_cls):
if index_cls is PineconeIndex:
# TEST LOCAL
pinecone_index = init_index(index_cls)
route_layer = RouteLayer(
encoder=openai_encoder, routes=routes_2, index=pinecone_index,
auto_sync="local"
)
time.sleep(PINECONE_SLEEP) # allow for index to be populated
assert route_layer.index.get_utterances() == [
("Route 1", "Hello", None, {}),
("Route 2", "Hi", None, {}),
], "The routes in the index should match the local routes"
@pytest.mark.skipif(
os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
)
def test_auto_sync_remote(self, openai_encoder, routes, index_cls):
if index_cls is PineconeIndex:
# TEST REMOTE
pinecone_index = init_index(index_cls)
route_layer = RouteLayer(
encoder=openai_encoder, routes=routes, index=pinecone_index,
auto_sync="remote"
)
time.sleep(PINECONE_SLEEP) # allow for index to be populated
assert route_layer.index.get_utterances() == [
("Route 1", "Hello", None, {}),
("Route 2", "Hi", None, {}),
], "The routes in the index should match the local routes"
@pytest.mark.skipif(
os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
)
def test_auto_sync_merge_force_remote(self, openai_encoder, routes, index_cls):
if index_cls is PineconeIndex:
# TEST MERGE FORCE REMOTE
pinecone_index = init_index(index_cls)
route_layer = RouteLayer(
encoder=openai_encoder, routes=routes, index=pinecone_index,
auto_sync="merge-force-remote"
)
time.sleep(PINECONE_SLEEP) # allow for index to be populated
assert route_layer.index.get_utterances() == [
("Route 1", "Hello", None, {}),
("Route 2", "Hi", None, {}),
], "The routes in the index should match the local routes"
@pytest.mark.skipif(
os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
)
def test_auto_sync_merge_force_local(self, openai_encoder, routes, index_cls):
if index_cls is PineconeIndex:
# TEST MERGE FORCE LOCAL
pinecone_index = init_index(index_cls)
route_layer = RouteLayer(
encoder=openai_encoder, routes=routes, index=pinecone_index,
auto_sync="merge-force-local"
)
time.sleep(PINECONE_SLEEP) # allow for index to be populated
assert route_layer.index.get_utterances() == [
("Route 1", "Hello", None, {"type": "default"}),
("Route 1", "Hi", None, {"type": "default"}),
("Route 2", "Bye", None, {}),
("Route 2", "Au revoir", None, {}),
("Route 2", "Goodbye", None, {}),
], "The routes in the index should match the local routes"
@pytest.mark.skipif(
os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
)
def test_auto_sync_merge(self, openai_encoder, routes_4, index_cls):
if index_cls is PineconeIndex:
# TEST MERGE
pinecone_index = init_index(index_cls)
route_layer = RouteLayer(
encoder=openai_encoder, routes=routes_4, index=pinecone_index,
auto_sync="merge"
)
time.sleep(PINECONE_SLEEP) # allow for index to be populated
assert route_layer.index.get_utterances() == [
("Route 1", "Hello", None, {"type": "default"}),
("Route 1", "Hi", None, {"type": "default"}),
("Route 1", "Goodbye", None, {"type": "default"}),
("Route 2", "Bye", None, {}),
("Route 2", "Asparagus", None, {}),
("Route 2", "Au revoir", None, {}),
("Route 2", "Goodbye", None, {}),
], "The routes in the index should match the local routes"
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment