From 981b039080253a68c11af8f0680c79c4f887e9f0 Mon Sep 17 00:00:00 2001
From: jamescalam <james.briggs@hotmail.com>
Date: Fri, 29 Nov 2024 15:32:30 +0100
Subject: [PATCH] fix: hybrid

---
 semantic_router/encoders/aurelio.py |  3 +-
 semantic_router/encoders/bm25.py    | 16 ++++++-
 semantic_router/encoders/tfidf.py   | 10 +++++
 tests/unit/encoders/test_bm25.py    | 66 ++++++++++++++++++++---------
 tests/unit/encoders/test_tfidf.py   |  9 +++-
 5 files changed, 79 insertions(+), 25 deletions(-)

diff --git a/semantic_router/encoders/aurelio.py b/semantic_router/encoders/aurelio.py
index 779fe6b1..8b2501ba 100644
--- a/semantic_router/encoders/aurelio.py
+++ b/semantic_router/encoders/aurelio.py
@@ -1,5 +1,5 @@
 import os
-from typing import Any, Dict, List, Optional
+from typing import Any, List, Optional
 from pydantic.v1 import Field
 
 from aurelio_sdk import AurelioClient, AsyncAurelioClient, EmbeddingResponse
@@ -10,7 +10,6 @@ from semantic_router.schema import SparseEmbedding
 
 class AurelioSparseEncoder(SparseEncoder):
     model: Optional[Any] = None
-    idx_mapping: Optional[Dict[int, int]] = None
     client: AurelioClient = Field(default_factory=AurelioClient, exclude=True)
     async_client: AsyncAurelioClient = Field(
         default_factory=AsyncAurelioClient, exclude=True
diff --git a/semantic_router/encoders/bm25.py b/semantic_router/encoders/bm25.py
index b5365cee..f42bf9c2 100644
--- a/semantic_router/encoders/bm25.py
+++ b/semantic_router/encoders/bm25.py
@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional
 from semantic_router.encoders.tfidf import TfidfEncoder
 from semantic_router.utils.logger import logger
 from semantic_router.schema import SparseEmbedding
+from semantic_router.route import Route
 
 
 class BM25Encoder(TfidfEncoder):
@@ -34,6 +35,7 @@ class BM25Encoder(TfidfEncoder):
             self._set_idx_mapping()
 
     def _set_idx_mapping(self):
+        # TODO(JB): calling this appears to train the model as a side effect; find out how.
         params = self.model.get_params()
         doc_freq = params["doc_freq"]
         if isinstance(doc_freq, dict):
@@ -42,8 +44,20 @@ class BM25Encoder(TfidfEncoder):
         else:
             raise TypeError("Expected a dictionary for 'doc_freq'")
 
+    def fit(self, routes: List[Route]):
+        """Fit the underlying model on the utterances of the given routes.
+
+        :param routes: Routes whose utterances make up the training corpus.
+        :type routes: List[Route]
+        :raises ValueError: If the underlying model is not initialized.
+        """
+        self._fit_validate(routes=routes)
+        if self.model is None:
+            raise ValueError("Model is not initialized.")
+        self.model.fit(corpus=[text for route in routes for text in route.utterances])
+
     def __call__(self, docs: List[str]) -> list[SparseEmbedding]:
-        if self.model is None or self.idx_mapping is None:
-            raise ValueError("Model or index mapping is not initialized.")
+        if self.model is None:
+            raise ValueError("Model is not initialized.")
         if len(docs) == 1:
             sparse_dicts = self.model.encode_queries(docs)
diff --git a/semantic_router/encoders/tfidf.py b/semantic_router/encoders/tfidf.py
index b865c170..fe285e5c 100644
--- a/semantic_router/encoders/tfidf.py
+++ b/semantic_router/encoders/tfidf.py
@@ -33,6 +33,12 @@ class TfidfEncoder(SparseEncoder):
         return self._array_to_sparse_embeddings(tfidf)
 
     def fit(self, routes: List[Route]):
+        """Trains the encoder weights on the provided routes.
+
+        :param routes: List of routes to train the encoder on.
+        :type routes: List[Route]
+        """
+        self._fit_validate(routes=routes)
         docs = []
         for route in routes:
             for doc in route.utterances:
@@ -42,6 +48,23 @@ class TfidfEncoder(SparseEncoder):
             raise ValueError(f"Too little data to fit {self.__class__.__name__}.")
         self.idf = self._compute_idf(docs)
 
+    def _fit_validate(self, routes: List[Route]):
+        """Check that ``routes`` is a non-empty list of Route objects.
+
+        :param routes: Value passed to ``fit`` that must be validated.
+        :type routes: List[Route]
+        :raises TypeError: If ``routes`` is not a list or holds a non-Route item.
+        :raises ValueError: If ``routes`` is an empty list.
+        """
+        # Check every element rather than only ``routes[0]``: indexing an empty
+        # list raised IndexError, and later non-Route items went unnoticed.
+        if not isinstance(routes, list) or not all(
+            isinstance(route, Route) for route in routes
+        ):
+            raise TypeError("`routes` parameter must be a list of Route objects.")
+        if not routes:
+            raise ValueError("`routes` parameter must contain at least one Route.")
+
     def _build_word_index(self, docs: List[str]) -> Dict:
         print(docs)
         words = set()
diff --git a/tests/unit/encoders/test_bm25.py b/tests/unit/encoders/test_bm25.py
index 73e52d55..a22b0cc2 100644
--- a/tests/unit/encoders/test_bm25.py
+++ b/tests/unit/encoders/test_bm25.py
@@ -1,31 +1,66 @@
 import pytest
+import numpy as np
 
 from semantic_router.encoders import BM25Encoder
+from semantic_router.route import Route
+import nltk
+
+nltk.download("punkt_tab")
+
+UTTERANCES = [
+    "Hello we need this text to be a little longer for our sparse encoders",
+    "In this case they need to learn from recurring tokens, ie words.",
+    "We give ourselves several examples from our encoders to learn from.",
+    "But given this is only an example we don't need too many",
+    "Just enough to test that our sparse encoders work as expected",
+]
 
 
 @pytest.fixture
 def bm25_encoder():
     sparse_encoder = BM25Encoder(use_default_params=False)
     sparse_encoder.fit(
-        ["The quick brown fox", "jumps over the lazy dog", "Hello, world!"]
+        [
+            Route(
+                name="test_route",
+                utterances=[
+                    "The quick brown fox",
+                    "jumps over the lazy dog",
+                    "Hello, world!",
+                ],
+            )
+        ]
     )
     return sparse_encoder
 
 
+@pytest.fixture
+def routes():
+    return [
+        Route(name="Route 1", utterances=[UTTERANCES[0], UTTERANCES[1]]),
+        Route(name="Route 2", utterances=[UTTERANCES[2], UTTERANCES[3], UTTERANCES[4]]),
+    ]
+
+
 class TestBM25Encoder:
     def test_initialization(self, bm25_encoder):
-        assert len(bm25_encoder.idx_mapping) != 0
+        assert bm25_encoder.model is not None
+
+    def test_fit(self, bm25_encoder, routes):
+        bm25_encoder.fit(routes)
+        assert bm25_encoder.model is not None
 
-    def test_fit(self, bm25_encoder):
-        bm25_encoder.fit(["some docs", "and more docs", "and even more docs"])
-        assert len(bm25_encoder.idx_mapping) != 0
+    def test_fit_with_strings(self, bm25_encoder):
+        route_strings = ["test a", "test b", "test c"]
+        with pytest.raises(TypeError):
+            bm25_encoder.fit(route_strings)
 
     def test_call_method(self, bm25_encoder):
         result = bm25_encoder(["test"])
         assert isinstance(result, list), "Result should be a list"
         assert all(
-            isinstance(sublist, list) for sublist in result
-        ), "Each item in result should be a list"
+            isinstance(sparse_emb.embedding, np.ndarray) for sparse_emb in result
+        ), "Each item in result should be an array"
 
     def test_call_method_no_docs_bm25_encoder(self, bm25_encoder):
         with pytest.raises(ValueError):
@@ -35,24 +70,15 @@ class TestBM25Encoder:
         result = bm25_encoder(["doc with fake word gta5jabcxyz"])
         assert isinstance(result, list), "Result should be a list"
         assert all(
-            isinstance(sublist, list) for sublist in result
-        ), "Each item in result should be a list"
-
-    def test_init_with_non_dict_doc_freq(self, mocker):
-        mock_encoder = mocker.MagicMock()
-        mock_encoder.get_params.return_value = {"doc_freq": "not a dict"}
-        mocker.patch(
-            "pinecone_text.sparse.BM25Encoder.default", return_value=mock_encoder
-        )
-        with pytest.raises(TypeError):
-            BM25Encoder()
+            isinstance(sparse_emb.embedding, np.ndarray) for sparse_emb in result
+        ), "Each item in result should be an array"
 
     def test_call_method_with_uninitialized_model_or_mapping(self, bm25_encoder):
         bm25_encoder.model = None
         with pytest.raises(ValueError):
             bm25_encoder(["test"])
 
-    def test_fit_with_uninitialized_model(self, bm25_encoder):
+    def test_fit_with_uninitialized_model(self, bm25_encoder, routes):
         bm25_encoder.model = None
         with pytest.raises(ValueError):
-            bm25_encoder.fit(["test"])
+            bm25_encoder.fit(routes)
diff --git a/tests/unit/encoders/test_tfidf.py b/tests/unit/encoders/test_tfidf.py
index 5052a035..94beec9f 100644
--- a/tests/unit/encoders/test_tfidf.py
+++ b/tests/unit/encoders/test_tfidf.py
@@ -38,7 +38,7 @@ class TestTfidfEncoder:
         assert isinstance(result, list), "Result should be a list"
         assert all(
             isinstance(sparse_emb.embedding, np.ndarray) for sparse_emb in result
-        ), "Each item in result should be a list"
+        ), "Each item in result should be an array"
 
     def test_call_method_no_docs_tfidf_encoder(self, tfidf_encoder):
         with pytest.raises(ValueError):
@@ -56,7 +56,12 @@ class TestTfidfEncoder:
         assert isinstance(result, list), "Result should be a list"
         assert all(
             isinstance(sparse_emb.embedding, np.ndarray) for sparse_emb in result
-        ), "Each item in result should be a list"
+        ), "Each item in result should be an array"
+
+    def test_fit_with_strings(self, tfidf_encoder):
+        routes = ["test a", "test b", "test c"]
+        with pytest.raises(TypeError):
+            tfidf_encoder.fit(routes)
 
     def test_call_method_with_uninitialized_model(self, tfidf_encoder):
         with pytest.raises(ValueError):
-- 
GitLab