From 51d8b5ae503187fe70307dee44dafa42dc8743dd Mon Sep 17 00:00:00 2001
From: tolgadevAI <164843802+tolgadevAI@users.noreply.github.com>
Date: Tue, 20 Aug 2024 09:09:42 +0300
Subject: [PATCH] Make index add() methods accept optional function_schemas

---
 semantic_router/index/base.py     |  2 +-
 semantic_router/index/local.py    |  2 ++
 semantic_router/index/pinecone.py |  2 +-
 semantic_router/index/postgres.py |  6 +++++-
 semantic_router/index/qdrant.py   |  1 +
 semantic_router/layer.py          | 20 ++++++++++++++++----
 6 files changed, 26 insertions(+), 7 deletions(-)

diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py
index d25d41dc..3d391083 100644
--- a/semantic_router/index/base.py
+++ b/semantic_router/index/base.py
@@ -26,7 +26,7 @@ class BaseIndex(BaseModel):
         embeddings: List[List[float]],
         routes: List[str],
         utterances: List[Any],
-        function_schemas: List[Dict[str, Any]],
+        function_schemas: List[Dict[str, Any]] = None,  # type: ignore
     ):
         """
         Add embeddings to the index.
diff --git a/semantic_router/index/local.py b/semantic_router/index/local.py
index 7150b267..802455cb 100644
--- a/semantic_router/index/local.py
+++ b/semantic_router/index/local.py
@@ -5,6 +5,7 @@ import numpy as np
 from semantic_router.index.base import BaseIndex
 from semantic_router.linear import similarity_matrix, top_scores
 from semantic_router.utils.logger import logger
+from typing import Any
 
 
 class LocalIndex(BaseIndex):
@@ -26,6 +27,7 @@ class LocalIndex(BaseIndex):
         embeddings: List[List[float]],
         routes: List[str],
         utterances: List[str],
+        function_schemas: List[Dict[str, Any]] = None,  # type: ignore
     ):
         embeds = np.array(embeddings)  # type: ignore
         routes_arr = np.array(routes)
diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py
index 0b84ad4c..da11ec6e 100644
--- a/semantic_router/index/pinecone.py
+++ b/semantic_router/index/pinecone.py
@@ -310,7 +310,7 @@ class PineconeIndex(BaseIndex):
         embeddings: List[List[float]],
         routes: List[str],
         utterances: List[str],
-        function_schemas: List[Dict[str, Any]] = "",
+        function_schemas: List[Dict[str, Any]] = None,  # type: ignore
         batch_size: int = 100,
     ):
         """Add vectors to Pinecone in batches."""
diff --git a/semantic_router/index/postgres.py b/semantic_router/index/postgres.py
index 4c971d4d..9fbac62f 100644
--- a/semantic_router/index/postgres.py
+++ b/semantic_router/index/postgres.py
@@ -254,7 +254,11 @@ class PostgresIndex(BaseIndex):
                 raise ValueError("No comment found for the 'vector' column.")
 
     def add(
-        self, embeddings: List[List[float]], routes: List[str], utterances: List[Any]
+        self,
+        embeddings: List[List[float]],
+        routes: List[str],
+        utterances: List[Any],
+        function_schemas: List[Dict[str, Any]] = None,  # type: ignore
     ) -> None:
         """
         Adds vectors to the index.
diff --git a/semantic_router/index/qdrant.py b/semantic_router/index/qdrant.py
index c1a5e28b..0fc6aa52 100644
--- a/semantic_router/index/qdrant.py
+++ b/semantic_router/index/qdrant.py
@@ -175,6 +175,7 @@ class QdrantIndex(BaseIndex):
         embeddings: List[List[float]],
         routes: List[str],
         utterances: List[str],
+        function_schemas: List[Dict[str, Any]] = None,  # type: ignore
         batch_size: int = DEFAULT_UPLOAD_BATCH_SIZE,
     ):
         self.dimensions = self.dimensions or len(embeddings[0])
diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index 2285a7d9..bae67544 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -437,7 +437,7 @@ class RouteLayer:
             function_schemas=(
                 route.function_schemas * len(route.utterances)
                 if route.function_schemas
-                else [""] * len(route.utterances)
+                else [""] * len(route.utterances)  # type: ignore
             ),
         )
 
@@ -482,7 +482,9 @@ class RouteLayer:
 
     def _add_routes(self, routes: List[Route]):
         # create embeddings for all routes
-        route_names, all_utterances = self._extract_routes_details(routes)
+        route_names, all_utterances, function_schemas = self._extract_routes_details(
+            routes
+        )
         embedded_utterances = self.encoder(all_utterances)
         # create route array
         # add everything to the index
@@ -490,11 +492,14 @@ class RouteLayer:
             embeddings=embedded_utterances,
             routes=route_names,
             utterances=all_utterances,
+            function_schemas=function_schemas,
         )
 
     def _add_and_sync_routes(self, routes: List[Route]):
         # create embeddings for all routes and sync at startup with remote ones based on sync setting
-        local_route_names, local_utterances = self._extract_routes_details(routes)
+        local_route_names, local_utterances, local_function_schemas = (
+            self._extract_routes_details(routes)
+        )
         routes_to_add, routes_to_delete, layer_routes_dict = self.index._sync_index(
             local_route_names=local_route_names,
             local_utterances=local_utterances,
@@ -522,6 +527,7 @@ class RouteLayer:
             embeddings=embedded_utterances_to_add,
             routes=route_names_to_add,
             utterances=all_utterances_to_add,
+            function_schemas=local_function_schemas,
         )
 
         self._set_layer_routes(layer_routes)
@@ -529,7 +535,13 @@ class RouteLayer:
     def _extract_routes_details(self, routes: List[Route]) -> Tuple:
         route_names = [route.name for route in routes for _ in route.utterances]
         utterances = [utterance for route in routes for utterance in route.utterances]
-        return route_names, utterances
+        function_schemas = [
+            function_schema if function_schema is not None else ""
+            for route in routes
+            if route.function_schemas is not None
+            for function_schema in route.function_schemas
+        ]
+        return route_names, utterances, function_schemas
 
     def _encode(self, text: str) -> Any:
         """Given some text, encode it."""
-- 
GitLab