Skip to content
Snippets Groups Projects
Commit f7f05081 authored by jamescalam's avatar jamescalam
Browse files

fix: hybrid fixes

parent 0364bab8
No related branches found
No related tags found
No related merge requests found
......@@ -24,7 +24,7 @@ class HybridLocalIndex(LocalIndex):
utterances: List[str],
function_schemas: Optional[List[Dict[str, Any]]] = None,
metadata_list: List[Dict[str, Any]] = [],
sparse_embeddings: Optional[List[dict[int, float]]] = None,
sparse_embeddings: Optional[List[SparseEmbedding]] = None,
):
if sparse_embeddings is None:
raise ValueError("Sparse embeddings are required for HybridLocalIndex.")
......@@ -32,21 +32,27 @@ class HybridLocalIndex(LocalIndex):
logger.warning("Function schemas are not supported for HybridLocalIndex.")
if metadata_list:
logger.warning("Metadata is not supported for HybridLocalIndex.")
embeds = np.array(embeddings)
embeds = np.array(
embeddings
) # TODO: we previously had this as an array, so switching back and forth seems inefficient
routes_arr = np.array(routes)
if isinstance(utterances[0], str):
utterances_arr = np.array(utterances)
else:
utterances_arr = np.array(utterances, dtype=object)
utterances_arr = np.array(
utterances, dtype=object
) # TODO: could we speed this up if it were already an array?
if self.index is None or self.sparse_index is None:
self.index = embeds
self.sparse_index = sparse_embeddings
self.sparse_index = [
x.to_dict() for x in sparse_embeddings
] # TODO: switch back to using SparseEmbedding later
self.routes = routes_arr
self.utterances = utterances_arr
else:
# TODO: we should probably switch to an `upsert` method and standardize elsewhere
self.index = np.concatenate([self.index, embeds])
self.sparse_index.extend(sparse_embeddings)
self.sparse_index.extend([x.to_dict() for x in sparse_embeddings])
self.routes = np.concatenate([self.routes, routes_arr])
self.utterances = np.concatenate([self.utterances, utterances_arr])
......
......@@ -335,9 +335,6 @@ class BaseRouter(BaseModel):
for route in self.routes:
if route.score_threshold is None:
route.score_threshold = self.score_threshold
# run initialize index now if auto sync is active
if self.auto_sync:
self._init_index_state()
def _get_index(self, index: Optional[BaseIndex]) -> BaseIndex:
if index is None:
......
from typing import List, Optional
from typing import Dict, List, Optional
import asyncio
from pydantic.v1 import Field
......@@ -12,7 +12,7 @@ from semantic_router.encoders import (
)
from semantic_router.route import Route
from semantic_router.index import BaseIndex, HybridLocalIndex
from semantic_router.schema import RouteChoice, SparseEmbedding
from semantic_router.schema import RouteChoice, SparseEmbedding, Utterance
from semantic_router.utils.logger import logger
from semantic_router.routers.base import BaseRouter
from semantic_router.llms import BaseLLM
......@@ -37,10 +37,13 @@ class HybridRouter(BaseRouter):
auto_sync: Optional[str] = None,
alpha: float = 0.3,
):
print("...2.1")
if index is None:
logger.warning("No index provided. Using default HybridLocalIndex.")
index = HybridLocalIndex()
print("...2.2")
encoder = self._get_encoder(encoder=encoder)
print("...2.3")
super().__init__(
encoder=encoder,
llm=llm,
......@@ -50,15 +53,22 @@ class HybridRouter(BaseRouter):
aggregation=aggregation,
auto_sync=auto_sync,
)
print("...0")
# initialize sparse encoder
self._set_sparse_encoder(sparse_encoder=sparse_encoder)
self.sparse_encoder = self._get_sparse_encoder(sparse_encoder=sparse_encoder)
print("...5")
# set alpha
self.alpha = alpha
print("...6")
# fit sparse encoder if needed
if isinstance(self.sparse_encoder, TfidfEncoder) and hasattr(
self.sparse_encoder, "fit"
) and self.routes:
if (
isinstance(self.sparse_encoder, TfidfEncoder)
and hasattr(self.sparse_encoder, "fit")
and self.routes
):
print("...3")
self.sparse_encoder.fit(self.routes)
print("...4")
# run initialize index now if auto sync is active
if self.auto_sync:
self._init_index_state()
......@@ -104,6 +114,39 @@ class HybridRouter(BaseRouter):
"to see details."
)
def _execute_sync_strategy(self, strategy: Dict[str, Dict[str, List[Utterance]]]):
    """Apply a sync strategy to the index and to the local route state.

    Steps run in a fixed order: remote deletions, remote upserts, local
    deletions, local upserts. A step whose utterance list is empty is
    skipped, and the hash is rewritten once at the end regardless of which
    steps ran.

    :param strategy: Mapping with ``"remote"`` and ``"local"`` entries, each
        holding ``"delete"`` and ``"upsert"`` lists of utterances.
    :type strategy: Dict[str, Dict[str, List[Utterance]]]
    """
    remote_plan = strategy["remote"]

    # Remote deletions: group utterance texts under their route name and hand
    # the grouped mapping to the index in a single call.
    stale_remote = remote_plan["delete"]
    if stale_remote:
        grouped = {}  # type: ignore
        for item in stale_remote:
            bucket = grouped.setdefault(item.route, [])
            bucket.append(item.utterance)
        # TODO: switch to remove without sync??
        self.index._remove_and_sync(grouped)

    # Remote upserts: encode the utterance texts into dense and sparse
    # embeddings, then add everything to the index.
    fresh_remote = remote_plan["upsert"]
    if fresh_remote:
        texts = [item.utterance for item in fresh_remote]
        dense_emb, sparse_emb = self._encode(texts)
        dense_rows = dense_emb.tolist()
        route_names = [item.route for item in fresh_remote]
        schemas = [item.function_schemas for item in fresh_remote]  # type: ignore
        metadata = [item.metadata for item in fresh_remote]
        self.index.add(
            embeddings=dense_rows,
            routes=route_names,
            utterances=texts,
            function_schemas=schemas,
            metadata_list=metadata,
            sparse_embeddings=sparse_emb,  # type: ignore
        )

    # Local changes are delegated to the router's own helpers.
    local_plan = strategy["local"]
    if local_plan["delete"]:
        self._local_delete(utterances=local_plan["delete"])
    if local_plan["upsert"]:
        self._local_upsert(utterances=local_plan["upsert"])

    # update hash
    self._write_hash()
def _get_index(self, index: Optional[BaseIndex]) -> BaseIndex:
if index is None:
logger.warning("No index provided. Using default HybridLocalIndex.")
......@@ -112,12 +155,15 @@ class HybridRouter(BaseRouter):
index = index
return index
def _set_sparse_encoder(self, sparse_encoder: Optional[SparseEncoder]):
def _get_sparse_encoder(
self, sparse_encoder: Optional[SparseEncoder]
) -> SparseEncoder:
if sparse_encoder is None:
logger.warning("No sparse_encoder provided. Using default BM25Encoder.")
self.sparse_encoder = BM25Encoder()
sparse_encoder = BM25Encoder()
else:
self.sparse_encoder = sparse_encoder
sparse_encoder = sparse_encoder
return sparse_encoder
def _encode(self, text: list[str]) -> tuple[np.ndarray, list[SparseEmbedding]]:
"""Given some text, generates dense and sparse embeddings, then scales them
......
......@@ -32,6 +32,9 @@ class SemanticRouter(BaseRouter):
aggregation=aggregation,
auto_sync=auto_sync,
)
# run initialize index now if auto sync is active
if self.auto_sync:
self._init_index_state()
def _encode(self, text: list[str]) -> Any:
"""Given some text, encode it."""
......@@ -81,4 +84,4 @@ class SemanticRouter(BaseRouter):
"Local and remote route layers were not aligned. Remote hash "
f"not updated. Use `{self.__class__.__name__}.get_utterance_diff()` "
"to see details."
)
\ No newline at end of file
)
......@@ -54,30 +54,37 @@ def azure_encoder(mocker):
model="test_model",
)
@pytest.fixture
def bm25_encoder():
#mocker.patch.object(BM25Encoder, "__call__", side_effect=mock_encoder_call)
# mocker.patch.object(BM25Encoder, "__call__", side_effect=mock_encoder_call)
return BM25Encoder(name="test-bm25-encoder")
@pytest.fixture
def tfidf_encoder():
#mocker.patch.object(TfidfEncoder, "__call__", side_effect=mock_encoder_call)
# mocker.patch.object(TfidfEncoder, "__call__", side_effect=mock_encoder_call)
return TfidfEncoder(name="test-tfidf-encoder")
@pytest.fixture
def routes():
return [
Route(name="Route 1", utterances=[
"Hello we need this text to be a little longer for our sparse encoders",
"In this case they need to learn from recurring tokens, ie words."
]),
Route(name="Route 2", utterances=[
"We give ourselves several examples from our encoders to learn from.",
"But given this is only an example we don't need too many",
"Just enough to test that our sparse encoders work as expected"
]),
Route(
name="Route 1",
utterances=[
"Hello we need this text to be a little longer for our sparse encoders",
"In this case they need to learn from recurring tokens, ie words.",
],
),
Route(
name="Route 2",
utterances=[
"We give ourselves several examples from our encoders to learn from.",
"But given this is only an example we don't need too many",
"Just enough to test that our sparse encoders work as expected",
],
),
]
......@@ -88,7 +95,7 @@ sparse_encoder.fit(
name="Route 1",
utterances=[
"The quick brown fox jumps over the lazy dog",
"some other useful text containing words like fox and dog"
"some other useful text containing words like fox and dog",
],
),
Route(name="Route 2", utterances=["Hello, world!"]),
......@@ -143,9 +150,14 @@ class TestHybridRouter:
assert len(route_layer.routes) == 2, "route_layer.routes is not 2"
def test_query_and_classification(self, openai_encoder, routes):
print("...1")
route_layer = HybridRouter(
encoder=openai_encoder, sparse_encoder=sparse_encoder, routes=routes
encoder=openai_encoder,
sparse_encoder=sparse_encoder,
routes=routes,
auto_sync="local",
)
print("...2")
query_result = route_layer("Hello")
assert query_result in ["Route 1", "Route 2"]
......@@ -153,9 +165,7 @@ class TestHybridRouter:
route_layer = HybridRouter(
encoder=openai_encoder, sparse_encoder=sparse_encoder
)
assert isinstance(
route_layer.sparse_encoder, BM25Encoder
) or isinstance(
assert isinstance(route_layer.sparse_encoder, BM25Encoder) or isinstance(
route_layer.sparse_encoder, TfidfEncoder
), (
f"route_layer.sparse_encoder is {route_layer.sparse_encoder.__class__.__name__} "
......@@ -213,7 +223,9 @@ class TestHybridRouter:
utterance for route in routes for utterance in route.utterances
]
assert hybrid_route_layer.index.sparse_index is not None, "sparse_index is None"
assert len(hybrid_route_layer.index.sparse_index) == len(all_utterances), "sparse_index length mismatch"
assert len(hybrid_route_layer.index.sparse_index) == len(
all_utterances
), "sparse_index length mismatch"
def test_setting_aggregation_methods(self, openai_encoder, routes):
for agg in ["sum", "mean", "max"]:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment