From bdfe5ea123a15e14dd1130a6764a38846aa09b05 Mon Sep 17 00:00:00 2001 From: Bogdan Buduroiu <bogdan@buduroiu.com> Date: Mon, 17 Feb 2025 22:24:08 +0800 Subject: [PATCH] chore: add `input_type` to encoders --- semantic_router/encoders/base.py | 41 +++++++++++++++--- semantic_router/encoders/bm25.py | 26 ++++++++++- semantic_router/encoders/encode_input_type.py | 3 ++ semantic_router/encoders/tfidf.py | 31 ++++++++----- semantic_router/routers/base.py | 20 ++++++--- semantic_router/routers/hybrid.py | 43 ++++++++++++------- 6 files changed, 126 insertions(+), 38 deletions(-) create mode 100644 semantic_router/encoders/encode_input_type.py diff --git a/semantic_router/encoders/base.py b/semantic_router/encoders/base.py index a9b460ad..3c14f945 100644 --- a/semantic_router/encoders/base.py +++ b/semantic_router/encoders/base.py @@ -3,6 +3,7 @@ from typing import Any, Coroutine, List, Optional import numpy as np from pydantic import BaseModel, Field, field_validator +from semantic_router.encoders.encode_input_type import EncodeInputType from semantic_router.route import Route from semantic_router.schema import SparseEmbedding @@ -26,25 +27,33 @@ class DenseEncoder(BaseModel): """ return float(v) if v is not None else None - def __call__(self, docs: List[Any]) -> List[List[float]]: + def __call__( + self, docs: List[Any], input_type: EncodeInputType + ) -> List[List[float]]: """Encode a list of documents. Documents can be any type, but the encoder must be built to handle that data type. Typically, these types are strings or arrays representing images. :param docs: The documents to encode. :type docs: List[Any] + :param input_type: Specify whether encoding 'queries' or 'documents', used in asymmetric retrieval + :type input_type: Literal["queries", "documents"] :return: The encoded documents. :rtype: List[List[float]] """ raise NotImplementedError("Subclasses must implement this method") - def acall(self, docs: List[Any]) -> Coroutine[Any, Any, List[List[float]]]: + def acall( + self, docs: List[Any], input_type: EncodeInputType + ) -> Coroutine[Any, Any, List[List[float]]]: """Encode a list of documents asynchronously. Documents can be any type, but the encoder must be built to handle that data type. Typically, these types are strings or arrays representing images. :param docs: The documents to encode. :type docs: List[Any] + :param input_type: Specify whether encoding 'queries' or 'documents', used in asymmetric retrieval + :type input_type: semantic_router.encoders.encode_input_type.EncodeInputType :return: The encoded documents. :rtype: List[List[float]] """ @@ -60,13 +69,35 @@ class SparseEncoder(BaseModel): class Config: arbitrary_types_allowed = True - def __call__(self, docs: List[str]) -> List[SparseEmbedding]: + def __call__( + self, + docs: List[str], + input_type: Optional[EncodeInputType] = "queries", + ) -> List[SparseEmbedding]: """Sparsely encode a list of documents. Documents can be any type, but the encoder must be built to handle that data type. Typically, these types are strings or arrays representing images. :param docs: The documents to encode. :type docs: List[Any] + :param input_type: Specify whether encoding 'queries' or 'documents', used in asymmetric retrieval + :type input_type: semantic_router.encoders.encode_input_type.EncodeInputType + :return: The encoded documents. + :rtype: List[SparseEmbedding] + """ + raise NotImplementedError("Subclasses must implement this method") + + def acall( + self, docs: List[Any], input_type: EncodeInputType + ) -> Coroutine[Any, Any, List[SparseEmbedding]]: + """Encode a list of documents asynchronously. Documents can be any type, but the + encoder must be built to handle that data type. Typically, these types are + strings or arrays representing images. + + :param docs: The documents to encode. + :type docs: List[Any] + :param input_type: Specify whether encoding 'queries' or 'documents', used in asymmetric retrieval + :type input_type: Literal["queries", "documents"] :return: The encoded documents. :rtype: List[SparseEmbedding] """ @@ -80,11 +111,11 @@ class SparseEncoder(BaseModel): """Convert document texts to sparse embeddings optimized for storage""" raise NotImplementedError("Subclasses must implement this method") - async def aencode_queries(self, docs: List[str]) -> list[SparseEmbedding]: + async def aencode_queries(self, docs: List[str]) -> List[SparseEmbedding]: """Async version of encode_queries""" raise NotImplementedError("Subclasses must implement this method") - async def aencode_documents(self, docs: List[str]) -> list[SparseEmbedding]: + async def aencode_documents(self, docs: List[str]) -> List[SparseEmbedding]: """Async version of encode_documents""" raise NotImplementedError("Subclasses must implement this method") diff --git a/semantic_router/encoders/bm25.py b/semantic_router/encoders/bm25.py index 5377e2e7..00c6039d 100644 --- a/semantic_router/encoders/bm25.py +++ b/semantic_router/encoders/bm25.py @@ -1,5 +1,6 @@ +import asyncio from functools import partial -from typing import List +from typing import Any, Coroutine, List, Literal import numpy as np @@ -269,5 +270,26 @@ class BM25Encoder(SparseEncoder, FittableMixin): return self.encode_queries(docs) - def __call__(self, docs: List[str]) -> list[SparseEmbedding]: + async def aencode_queries(self, docs: List[str]) -> list[SparseEmbedding]: + # While this is a CPU-bound operation, and doesn't benefit from asyncio + # we provide this method to abide by the `SparseEncoder` superclass return self.encode_queries(docs) + + async def aencode_documents(self, docs: List[str]) -> list[SparseEmbedding]: + # While this is a CPU-bound operation, and doesn't benefit from asyncio + # we provide this method to abide by the `SparseEncoder` superclass + return self.encode_documents(docs) + + def __call__( + self, docs: List[str], input_type: Literal["queries", "documents"] + ) -> list[SparseEmbedding]: + match input_type: + case "queries": + return self.encode_queries(docs) + case "documents": + return self.encode_documents(docs) + + def acall( + self, docs: List[Any], input_type: Literal["queries", "documents"] + ) -> Coroutine[Any, Any, List[SparseEmbedding]]: + return asyncio.to_thread(lambda: self.__call__(docs, input_type)) diff --git a/semantic_router/encoders/encode_input_type.py b/semantic_router/encoders/encode_input_type.py new file mode 100644 index 00000000..1cfe9c87 --- /dev/null +++ b/semantic_router/encoders/encode_input_type.py @@ -0,0 +1,3 @@ +from typing import Literal + +EncodeInputType = Literal["queries", "documents"] diff --git a/semantic_router/encoders/tfidf.py b/semantic_router/encoders/tfidf.py index 0d6642b7..3d3cbe57 100644 --- a/semantic_router/encoders/tfidf.py +++ b/semantic_router/encoders/tfidf.py @@ -1,11 +1,13 @@ +import asyncio import string from collections import Counter -from typing import Dict, List +from typing import Any, Coroutine, Dict, List import numpy as np from semantic_router.encoders import SparseEncoder from semantic_router.encoders.base import FittableMixin +from semantic_router.encoders.encode_input_type import EncodeInputType from semantic_router.route import Route from semantic_router.schema import SparseEmbedding @@ -22,7 +24,9 @@ class TfidfEncoder(SparseEncoder, FittableMixin): self.word_index = {} self.idf = np.array([]) - def __call__(self, docs: List[str]) -> list[SparseEmbedding]: + def __call__( + self, docs: List[str], input_type: EncodeInputType + ) -> list[SparseEmbedding]: if len(self.word_index) == 0 or self.idf.size == 0: raise ValueError("Vectorizer is not initialized.") if len(docs) == 0: @@ -33,21 +37,28 @@ class TfidfEncoder(SparseEncoder, FittableMixin): tfidf = tf * self.idf return self._array_to_sparse_embeddings(tfidf) - def encode_queries(self, docs: List[str]) -> list[SparseEmbedding]: + async def acall( + self, docs: List[str], input_type: EncodeInputType + ) -> Coroutine[Any, Any, List[SparseEmbedding]]: + return asyncio.to_thread(lambda: self.__call__(docs, input_type)) + + def encode_queries(self, docs: List[str]) -> List[SparseEmbedding]: """Encode documents using TF-IDF""" - return self.__call__(docs) # TF-IDF uses same method for docs and queries + # TF-IDF uses same method for docs and queries + return self.__call__(docs, input_type="queries") - def encode_documents(self, docs: List[str]) -> list[SparseEmbedding]: + def encode_documents(self, docs: List[str]) -> List[SparseEmbedding]: """Encode documents using TF-IDF""" - return self.__call__(docs) # TF-IDF uses same method for docs and queries + # TF-IDF uses same method for docs and queries + return self.__call__(docs, input_type="documents") - async def aencode_queries(self, docs: List[str]) -> list[SparseEmbedding]: + async def aencode_queries(self, docs: List[str]) -> List[SparseEmbedding]: """Async version of encode_queries""" - return self.__call__(docs) + return self.__call__(docs, input_type="queries") - async def aencode_documents(self, docs: List[str]) -> list[SparseEmbedding]: + async def aencode_documents(self, docs: List[str]) -> List[SparseEmbedding]: """Async version of encode_documents""" - return self.__call__(docs) + return self.__call__(docs, input_type="documents") def fit(self, routes: List[Route]): """Trains the encoder weights on the provided routes. diff --git a/semantic_router/routers/base.py b/semantic_router/routers/base.py index 3d5c18af..61229b1e 100644 --- a/semantic_router/routers/base.py +++ b/semantic_router/routers/base.py @@ -3,7 +3,7 @@ import importlib import json import os import random -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union import numpy as np import yaml # type: ignore @@ -822,7 +822,8 @@ class BaseRouter(BaseModel): routes=[utt.route for utt in strategy["remote"]["upsert"]], utterances=utterances_text, function_schemas=[ - utt.function_schemas for utt in strategy["remote"]["upsert"] # type: ignore + utt.function_schemas + for utt in strategy["remote"]["upsert"] # type: ignore ], metadata_list=[utt.metadata for utt in strategy["remote"]["upsert"]], ) @@ -855,7 +856,8 @@ class BaseRouter(BaseModel): routes=[utt.route for utt in strategy["remote"]["upsert"]], utterances=utterances_text, function_schemas=[ - utt.function_schemas for utt in strategy["remote"]["upsert"] # type: ignore + utt.function_schemas + for utt in strategy["remote"]["upsert"] # type: ignore ], metadata_list=[utt.metadata for utt in strategy["remote"]["upsert"]], ) @@ -1333,26 +1335,34 @@ class BaseRouter(BaseModel): return route_names, utterances, function_schemas, metadata return route_names, utterances, function_schemas - def _encode(self, text: list[str]) -> Any: + def _encode( + self, text: list[str], input_type: Literal["queries", "documents"] + ) -> Any: """Generates embeddings for a given text. Must be implemented by a subclass. :param text: The text to encode. :type text: list[str] + :param input_type: Specify whether encoding 'queries' or 'documents', used in asymmetric retrieval + :type input_type: Literal["queries", "documents"] :return: The embeddings of the text. :rtype: Any """ # TODO: should encode "content" rather than text raise NotImplementedError("This method should be implemented by subclasses.") - async def _async_encode(self, text: list[str]) -> Any: + async def _async_encode( + self, text: list[str], input_type: Literal["queries", "documents"] + ) -> Any: """Asynchronously generates embeddings for a given text. Must be implemented by a subclass. :param text: The text to encode. :type text: list[str] + :param input_type: Specify whether encoding 'queries' or 'documents', used in asymmetric retrieval + :type input_type: Literal["queries", "documents"] :return: The embeddings of the text. :rtype: Any """ diff --git a/semantic_router/routers/hybrid.py b/semantic_router/routers/hybrid.py index 9c6680c6..ec210f25 100644 --- a/semantic_router/routers/hybrid.py +++ b/semantic_router/routers/hybrid.py @@ -11,6 +11,7 @@ from semantic_router.encoders import ( SparseEncoder, ) from semantic_router.encoders.base import FittableMixin +from semantic_router.encoders.encode_input_type import EncodeInputType from semantic_router.index import BaseIndex, HybridLocalIndex from semantic_router.llms import BaseLLM from semantic_router.route import Route @@ -112,7 +113,7 @@ class HybridRouter(BaseRouter): ) = self._extract_routes_details(routes, include_metadata=True) # TODO: to merge, self._encode should probably output a special # TODO Embedding type that can be either dense or hybrid - dense_emb, sparse_emb = self._encode(all_utterances) + dense_emb, sparse_emb = self._encode(all_utterances, input_type="documents") self.index.add( embeddings=dense_emb.tolist(), routes=route_names, @@ -148,7 +149,9 @@ class HybridRouter(BaseRouter): self.index._remove_and_sync(data_to_delete) if strategy["remote"]["upsert"]: utterances_text = [utt.utterance for utt in strategy["remote"]["upsert"]] - dense_emb, sparse_emb = self._encode(utterances_text) + dense_emb, sparse_emb = self._encode( + utterances_text, input_type="documents" + ) self.index.add( embeddings=dense_emb.tolist(), routes=[utt.route for utt in strategy["remote"]["upsert"]], @@ -202,38 +205,44 @@ class HybridRouter(BaseRouter): return sparse_encoder def _encode( - self, - text: list[str], + self, text: list[str], input_type: EncodeInputType ) -> tuple[np.ndarray, list[SparseEmbedding]]: """Given some text, generates dense and sparse embeddings, then scales them using the chosen alpha value. :param text: List of texts to encode :type text: List[str] + :param input_type: Specify whether encoding 'queries' or 'documents', used in asymmetric retrieval + :type input_type: semantic_router.encoders.encode_input_type.EncodeInputType :return: Tuple of dense and sparse embeddings """ if self.sparse_encoder is None: raise ValueError("self.sparse_encoder is not set.") # Create dense vector - xq_d = np.array(self.encoder(text)) + xq_d = np.array(self.encoder(text, input_type=input_type)) # Create sparse vector, handling BM25's query/document distinction - xq_s = self.sparse_encoder.encode_documents(text) + match input_type: + case "queries": + xq_s = self.sparse_encoder.encode_queries(text) + case "documents": + xq_s = self.sparse_encoder.encode_documents(text) # Convex scaling xq_d, xq_s = self._convex_scaling(dense=xq_d, sparse=xq_s) return xq_d, xq_s async def _async_encode( - self, - text: List[str], + self, text: List[str], input_type: EncodeInputType ) -> tuple[np.ndarray, list[SparseEmbedding]]: """Given some text, generates dense and sparse embeddings, then scales them using the chosen alpha value. :param text: The text to encode. :type text: List[str] + :param input_type: Specify whether encoding 'queries' or 'documents', used in asymmetric retrieval + :type input_type: semantic_router.encoders.encode_input_type.EncodeInputType :return: A tuple of the dense and sparse embeddings. :rtype: tuple[np.ndarray, list[SparseEmbedding]] """ @@ -242,8 +251,8 @@ class HybridRouter(BaseRouter): # TODO: should encode "content" rather than text # TODO: add alpha as a parameter # async encode both dense and sparse - dense_coro = self.encoder.acall(text) - sparse_coro = self.sparse_encoder.aencode_documents(text) + dense_coro = self.encoder.acall(text, input_type=input_type) + sparse_coro = self.sparse_encoder.acall(text, input_type=input_type) dense_vec, xq_s = await asyncio.gather(dense_coro, sparse_coro) # create dense query vector xq_d = np.array(dense_vec) @@ -281,8 +290,8 @@ class HybridRouter(BaseRouter): if vector is None: if text is None: raise ValueError("Either text or vector must be provided") - xq_d = np.array(self.encoder([text])) - xq_s = self.sparse_encoder.encode_queries([text]) + xq_d = np.array(self.encoder([text], input_type="queries")) + xq_s = self.sparse_encoder([text], input_type="queries") vector, potential_sparse_vector = self._convex_scaling( dense=xq_d, sparse=xq_s ) @@ -377,7 +386,7 @@ class HybridRouter(BaseRouter): routes.append(utterance.route) utterances.append(utterance.utterance) metadata.append(utterance.metadata) - embeddings = self.encoder(utterances) + embeddings = self.encoder(utterances, input_type="documents") sparse_embeddings = self.sparse_encoder.encode_documents(utterances) self.index = HybridLocalIndex() self.index.add( @@ -392,7 +401,9 @@ class HybridRouter(BaseRouter): Xq_d: List[List[float]] = [] Xq_s: List[SparseEmbedding] = [] for i in tqdm(range(0, len(X), batch_size), desc="Generating embeddings"): - emb_d = np.array(self.encoder(X[i : i + batch_size])) + emb_d = np.array( + self.encoder(X[i : i + batch_size], input_type="documents") + ) # TODO JB: for some reason the sparse encoder is receiving a tuple # like `("Hello",)` emb_s = self.sparse_encoder.encode_documents(X[i : i + batch_size]) @@ -441,8 +452,8 @@ class HybridRouter(BaseRouter): Xq_d: List[List[float]] = [] Xq_s: List[SparseEmbedding] = [] for i in tqdm(range(0, len(X), batch_size), desc="Generating embeddings"): - emb_d = np.array(self.encoder(X[i : i + batch_size])) - emb_s = self.sparse_encoder.encode_queries(X[i : i + batch_size]) + emb_d = np.array(self.encoder(X[i : i + batch_size], input_type="queries")) + emb_s = self.sparse_encoder(X[i : i + batch_size], input_type="queries") Xq_d.extend(emb_d) Xq_s.extend(emb_s) -- GitLab