diff --git a/semantic_router/schema.py b/semantic_router/schema.py index 012ef6168b4a1882c17b78fc2c3f4a5407a1369e..13f3c5d5caea23510aa971319b60e0de17859fcf 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -232,7 +232,9 @@ class UtteranceDiff(BaseModel): if sync_mode not in SyncMode.to_list(): raise ValueError(f"sync_mode must be one of {SyncMode.to_list()}") local_only = self.get_tag("-") + local_only_mapper = {utt.route: (utt.function_schemas, utt.metadata) for utt in local_only} remote_only = self.get_tag("+") + remote_only_mapper = {utt.route: (utt.function_schemas, utt.metadata) for utt in remote_only} local_and_remote = self.get_tag(" ") if sync_mode == "error": if len(local_only) > 0 or len(remote_only) > 0: @@ -275,6 +277,7 @@ class UtteranceDiff(BaseModel): } } elif sync_mode == "merge-force-remote": # merge-to-local merge-join-local + # PRIORITIZE LOCAL # get set of route names that exist in local (we keep these if # they are in remote) local_route_names = set([utt.route for utt in local_only]) @@ -285,11 +288,36 @@ class UtteranceDiff(BaseModel): remote_to_keep = [utt for utt in remote_only if ( utt.route in local_route_names and utt.to_str() not in local_route_utt_strs )] + # overwrite remote routes with local metadata and function schemas + logger.info(f"local_only_mapper: {local_only_mapper}") + remote_to_update = [ + Utterance( + route=utt.route, + utterance=utt.utterance, + metadata=local_only_mapper[utt.route][1], + function_schemas=local_only_mapper[utt.route][0] + ) for utt in remote_only if ( + utt.route in local_only_mapper and ( + utt.metadata != local_only_mapper[utt.route][1] or + utt.function_schemas != local_only_mapper[utt.route][0] + ) + ) + ] + remote_to_keep = [ + Utterance( + route=utt.route, + utterance=utt.utterance, + metadata=local_only_mapper[utt.route][1], + function_schemas=local_only_mapper[utt.route][0] + ) for utt in remote_to_keep if utt.to_str() not in [ + x.to_str() for x in remote_to_update + ] + 
] # get remote utterances that are NOT in local remote_to_delete = [utt for utt in remote_only if utt.route not in local_route_names] return { "remote": { - "upsert": local_only, + "upsert": local_only + remote_to_update, "delete": remote_to_delete }, "local": { @@ -308,6 +336,15 @@ class UtteranceDiff(BaseModel): local_to_keep = [utt for utt in local_only if ( utt.route in remote_route_names and utt.to_str() not in remote_route_utt_strs )] + # overwrite local routes with remote metadata and function schemas + local_to_keep = [ + Utterance( + route=utt.route, + utterance=utt.utterance, + metadata=remote_only_mapper[utt.route][1], + function_schemas=remote_only_mapper[utt.route][0] + ) for utt in local_to_keep + ] # get local utterances that are NOT in remote local_to_delete = [utt for utt in local_only if utt.route not in remote_route_names] return { @@ -321,13 +358,37 @@ class UtteranceDiff(BaseModel): } } elif sync_mode == "merge": + # overwrite remote routes with local metadata and function schemas + remote_only_updated = [ + Utterance( + route=utt.route, + utterance=utt.utterance, + metadata=local_only_mapper[utt.route][1], + function_schemas=local_only_mapper[utt.route][0] + ) if utt.route in local_only_mapper else utt + for utt in remote_only + ] + # propagate the same to shared routes + shared_updated = [ + Utterance( + route=utt.route, + utterance=utt.utterance, + metadata=local_only_mapper[utt.route][1], + function_schemas=local_only_mapper[utt.route][0] + ) for utt in local_and_remote if ( + utt.route in local_only_mapper and ( + utt.metadata != local_only_mapper[utt.route][1] or + utt.function_schemas != local_only_mapper[utt.route][0] + ) + ) + ] return { "remote": { - "upsert": local_only, + "upsert": local_only + shared_updated + remote_only_updated, "delete": [] }, "local": { - "upsert": remote_only, + "upsert": remote_only_updated + shared_updated, "delete": [] } } diff --git a/tests/unit/test_sync.py b/tests/unit/test_sync.py index 
7a20d7a5f210414ca1c9c7e3db99622fc3b1bd42..cdc0745ff2319a4ef8670cafd8820566c75eb2b6 100644 --- a/tests/unit/test_sync.py +++ b/tests/unit/test_sync.py @@ -12,7 +12,7 @@ from semantic_router.route import Route from platform import python_version -PINECONE_SLEEP = 6 +PINECONE_SLEEP = 12 def mock_encoder_call(utterances): @@ -290,9 +290,6 @@ class TestRouteLayer: Utterance(route="Route 2", utterance="Hi"), ], "The routes in the index should match the local routes" - # clear index - route_layer.index.index.delete(namespace="", delete_all=True) - @pytest.mark.skipif( os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" ) @@ -317,23 +314,14 @@ class TestRouteLayer: # we sort to ensure order is the same local_utterances.sort(key=lambda x: x.to_str(include_metadata=True)) assert local_utterances == [ - Utterance( - route='Route 1', utterance='Hello', - metadata={'type': 'default'} - ), - Utterance( - route='Route 1', utterance='Hi', - metadata={'type': 'default'} - ), + Utterance(route='Route 1', utterance='Hello'), + Utterance(route='Route 1', utterance='Hi'), Utterance(route='Route 2', utterance='Au revoir'), Utterance(route='Route 2', utterance='Bye'), Utterance(route='Route 2', utterance='Goodbye'), Utterance(route='Route 2', utterance='Hi') ], "The routes in the index should match the local routes" - # clear index - route_layer.index.index.delete(namespace="", delete_all=True) - @pytest.mark.skipif( os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" ) @@ -342,13 +330,13 @@ class TestRouteLayer: # TEST MERGE FORCE LOCAL pinecone_index = init_index(index_cls) route_layer = RouteLayer( - encoder=openai_encoder, routes=routes_2, index=pinecone_index, + encoder=openai_encoder, routes=routes, index=pinecone_index, auto_sync="local" ) time.sleep(PINECONE_SLEEP) # allow for index to be populated route_layer = RouteLayer( - encoder=openai_encoder, routes=routes, index=pinecone_index, - auto_sync="merge-force-remote" + 
encoder=openai_encoder, routes=routes_2, index=pinecone_index, + auto_sync="merge-force-local" ) time.sleep(PINECONE_SLEEP) # allow for index to be populated # confirm local and remote are synced @@ -358,20 +346,15 @@ class TestRouteLayer: # we sort to ensure order is the same local_utterances.sort(key=lambda x: x.to_str(include_metadata=True)) assert local_utterances == [ - Utterance(route='Route 1', utterance='Hello'), - Utterance( - route='Route 1', utterance='Hi', - metadata={'type': 'default'} - ), + Utterance(route='Route 1', utterance='Hello', metadata={'type': 'default'}), + Utterance(route='Route 1', utterance='Hi', metadata={'type': 'default'}), Utterance(route='Route 2', utterance='Au revoir'), Utterance(route='Route 2', utterance='Bye'), Utterance(route='Route 2', utterance='Goodbye'), - Utterance(route='Route 2', utterance='Hi') + Utterance(route='Route 2', utterance='Hi'), + Utterance(route='Route 3', utterance='Boo') ], "The routes in the index should match the local routes" - # clear index - route_layer.index.index.delete(namespace="", delete_all=True) - @pytest.mark.skipif( os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" ) @@ -380,12 +363,12 @@ class TestRouteLayer: # TEST MERGE pinecone_index = init_index(index_cls) route_layer = RouteLayer( - encoder=openai_encoder, routes=routes, index=pinecone_index, + encoder=openai_encoder, routes=routes_2, index=pinecone_index, auto_sync="local" ) time.sleep(PINECONE_SLEEP) # allow for index to be populated route_layer = RouteLayer( - encoder=openai_encoder, routes=routes_2, index=pinecone_index, + encoder=openai_encoder, routes=routes, index=pinecone_index, auto_sync="merge" ) time.sleep(PINECONE_SLEEP) # allow for index to be populated @@ -396,14 +379,19 @@ class TestRouteLayer: # we sort to ensure order is the same local_utterances.sort(key=lambda x: x.to_str(include_metadata=True)) assert local_utterances == [ - ("Route 1", "Hello", None, {"type": "default"}), - ("Route 
1", "Hi", None, {"type": "default"}), - ("Route 1", "Goodbye", None, {"type": "default"}), - ("Route 2", "Bye", None, {}), - ("Route 2", "Asparagus", None, {}), - ("Route 2", "Au revoir", None, {}), - ("Route 2", "Goodbye", None, {}), - ("Route 3", "Boo", None, {}), + Utterance( + route='Route 1', utterance='Hello', + metadata={'type': 'default'} + ), + Utterance( + route='Route 1', utterance='Hi', + metadata={'type': 'default'} + ), + Utterance(route='Route 2', utterance='Au revoir'), + Utterance(route='Route 2', utterance='Bye'), + Utterance(route='Route 2', utterance='Goodbye'), + Utterance(route='Route 2', utterance='Hi'), + Utterance(route='Route 3', utterance='Boo') ], "The routes in the index should match the local routes" # clear index