Skip to content
Snippets Groups Projects
Commit 3e5a9bc2 authored by Vits's avatar Vits
Browse files

Formatting and linting

parent 8fd88ad0
No related branches found
No related tags found
No related merge requests found
...@@ -268,12 +268,12 @@ class PineconeIndex(BaseIndex): ...@@ -268,12 +268,12 @@ class PineconeIndex(BaseIndex):
for route in all_routes: for route in all_routes:
local_utterances = local_dict.get(route, {}).get("utterances", set()) local_utterances = local_dict.get(route, {}).get("utterances", set())
remote_utterances = remote_dict.get(route, {}).get("utterances", set()) remote_utterances = remote_dict.get(route, {}).get("utterances", set())
local_function_schemas = local_dict.get(route, {}).get( local_function_schemas = (
"function_schemas", {} local_dict.get(route, {}).get("function_schemas", {}) or {}
) or {} )
remote_function_schemas = remote_dict.get(route, {}).get( remote_function_schemas = (
"function_schemas", {} remote_dict.get(route, {}).get("function_schemas", {}) or {}
) or {} )
local_metadata = local_dict.get(route, {}).get("metadata", {}) local_metadata = local_dict.get(route, {}).get("metadata", {})
remote_metadata = remote_dict.get(route, {}).get("metadata", {}) remote_metadata = remote_dict.get(route, {}).get("metadata", {})
...@@ -295,7 +295,9 @@ class PineconeIndex(BaseIndex): ...@@ -295,7 +295,9 @@ class PineconeIndex(BaseIndex):
if local_utterances: if local_utterances:
layer_routes[route] = { layer_routes[route] = {
"utterances": list(local_utterances), "utterances": list(local_utterances),
"function_schemas": local_function_schemas if local_function_schemas else None, "function_schemas": (
local_function_schemas if local_function_schemas else None
),
"metadata": local_metadata, "metadata": local_metadata,
} }
...@@ -303,7 +305,9 @@ class PineconeIndex(BaseIndex): ...@@ -303,7 +305,9 @@ class PineconeIndex(BaseIndex):
if remote_utterances: if remote_utterances:
layer_routes[route] = { layer_routes[route] = {
"utterances": list(remote_utterances), "utterances": list(remote_utterances),
"function_schemas": remote_function_schemas if remote_function_schemas else None, "function_schemas": (
remote_function_schemas if remote_function_schemas else None
),
"metadata": remote_metadata, "metadata": remote_metadata,
} }
...@@ -319,7 +323,9 @@ class PineconeIndex(BaseIndex): ...@@ -319,7 +323,9 @@ class PineconeIndex(BaseIndex):
if local_utterances: if local_utterances:
layer_routes[route] = { layer_routes[route] = {
"utterances": list(local_utterances), "utterances": list(local_utterances),
"function_schemas": local_function_schemas if local_function_schemas else None, "function_schemas": (
local_function_schemas if local_function_schemas else None
),
"metadata": local_metadata, "metadata": local_metadata,
} }
...@@ -329,14 +335,22 @@ class PineconeIndex(BaseIndex): ...@@ -329,14 +335,22 @@ class PineconeIndex(BaseIndex):
if local_utterances: if local_utterances:
layer_routes[route] = { layer_routes[route] = {
"utterances": list(local_utterances), "utterances": list(local_utterances),
"function_schemas": local_function_schemas if local_function_schemas else None, "function_schemas": (
local_function_schemas
if local_function_schemas
else None
),
"metadata": local_metadata, "metadata": local_metadata,
} }
else: else:
if remote_utterances: if remote_utterances:
layer_routes[route] = { layer_routes[route] = {
"utterances": list(remote_utterances), "utterances": list(remote_utterances),
"function_schemas": remote_function_schemas if remote_function_schemas else None, "function_schemas": (
remote_function_schemas
if remote_function_schemas
else None
),
"metadata": remote_metadata, "metadata": remote_metadata,
} }
...@@ -353,14 +367,22 @@ class PineconeIndex(BaseIndex): ...@@ -353,14 +367,22 @@ class PineconeIndex(BaseIndex):
if local_utterances: if local_utterances:
layer_routes[route] = { layer_routes[route] = {
"utterances": list(local_utterances), "utterances": list(local_utterances),
"function_schemas": local_function_schemas if local_function_schemas else None, "function_schemas": (
local_function_schemas
if local_function_schemas
else None
),
"metadata": local_metadata, "metadata": local_metadata,
} }
else: else:
if remote_utterances: if remote_utterances:
layer_routes[route] = { layer_routes[route] = {
"utterances": list(remote_utterances), "utterances": list(remote_utterances),
"function_schemas": remote_function_schemas if remote_function_schemas else None, "function_schemas": (
remote_function_schemas
if remote_function_schemas
else None
),
"metadata": remote_metadata, "metadata": remote_metadata,
} }
...@@ -375,7 +397,9 @@ class PineconeIndex(BaseIndex): ...@@ -375,7 +397,9 @@ class PineconeIndex(BaseIndex):
} }
layer_routes[route] = { layer_routes[route] = {
"utterances": list(remote_utterances.union(local_utterances)), "utterances": list(remote_utterances.union(local_utterances)),
"function_schemas": merged_function_schemas if merged_function_schemas else None, "function_schemas": (
merged_function_schemas if merged_function_schemas else None
),
"metadata": merged_metadata, "metadata": merged_metadata,
} }
...@@ -389,17 +413,36 @@ class PineconeIndex(BaseIndex): ...@@ -389,17 +413,36 @@ class PineconeIndex(BaseIndex):
]: ]:
for utterance in local_utterances: for utterance in local_utterances:
routes_to_add.append( routes_to_add.append(
(route, utterance, local_function_schemas if local_function_schemas else None, local_metadata) (
route,
utterance,
local_function_schemas if local_function_schemas else None,
local_metadata,
)
) )
if (metadata_changed or function_schema_changed) and self.sync == "merge": if (metadata_changed or function_schema_changed) and self.sync == "merge":
for utterance in local_utterances: for utterance in local_utterances:
routes_to_add.append( routes_to_add.append(
(route, utterance, merged_function_schemas if merged_function_schemas else None, merged_metadata) (
route,
utterance,
(
merged_function_schemas
if merged_function_schemas
else None
),
merged_metadata,
)
) )
elif utterances_to_include: elif utterances_to_include:
for utterance in utterances_to_include: for utterance in utterances_to_include:
routes_to_add.append( routes_to_add.append(
(route, utterance, local_function_schemas if local_function_schemas else None, local_metadata) (
route,
utterance,
local_function_schemas if local_function_schemas else None,
local_metadata,
)
) )
return routes_to_add, routes_to_delete, layer_routes return routes_to_add, routes_to_delete, layer_routes
......
...@@ -546,7 +546,11 @@ class RouteLayer: ...@@ -546,7 +546,11 @@ class RouteLayer:
Route( Route(
name=route, name=route,
utterances=data.get("utterances", []), utterances=data.get("utterances", []),
function_schemas=[data.get("function_schemas", None)] if data.get("function_schemas") else None, function_schemas=(
[data.get("function_schemas", None)]
if data.get("function_schemas")
else None
),
metadata=data.get("metadata", {}), metadata=data.get("metadata", {}),
) )
for route, data in layer_routes_dict.items() for route, data in layer_routes_dict.items()
......
...@@ -114,6 +114,7 @@ def routes_3(): ...@@ -114,6 +114,7 @@ def routes_3():
Route(name="Route 2", utterances=["Asparagus"]), Route(name="Route 2", utterances=["Asparagus"]),
] ]
@pytest.fixture @pytest.fixture
def routes_4(): def routes_4():
return [ return [
...@@ -168,7 +169,7 @@ class TestRouteLayer: ...@@ -168,7 +169,7 @@ class TestRouteLayer:
) )
if index_cls is PineconeIndex: if index_cls is PineconeIndex:
time.sleep(15) # allow for index to be populated time.sleep(15) # allow for index to be populated
assert openai_encoder.score_threshold == 0.3 assert openai_encoder.score_threshold == 0.3
assert route_layer.score_threshold == 0.3 assert route_layer.score_threshold == 0.3
assert route_layer.top_k == 10 assert route_layer.top_k == 10
...@@ -363,59 +364,69 @@ class TestRouteLayer: ...@@ -363,59 +364,69 @@ class TestRouteLayer:
if index_cls is PineconeIndex: if index_cls is PineconeIndex:
# TEST LOCAL # TEST LOCAL
pinecone_index = PineconeIndex(sync="local") pinecone_index = PineconeIndex(sync="local")
route_layer = RouteLayer(encoder=openai_encoder, routes=routes_2, index=pinecone_index) route_layer = RouteLayer(
encoder=openai_encoder, routes=routes_2, index=pinecone_index
)
time.sleep(15) # allow for index to be populated time.sleep(15) # allow for index to be populated
assert route_layer.index.get_routes() == [ assert route_layer.index.get_routes() == [
('Route 1', 'Hello', None, {}), ("Route 1", "Hello", None, {}),
('Route 2', 'Hi', None, {}), ("Route 2", "Hi", None, {}),
], "The routes in the index should match the local routes" ], "The routes in the index should match the local routes"
# TEST REMOTE # TEST REMOTE
pinecone_index = PineconeIndex(sync="remote") pinecone_index = PineconeIndex(sync="remote")
route_layer = RouteLayer(encoder=openai_encoder, routes=routes, index=pinecone_index) route_layer = RouteLayer(
encoder=openai_encoder, routes=routes, index=pinecone_index
)
time.sleep(15) # allow for index to be populated time.sleep(15) # allow for index to be populated
assert route_layer.index.get_routes() == [ assert route_layer.index.get_routes() == [
('Route 1', 'Hello', None, {}), ("Route 1", "Hello", None, {}),
('Route 2', 'Hi', None, {}), ("Route 2", "Hi", None, {}),
], "The routes in the index should match the local routes" ], "The routes in the index should match the local routes"
# TEST MERGE FORCE REMOTE # TEST MERGE FORCE REMOTE
pinecone_index = PineconeIndex(sync="merge-force-remote") pinecone_index = PineconeIndex(sync="merge-force-remote")
route_layer = RouteLayer(encoder=openai_encoder, routes=routes, index=pinecone_index) route_layer = RouteLayer(
encoder=openai_encoder, routes=routes, index=pinecone_index
)
time.sleep(15) # allow for index to be populated time.sleep(15) # allow for index to be populated
assert route_layer.index.get_routes() == [ assert route_layer.index.get_routes() == [
('Route 1', 'Hello', None, {}), ("Route 1", "Hello", None, {}),
('Route 2', 'Hi', None, {}), ("Route 2", "Hi", None, {}),
], "The routes in the index should match the local routes" ], "The routes in the index should match the local routes"
# TEST MERGE FORCE LOCAL # TEST MERGE FORCE LOCAL
pinecone_index = PineconeIndex(sync="merge-force-local") pinecone_index = PineconeIndex(sync="merge-force-local")
route_layer = RouteLayer(encoder=openai_encoder, routes=routes, index=pinecone_index) route_layer = RouteLayer(
encoder=openai_encoder, routes=routes, index=pinecone_index
)
time.sleep(15) # allow for index to be populated time.sleep(15) # allow for index to be populated
assert route_layer.index.get_routes() == [ assert route_layer.index.get_routes() == [
('Route 1', 'Hello', None, {'type': 'default'}), ("Route 1", "Hello", None, {"type": "default"}),
('Route 1', 'Hi', None, {'type': 'default'}), ("Route 1", "Hi", None, {"type": "default"}),
('Route 2', 'Bye', None, {}), ("Route 2", "Bye", None, {}),
('Route 2', 'Au revoir', None, {}), ("Route 2", "Au revoir", None, {}),
('Route 2', 'Goodbye', None, {}) ("Route 2", "Goodbye", None, {}),
], "The routes in the index should match the local routes" ], "The routes in the index should match the local routes"
# TEST MERGE # TEST MERGE
pinecone_index = PineconeIndex(sync="merge") pinecone_index = PineconeIndex(sync="merge")
route_layer = RouteLayer(encoder=openai_encoder, routes=routes_4, index=pinecone_index) route_layer = RouteLayer(
encoder=openai_encoder, routes=routes_4, index=pinecone_index
)
time.sleep(15) # allow for index to be populated time.sleep(15) # allow for index to be populated
assert route_layer.index.get_routes() == [ assert route_layer.index.get_routes() == [
('Route 1', 'Hello', None, {'type': 'default'}), ("Route 1", "Hello", None, {"type": "default"}),
('Route 1', 'Hi', None, {'type': 'default'}), ("Route 1", "Hi", None, {"type": "default"}),
('Route 1', 'Goodbye', None, {'type': 'default'}), ("Route 1", "Goodbye", None, {"type": "default"}),
('Route 2', 'Bye', None, {}), ("Route 2", "Bye", None, {}),
('Route 2', 'Asparagus', None, {}), ("Route 2", "Asparagus", None, {}),
('Route 2', 'Au revoir', None, {}), ("Route 2", "Au revoir", None, {}),
('Route 2', 'Goodbye', None, {}), ("Route 2", "Goodbye", None, {}),
], "The routes in the index should match the local routes" ], "The routes in the index should match the local routes"
def test_query_with_no_index(self, openai_encoder, index_cls): def test_query_with_no_index(self, openai_encoder, index_cls):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment