Skip to content
Snippets Groups Projects
Commit c8de4667 authored by James Briggs's avatar James Briggs
Browse files

fix: metadata propagation

parent 248edf8b
No related branches found
No related tags found
No related merge requests found
...@@ -232,7 +232,9 @@ class UtteranceDiff(BaseModel): ...@@ -232,7 +232,9 @@ class UtteranceDiff(BaseModel):
if sync_mode not in SyncMode.to_list(): if sync_mode not in SyncMode.to_list():
raise ValueError(f"sync_mode must be one of {SyncMode.to_list()}") raise ValueError(f"sync_mode must be one of {SyncMode.to_list()}")
local_only = self.get_tag("-") local_only = self.get_tag("-")
local_only_mapper = {utt.route: (utt.function_schemas, utt.metadata) for utt in local_only}
remote_only = self.get_tag("+") remote_only = self.get_tag("+")
remote_only_mapper = {utt.route: (utt.function_schemas, utt.metadata) for utt in remote_only}
local_and_remote = self.get_tag(" ") local_and_remote = self.get_tag(" ")
if sync_mode == "error": if sync_mode == "error":
if len(local_only) > 0 or len(remote_only) > 0: if len(local_only) > 0 or len(remote_only) > 0:
...@@ -275,6 +277,7 @@ class UtteranceDiff(BaseModel): ...@@ -275,6 +277,7 @@ class UtteranceDiff(BaseModel):
} }
} }
elif sync_mode == "merge-force-remote": # merge-to-local merge-join-local elif sync_mode == "merge-force-remote": # merge-to-local merge-join-local
# PRIORITIZE LOCAL
# get set of route names that exist in local (we keep these if # get set of route names that exist in local (we keep these if
# they are in remote) # they are in remote)
local_route_names = set([utt.route for utt in local_only]) local_route_names = set([utt.route for utt in local_only])
...@@ -285,11 +288,36 @@ class UtteranceDiff(BaseModel): ...@@ -285,11 +288,36 @@ class UtteranceDiff(BaseModel):
remote_to_keep = [utt for utt in remote_only if ( remote_to_keep = [utt for utt in remote_only if (
utt.route in local_route_names and utt.to_str() not in local_route_utt_strs utt.route in local_route_names and utt.to_str() not in local_route_utt_strs
)] )]
# overwrite remote routes with local metadata and function schemas
logger.info(f"local_only_mapper: {local_only_mapper}")
remote_to_update = [
Utterance(
route=utt.route,
utterance=utt.utterance,
metadata=local_only_mapper[utt.route][1],
function_schemas=local_only_mapper[utt.route][0]
) for utt in remote_only if (
utt.route in local_only_mapper and (
utt.metadata != local_only_mapper[utt.route][1] or
utt.function_schemas != local_only_mapper[utt.route][0]
)
)
]
remote_to_keep = [
Utterance(
route=utt.route,
utterance=utt.utterance,
metadata=local_only_mapper[utt.route][1],
function_schemas=local_only_mapper[utt.route][0]
) for utt in remote_to_keep if utt.to_str() not in [
x.to_str() for x in remote_to_update
]
]
# get remote utterances that are NOT in local # get remote utterances that are NOT in local
remote_to_delete = [utt for utt in remote_only if utt.route not in local_route_names] remote_to_delete = [utt for utt in remote_only if utt.route not in local_route_names]
return { return {
"remote": { "remote": {
"upsert": local_only, "upsert": local_only + remote_to_update,
"delete": remote_to_delete "delete": remote_to_delete
}, },
"local": { "local": {
...@@ -308,6 +336,15 @@ class UtteranceDiff(BaseModel): ...@@ -308,6 +336,15 @@ class UtteranceDiff(BaseModel):
local_to_keep = [utt for utt in local_only if ( local_to_keep = [utt for utt in local_only if (
utt.route in remote_route_names and utt.to_str() not in remote_route_utt_strs utt.route in remote_route_names and utt.to_str() not in remote_route_utt_strs
)] )]
# overwrite local routes with remote metadata and function schemas
local_to_keep = [
Utterance(
route=utt.route,
utterance=utt.utterance,
metadata=remote_only_mapper[utt.route][1],
function_schemas=remote_only_mapper[utt.route][0]
) for utt in local_to_keep
]
# get local utterances that are NOT in remote # get local utterances that are NOT in remote
local_to_delete = [utt for utt in local_only if utt.route not in remote_route_names] local_to_delete = [utt for utt in local_only if utt.route not in remote_route_names]
return { return {
...@@ -321,13 +358,37 @@ class UtteranceDiff(BaseModel): ...@@ -321,13 +358,37 @@ class UtteranceDiff(BaseModel):
} }
} }
elif sync_mode == "merge": elif sync_mode == "merge":
# overwrite remote routes with local metadata and function schemas
remote_only_updated = [
Utterance(
route=utt.route,
utterance=utt.utterance,
metadata=local_only_mapper[utt.route][1],
function_schemas=local_only_mapper[utt.route][0]
) if utt.route in local_only_mapper else utt
for utt in remote_only
]
# propagate same to shared routes
shared_updated = [
Utterance(
route=utt.route,
utterance=utt.utterance,
metadata=local_only_mapper[utt.route][1],
function_schemas=local_only_mapper[utt.route][0]
) for utt in local_and_remote if (
utt.route in local_only_mapper and (
utt.metadata != local_only_mapper[utt.route][1] or
utt.function_schemas != local_only_mapper[utt.route][0]
)
)
]
return { return {
"remote": { "remote": {
"upsert": local_only, "upsert": local_only + shared_updated + remote_only_updated,
"delete": [] "delete": []
}, },
"local": { "local": {
"upsert": remote_only, "upsert": remote_only_updated + shared_updated,
"delete": [] "delete": []
} }
} }
......
...@@ -12,7 +12,7 @@ from semantic_router.route import Route ...@@ -12,7 +12,7 @@ from semantic_router.route import Route
from platform import python_version from platform import python_version
PINECONE_SLEEP = 6 PINECONE_SLEEP = 12
def mock_encoder_call(utterances): def mock_encoder_call(utterances):
...@@ -290,9 +290,6 @@ class TestRouteLayer: ...@@ -290,9 +290,6 @@ class TestRouteLayer:
Utterance(route="Route 2", utterance="Hi"), Utterance(route="Route 2", utterance="Hi"),
], "The routes in the index should match the local routes" ], "The routes in the index should match the local routes"
# clear index
route_layer.index.index.delete(namespace="", delete_all=True)
@pytest.mark.skipif( @pytest.mark.skipif(
os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
) )
...@@ -317,23 +314,14 @@ class TestRouteLayer: ...@@ -317,23 +314,14 @@ class TestRouteLayer:
# we sort to ensure order is the same # we sort to ensure order is the same
local_utterances.sort(key=lambda x: x.to_str(include_metadata=True)) local_utterances.sort(key=lambda x: x.to_str(include_metadata=True))
assert local_utterances == [ assert local_utterances == [
Utterance( Utterance(route='Route 1', utterance='Hello'),
route='Route 1', utterance='Hello', Utterance(route='Route 1', utterance='Hi'),
metadata={'type': 'default'}
),
Utterance(
route='Route 1', utterance='Hi',
metadata={'type': 'default'}
),
Utterance(route='Route 2', utterance='Au revoir'), Utterance(route='Route 2', utterance='Au revoir'),
Utterance(route='Route 2', utterance='Bye'), Utterance(route='Route 2', utterance='Bye'),
Utterance(route='Route 2', utterance='Goodbye'), Utterance(route='Route 2', utterance='Goodbye'),
Utterance(route='Route 2', utterance='Hi') Utterance(route='Route 2', utterance='Hi')
], "The routes in the index should match the local routes" ], "The routes in the index should match the local routes"
# clear index
route_layer.index.index.delete(namespace="", delete_all=True)
@pytest.mark.skipif( @pytest.mark.skipif(
os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
) )
...@@ -342,13 +330,13 @@ class TestRouteLayer: ...@@ -342,13 +330,13 @@ class TestRouteLayer:
# TEST MERGE FORCE LOCAL # TEST MERGE FORCE LOCAL
pinecone_index = init_index(index_cls) pinecone_index = init_index(index_cls)
route_layer = RouteLayer( route_layer = RouteLayer(
encoder=openai_encoder, routes=routes_2, index=pinecone_index, encoder=openai_encoder, routes=routes, index=pinecone_index,
auto_sync="local" auto_sync="local"
) )
time.sleep(PINECONE_SLEEP) # allow for index to be populated time.sleep(PINECONE_SLEEP) # allow for index to be populated
route_layer = RouteLayer( route_layer = RouteLayer(
encoder=openai_encoder, routes=routes, index=pinecone_index, encoder=openai_encoder, routes=routes_2, index=pinecone_index,
auto_sync="merge-force-remote" auto_sync="merge-force-local"
) )
time.sleep(PINECONE_SLEEP) # allow for index to be populated time.sleep(PINECONE_SLEEP) # allow for index to be populated
# confirm local and remote are synced # confirm local and remote are synced
...@@ -358,20 +346,15 @@ class TestRouteLayer: ...@@ -358,20 +346,15 @@ class TestRouteLayer:
# we sort to ensure order is the same # we sort to ensure order is the same
local_utterances.sort(key=lambda x: x.to_str(include_metadata=True)) local_utterances.sort(key=lambda x: x.to_str(include_metadata=True))
assert local_utterances == [ assert local_utterances == [
Utterance(route='Route 1', utterance='Hello'), Utterance(route='Route 1', utterance='Hello', metadata={'type': 'default'}),
Utterance( Utterance(route='Route 1', utterance='Hi', metadata={'type': 'default'}),
route='Route 1', utterance='Hi',
metadata={'type': 'default'}
),
Utterance(route='Route 2', utterance='Au revoir'), Utterance(route='Route 2', utterance='Au revoir'),
Utterance(route='Route 2', utterance='Bye'), Utterance(route='Route 2', utterance='Bye'),
Utterance(route='Route 2', utterance='Goodbye'), Utterance(route='Route 2', utterance='Goodbye'),
Utterance(route='Route 2', utterance='Hi') Utterance(route='Route 2', utterance='Hi'),
Utterance(route='Route 3', utterance='Boo')
], "The routes in the index should match the local routes" ], "The routes in the index should match the local routes"
# clear index
route_layer.index.index.delete(namespace="", delete_all=True)
@pytest.mark.skipif( @pytest.mark.skipif(
os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
) )
...@@ -380,12 +363,12 @@ class TestRouteLayer: ...@@ -380,12 +363,12 @@ class TestRouteLayer:
# TEST MERGE # TEST MERGE
pinecone_index = init_index(index_cls) pinecone_index = init_index(index_cls)
route_layer = RouteLayer( route_layer = RouteLayer(
encoder=openai_encoder, routes=routes, index=pinecone_index, encoder=openai_encoder, routes=routes_2, index=pinecone_index,
auto_sync="local" auto_sync="local"
) )
time.sleep(PINECONE_SLEEP) # allow for index to be populated time.sleep(PINECONE_SLEEP) # allow for index to be populated
route_layer = RouteLayer( route_layer = RouteLayer(
encoder=openai_encoder, routes=routes_2, index=pinecone_index, encoder=openai_encoder, routes=routes, index=pinecone_index,
auto_sync="merge" auto_sync="merge"
) )
time.sleep(PINECONE_SLEEP) # allow for index to be populated time.sleep(PINECONE_SLEEP) # allow for index to be populated
...@@ -396,14 +379,19 @@ class TestRouteLayer: ...@@ -396,14 +379,19 @@ class TestRouteLayer:
# we sort to ensure order is the same # we sort to ensure order is the same
local_utterances.sort(key=lambda x: x.to_str(include_metadata=True)) local_utterances.sort(key=lambda x: x.to_str(include_metadata=True))
assert local_utterances == [ assert local_utterances == [
("Route 1", "Hello", None, {"type": "default"}), Utterance(
("Route 1", "Hi", None, {"type": "default"}), route='Route 1', utterance='Hello',
("Route 1", "Goodbye", None, {"type": "default"}), metadata={'type': 'default'}
("Route 2", "Bye", None, {}), ),
("Route 2", "Asparagus", None, {}), Utterance(
("Route 2", "Au revoir", None, {}), route='Route 1', utterance='Hi',
("Route 2", "Goodbye", None, {}), metadata={'type': 'default'}
("Route 3", "Boo", None, {}), ),
Utterance(route='Route 2', utterance='Au revoir'),
Utterance(route='Route 2', utterance='Bye'),
Utterance(route='Route 2', utterance='Goodbye'),
Utterance(route='Route 2', utterance='Hi'),
Utterance(route='Route 3', utterance='Boo')
], "The routes in the index should match the local routes" ], "The routes in the index should match the local routes"
# clear index # clear index
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment