diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py index 1f4ccaefb37d6b1f6d3cd7f37ff74bb153f10159..be3f6592edbc12e9afc944ed10c5887772cff402 100644 --- a/semantic_router/index/base.py +++ b/semantic_router/index/base.py @@ -278,6 +278,8 @@ def parse_route_info(metadata: List[Dict[str, Any]]) -> List[Tuple]: for key, value in record.items() if key not in ["sr_route", "sr_utterance", "sr_function_schema"] } + if additional_metadata is None: + additional_metadata = {} # TODO: Not a fan of tuple packing here route_info.append( (sr_route, sr_utterance, sr_function_schema, additional_metadata) diff --git a/semantic_router/layer.py b/semantic_router/layer.py index 74b8b50c2c45ffb297e95d52ab87b3ed1726b7e1..78de158e7b919dd4ae0521f35f1b0419eb40860e 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -311,15 +311,21 @@ class RouteLayer: route.score_threshold = self.score_threshold # if routes list has been passed, we initialize index now if self.auto_sync: - # initialize index now - dims = self.encoder.dimensions - self.index._init_index(force_create=True, dimensions=dims) + # initialize index now, check if we need dimensions + if self.index.dimensions is None: + dims = len(self.encoder(["test"])[0]) + self.index.dimensions = dims + # now init index + self.index.index = self.index._init_index(force_create=True) if len(self.routes) > 0: - self._add_and_sync_routes(routes=self.routes) - else: - self._add_and_sync_routes(routes=[]) - elif len(self.routes) > 0: - self._add_routes(routes=self.routes) + local_utterances = self.to_config().to_utterances() + remote_utterances = self.index.get_utterances() + diff = UtteranceDiff.from_utterances( + local_utterances=local_utterances, + remote_utterances=remote_utterances + ) + sync_strategy = diff.get_sync_strategy(self.auto_sync) + self._execute_sync_strategy(sync_strategy) def check_for_matching_routes(self, top_class: str) -> Optional[Route]: matching_route = next( @@ -803,7 +809,7 @@ class RouteLayer: else: return False - def get_utterance_diff(self) -> List[str]: + def get_utterance_diff(self, include_metadata: bool = False) -> List[str]: """Get the difference between the local and remote utterances. Returns a list of strings showing what is different in the remote when compared to the local. For example: @@ -831,7 +837,7 @@ class RouteLayer: diff_obj = UtteranceDiff.from_utterances( local_utterances=local_utterances, remote_utterances=remote_utterances ) - return diff_obj.to_utterance_str() + return diff_obj.to_utterance_str(include_metadata=include_metadata) def _add_and_sync_routes(self, routes: List[Route]): self.routes.extend(routes) diff --git a/semantic_router/schema.py b/semantic_router/schema.py index 634cd25bacffb9c2e927c94a60cb877a12080f8d..fbf1da5099d3a7a695dc1b0ccdc63af58a8167b0 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -3,6 +3,7 @@ from difflib import Differ from enum import Enum from typing import List, Optional, Union, Any, Dict, Tuple from pydantic.v1 import BaseModel, Field +from semantic_router.utils.logger import logger class EncoderType(Enum): @@ -94,7 +95,7 @@ class Utterance(BaseModel): route: str utterance: str function_schemas: Optional[List[Dict]] = None - metadata: Optional[Dict] = None + metadata: dict = {} diff_tag: str = " " @classmethod @@ -112,7 +113,7 @@ class Utterance(BaseModel): """ route, utterance = tuple_obj[0], tuple_obj[1] function_schemas = tuple_obj[2] if len(tuple_obj) > 2 else None - metadata = tuple_obj[3] if len(tuple_obj) > 3 else None + metadata = tuple_obj[3] if len(tuple_obj) > 3 else {} return cls( route=route, utterance=utterance, @@ -133,8 +134,8 @@ class Utterance(BaseModel): return f"{self.route}: {self.utterance} | {self.function_schemas} | {self.metadata}" return f"{self.route}: {self.utterance}" - def to_diff_str(self): - return f"{self.diff_tag} {self.to_str()}" + def to_diff_str(self, include_metadata: bool = False): + return f"{self.diff_tag} {self.to_str(include_metadata=include_metadata)}" class SyncMode(Enum): @@ -160,8 +161,8 @@ class UtteranceDiff(BaseModel): local_utterances: List[Utterance], remote_utterances: List[Utterance] ): - local_utterances_map = {x.to_str(): x for x in local_utterances} - remote_utterances_map = {x.to_str(): x for x in remote_utterances} + local_utterances_map = {x.to_str(include_metadata=True): x for x in local_utterances} + remote_utterances_map = {x.to_str(include_metadata=True): x for x in remote_utterances} # sort local and remote utterances local_utterances_str = list(local_utterances_map.keys()) local_utterances_str.sort() @@ -175,12 +176,15 @@ class UtteranceDiff(BaseModel): for line in diff_obj: utterance_str = line[2:] utterance_diff_tag = line[0] + if utterance_diff_tag == "?": + # this is a new line from diff string, we can ignore + continue utterance = remote_utterances_map[utterance_str] if utterance_diff_tag == "+" else local_utterances_map[utterance_str] utterance.diff_tag = utterance_diff_tag utterance_diffs.append(utterance) return UtteranceDiff(diff=utterance_diffs) - def to_utterance_str(self) -> List[str]: + def to_utterance_str(self, include_metadata: bool = False) -> List[str]: """Outputs the utterance diff as a list of diff strings. Returns a list of strings showing what is different in the remote when compared to the local. For example: @@ -201,7 +205,7 @@ class UtteranceDiff(BaseModel): This diff tells us that the remote has "route2: utterance3" and "route2: utterance4", which do not exist locally. """ - return [x.to_diff_str() for x in self.diff] + return [x.to_diff_str(include_metadata=include_metadata) for x in self.diff] def get_tag(self, diff_tag: str) -> List[Utterance]: """Get all utterances with a given diff tag. diff --git a/tests/unit/test_sync.py b/tests/unit/test_sync.py index c4877151639699f7c5d3c937618911c703db71db..3cb425bbd18549a155853c5144b5f97b43c62594 100644 --- a/tests/unit/test_sync.py +++ b/tests/unit/test_sync.py @@ -6,6 +6,7 @@ import time from typing import Optional from semantic_router.encoders import BaseEncoder, CohereEncoder, OpenAIEncoder from semantic_router.index.pinecone import PineconeIndex +from semantic_router.schema import Utterance from semantic_router.layer import RouteLayer from semantic_router.route import Route from platform import python_version @@ -241,57 +242,77 @@ class TestRouteLayer: @pytest.mark.skipif( os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" ) - def test_auto_sync_local(self, openai_encoder, routes, routes_2, routes_4, index_cls): + def test_auto_sync_local(self, openai_encoder, routes, routes_2, index_cls): if index_cls is PineconeIndex: # TEST LOCAL pinecone_index = init_index(index_cls) + _ = RouteLayer( + encoder=openai_encoder, routes=routes, index=pinecone_index, + ) + time.sleep(PINECONE_SLEEP) # allow for index to be populated route_layer = RouteLayer( encoder=openai_encoder, routes=routes_2, index=pinecone_index, auto_sync="local" ) time.sleep(PINECONE_SLEEP) # allow for index to be populated assert route_layer.index.get_utterances() == [ - ("Route 1", "Hello", None, {}), - ("Route 2", "Hi", None, {}), + Utterance(route="Route 1", utterance="Hello"), + Utterance(route="Route 2", utterance="Hi"), ], "The routes in the index should match the local routes" @pytest.mark.skipif( os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" ) - def test_auto_sync_remote(self, openai_encoder, routes, index_cls): + def test_auto_sync_remote(self, openai_encoder, routes, routes_2, index_cls): if index_cls is PineconeIndex: # TEST REMOTE pinecone_index = init_index(index_cls) + _ = RouteLayer( + encoder=openai_encoder, routes=routes_2, index=pinecone_index, + auto_sync="local" + ) + time.sleep(PINECONE_SLEEP) # allow for index to be populated route_layer = RouteLayer( encoder=openai_encoder, routes=routes, index=pinecone_index, auto_sync="remote" ) - time.sleep(PINECONE_SLEEP) # allow for index to be populated assert route_layer.index.get_utterances() == [ - ("Route 1", "Hello", None, {}), - ("Route 2", "Hi", None, {}), + Utterance(route="Route 1", utterance="Hello"), + Utterance(route="Route 2", utterance="Hi"), ], "The routes in the index should match the local routes" + # clear index + route_layer.index.index.delete(namespace="", delete_all=True) + @pytest.mark.skipif( os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" ) - def test_auto_sync_merge_force_remote(self, openai_encoder, routes, index_cls): + def test_auto_sync_merge_force_remote(self, openai_encoder, routes, routes_2, index_cls): if index_cls is PineconeIndex: # TEST MERGE FORCE REMOTE pinecone_index = init_index(index_cls) route_layer = RouteLayer( encoder=openai_encoder, routes=routes, index=pinecone_index, + ) + time.sleep(PINECONE_SLEEP) # allow for index to be populated + route_layer = RouteLayer( + encoder=openai_encoder, routes=routes_2, index=pinecone_index, auto_sync="merge-force-remote" ) - time.sleep(PINECONE_SLEEP) # allow for index to be populated assert route_layer.index.get_utterances() == [ - ("Route 1", "Hello", None, {}), - ("Route 2", "Hi", None, {}), + Utterance( + route="Route 1", utterance="Hello", + metadata={"type": "default"} + ), + Utterance(route="Route 2", utterance="Hi"), ], "The routes in the index should match the local routes" + # clear index + route_layer.index.index.delete(namespace="", delete_all=True) + @pytest.mark.skipif( os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" ) @@ -306,13 +327,22 @@ class TestRouteLayer: time.sleep(PINECONE_SLEEP) # allow for index to be populated assert route_layer.index.get_utterances() == [ - ("Route 1", "Hello", None, {"type": "default"}), - ("Route 1", "Hi", None, {"type": "default"}), - ("Route 2", "Bye", None, {}), - ("Route 2", "Au revoir", None, {}), - ("Route 2", "Goodbye", None, {}), + Utterance( + route="Route 1", utterance="Hello", + metadata={"type": "default"} + ), + Utterance( + route="Route 1", utterance="Hi", + metadata={"type": "default"} + ), + Utterance(route="Route 2", utterance="Bye"), + Utterance(route="Route 2", utterance="Au revoir"), + Utterance(route="Route 2", utterance="Goodbye"), ], "The routes in the index should match the local routes" + # clear index + route_layer.index.index.delete(namespace="", delete_all=True) + @pytest.mark.skipif( os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" ) @@ -335,3 +365,6 @@ class TestRouteLayer: ("Route 2", "Au revoir", None, {}), ("Route 2", "Goodbye", None, {}), ], "The routes in the index should match the local routes" + + # clear index + route_layer.index.index.delete(namespace="", delete_all=True)