From f7f05081867de6e9cc629df33fedb575f431c806 Mon Sep 17 00:00:00 2001
From: jamescalam <james.briggs@hotmail.com>
Date: Fri, 29 Nov 2024 11:24:46 +0100
Subject: [PATCH] fix: hybrid fixes

---
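Notes:

With this patch, HybridRouter resolves its sparse encoder via _get_sparse_encoder()
(falling back to BM25Encoder with a warning), fits a TfidfEncoder on the supplied
routes when one is given, and only then runs _init_index_state() when auto_sync is
set; SemanticRouter gains the equivalent call after its own super().__init__().
A rough construction sketch of that flow, mirroring tests/unit/test_hybrid_layer.py;
the routes and utterances below are illustrative placeholders, and OpenAIEncoder is
only assumed from the tests' openai_encoder fixture, it is not part of this patch:

    from semantic_router.route import Route
    from semantic_router.routers.hybrid import HybridRouter
    from semantic_router.encoders import BM25Encoder, OpenAIEncoder

    routes = [
        Route(name="Route 1", utterances=["Hello there", "General greeting"]),
        Route(name="Route 2", utterances=["Goodbye", "See you later"]),
    ]
    router = HybridRouter(
        encoder=OpenAIEncoder(),       # assumed dense encoder; needs an OpenAI key at query time
        sparse_encoder=BM25Encoder(),  # optional; defaults to BM25Encoder with a warning
        routes=routes,
        alpha=0.3,                     # weighting between dense and sparse scores
        auto_sync="local",             # with auto_sync set, _init_index_state() runs at the end of __init__
    )
    route_choice = router("Hello")     # the updated test asserts this is "Route 1" or "Route 2"
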
 semantic_router/index/hybrid_local.py | 16 ++++---
 semantic_router/routers/base.py       |  3 --
 semantic_router/routers/hybrid.py     | 56 ++++++++++++++++++++++++----
 semantic_router/routers/semantic.py   |  5 ++-
 tests/unit/test_hybrid_layer.py       | 44 ++++++++++++-------
 5 files changed, 89 insertions(+), 35 deletions(-)
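
Related to the hybrid_local.py hunk below: HybridLocalIndex.add() now expects
SparseEmbedding objects rather than raw dicts, and converts them with to_dict()
internally for now (per the TODO in that hunk). Continuing the sketch above and
reusing the router constructed there, this is roughly how the new
_execute_sync_strategy() feeds the index; route and utterance values are placeholders:

    from semantic_router.index import HybridLocalIndex

    index = HybridLocalIndex()
    dense_emb, sparse_emb = router._encode(["Hello there"])  # (np.ndarray, list[SparseEmbedding])
    index.add(
        embeddings=dense_emb.tolist(),
        routes=["Route 1"],
        utterances=["Hello there"],
        sparse_embeddings=sparse_emb,  # now required; a ValueError is raised if omitted
    )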

diff --git a/semantic_router/index/hybrid_local.py b/semantic_router/index/hybrid_local.py
index 2a5a43d5..f927914e 100644
--- a/semantic_router/index/hybrid_local.py
+++ b/semantic_router/index/hybrid_local.py
@@ -24,7 +24,7 @@ class HybridLocalIndex(LocalIndex):
         utterances: List[str],
         function_schemas: Optional[List[Dict[str, Any]]] = None,
         metadata_list: List[Dict[str, Any]] = [],
-        sparse_embeddings: Optional[List[dict[int, float]]] = None,
+        sparse_embeddings: Optional[List[SparseEmbedding]] = None,
     ):
         if sparse_embeddings is None:
             raise ValueError("Sparse embeddings are required for HybridLocalIndex.")
@@ -32,21 +32,27 @@ class HybridLocalIndex(LocalIndex):
             logger.warning("Function schemas are not supported for HybridLocalIndex.")
         if metadata_list:
             logger.warning("Metadata is not supported for HybridLocalIndex.")
-        embeds = np.array(embeddings)
+        embeds = np.array(
+            embeddings
+        )  # TODO: we previously had this as an array, so switching back and forth seems inefficient
         routes_arr = np.array(routes)
         if isinstance(utterances[0], str):
             utterances_arr = np.array(utterances)
         else:
-            utterances_arr = np.array(utterances, dtype=object)
+            utterances_arr = np.array(
+                utterances, dtype=object
+            )  # TODO: could we speed up if this were already an array?
         if self.index is None or self.sparse_index is None:
             self.index = embeds
-            self.sparse_index = sparse_embeddings
+            self.sparse_index = [
+                x.to_dict() for x in sparse_embeddings
+            ]  # TODO: switch back to using SparseEmbedding later
             self.routes = routes_arr
             self.utterances = utterances_arr
         else:
             # TODO: we should probably switch to an `upsert` method and standardize elsewhere
             self.index = np.concatenate([self.index, embeds])
-            self.sparse_index.extend(sparse_embeddings)
+            self.sparse_index.extend([x.to_dict() for x in sparse_embeddings])
             self.routes = np.concatenate([self.routes, routes_arr])
             self.utterances = np.concatenate([self.utterances, utterances_arr])
 
diff --git a/semantic_router/routers/base.py b/semantic_router/routers/base.py
index 615d4699..392e91d8 100644
--- a/semantic_router/routers/base.py
+++ b/semantic_router/routers/base.py
@@ -335,9 +335,6 @@ class BaseRouter(BaseModel):
         for route in self.routes:
             if route.score_threshold is None:
                 route.score_threshold = self.score_threshold
-        # run initialize index now if auto sync is active
-        if self.auto_sync:
-            self._init_index_state()
 
     def _get_index(self, index: Optional[BaseIndex]) -> BaseIndex:
         if index is None:
diff --git a/semantic_router/routers/hybrid.py b/semantic_router/routers/hybrid.py
index f74e3e6b..e603add3 100644
--- a/semantic_router/routers/hybrid.py
+++ b/semantic_router/routers/hybrid.py
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import Dict, List, Optional
 import asyncio
 from pydantic.v1 import Field
 
@@ -12,7 +12,7 @@ from semantic_router.encoders import (
 )
 from semantic_router.route import Route
 from semantic_router.index import BaseIndex, HybridLocalIndex
-from semantic_router.schema import RouteChoice, SparseEmbedding
+from semantic_router.schema import RouteChoice, SparseEmbedding, Utterance
 from semantic_router.utils.logger import logger
 from semantic_router.routers.base import BaseRouter
 from semantic_router.llms import BaseLLM
@@ -37,10 +37,10 @@ class HybridRouter(BaseRouter):
         auto_sync: Optional[str] = None,
         alpha: float = 0.3,
     ):
         if index is None:
             logger.warning("No index provided. Using default HybridLocalIndex.")
             index = HybridLocalIndex()
         encoder = self._get_encoder(encoder=encoder)
         super().__init__(
             encoder=encoder,
             llm=llm,
@@ -50,15 +50,17 @@ class HybridRouter(BaseRouter):
             aggregation=aggregation,
             auto_sync=auto_sync,
         )
         # initialize sparse encoder
-        self._set_sparse_encoder(sparse_encoder=sparse_encoder)
+        self.sparse_encoder = self._get_sparse_encoder(sparse_encoder=sparse_encoder)
         # set alpha
         self.alpha = alpha
         # fit sparse encoder if needed
-        if isinstance(self.sparse_encoder, TfidfEncoder) and hasattr(
-            self.sparse_encoder, "fit"
-        ) and self.routes:
+        if (
+            isinstance(self.sparse_encoder, TfidfEncoder)
+            and hasattr(self.sparse_encoder, "fit")
+            and self.routes
+        ):
             self.sparse_encoder.fit(self.routes)
         # run initialize index now if auto sync is active
         if self.auto_sync:
             self._init_index_state()
@@ -104,6 +106,39 @@ class HybridRouter(BaseRouter):
                 "to see details."
             )
 
+    def _execute_sync_strategy(self, strategy: Dict[str, Dict[str, List[Utterance]]]):
+        """Executes the provided sync strategy, either deleting or upserting
+        routes from the local and remote instances as defined in the strategy.
+
+        :param strategy: The sync strategy to execute.
+        :type strategy: Dict[str, Dict[str, List[Utterance]]]
+        """
+        if strategy["remote"]["delete"]:
+            data_to_delete = {}  # type: ignore
+            for utt_obj in strategy["remote"]["delete"]:
+                data_to_delete.setdefault(utt_obj.route, []).append(utt_obj.utterance)
+            # TODO: switch to remove without sync??
+            self.index._remove_and_sync(data_to_delete)
+        if strategy["remote"]["upsert"]:
+            utterances_text = [utt.utterance for utt in strategy["remote"]["upsert"]]
+            dense_emb, sparse_emb = self._encode(utterances_text)
+            self.index.add(
+                embeddings=dense_emb.tolist(),
+                routes=[utt.route for utt in strategy["remote"]["upsert"]],
+                utterances=utterances_text,
+                function_schemas=[
+                    utt.function_schemas for utt in strategy["remote"]["upsert"]  # type: ignore
+                ],
+                metadata_list=[utt.metadata for utt in strategy["remote"]["upsert"]],
+                sparse_embeddings=sparse_emb,  # type: ignore
+            )
+        if strategy["local"]["delete"]:
+            self._local_delete(utterances=strategy["local"]["delete"])
+        if strategy["local"]["upsert"]:
+            self._local_upsert(utterances=strategy["local"]["upsert"])
+        # update hash
+        self._write_hash()
+
     def _get_index(self, index: Optional[BaseIndex]) -> BaseIndex:
         if index is None:
             logger.warning("No index provided. Using default HybridLocalIndex.")
@@ -112,12 +147,15 @@ class HybridRouter(BaseRouter):
             index = index
         return index
 
-    def _set_sparse_encoder(self, sparse_encoder: Optional[SparseEncoder]):
+    def _get_sparse_encoder(
+        self, sparse_encoder: Optional[SparseEncoder]
+    ) -> SparseEncoder:
         if sparse_encoder is None:
             logger.warning("No sparse_encoder provided. Using default BM25Encoder.")
-            self.sparse_encoder = BM25Encoder()
+            sparse_encoder = BM25Encoder()
         else:
-            self.sparse_encoder = sparse_encoder
+            sparse_encoder = sparse_encoder
+        return sparse_encoder
 
     def _encode(self, text: list[str]) -> tuple[np.ndarray, list[SparseEmbedding]]:
         """Given some text, generates dense and sparse embeddings, then scales them
diff --git a/semantic_router/routers/semantic.py b/semantic_router/routers/semantic.py
index 94c3e179..33af2a32 100644
--- a/semantic_router/routers/semantic.py
+++ b/semantic_router/routers/semantic.py
@@ -32,6 +32,9 @@ class SemanticRouter(BaseRouter):
             aggregation=aggregation,
             auto_sync=auto_sync,
         )
+        # run initialize index now if auto sync is active
+        if self.auto_sync:
+            self._init_index_state()
 
     def _encode(self, text: list[str]) -> Any:
         """Given some text, encode it."""
@@ -81,4 +84,4 @@ class SemanticRouter(BaseRouter):
                 "Local and remote route layers were not aligned. Remote hash "
                 f"not updated. Use `{self.__class__.__name__}.get_utterance_diff()` "
                 "to see details."
-            )
\ No newline at end of file
+            )
diff --git a/tests/unit/test_hybrid_layer.py b/tests/unit/test_hybrid_layer.py
index aadad86a..a7d29b46 100644
--- a/tests/unit/test_hybrid_layer.py
+++ b/tests/unit/test_hybrid_layer.py
@@ -54,30 +54,37 @@ def azure_encoder(mocker):
         model="test_model",
     )
 
+
 @pytest.fixture
 def bm25_encoder():
-    #mocker.patch.object(BM25Encoder, "__call__", side_effect=mock_encoder_call)
+    # mocker.patch.object(BM25Encoder, "__call__", side_effect=mock_encoder_call)
     return BM25Encoder(name="test-bm25-encoder")
 
 
 @pytest.fixture
 def tfidf_encoder():
-    #mocker.patch.object(TfidfEncoder, "__call__", side_effect=mock_encoder_call)
+    # mocker.patch.object(TfidfEncoder, "__call__", side_effect=mock_encoder_call)
     return TfidfEncoder(name="test-tfidf-encoder")
 
 
 @pytest.fixture
 def routes():
     return [
-        Route(name="Route 1", utterances=[
-            "Hello we need this text to be a little longer for our sparse encoders",
-            "In this case they need to learn from recurring tokens, ie words."
-        ]),
-        Route(name="Route 2", utterances=[
-            "We give ourselves several examples from our encoders to learn from.",
-            "But given this is only an example we don't need too many",
-            "Just enough to test that our sparse encoders work as expected"
-        ]),
+        Route(
+            name="Route 1",
+            utterances=[
+                "Hello we need this text to be a little longer for our sparse encoders",
+                "In this case they need to learn from recurring tokens, i.e. words.",
+            ],
+        ),
+        Route(
+            name="Route 2",
+            utterances=[
+                "We give ourselves several examples from our encoders to learn from.",
+                "But given this is only an example we don't need too many",
+                "Just enough to test that our sparse encoders work as expected",
+            ],
+        ),
     ]
 
 
@@ -88,7 +95,7 @@ sparse_encoder.fit(
             name="Route 1",
             utterances=[
                 "The quick brown fox jumps over the lazy dog",
-                "some other useful text containing words like fox and dog"
+                "some other useful text containing words like fox and dog",
             ],
         ),
         Route(name="Route 2", utterances=["Hello, world!"]),
@@ -143,9 +150,12 @@ class TestHybridRouter:
         assert len(route_layer.routes) == 2, "route_layer.routes is not 2"
 
     def test_query_and_classification(self, openai_encoder, routes):
         route_layer = HybridRouter(
-            encoder=openai_encoder, sparse_encoder=sparse_encoder, routes=routes
+            encoder=openai_encoder,
+            sparse_encoder=sparse_encoder,
+            routes=routes,
+            auto_sync="local",
         )
         query_result = route_layer("Hello")
         assert query_result in ["Route 1", "Route 2"]
 
@@ -153,9 +163,7 @@ class TestHybridRouter:
         route_layer = HybridRouter(
             encoder=openai_encoder, sparse_encoder=sparse_encoder
         )
-        assert isinstance(
-            route_layer.sparse_encoder, BM25Encoder
-        ) or isinstance(
+        assert isinstance(route_layer.sparse_encoder, BM25Encoder) or isinstance(
             route_layer.sparse_encoder, TfidfEncoder
         ), (
             f"route_layer.sparse_encoder is {route_layer.sparse_encoder.__class__.__name__} "
@@ -213,7 +221,9 @@ class TestHybridRouter:
             utterance for route in routes for utterance in route.utterances
         ]
         assert hybrid_route_layer.index.sparse_index is not None, "sparse_index is None"
-        assert len(hybrid_route_layer.index.sparse_index) == len(all_utterances), "sparse_index length mismatch"
+        assert len(hybrid_route_layer.index.sparse_index) == len(
+            all_utterances
+        ), "sparse_index length mismatch"
 
     def test_setting_aggregation_methods(self, openai_encoder, routes):
         for agg in ["sum", "mean", "max"]:
-- 
GitLab
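
Post-script note on HybridRouter._execute_sync_strategy(): the strategy argument
follows the Dict[str, Dict[str, List[Utterance]]] shape from its type hint, keyed
by target ("remote"/"local") and action ("delete"/"upsert"). A purely illustrative
sketch of that shape; the Utterance keyword arguments are assumed from the
attributes the method reads, and the values are placeholders:

    from semantic_router.schema import Utterance

    strategy = {
        "remote": {
            "delete": [Utterance(route="Route 1", utterance="stale utterance")],
            "upsert": [
                Utterance(
                    route="Route 2",
                    utterance="new utterance",
                    function_schemas=None,  # forwarded to index.add(), which warns that schemas are unsupported
                    metadata={},            # likewise warned about and ignored by HybridLocalIndex
                )
            ],
        },
        "local": {"delete": [], "upsert": []},
    }
    # router._execute_sync_strategy(strategy) deletes and re-encodes/upserts on the
    # remote index, applies the local delete/upsert, then calls _write_hash().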