From 3e5a9bc278519180e88f016c18c2d95b2398746a Mon Sep 17 00:00:00 2001
From: Vits <vittorio.mayellaro.dev@gmail.com>
Date: Thu, 5 Sep 2024 12:12:30 +0200
Subject: [PATCH] Formatting and linting

---
 semantic_router/index/pinecone.py | 77 ++++++++++++++++++++++++-------
 semantic_router/layer.py          |  6 ++-
 tests/unit/test_layer.py          | 59 +++++++++++++----------
 3 files changed, 100 insertions(+), 42 deletions(-)

diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py
index 3de221ba..9e5d0f2e 100644
--- a/semantic_router/index/pinecone.py
+++ b/semantic_router/index/pinecone.py
@@ -268,12 +268,12 @@ class PineconeIndex(BaseIndex):
         for route in all_routes:
             local_utterances = local_dict.get(route, {}).get("utterances", set())
             remote_utterances = remote_dict.get(route, {}).get("utterances", set())
-            local_function_schemas = local_dict.get(route, {}).get(
-                "function_schemas", {}
-            ) or {}
-            remote_function_schemas = remote_dict.get(route, {}).get(
-                "function_schemas", {}
-            ) or {}
+            local_function_schemas = (
+                local_dict.get(route, {}).get("function_schemas", {}) or {}
+            )
+            remote_function_schemas = (
+                remote_dict.get(route, {}).get("function_schemas", {}) or {}
+            )
             local_metadata = local_dict.get(route, {}).get("metadata", {})
             remote_metadata = remote_dict.get(route, {}).get("metadata", {})
 
@@ -295,7 +295,9 @@ class PineconeIndex(BaseIndex):
                 if local_utterances:
                     layer_routes[route] = {
                         "utterances": list(local_utterances),
-                        "function_schemas": local_function_schemas if local_function_schemas else None,
+                        "function_schemas": (
+                            local_function_schemas if local_function_schemas else None
+                        ),
                         "metadata": local_metadata,
                     }
 
@@ -303,7 +305,9 @@ class PineconeIndex(BaseIndex):
                 if remote_utterances:
                     layer_routes[route] = {
                         "utterances": list(remote_utterances),
-                        "function_schemas": remote_function_schemas if remote_function_schemas else None,
+                        "function_schemas": (
+                            remote_function_schemas if remote_function_schemas else None
+                        ),
                         "metadata": remote_metadata,
                     }
 
@@ -319,7 +323,9 @@ class PineconeIndex(BaseIndex):
                 if local_utterances:
                     layer_routes[route] = {
                         "utterances": list(local_utterances),
-                        "function_schemas": local_function_schemas if local_function_schemas else None,
+                        "function_schemas": (
+                            local_function_schemas if local_function_schemas else None
+                        ),
                         "metadata": local_metadata,
                     }
 
@@ -329,14 +335,22 @@ class PineconeIndex(BaseIndex):
                     if local_utterances:
                         layer_routes[route] = {
                             "utterances": list(local_utterances),
-                            "function_schemas": local_function_schemas if local_function_schemas else None,
+                            "function_schemas": (
+                                local_function_schemas
+                                if local_function_schemas
+                                else None
+                            ),
                             "metadata": local_metadata,
                         }
                 else:
                     if remote_utterances:
                         layer_routes[route] = {
                             "utterances": list(remote_utterances),
-                            "function_schemas": remote_function_schemas if remote_function_schemas else None,
+                            "function_schemas": (
+                                remote_function_schemas
+                                if remote_function_schemas
+                                else None
+                            ),
                             "metadata": remote_metadata,
                         }
 
@@ -353,14 +367,22 @@ class PineconeIndex(BaseIndex):
                     if local_utterances:
                         layer_routes[route] = {
                             "utterances": list(local_utterances),
-                            "function_schemas": local_function_schemas if local_function_schemas else None,
+                            "function_schemas": (
+                                local_function_schemas
+                                if local_function_schemas
+                                else None
+                            ),
                             "metadata": local_metadata,
                         }
                 else:
                     if remote_utterances:
                         layer_routes[route] = {
                             "utterances": list(remote_utterances),
-                            "function_schemas": remote_function_schemas if remote_function_schemas else None,
+                            "function_schemas": (
+                                remote_function_schemas
+                                if remote_function_schemas
+                                else None
+                            ),
                             "metadata": remote_metadata,
                         }
 
@@ -375,7 +397,9 @@ class PineconeIndex(BaseIndex):
                     }
                     layer_routes[route] = {
                         "utterances": list(remote_utterances.union(local_utterances)),
-                        "function_schemas": merged_function_schemas if merged_function_schemas else None,
+                        "function_schemas": (
+                            merged_function_schemas if merged_function_schemas else None
+                        ),
                         "metadata": merged_metadata,
                     }
 
@@ -389,17 +413,36 @@ class PineconeIndex(BaseIndex):
             ]:
                 for utterance in local_utterances:
                     routes_to_add.append(
-                        (route, utterance, local_function_schemas if local_function_schemas else None, local_metadata)
+                        (
+                            route,
+                            utterance,
+                            local_function_schemas if local_function_schemas else None,
+                            local_metadata,
+                        )
                     )
             if (metadata_changed or function_schema_changed) and self.sync == "merge":
                 for utterance in local_utterances:
                     routes_to_add.append(
-                        (route, utterance, merged_function_schemas if merged_function_schemas else None, merged_metadata)
+                        (
+                            route,
+                            utterance,
+                            (
+                                merged_function_schemas
+                                if merged_function_schemas
+                                else None
+                            ),
+                            merged_metadata,
+                        )
                     )
             elif utterances_to_include:
                 for utterance in utterances_to_include:
                     routes_to_add.append(
-                        (route, utterance, local_function_schemas if local_function_schemas else None, local_metadata)
+                        (
+                            route,
+                            utterance,
+                            local_function_schemas if local_function_schemas else None,
+                            local_metadata,
+                        )
                     )
 
         return routes_to_add, routes_to_delete, layer_routes
diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index f7d1395a..45a80bde 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -546,7 +546,11 @@ class RouteLayer:
             Route(
                 name=route,
                 utterances=data.get("utterances", []),
-                function_schemas=[data.get("function_schemas", None)] if data.get("function_schemas") else None,
+                function_schemas=(
+                    [data.get("function_schemas", None)]
+                    if data.get("function_schemas")
+                    else None
+                ),
                 metadata=data.get("metadata", {}),
             )
             for route, data in layer_routes_dict.items()
diff --git a/tests/unit/test_layer.py b/tests/unit/test_layer.py
index eaec568e..ff87e6f4 100644
--- a/tests/unit/test_layer.py
+++ b/tests/unit/test_layer.py
@@ -114,6 +114,7 @@ def routes_3():
         Route(name="Route 2", utterances=["Asparagus"]),
     ]
 
+
 @pytest.fixture
 def routes_4():
     return [
@@ -168,7 +169,7 @@ class TestRouteLayer:
         )
         if index_cls is PineconeIndex:
             time.sleep(15)  # allow for index to be populated
-            
+
         assert openai_encoder.score_threshold == 0.3
         assert route_layer.score_threshold == 0.3
         assert route_layer.top_k == 10
@@ -363,59 +364,69 @@ class TestRouteLayer:
         if index_cls is PineconeIndex:
             # TEST LOCAL
             pinecone_index = PineconeIndex(sync="local")
-            route_layer = RouteLayer(encoder=openai_encoder, routes=routes_2, index=pinecone_index)
+            route_layer = RouteLayer(
+                encoder=openai_encoder, routes=routes_2, index=pinecone_index
+            )
             time.sleep(15)  # allow for index to be populated
             assert route_layer.index.get_routes() == [
-                ('Route 1', 'Hello', None, {}), 
-                ('Route 2', 'Hi', None, {}), 
+                ("Route 1", "Hello", None, {}),
+                ("Route 2", "Hi", None, {}),
             ], "The routes in the index should match the local routes"
 
             # TEST REMOTE
             pinecone_index = PineconeIndex(sync="remote")
-            route_layer = RouteLayer(encoder=openai_encoder, routes=routes, index=pinecone_index)
+            route_layer = RouteLayer(
+                encoder=openai_encoder, routes=routes, index=pinecone_index
+            )
 
             time.sleep(15)  # allow for index to be populated
             assert route_layer.index.get_routes() == [
-                ('Route 1', 'Hello', None, {}), 
-                ('Route 2', 'Hi', None, {}), 
+                ("Route 1", "Hello", None, {}),
+                ("Route 2", "Hi", None, {}),
             ], "The routes in the index should match the local routes"
 
             # TEST MERGE FORCE REMOTE
             pinecone_index = PineconeIndex(sync="merge-force-remote")
-            route_layer = RouteLayer(encoder=openai_encoder, routes=routes, index=pinecone_index)
+            route_layer = RouteLayer(
+                encoder=openai_encoder, routes=routes, index=pinecone_index
+            )
 
             time.sleep(15)  # allow for index to be populated
             assert route_layer.index.get_routes() == [
-                ('Route 1', 'Hello', None, {}),
-                ('Route 2', 'Hi', None, {}),
+                ("Route 1", "Hello", None, {}),
+                ("Route 2", "Hi", None, {}),
             ], "The routes in the index should match the local routes"
 
             # TEST MERGE FORCE LOCAL
             pinecone_index = PineconeIndex(sync="merge-force-local")
-            route_layer = RouteLayer(encoder=openai_encoder, routes=routes, index=pinecone_index)
+            route_layer = RouteLayer(
+                encoder=openai_encoder, routes=routes, index=pinecone_index
+            )
 
             time.sleep(15)  # allow for index to be populated
             assert route_layer.index.get_routes() == [
-                ('Route 1', 'Hello', None, {'type': 'default'}),
-                ('Route 1', 'Hi', None, {'type': 'default'}),
-                ('Route 2', 'Bye', None, {}),
-                ('Route 2', 'Au revoir', None, {}),
-                ('Route 2', 'Goodbye', None, {})
+                ("Route 1", "Hello", None, {"type": "default"}),
+                ("Route 1", "Hi", None, {"type": "default"}),
+                ("Route 2", "Bye", None, {}),
+                ("Route 2", "Au revoir", None, {}),
+                ("Route 2", "Goodbye", None, {}),
             ], "The routes in the index should match the local routes"
 
             # TEST MERGE
             pinecone_index = PineconeIndex(sync="merge")
-            route_layer = RouteLayer(encoder=openai_encoder, routes=routes_4, index=pinecone_index)
+            route_layer = RouteLayer(
+                encoder=openai_encoder, routes=routes_4, index=pinecone_index
+            )
 
             time.sleep(15)  # allow for index to be populated
             assert route_layer.index.get_routes() == [
-                ('Route 1', 'Hello', None, {'type': 'default'}),
-                ('Route 1', 'Hi', None, {'type': 'default'}),
-                ('Route 1', 'Goodbye', None, {'type': 'default'}),
-                ('Route 2', 'Bye', None, {}),
-                ('Route 2', 'Asparagus', None, {}),
-                ('Route 2', 'Au revoir', None, {}),
-                ('Route 2', 'Goodbye', None, {}),
+                ("Route 1", "Hello", None, {"type": "default"}),
+                ("Route 1", "Hi", None, {"type": "default"}),
+                ("Route 1", "Goodbye", None, {"type": "default"}),
+                ("Route 2", "Bye", None, {}),
+                ("Route 2", "Asparagus", None, {}),
+                ("Route 2", "Au revoir", None, {}),
+                ("Route 2", "Goodbye", None, {}),
             ], "The routes in the index should match the local routes"
 
     def test_query_with_no_index(self, openai_encoder, index_cls):
-- 
GitLab