From 981b039080253a68c11af8f0680c79c4f887e9f0 Mon Sep 17 00:00:00 2001 From: jamescalam <james.briggs@hotmail.com> Date: Fri, 29 Nov 2024 15:32:30 +0100 Subject: [PATCH] fix: hybrid --- semantic_router/encoders/aurelio.py | 3 +- semantic_router/encoders/bm25.py | 16 ++++++- semantic_router/encoders/tfidf.py | 10 +++++ tests/unit/encoders/test_bm25.py | 66 ++++++++++++++++++++--------- tests/unit/encoders/test_tfidf.py | 9 +++- 5 files changed, 79 insertions(+), 25 deletions(-) diff --git a/semantic_router/encoders/aurelio.py b/semantic_router/encoders/aurelio.py index 779fe6b1..8b2501ba 100644 --- a/semantic_router/encoders/aurelio.py +++ b/semantic_router/encoders/aurelio.py @@ -1,5 +1,5 @@ import os -from typing import Any, Dict, List, Optional +from typing import Any, List, Optional from pydantic.v1 import Field from aurelio_sdk import AurelioClient, AsyncAurelioClient, EmbeddingResponse @@ -10,7 +10,6 @@ from semantic_router.schema import SparseEmbedding class AurelioSparseEncoder(SparseEncoder): model: Optional[Any] = None - idx_mapping: Optional[Dict[int, int]] = None client: AurelioClient = Field(default_factory=AurelioClient, exclude=True) async_client: AsyncAurelioClient = Field( default_factory=AsyncAurelioClient, exclude=True diff --git a/semantic_router/encoders/bm25.py b/semantic_router/encoders/bm25.py index b5365cee..f42bf9c2 100644 --- a/semantic_router/encoders/bm25.py +++ b/semantic_router/encoders/bm25.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional from semantic_router.encoders.tfidf import TfidfEncoder from semantic_router.utils.logger import logger from semantic_router.schema import SparseEmbedding +from semantic_router.route import Route class BM25Encoder(TfidfEncoder): @@ -34,6 +35,7 @@ class BM25Encoder(TfidfEncoder): self._set_idx_mapping() def _set_idx_mapping(self): + # TODO JB: this is training the model somehow - not sure how... params = self.model.get_params() doc_freq = params["doc_freq"] if isinstance(doc_freq, dict): @@ -42,8 +44,20 @@ class BM25Encoder(TfidfEncoder): else: raise TypeError("Expected a dictionary for 'doc_freq'") + def fit(self, routes: List[Route]): + """Trains the encoder weights on the provided routes. + + :param routes: List of routes to train the encoder on. + :type routes: List[Route] + """ + self._fit_validate(routes=routes) + if self.model is None: + raise ValueError("Model is not initialized.") + utterances = [utterance for route in routes for utterance in route.utterances] + self.model.fit(corpus=utterances) + def __call__(self, docs: List[str]) -> list[SparseEmbedding]: - if self.model is None or self.idx_mapping is None: + if self.model is None: raise ValueError("Model or index mapping is not initialized.") if len(docs) == 1: sparse_dicts = self.model.encode_queries(docs) diff --git a/semantic_router/encoders/tfidf.py b/semantic_router/encoders/tfidf.py index b865c170..fe285e5c 100644 --- a/semantic_router/encoders/tfidf.py +++ b/semantic_router/encoders/tfidf.py @@ -33,6 +33,12 @@ class TfidfEncoder(SparseEncoder): return self._array_to_sparse_embeddings(tfidf) def fit(self, routes: List[Route]): + """Trains the encoder weights on the provided routes. + + :param routes: List of routes to train the encoder on. + :type routes: List[Route] + """ + self._fit_validate(routes=routes) docs = [] for route in routes: for doc in route.utterances: @@ -42,6 +48,10 @@ class TfidfEncoder(SparseEncoder): raise ValueError(f"Too little data to fit {self.__class__.__name__}.") self.idf = self._compute_idf(docs) + def _fit_validate(self, routes: List[Route]): + if not isinstance(routes, list) or not isinstance(routes[0], Route): + raise TypeError("`routes` parameter must be a list of Route objects.") + def _build_word_index(self, docs: List[str]) -> Dict: print(docs) words = set() diff --git a/tests/unit/encoders/test_bm25.py b/tests/unit/encoders/test_bm25.py index 73e52d55..a22b0cc2 100644 --- a/tests/unit/encoders/test_bm25.py +++ b/tests/unit/encoders/test_bm25.py @@ -1,31 +1,66 @@ import pytest +import numpy as np from semantic_router.encoders import BM25Encoder +from semantic_router.route import Route +import nltk + +nltk.download("punkt_tab") + +UTTERANCES = [ + "Hello we need this text to be a little longer for our sparse encoders", + "In this case they need to learn from recurring tokens, ie words.", + "We give ourselves several examples from our encoders to learn from.", + "But given this is only an example we don't need too many", + "Just enough to test that our sparse encoders work as expected", +] @pytest.fixture def bm25_encoder(): sparse_encoder = BM25Encoder(use_default_params=False) sparse_encoder.fit( - ["The quick brown fox", "jumps over the lazy dog", "Hello, world!"] + [ + Route( + name="test_route", + utterances=[ + "The quick brown fox", + "jumps over the lazy dog", + "Hello, world!", + ], + ) + ] ) return sparse_encoder +@pytest.fixture +def routes(): + return [ + Route(name="Route 1", utterances=[UTTERANCES[0], UTTERANCES[1]]), + Route(name="Route 2", utterances=[UTTERANCES[2], UTTERANCES[3], UTTERANCES[4]]), + ] + + class TestBM25Encoder: def test_initialization(self, bm25_encoder): - assert len(bm25_encoder.idx_mapping) != 0 + assert bm25_encoder.model is not None + + def test_fit(self, bm25_encoder, routes): + bm25_encoder.fit(routes) + assert bm25_encoder.model is not None - def test_fit(self, bm25_encoder): - bm25_encoder.fit(["some docs", "and more docs", "and even more docs"]) - assert len(bm25_encoder.idx_mapping) != 0 + def test_fit_with_strings(self, bm25_encoder): + route_strings = ["test a", "test b", "test c"] + with pytest.raises(TypeError): + bm25_encoder.fit(route_strings) def test_call_method(self, bm25_encoder): result = bm25_encoder(["test"]) assert isinstance(result, list), "Result should be a list" assert all( - isinstance(sublist, list) for sublist in result - ), "Each item in result should be a list" + isinstance(sparse_emb.embedding, np.ndarray) for sparse_emb in result + ), "Each item in result should be an array" def test_call_method_no_docs_bm25_encoder(self, bm25_encoder): with pytest.raises(ValueError): @@ -35,24 +70,15 @@ class TestBM25Encoder: result = bm25_encoder(["doc with fake word gta5jabcxyz"]) assert isinstance(result, list), "Result should be a list" assert all( - isinstance(sublist, list) for sublist in result - ), "Each item in result should be a list" - - def test_init_with_non_dict_doc_freq(self, mocker): - mock_encoder = mocker.MagicMock() - mock_encoder.get_params.return_value = {"doc_freq": "not a dict"} - mocker.patch( - "pinecone_text.sparse.BM25Encoder.default", return_value=mock_encoder - ) - with pytest.raises(TypeError): - BM25Encoder() + isinstance(sparse_emb.embedding, np.ndarray) for sparse_emb in result + ), "Each item in result should be an array" def test_call_method_with_uninitialized_model_or_mapping(self, bm25_encoder): bm25_encoder.model = None with pytest.raises(ValueError): bm25_encoder(["test"]) - def test_fit_with_uninitialized_model(self, bm25_encoder): + def test_fit_with_uninitialized_model(self, bm25_encoder, routes): bm25_encoder.model = None with pytest.raises(ValueError): - bm25_encoder.fit(["test"]) + bm25_encoder.fit(routes) diff --git a/tests/unit/encoders/test_tfidf.py b/tests/unit/encoders/test_tfidf.py index 5052a035..94beec9f 100644 --- a/tests/unit/encoders/test_tfidf.py +++ b/tests/unit/encoders/test_tfidf.py @@ -38,7 +38,7 @@ class TestTfidfEncoder: assert isinstance(result, list), "Result should be a list" assert all( isinstance(sparse_emb.embedding, np.ndarray) for sparse_emb in result - ), "Each item in result should be a list" + ), "Each item in result should be an array" def test_call_method_no_docs_tfidf_encoder(self, tfidf_encoder): with pytest.raises(ValueError): @@ -56,7 +56,12 @@ class TestTfidfEncoder: assert isinstance(result, list), "Result should be a list" assert all( isinstance(sparse_emb.embedding, np.ndarray) for sparse_emb in result - ), "Each item in result should be a list" + ), "Each item in result should be an array" + + def test_fit_with_strings(self, tfidf_encoder): + routes = ["test a", "test b", "test c"] + with pytest.raises(TypeError): + tfidf_encoder.fit(routes) def test_call_method_with_uninitialized_model(self, tfidf_encoder): with pytest.raises(ValueError): -- GitLab