From 87d6ae4b160f66efbdaaabd0234129dc50c6a669 Mon Sep 17 00:00:00 2001
From: jamescalam <james.briggs@hotmail.com>
Date: Tue, 26 Nov 2024 22:39:28 +0100
Subject: [PATCH] feat: new sparse embedding support and abstractions

---
 docs/examples/pinecone-hybrid.ipynb | 196 ++++++++++++++++------------
 semantic_router/encoders/aurelio.py |  17 +--
 semantic_router/index/pinecone.py   |  27 ++--
 semantic_router/routers/base.py     |  92 ++++++++-----
 semantic_router/routers/hybrid.py   |  47 ++-----
 semantic_router/routers/semantic.py | 100 +-------------
 semantic_router/schema.py           |  38 +++++-
 7 files changed, 245 insertions(+), 272 deletions(-)

diff --git a/docs/examples/pinecone-hybrid.ipynb b/docs/examples/pinecone-hybrid.ipynb
index f18eef6e..134b6e0d 100644
--- a/docs/examples/pinecone-hybrid.ipynb
+++ b/docs/examples/pinecone-hybrid.ipynb
@@ -174,9 +174,9 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "2024-11-24 19:41:05 - pinecone_plugin_interface.logging - INFO - discover_namespace_packages.py:12 - discover_subpackages() - Discovering subpackages in _NamespacePath(['/Users/jamesbriggs/Library/Caches/pypoetry/virtualenvs/semantic-router-C1zr4a78-py3.12/lib/python3.12/site-packages/pinecone_plugins'])\n",
-      "2024-11-24 19:41:05 - pinecone_plugin_interface.logging - INFO - discover_plugins.py:9 - discover_plugins() - Looking for plugins in pinecone_plugins.inference\n",
-      "2024-11-24 19:41:05 - pinecone_plugin_interface.logging - INFO - installation.py:10 - install_plugins() - Installing plugin inference into Pinecone\n"
+      "2024-11-26 22:34:54 - pinecone_plugin_interface.logging - INFO - discover_namespace_packages.py:12 - discover_subpackages() - Discovering subpackages in _NamespacePath(['/Users/jamesbriggs/Library/Caches/pypoetry/virtualenvs/semantic-router-C1zr4a78-py3.12/lib/python3.12/site-packages/pinecone_plugins'])\n",
+      "2024-11-26 22:34:54 - pinecone_plugin_interface.logging - INFO - discover_plugins.py:9 - discover_plugins() - Looking for plugins in pinecone_plugins.inference\n",
+      "2024-11-26 22:34:54 - pinecone_plugin_interface.logging - INFO - installation.py:10 - install_plugins() - Installing plugin inference into Pinecone\n"
      ]
     }
    ],
@@ -205,29 +205,7 @@
    "cell_type": "code",
    "execution_count": 7,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "2024-11-24 19:41:15 - httpx - INFO - _client.py:1013 - _send_single_request() - HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n",
-      "2024-11-24 19:41:17 - semantic_router.utils.logger - WARNING - pinecone.py:247 - add() - TEMP | add:\n",
-      "politics: isn't politics the best thing ever\n",
-      "politics: why don't you tell me about your political opinions\n",
-      "politics: don't you just love the president\n",
-      "politics: don't you just hate the president\n",
-      "politics: they're going to destroy this country!\n",
-      "politics: they will save the country!\n",
-      "2024-11-24 19:41:17 - httpx - INFO - _client.py:1013 - _send_single_request() - HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n",
-      "2024-11-24 19:41:18 - semantic_router.utils.logger - WARNING - pinecone.py:247 - add() - TEMP | add:\n",
-      "chitchat: how's the weather today?\n",
-      "chitchat: how are things going?\n",
-      "chitchat: lovely weather today\n",
-      "chitchat: the weather is horrendous\n",
-      "chitchat: let's go to the chippy\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "from semantic_router.routers import HybridRouter\n",
     "\n",
@@ -248,23 +226,16 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 8,
    "metadata": {},
    "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "2024-11-24 19:42:06 - semantic_router.utils.logger - WARNING - pinecone.py:424 - _read_hash() - Configuration for hash parameter not found in index.\n"
-     ]
-    },
     {
      "data": {
       "text/plain": [
-       "False"
+       "True"
       ]
      },
-     "execution_count": 9,
+     "execution_count": 8,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -282,26 +253,26 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 9,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "['- chitchat: how are things going?',\n",
-       " \"- chitchat: how's the weather today?\",\n",
-       " \"- chitchat: let's go to the chippy\",\n",
-       " '- chitchat: lovely weather today',\n",
-       " '- chitchat: the weather is horrendous',\n",
-       " \"- politics: don't you just hate the president\",\n",
-       " \"- politics: don't you just love the president\",\n",
-       " \"- politics: isn't politics the best thing ever\",\n",
-       " '- politics: they will save the country!',\n",
-       " \"- politics: they're going to destroy this country!\",\n",
-       " \"- politics: why don't you tell me about your political opinions\"]"
+       "['  chitchat: how are things going?',\n",
+       " \"  chitchat: how's the weather today?\",\n",
+       " \"  chitchat: let's go to the chippy\",\n",
+       " '  chitchat: lovely weather today',\n",
+       " '  chitchat: the weather is horrendous',\n",
+       " \"  politics: don't you just hate the president\",\n",
+       " \"  politics: don't you just love the president\",\n",
+       " \"  politics: isn't politics the best thing ever\",\n",
+       " '  politics: they will save the country!',\n",
+       " \"  politics: they're going to destroy this country!\",\n",
+       " \"  politics: why don't you tell me about your political opinions\"]"
       ]
      },
-     "execution_count": 10,
+     "execution_count": 9,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -319,16 +290,26 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 10,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "[]"
+       "[Utterance(route='chitchat', utterance='how are things going?', function_schemas=None, metadata={}, diff_tag=' '),\n",
+       " Utterance(route='chitchat', utterance=\"how's the weather today?\", function_schemas=None, metadata={}, diff_tag=' '),\n",
+       " Utterance(route='chitchat', utterance='the weather is horrendous', function_schemas=None, metadata={}, diff_tag=' '),\n",
+       " Utterance(route='chitchat', utterance='lovely weather today', function_schemas=None, metadata={}, diff_tag=' '),\n",
+       " Utterance(route='chitchat', utterance=\"let's go to the chippy\", function_schemas=None, metadata={}, diff_tag=' '),\n",
+       " Utterance(route='politics', utterance=\"don't you just hate the president\", function_schemas=None, metadata={}, diff_tag=' '),\n",
+       " Utterance(route='politics', utterance=\"don't you just love the president\", function_schemas=None, metadata={}, diff_tag=' '),\n",
+       " Utterance(route='politics', utterance=\"they're going to destroy this country!\", function_schemas=None, metadata={}, diff_tag=' '),\n",
+       " Utterance(route='politics', utterance='they will save the country!', function_schemas=None, metadata={}, diff_tag=' '),\n",
+       " Utterance(route='politics', utterance=\"isn't politics the best thing ever\", function_schemas=None, metadata={}, diff_tag=' '),\n",
+       " Utterance(route='politics', utterance=\"why don't you tell me about your political opinions\", function_schemas=None, metadata={}, diff_tag=' ')]"
       ]
      },
-     "execution_count": 12,
+     "execution_count": 10,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -346,31 +327,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": 11,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "2024-11-24 19:48:29 - httpx - INFO - _client.py:1013 - _send_single_request() - HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n",
-      "2024-11-24 19:48:31 - semantic_router.utils.logger - WARNING - pinecone.py:247 - add() - TEMP | add:\n",
-      "politics: isn't politics the best thing ever\n",
-      "politics: why don't you tell me about your political opinions\n",
-      "politics: don't you just love the president\n",
-      "politics: don't you just hate the president\n",
-      "politics: they're going to destroy this country!\n",
-      "politics: they will save the country!\n",
-      "2024-11-24 19:48:31 - httpx - INFO - _client.py:1013 - _send_single_request() - HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n",
-      "2024-11-24 19:48:32 - semantic_router.utils.logger - WARNING - pinecone.py:247 - add() - TEMP | add:\n",
-      "chitchat: how's the weather today?\n",
-      "chitchat: how are things going?\n",
-      "chitchat: lovely weather today\n",
-      "chitchat: the weather is horrendous\n",
-      "chitchat: let's go to the chippy\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "router = HybridRouter(\n",
     "    encoder=encoder,\n",
@@ -390,16 +349,16 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 16,
+   "execution_count": 12,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "False"
+       "True"
       ]
      },
-     "execution_count": 16,
+     "execution_count": 12,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -408,6 +367,36 @@
     "router.is_synced()"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "['  chitchat: how are things going?',\n",
+       " \"  chitchat: how's the weather today?\",\n",
+       " \"  chitchat: let's go to the chippy\",\n",
+       " '  chitchat: lovely weather today',\n",
+       " '  chitchat: the weather is horrendous',\n",
+       " \"  politics: don't you just hate the president\",\n",
+       " \"  politics: don't you just love the president\",\n",
+       " \"  politics: isn't politics the best thing ever\",\n",
+       " '  politics: they will save the country!',\n",
+       " \"  politics: they're going to destroy this country!\",\n",
+       " \"  politics: why don't you tell me about your political opinions\"]"
+      ]
+     },
+     "execution_count": 13,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "router.get_utterance_diff()"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
@@ -417,9 +406,54 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 15,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "2024-11-26 22:35:56 - httpx - INFO - _client.py:1013 - _send_single_request() - HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "RouteChoice(name=None, function_call=None, similarity_score=None)"
+      ]
+     },
+     "execution_count": 15,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "router(\"it's raining cats and dogs today\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "2024-11-26 22:35:20 - httpx - INFO - _client.py:1013 - _send_single_request() - HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "RouteChoice(name=None, function_call=None, similarity_score=None)"
+      ]
+     },
+     "execution_count": 14,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
    "source": [
     "router(\"I'm interested in learning about llama 2\")"
    ]
diff --git a/semantic_router/encoders/aurelio.py b/semantic_router/encoders/aurelio.py
index c257b514..bc150e50 100644
--- a/semantic_router/encoders/aurelio.py
+++ b/semantic_router/encoders/aurelio.py
@@ -5,6 +5,7 @@ from pydantic.v1 import Field
 from aurelio_sdk import AurelioClient, AsyncAurelioClient, EmbeddingResponse
 
 from semantic_router.encoders.base import BaseEncoder
+from semantic_router.schema import SparseEmbedding
 
 
 class AurelioSparseEncoder(BaseEncoder):
@@ -28,19 +29,15 @@ class AurelioSparseEncoder(BaseEncoder):
         self.client = AurelioClient(api_key=api_key)
         self.async_client = AsyncAurelioClient(api_key=api_key)
 
-    def __call__(self, docs: list[str]) -> list[dict[int, float]]:
+    def __call__(self, docs: list[str]) -> list[SparseEmbedding]:
         res: EmbeddingResponse = self.client.embedding(input=docs, model=self.name)
-        embeds = [r.embedding.model_dump() for r in res.data]
-        # convert sparse vector to {index: value} format
-        sparse_dicts = [{i: v for i, v in zip(e["indices"], e["values"])} for e in embeds]
-        return sparse_dicts
+        embeds = [SparseEmbedding.from_aurelio(r.embedding) for r in res.data]
+        return embeds
     
-    async def acall(self, docs: list[str]) -> list[dict[int, float]]:
+    async def acall(self, docs: list[str]) -> list[SparseEmbedding]:
         res: EmbeddingResponse = await self.async_client.embedding(input=docs, model=self.name)
-        embeds = [r.embedding.model_dump() for r in res.data]
-        # convert sparse vector to {index: value} format
-        sparse_dicts = [{i: v for i, v in zip(e["indices"], e["values"])} for e in embeds]
-        return sparse_dicts
+        embeds = [SparseEmbedding.from_aurelio(r.embedding) for r in res.data]
+        return embeds
 
     def fit(self, docs: List[str]):
         raise NotImplementedError("AurelioSparseEncoder does not support fit.")
diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py
index 2d432d33..5eb0aecb 100644
--- a/semantic_router/index/pinecone.py
+++ b/semantic_router/index/pinecone.py
@@ -11,7 +11,7 @@ import numpy as np
 from pydantic.v1 import BaseModel, Field
 
 from semantic_router.index.base import BaseIndex
-from semantic_router.schema import ConfigParameter
+from semantic_router.schema import ConfigParameter, SparseEmbedding
 from semantic_router.utils.logger import logger
 
 
@@ -243,8 +243,6 @@ class PineconeIndex(BaseIndex):
         sparse_embeddings: Optional[List[dict[int, float]]] = None,
     ):
         """Add vectors to Pinecone in batches."""
-        temp = "\n".join([f"{x[0]}: {x[1]}" for x in zip(routes, utterances)])
-        logger.warning("TEMP | add:\n" + temp)
         if self.index is None:
             self.dimensions = self.dimensions or len(embeddings[0])
             self.index = self._init_index(force_create=True)
@@ -272,10 +270,6 @@ class PineconeIndex(BaseIndex):
             self._batch_upsert(batch)
 
     def _remove_and_sync(self, routes_to_delete: dict):
-        temp = "\n".join(
-            [f"{route}: {utterances}" for route, utterances in routes_to_delete.items()]
-        )
-        logger.warning("TEMP | _remove_and_sync:\n" + temp)
         for route, utterances in routes_to_delete.items():
             remote_routes = self._get_routes_with_ids(route_name=route)
             ids_to_delete = [
@@ -364,6 +358,7 @@ class PineconeIndex(BaseIndex):
         vector: np.ndarray,
         top_k: int = 5,
         route_filter: Optional[List[str]] = None,
+        sparse_vector: dict[int, float] | SparseEmbedding | None = None,
         **kwargs: Any,
     ) -> Tuple[np.ndarray, List[str]]:
         """Search the index for the query vector and return the top_k results.
@@ -374,10 +369,10 @@ class PineconeIndex(BaseIndex):
         :type top_k: int, optional
         :param route_filter: A list of route names to filter the search results, defaults to None.
         :type route_filter: Optional[List[str]], optional
+        :param sparse_vector: An optional sparse vector to include in the query.
+        :type sparse_vector: Optional[SparseEmbedding]
         :param kwargs: Additional keyword arguments for the query, including sparse_vector.
         :type kwargs: Any
-        :keyword sparse_vector: An optional sparse vector to include in the query.
-        :type sparse_vector: Optional[dict]
         :return: A tuple containing an array of scores and a list of route names.
         :rtype: Tuple[np.ndarray, List[str]]
         :raises ValueError: If the index is not populated.
@@ -389,9 +384,13 @@ class PineconeIndex(BaseIndex):
             filter_query = {"sr_route": {"$in": route_filter}}
         else:
             filter_query = None
+        if sparse_vector is not None:
+            if isinstance(sparse_vector, dict):
+                sparse_vector = SparseEmbedding.from_dict(sparse_vector)
+            sparse_vector = sparse_vector.to_pinecone()
         results = self.index.query(
             vector=[query_vector_list],
-            sparse_vector=kwargs.get("sparse_vector", None),
+            sparse_vector=sparse_vector,
             top_k=top_k,
             filter=filter_query,
             include_metadata=True,
@@ -653,6 +652,8 @@ class PineconeIndex(BaseIndex):
             )
 
     def __len__(self):
-        return self.index.describe_index_stats()["namespaces"][self.namespace][
-            "vector_count"
-        ]
+        namespace_stats = self.index.describe_index_stats()["namespaces"].get(self.namespace)
+        if namespace_stats:
+            return namespace_stats["vector_count"]
+        else:
+            return 0
diff --git a/semantic_router/routers/base.py b/semantic_router/routers/base.py
index f9170051..f7ab5226 100644
--- a/semantic_router/routers/base.py
+++ b/semantic_router/routers/base.py
@@ -250,8 +250,39 @@ class RouterConfig:
         return utterances
 
     def add(self, route: Route):
+        """Add a route to the local SemanticRouter and index.
+
+        :param route: The route to add.
+        :type route: Route
+        """
+        current_local_hash = self._get_hash()
+        current_remote_hash = self.index._read_hash()
+        if current_remote_hash.value == "":
+            # if remote hash is empty, the index is to be initialized
+            current_remote_hash = current_local_hash
+        embedded_utterances = self.encoder(route.utterances)
+        self.index.add(
+            embeddings=embedded_utterances,
+            routes=[route.name] * len(route.utterances),
+            utterances=route.utterances,
+            function_schemas=(
+                route.function_schemas * len(route.utterances)
+                if route.function_schemas
+                else [{}] * len(route.utterances)
+            ),
+            metadata_list=[route.metadata if route.metadata else {}]
+            * len(route.utterances),
+        )
+
         self.routes.append(route)
-        logger.info(f"Added route `{route.name}`")
+        if current_local_hash.value == current_remote_hash.value:
+            self._write_hash()  # update current hash in index
+        else:
+            logger.warning(
+                "Local and remote route layers were not aligned. Remote hash "
+                "not updated. Use `SemanticRouter.get_utterance_diff()` to see "
+                "details."
+            )
 
     def get(self, name: str) -> Optional[Route]:
         for route in self.routes:
@@ -289,10 +320,6 @@ class BaseRouter(BaseModel):
     class Config:
         arbitrary_types_allowed = True
 
-    @validator("index", pre=True, always=True)
-    def set_index(cls, v):
-        return v if v is not None else LocalIndex()
-
     def __init__(
         self,
         encoder: Optional[BaseEncoder] = None,
@@ -321,14 +348,11 @@ class BaseRouter(BaseModel):
         else:
             self.encoder = encoder
         self.llm = llm
-        self.routes = routes if routes else []
-        if self.encoder.score_threshold is not None:
-            self.score_threshold = self.encoder.score_threshold
-            if self.score_threshold is None:
-                logger.warning(
-                    "No score threshold value found in encoder. Using the default "
-                    "'None' value can lead to unexpected results."
-                )
+        self.routes = routes.copy() if routes else []
+        # initialize index
+        self._set_index(index=index)
+        # set score threshold using default method
+        self._set_score_threshold()
         self.top_k = top_k
         if self.top_k < 1:
             raise ValueError(f"top_k needs to be >= 1, but was: {self.top_k}.")
@@ -344,15 +368,29 @@ class BaseRouter(BaseModel):
         for route in self.routes:
             if route.score_threshold is None:
                 route.score_threshold = self.score_threshold
-        # if routes list has been passed, we initialize index now
+        # run initialize index now if auto sync is active
+        if self.auto_sync:
+            self._init_index_state()
+
+    def _set_index(self, index: Optional[BaseIndex]):
+        if index is None:
+            logger.warning("No index provided. Using default LocalIndex.")
+            self.index = LocalIndex()
+        else:
+            self.index = index
+
+    def _init_index_state(self):
+        """Initializes an index (where required) and runs auto_sync if active.
+        """
+        # initialize index now, check if we need dimensions
+        if self.index.dimensions is None:
+            dims = len(self.encoder(["test"])[0])
+            self.index.dimensions = dims
+        # now init index
+        if isinstance(self.index, PineconeIndex):
+            self.index.index = self.index._init_index(force_create=True)
+        # run auto sync if active
         if self.auto_sync:
-            # initialize index now, check if we need dimensions
-            if self.index.dimensions is None:
-                dims = len(self.encoder(["test"])[0])
-                self.index.dimensions = dims
-            # now init index
-            if isinstance(self.index, PineconeIndex):
-                self.index.index = self.index._init_index(force_create=True)
             local_utterances = self.to_config().to_utterances()
             remote_utterances = self.index.get_utterances()
             diff = UtteranceDiff.from_utterances(
@@ -585,8 +623,6 @@ class BaseRouter(BaseModel):
                     new_routes[utt_obj.route].utterances.append(utt_obj.utterance)
                 new_routes[utt_obj.route].function_schemas = utt_obj.function_schemas
                 new_routes[utt_obj.route].metadata = utt_obj.metadata
-        temp = "\n".join([f"{name}: {r.utterances}" for name, r in new_routes.items()])
-        logger.warning("TEMP | _local_upsert:\n" + temp)
         self.routes = list(new_routes.values())
 
     def _local_delete(self, utterances: List[Utterance]):
@@ -599,8 +635,6 @@ class BaseRouter(BaseModel):
         route_dict: dict[str, List[str]] = {}
         for utt in utterances:
             route_dict.setdefault(utt.route, []).append(utt.utterance)
-        temp = "\n".join([f"{r}: {u}" for r, u in route_dict.items()])
-        logger.warning("TEMP | _local_delete:\n" + temp)
         # iterate over current routes and delete specific utterance if found
         new_routes = []
         for route in self.routes:
@@ -622,17 +656,9 @@ class BaseRouter(BaseModel):
                             metadata=route.metadata,
                         )
                     )
-                logger.warning(
-                    f"TEMP | _local_delete OLD | {route.name}: {route.utterances}"
-                )
-                logger.warning(
-                    f"TEMP | _local_delete NEW | {route.name}: {new_routes[-1].utterances}"
-                )
             else:
                 # the route is not in the route_dict, so we keep it as is
                 new_routes.append(route)
-        temp = "\n".join([f"{r}: {u}" for r, u in route_dict.items()])
-        logger.warning("TEMP | _local_delete:\n" + temp)
 
         self.routes = new_routes
 
diff --git a/semantic_router/routers/hybrid.py b/semantic_router/routers/hybrid.py
index 0b9271a7..a0a96671 100644
--- a/semantic_router/routers/hybrid.py
+++ b/semantic_router/routers/hybrid.py
@@ -11,7 +11,7 @@ from semantic_router.encoders import (
 )
 from semantic_router.route import Route
 from semantic_router.index.hybrid_local import HybridLocalIndex
-from semantic_router.schema import RouteChoice
+from semantic_router.schema import RouteChoice, SparseEmbedding
 from semantic_router.utils.logger import logger
 from semantic_router.routers.base import BaseRouter
 from semantic_router.llms import BaseLLM
@@ -36,10 +36,13 @@ class HybridRouter(BaseRouter):
         auto_sync: Optional[str] = None,
         alpha: float = 0.3,
     ):
+        if index is None:
+            logger.warning("No index provided. Using default HybridLocalIndex.")
+            index = HybridLocalIndex()
         super().__init__(
             encoder=encoder,
             llm=llm,
-            #routes=routes.copy(),
+            routes=routes,
             index=index,
             top_k=top_k,
             aggregation=aggregation,
@@ -49,28 +52,14 @@ class HybridRouter(BaseRouter):
         self._set_sparse_encoder(sparse_encoder=sparse_encoder)
         # set alpha
         self.alpha = alpha
-        # create copy of routes
-        routes_copy = routes.copy()
         # fit sparse encoder if needed
         if isinstance(self.sparse_encoder, TfidfEncoder) and hasattr(
             self.sparse_encoder, "fit"
         ):
-            self.sparse_encoder.fit(routes_copy)
-        # initialize index if not provided
-        self._set_index(index=index)
-        # add routes if we have them
-        if routes_copy:
-            for route in routes_copy:
-                self.add(route)
-        # set score threshold using default method
-        self._set_score_threshold()  # TODO: we can't really use this with hybrid...
-
-    def _set_index(self, index: Optional[HybridLocalIndex]):
-        if index is None:
-            logger.warning("No index provided. Using default HybridLocalIndex.")
-            self.index = HybridLocalIndex()
-        else:
-            self.index = index
+            self.sparse_encoder.fit(self.routes)
+        # run initialize index now if auto sync is active
+        if self.auto_sync:
+            self._init_index_state()
     
     def _set_sparse_encoder(self, sparse_encoder: Optional[BaseEncoder]):
         if sparse_encoder is None:
@@ -121,7 +110,7 @@ class HybridRouter(BaseRouter):
         vector: Optional[List[float]] = None,
         simulate_static: bool = False,
         route_filter: Optional[List[str]] = None,
-        sparse_vector: Optional[dict[int, float]] = None,
+        sparse_vector: dict[int, float] | SparseEmbedding | None = None,
     ) -> RouteChoice:
         # if no vector provided, encode text to get vector
         if vector is None:
@@ -148,22 +137,6 @@ class HybridRouter(BaseRouter):
         else:
             return RouteChoice()
 
-    def add(self, route: Route):
-        self.routes += [route]
-
-        route_names = [route.name] * len(route.utterances)
-
-        # create embeddings for all routes
-        dense_embeds, sparse_embeds = self._encode(route.utterances)
-        self.index.add(
-            embeddings=dense_embeds,
-            sparse_embeddings=sparse_embeds,
-            routes=route_names,  # TODO: aligning names of routes v route_names
-            utterances=route.utterances,
-        )
-        # TODO: in some places we say vector, sparse_vector and in others
-        # TODO: we say embeddings, sparse_embeddings
-
     def _convex_scaling(self, dense: np.ndarray, sparse: list[dict[int, float]]):
         # scale sparse and dense vecs
         scaled_dense = np.array(dense) * self.alpha
diff --git a/semantic_router/routers/semantic.py b/semantic_router/routers/semantic.py
index 45493e67..f0df2717 100644
--- a/semantic_router/routers/semantic.py
+++ b/semantic_router/routers/semantic.py
@@ -53,12 +53,6 @@ def is_valid(layer_config: str) -> bool:
 
 
 class SemanticRouter(BaseRouter):
-    index: BaseIndex = Field(default_factory=LocalIndex)
-
-    @validator("index", pre=True, always=True)
-    def set_index(cls, v):
-        return v if v is not None else LocalIndex()
-
     def __init__(
         self,
         encoder: Optional[BaseEncoder] = None,
@@ -72,56 +66,15 @@ class SemanticRouter(BaseRouter):
         super().__init__(
             encoder=encoder,
             llm=llm,
-            routes=routes.copy() if routes else [],
+            routes=routes if routes else [],
             index=index,
             top_k=top_k,
             aggregation=aggregation,
             auto_sync=auto_sync,
         )
-        if encoder is None:
-            logger.warning(
-                "No encoder provided. Using default OpenAIEncoder. Ensure "
-                "that you have set OPENAI_API_KEY in your environment."
-            )
-            self.encoder = OpenAIEncoder()
-        else:
-            self.encoder = encoder
-        self.llm = llm
-        self.routes = routes if routes else []
-        # set score threshold using default method
-        self._set_score_threshold()
-        self.top_k = top_k
-        if self.top_k < 1:
-            raise ValueError(f"top_k needs to be >= 1, but was: {self.top_k}.")
-        self.aggregation = aggregation
-        if self.aggregation not in ["sum", "mean", "max"]:
-            raise ValueError(
-                f"Unsupported aggregation method chosen: {aggregation}. Choose either 'SUM', 'MEAN', or 'MAX'."
-            )
-        self.aggregation_method = self._set_aggregation_method(self.aggregation)
-        self.auto_sync = auto_sync
-
-        # set route score thresholds if not already set
-        for route in self.routes:
-            if route.score_threshold is None:
-                route.score_threshold = self.score_threshold
-        # if routes list has been passed, we initialize index now
+        # run initialize index now if auto sync is active
         if self.auto_sync:
-            # initialize index now, check if we need dimensions
-            if self.index.dimensions is None:
-                dims = len(self.encoder(["test"])[0])
-                self.index.dimensions = dims
-            # now init index
-            if isinstance(self.index, PineconeIndex):
-                self.index.index = self.index._init_index(force_create=True)
-            local_utterances = self.to_config().to_utterances()
-            remote_utterances = self.index.get_utterances()
-            diff = UtteranceDiff.from_utterances(
-                local_utterances=local_utterances,
-                remote_utterances=remote_utterances,
-            )
-            sync_strategy = diff.get_sync_strategy(self.auto_sync)
-            self._execute_sync_strategy(sync_strategy)
+            self._init_index_state()
 
     def check_for_matching_routes(self, top_class: str) -> Optional[Route]:
         matching_route = next(
@@ -331,8 +284,6 @@ class SemanticRouter(BaseRouter):
                     new_routes[utt_obj.route].utterances.append(utt_obj.utterance)
                 new_routes[utt_obj.route].function_schemas = utt_obj.function_schemas
                 new_routes[utt_obj.route].metadata = utt_obj.metadata
-        temp = "\n".join([f"{name}: {r.utterances}" for name, r in new_routes.items()])
-        logger.warning("TEMP | _local_upsert:\n" + temp)
         self.routes = list(new_routes.values())
 
     def _local_delete(self, utterances: List[Utterance]):
@@ -345,8 +296,6 @@ class SemanticRouter(BaseRouter):
         route_dict: dict[str, List[str]] = {}
         for utt in utterances:
             route_dict.setdefault(utt.route, []).append(utt.utterance)
-        temp = "\n".join([f"{r}: {u}" for r, u in route_dict.items()])
-        logger.warning("TEMP | _local_delete:\n" + temp)
         # iterate over current routes and delete specific utterance if found
         new_routes = []
         for route in self.routes:
@@ -368,17 +317,9 @@ class SemanticRouter(BaseRouter):
                             metadata=route.metadata,
                         )
                     )
-                logger.warning(
-                    f"TEMP | _local_delete OLD | {route.name}: {route.utterances}"
-                )
-                logger.warning(
-                    f"TEMP | _local_delete NEW | {route.name}: {new_routes[-1].utterances}"
-                )
             else:
                 # the route is not in the route_dict, so we keep it as is
                 new_routes.append(route)
-        temp = "\n".join([f"{r}: {u}" for r, u in route_dict.items()])
-        logger.warning("TEMP | _local_delete:\n" + temp)
 
         self.routes = new_routes
 
@@ -449,41 +390,6 @@ class SemanticRouter(BaseRouter):
         encoder = AutoEncoder(type=config.encoder_type, name=config.encoder_name).model
         return cls(encoder=encoder, routes=config.routes, index=index)
 
-    def add(self, route: Route):
-        """Add a route to the local SemanticRouter and index.
-
-        :param route: The route to add.
-        :type route: Route
-        """
-        current_local_hash = self._get_hash()
-        current_remote_hash = self.index._read_hash()
-        if current_remote_hash.value == "":
-            # if remote hash is empty, the index is to be initialized
-            current_remote_hash = current_local_hash
-        embedded_utterances = self.encoder(route.utterances)
-        self.index.add(
-            embeddings=embedded_utterances,
-            routes=[route.name] * len(route.utterances),
-            utterances=route.utterances,
-            function_schemas=(
-                route.function_schemas * len(route.utterances)
-                if route.function_schemas
-                else [{}] * len(route.utterances)
-            ),
-            metadata_list=[route.metadata if route.metadata else {}]
-            * len(route.utterances),
-        )
-
-        self.routes.append(route)
-        if current_local_hash.value == current_remote_hash.value:
-            self._write_hash()  # update current hash in index
-        else:
-            logger.warning(
-                "Local and remote route layers were not aligned. Remote hash "
-                "not updated. Use `SemanticRouter.get_utterance_diff()` to see "
-                "details."
-            )
-
     def list_route_names(self) -> List[str]:
         return [route.name for route in self.routes]
 
diff --git a/semantic_router/schema.py b/semantic_router/schema.py
index eea86b2e..2d00572f 100644
--- a/semantic_router/schema.py
+++ b/semantic_router/schema.py
@@ -4,7 +4,7 @@ from enum import Enum
 from typing import List, Optional, Union, Any, Dict, Tuple
 from pydantic.v1 import BaseModel, Field
 from semantic_router.utils.logger import logger
-
+from aurelio_sdk.schema import BM25Embedding
 
 class EncoderType(Enum):
     AURELIO = "aurelio"
@@ -404,3 +404,65 @@ class Metric(Enum):
     DOTPRODUCT = "dotproduct"
     EUCLIDEAN = "euclidean"
     MANHATTAN = "manhattan"
+
+
+class SparseValue(BaseModel):
+    """A single entry of a sparse vector."""
+
+    # position of the entry within the sparse vector
+    index: int
+    # weight stored at that position
+    value: float
+
+
+class SparseEmbedding(BaseModel):
+    """Sparse embedding held as a list of ``SparseValue`` entries.
+
+    Provides conversions to and from a plain ``{index: value}`` mapping, the
+    ``indices`` / ``values`` parallel-list layout returned by ``to_pinecone``,
+    and Aurelio ``BM25Embedding`` objects.
+    """
+
+    embedding: List[SparseValue]
+
+    def to_dict(self) -> Dict[int, float]:
+        """Return the embedding as an ``{index: value}`` mapping."""
+        return {x.index: x.value for x in self.embedding}
+
+    def to_pinecone(self) -> Dict[str, Any]:
+        """Return parallel ``indices`` and ``values`` lists, in entry order."""
+        return {
+            "indices": [x.index for x in self.embedding],
+            "values": [x.value for x in self.embedding],
+        }
+
+    @classmethod
+    def from_dict(cls, sparse_dict: dict) -> "SparseEmbedding":
+        """Build an embedding from an ``{index: value}`` mapping.
+
+        :param sparse_dict: Mapping of sparse index to value.
+        :type sparse_dict: dict
+        """
+        return cls(
+            embedding=[SparseValue(index=i, value=v) for i, v in sparse_dict.items()]
+        )
+
+    @classmethod
+    def from_aurelio(cls, embedding: BM25Embedding) -> "SparseEmbedding":
+        """Build an embedding from an Aurelio ``BM25Embedding``.
+
+        :param embedding: Object whose ``indices`` and ``values`` are paired up
+            element-wise.
+        :type embedding: BM25Embedding
+        """
+        return cls(
+            embedding=[
+                SparseValue(index=i, value=v)
+                for i, v in zip(embedding.indices, embedding.values)
+            ]
+        )
+
+    # dictionary interface
+    def items(self) -> List[Tuple[int, float]]:
+        """Return ``(index, value)`` pairs, mirroring ``dict.items()``."""
+        return [(x.index, x.value) for x in self.embedding]
-- 
GitLab