From 51d8b5ae503187fe70307dee44dafa42dc8743dd Mon Sep 17 00:00:00 2001 From: tolgadevAI <164843802+tolgadevAI@users.noreply.github.com> Date: Tue, 20 Aug 2024 09:09:42 +0300 Subject: [PATCH] update the indexes to add function_schemas --- semantic_router/index/base.py | 2 +- semantic_router/index/local.py | 2 ++ semantic_router/index/pinecone.py | 2 +- semantic_router/index/postgres.py | 6 +++++- semantic_router/index/qdrant.py | 1 + semantic_router/layer.py | 20 ++++++++++++++++---- 6 files changed, 26 insertions(+), 7 deletions(-) diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py index d25d41dc..3d391083 100644 --- a/semantic_router/index/base.py +++ b/semantic_router/index/base.py @@ -26,7 +26,7 @@ class BaseIndex(BaseModel): embeddings: List[List[float]], routes: List[str], utterances: List[Any], - function_schemas: List[Dict[str, Any]], + function_schemas: List[Dict[str, Any]] = None, # type: ignore ): """ Add embeddings to the index. diff --git a/semantic_router/index/local.py b/semantic_router/index/local.py index 7150b267..802455cb 100644 --- a/semantic_router/index/local.py +++ b/semantic_router/index/local.py @@ -5,6 +5,7 @@ import numpy as np from semantic_router.index.base import BaseIndex from semantic_router.linear import similarity_matrix, top_scores from semantic_router.utils.logger import logger +from typing import Any class LocalIndex(BaseIndex): @@ -26,6 +27,7 @@ class LocalIndex(BaseIndex): embeddings: List[List[float]], routes: List[str], utterances: List[str], + function_schemas: List[Dict[str, Any]] = None, # type: ignore ): embeds = np.array(embeddings) # type: ignore routes_arr = np.array(routes) diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py index 0b84ad4c..da11ec6e 100644 --- a/semantic_router/index/pinecone.py +++ b/semantic_router/index/pinecone.py @@ -310,7 +310,7 @@ class PineconeIndex(BaseIndex): embeddings: List[List[float]], routes: List[str], utterances: List[str], - function_schemas: List[Dict[str, Any]] = "", + function_schemas: List[Dict[str, Any]] = None, # type: ignore batch_size: int = 100, ): """Add vectors to Pinecone in batches.""" diff --git a/semantic_router/index/postgres.py b/semantic_router/index/postgres.py index 4c971d4d..9fbac62f 100644 --- a/semantic_router/index/postgres.py +++ b/semantic_router/index/postgres.py @@ -254,7 +254,11 @@ class PostgresIndex(BaseIndex): raise ValueError("No comment found for the 'vector' column.") def add( - self, embeddings: List[List[float]], routes: List[str], utterances: List[Any] + self, + embeddings: List[List[float]], + routes: List[str], + utterances: List[Any], + function_schemas: List[Dict[str, Any]] = None, # type: ignore ) -> None: """ Adds vectors to the index. diff --git a/semantic_router/index/qdrant.py b/semantic_router/index/qdrant.py index c1a5e28b..0fc6aa52 100644 --- a/semantic_router/index/qdrant.py +++ b/semantic_router/index/qdrant.py @@ -175,6 +175,7 @@ class QdrantIndex(BaseIndex): embeddings: List[List[float]], routes: List[str], utterances: List[str], + function_schemas: List[Dict[str, Any]] = None, # type: ignore batch_size: int = DEFAULT_UPLOAD_BATCH_SIZE, ): self.dimensions = self.dimensions or len(embeddings[0]) diff --git a/semantic_router/layer.py b/semantic_router/layer.py index 2285a7d9..bae67544 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -437,7 +437,7 @@ class RouteLayer: function_schemas=( route.function_schemas * len(route.utterances) if route.function_schemas - else [""] * len(route.utterances) + else [""] * len(route.utterances) # type: ignore ), ) @@ -482,7 +482,9 @@ class RouteLayer: def _add_routes(self, routes: List[Route]): # create embeddings for all routes - route_names, all_utterances = self._extract_routes_details(routes) + route_names, all_utterances, function_schemas = self._extract_routes_details( + routes + ) embedded_utterances = self.encoder(all_utterances) # create route array # add everything to the index @@ -490,11 +492,14 @@ class RouteLayer: embeddings=embedded_utterances, routes=route_names, utterances=all_utterances, + function_schemas=function_schemas, ) def _add_and_sync_routes(self, routes: List[Route]): # create embeddings for all routes and sync at startup with remote ones based on sync setting - local_route_names, local_utterances = self._extract_routes_details(routes) + local_route_names, local_utterances, local_function_schemas = ( + self._extract_routes_details(routes) + ) routes_to_add, routes_to_delete, layer_routes_dict = self.index._sync_index( local_route_names=local_route_names, local_utterances=local_utterances, @@ -522,6 +527,7 @@ class RouteLayer: embeddings=embedded_utterances_to_add, routes=route_names_to_add, utterances=all_utterances_to_add, + function_schemas=local_function_schemas, ) self._set_layer_routes(layer_routes) @@ -529,7 +535,13 @@ class RouteLayer: def _extract_routes_details(self, routes: List[Route]) -> Tuple: route_names = [route.name for route in routes for _ in route.utterances] utterances = [utterance for route in routes for utterance in route.utterances] - return route_names, utterances + function_schemas = [ + function_schema if function_schema is not None else "" + for route in routes + if route.function_schemas is not None + for function_schema in route.function_schemas + ] + return route_names, utterances, function_schemas def _encode(self, text: str) -> Any: """Given some text, encode it.""" -- GitLab