diff --git a/semantic_router/encoders/aurelio.py b/semantic_router/encoders/aurelio.py index 2859289b016b88440466e2ee679fede4b183c8b6..6df3da151bb0e535a76621e7b1eece0e00f21283 100644 --- a/semantic_router/encoders/aurelio.py +++ b/semantic_router/encoders/aurelio.py @@ -1,14 +1,15 @@ import os -from typing import Any, List, Optional +from typing import Any, Coroutine, List, Optional from aurelio_sdk import AsyncAurelioClient, AurelioClient, EmbeddingResponse +from aurelio_sdk.client_async import asyncio from pydantic import Field -from semantic_router.encoders.base import SparseEncoder +from semantic_router.encoders.base import AsymmetricSparseMixin, SparseEncoder from semantic_router.schema import SparseEmbedding -class AurelioSparseEncoder(SparseEncoder): +class AurelioSparseEncoder(SparseEncoder, AsymmetricSparseMixin): """Sparse encoder using Aurelio Platform's embedding API. Requires an API key from https://platform.aurelio.ai """ @@ -43,7 +44,7 @@ class AurelioSparseEncoder(SparseEncoder): self.async_client = AsyncAurelioClient(api_key=api_key) def __call__(self, docs: list[str]) -> list[SparseEmbedding]: - """Encode a list of documents using the Aurelio Platform embedding API. Documents + """Encode a list of queries using the Aurelio Platform embedding API. Documents must be strings, sparse encoders do not support other types. """ return self.encode_queries(docs) @@ -62,27 +63,39 @@ class AurelioSparseEncoder(SparseEncoder): embeds = [SparseEmbedding.from_aurelio(r.embedding) for r in res.data] return embeds - async def aencode_queries(self, docs: List[str]) -> list[SparseEmbedding]: + async def aencode_queries( + self, docs: List[str] + ) -> Coroutine[Any, Any, list[SparseEmbedding]]: res: EmbeddingResponse = await self.async_client.embedding( input=docs, model=self.name, input_type="queries" ) - embeds = [SparseEmbedding.from_aurelio(r.embedding) for r in res.data] + embeds = asyncio.to_thread( + lambda: [SparseEmbedding.from_aurelio(r.embedding) for r in res.data] + ) return embeds - async def aencode_documents(self, docs: List[str]) -> list[SparseEmbedding]: + async def aencode_documents( + self, docs: List[str] + ) -> Coroutine[Any, Any, list[SparseEmbedding]]: res: EmbeddingResponse = await self.async_client.embedding( input=docs, model=self.name, input_type="documents" ) - embeds = [SparseEmbedding.from_aurelio(r.embedding) for r in res.data] + embeds = asyncio.to_thread( + lambda: [SparseEmbedding.from_aurelio(r.embedding) for r in res.data] + ) return embeds - async def acall(self, docs: list[str]) -> list[SparseEmbedding]: + async def acall( + self, docs: list[str] + ) -> Coroutine[Any, Any, list[SparseEmbedding]]: """Asynchronously encode a list of documents using the Aurelio Platform embedding API. Documents must be strings, sparse encoders do not support other types. :param docs: The documents to encode. :type docs: list[str] + :param input_type: + :type semantic_router.encoders.encode_input_type.EncodeInputType :return: The encoded documents. :rtype: list[SparseEmbedding] """ diff --git a/semantic_router/encoders/base.py b/semantic_router/encoders/base.py index 3c14f94557a6bbc489e4797586c50a4647488d7e..4dec0ba204c5497d65562d1dadb365022facbae5 100644 --- a/semantic_router/encoders/base.py +++ b/semantic_router/encoders/base.py @@ -3,7 +3,6 @@ from typing import Any, Coroutine, List, Optional import numpy as np from pydantic import BaseModel, Field, field_validator -from semantic_router.encoders.encode_input_type import EncodeInputType from semantic_router.route import Route from semantic_router.schema import SparseEmbedding @@ -27,33 +26,25 @@ class DenseEncoder(BaseModel): """ return float(v) if v is not None else None - def __call__( - self, docs: List[Any], input_type: EncodeInputType - ) -> List[List[float]]: + def __call__(self, docs: List[Any]) -> List[List[float]]: """Encode a list of documents. Documents can be any type, but the encoder must be built to handle that data type. Typically, these types are strings or arrays representing images. :param docs: The documents to encode. :type docs: List[Any] - :param input_type: Specify whether encoding 'queries' or 'documents', used in asymmetric retrieval - :type input_type: Literal["queries", "documents"] :return: The encoded documents. :rtype: List[List[float]] """ raise NotImplementedError("Subclasses must implement this method") - def acall( - self, docs: List[Any], input_type: EncodeInputType - ) -> Coroutine[Any, Any, List[List[float]]]: + def acall(self, docs: List[Any]) -> Coroutine[Any, Any, List[List[float]]]: """Encode a list of documents asynchronously. Documents can be any type, but the encoder must be built to handle that data type. Typically, these types are strings or arrays representing images. :param docs: The documents to encode. :type docs: List[Any] - :param input_type: Specify whether encoding 'queries' or 'documents', used in asymmetric retrieval - :type input_type: semantic_router.encoders.encode_input_type.EncodeInputType :return: The encoded documents. :rtype: List[List[float]] """ @@ -69,56 +60,30 @@ class SparseEncoder(BaseModel): class Config: arbitrary_types_allowed = True - def __call__( - self, - docs: List[str], - input_type: Optional[EncodeInputType] = "queries", - ) -> List[SparseEmbedding]: + def __call__(self, docs: List[str]) -> List[SparseEmbedding]: """Sparsely encode a list of documents. Documents can be any type, but the encoder must be built to handle that data type. Typically, these types are strings or arrays representing images. :param docs: The documents to encode. :type docs: List[Any] - :param input_type: Specify whether encoding 'queries' or 'documents', used in asymmetric retrieval - :type input_type: semantic_router.encoders.encode_input_type.EncodeInputType :return: The encoded documents. :rtype: List[SparseEmbedding] """ raise NotImplementedError("Subclasses must implement this method") - def acall( - self, docs: List[Any], input_type: EncodeInputType - ) -> Coroutine[Any, Any, List[SparseEmbedding]]: + def acall(self, docs: List[Any]) -> Coroutine[Any, Any, List[SparseEmbedding]]: """Encode a list of documents asynchronously. Documents can be any type, but the encoder must be built to handle that data type. Typically, these types are strings or arrays representing images. :param docs: The documents to encode. :type docs: List[Any] - :param input_type: Specify whether encoding 'queries' or 'documents', used in asymmetric retrieval - :type input_type: Literal["queries", "documents"] :return: The encoded documents. :rtype: List[SparseEmbedding] """ raise NotImplementedError("Subclasses must implement this method") - def encode_queries(self, docs: List[str]) -> List[SparseEmbedding]: - """Convert query texts to sparse embeddings optimized for querying""" - raise NotImplementedError("Subclasses must implement this method") - - def encode_documents(self, docs: List[str]) -> List[SparseEmbedding]: - """Convert document texts to sparse embeddings optimized for storage""" - raise NotImplementedError("Subclasses must implement this method") - - async def aencode_queries(self, docs: List[str]) -> List[SparseEmbedding]: - """Async version of encode_queries""" - raise NotImplementedError("Subclasses must implement this method") - - async def aencode_documents(self, docs: List[str]) -> List[SparseEmbedding]: - """Async version of encode_documents""" - raise NotImplementedError("Subclasses must implement this method") - def _array_to_sparse_embeddings( self, sparse_arrays: np.ndarray ) -> List[SparseEmbedding]: @@ -147,3 +112,47 @@ class SparseEncoder(BaseModel): class FittableMixin: def fit(self, routes: list[Route]): pass + + +class AsymmetricDenseMixin: + def encode_queries(self, docs: List[str]) -> List[List[float]]: + """Convert query texts to dense embeddings optimized for querying""" + raise NotImplementedError("Subclasses must implement this method") + + def encode_documents(self, docs: List[str]) -> List[List[float]]: + """Convert document texts to dense embeddings optimized for storage""" + raise NotImplementedError("Subclasses must implement this method") + + async def aencode_queries( + self, docs: List[str] + ) -> Coroutine[Any, Any, List[List[float]]]: + """Async version of encode_queries""" + raise NotImplementedError("Subclasses must implement this method") + + async def aencode_documents( + self, docs: List[str] + ) -> Coroutine[Any, Any, List[List[float]]]: + """Async version of encode_documents""" + raise NotImplementedError("Subclasses must implement this method") + + +class AsymmetricSparseMixin: + def encode_queries(self, docs: List[str]) -> List[SparseEmbedding]: + """Convert query texts to dense embeddings optimized for querying""" + raise NotImplementedError("Subclasses must implement this method") + + def encode_documents(self, docs: List[str]) -> List[SparseEmbedding]: + """Convert document texts to dense embeddings optimized for storage""" + raise NotImplementedError("Subclasses must implement this method") + + async def aencode_queries( + self, docs: List[str] + ) -> Coroutine[Any, Any, List[SparseEmbedding]]: + """Async version of encode_queries""" + raise NotImplementedError("Subclasses must implement this method") + + async def aencode_documents( + self, docs: List[str] + ) -> Coroutine[Any, Any, List[SparseEmbedding]]: + """Async version of encode_documents""" + raise NotImplementedError("Subclasses must implement this method") diff --git a/semantic_router/encoders/bm25.py b/semantic_router/encoders/bm25.py index 00c6039d3eefb0e076753392e287d9ffd28ed217..464ce50c23d7dd6156cfda354435091ba4350d34 100644 --- a/semantic_router/encoders/bm25.py +++ b/semantic_router/encoders/bm25.py @@ -1,17 +1,21 @@ import asyncio from functools import partial -from typing import Any, Coroutine, List, Literal +from typing import Any, Coroutine, List import numpy as np -from semantic_router.encoders.base import FittableMixin, SparseEncoder +from semantic_router.encoders.base import ( + AsymmetricSparseMixin, + FittableMixin, + SparseEncoder, +) from semantic_router.route import Route from semantic_router.schema import SparseEmbedding from semantic_router.tokenizers import BaseTokenizer, PretrainedTokenizer from semantic_router.utils.logger import logger -class BM25Encoder(SparseEncoder, FittableMixin): +class BM25Encoder(SparseEncoder, FittableMixin, AsymmetricSparseMixin): """BM25Encoder, running a vectorized version of ATIRE BM25 algorithm Concept: @@ -270,26 +274,22 @@ class BM25Encoder(SparseEncoder, FittableMixin): return self.encode_queries(docs) - async def aencode_queries(self, docs: List[str]) -> list[SparseEmbedding]: + async def aencode_queries( + self, docs: List[str] + ) -> Coroutine[Any, Any, List[SparseEmbedding]]: # While this is a CPU-bound operation, and doesn't benefit from asyncio # we provide this method to abide by the `SparseEncoder` superclass - return self.encode_queries(docs) + return asyncio.to_thread(lambda: self.encode_queries(docs)) - async def aencode_documents(self, docs: List[str]) -> list[SparseEmbedding]: + async def aencode_documents( + self, docs: List[str] + ) -> Coroutine[Any, Any, List[SparseEmbedding]]: # While this is a CPU-bound operation, and doesn't benefit from asyncio # we provide this method to abide by the `SparseEncoder` superclass - return self.encode_documents(docs) + return asyncio.to_thread(lambda: self.encode_documents(docs)) - def __call__( - self, docs: List[str], input_type: Literal["queries", "documents"] - ) -> list[SparseEmbedding]: - match input_type: - case "queries": - return self.encode_queries(docs) - case "documents": - return self.encode_documents(docs) - - def acall( - self, docs: List[Any], input_type: Literal["queries", "documents"] - ) -> Coroutine[Any, Any, List[SparseEmbedding]]: - return asyncio.to_thread(lambda: self.__call__(docs, input_type)) + def __call__(self, docs: List[str]) -> list[SparseEmbedding]: + return self.encode_queries(docs) + + def acall(self, docs: List[Any]) -> Coroutine[Any, Any, List[SparseEmbedding]]: + return asyncio.to_thread(lambda: self.__call__(docs)) diff --git a/semantic_router/encoders/cohere.py b/semantic_router/encoders/cohere.py index dc5b0cdbdc695c8e92c43f23a3ca580ccfffedc8..d81f142b4a42c982557f26aab71624bb355fbd63 100644 --- a/semantic_router/encoders/cohere.py +++ b/semantic_router/encoders/cohere.py @@ -1,20 +1,20 @@ import os -from typing import Any, List, Optional +from typing import Any, Coroutine, List, Optional -from cohere import EmbedInputType from pydantic import PrivateAttr from semantic_router.encoders import DenseEncoder -from semantic_router.encoders.encode_input_type import EncodeInputType +from semantic_router.encoders.base import AsymmetricDenseMixin from semantic_router.utils.defaults import EncoderDefault -class CohereEncoder(DenseEncoder): +class CohereEncoder(DenseEncoder, AsymmetricDenseMixin): """Dense encoder that uses Cohere API to embed documents. Supports text only. Requires a Cohere API key from https://dashboard.cohere.com/api-keys. """ _client: Any = PrivateAttr() + _async_client: Any = PrivateAttr() _embed_type: Any = PrivateAttr() type: str = "cohere" @@ -42,7 +42,7 @@ class CohereEncoder(DenseEncoder): name=name, score_threshold=score_threshold, ) - self._client = self._initialize_client(cohere_api_key) + self._client, self._async_client = self._initialize_client(cohere_api_key) def _initialize_client(self, cohere_api_key: Optional[str] = None): """Initializes the Cohere client. @@ -69,38 +69,86 @@ class CohereEncoder(DenseEncoder): raise ValueError("Cohere API key cannot be 'None'.") try: client = cohere.Client(cohere_api_key) + async_client = cohere.AsyncClient(cohere_api_key) except Exception as e: raise ValueError( f"Cohere API client failed to initialize. Error: {e}" ) from e - return client + return client, async_client - def __call__( - self, docs: List[str], input_type: EncodeInputType - ) -> List[List[float]]: + def __call__(self, docs: List[str]) -> List[List[float]]: """Embed a list of documents. Supports text only. :param docs: The documents to embed. :type docs: List[str] - :param input_type: Specify whether encoding 'queries' or 'documents', used in asymmetric retrieval - :type input_type: semantic_router.encoders.encode_input_type.EncodeInputType :return: The vector embeddings of the documents. :rtype: List[List[float]] """ + return self.encode_queries(docs) + + def encode_queries(self, docs: List[str]) -> List[List[float]]: if self._client is None: raise ValueError("Cohere client is not initialized.") - cohere_input_type: EmbedInputType = None - match input_type: - case "queries": - cohere_input_type = "search_query" - case "documents": - cohere_input_type = "search_document" try: embeds = self._client.embed( - texts=docs, input_type=cohere_input_type, model=self.name + texts=docs, input_type="search_query", model=self.name + ) + if isinstance(embeds, self._embed_type): + raise NotImplementedError( + "Handling of EmbedByTypeResponseEmbeddings is not implemented." + ) + else: + return embeds.embeddings + except Exception as e: + raise ValueError(f"Cohere API call failed. Error: {e}") from e + + def encode_documents(self, docs: List[str]) -> List[List[float]]: + if self._client is None: + raise ValueError("Cohere client is not initialized.") + + try: + embeds = self._client.embed( + texts=docs, input_type="search_document", model=self.name + ) + if isinstance(embeds, self._embed_type): + raise NotImplementedError( + "Handling of EmbedByTypeResponseEmbeddings is not implemented." + ) + else: + return embeds.embeddings + except Exception as e: + raise ValueError(f"Cohere API call failed. Error: {e}") from e + + async def aencode_queries( + self, docs: List[str] + ) -> Coroutine[Any, Any, List[List[float]]]: + if self._async_client is None: + raise ValueError("Cohere client is not initialized.") + + try: + embeds = await self._async_client.embed( + texts=docs, input_type="search_query", model=self.name + ) + if isinstance(embeds, self._embed_type): + raise NotImplementedError( + "Handling of EmbedByTypeResponseEmbeddings is not implemented." + ) + else: + return embeds.embeddings + except Exception as e: + raise ValueError(f"Cohere API call failed. Error: {e}") from e + + async def aencode_documents( + self, docs: List[str] + ) -> Coroutine[Any, Any, List[List[float]]]: + if self._async_client is None: + raise ValueError("Cohere client is not initialized.") + + try: + embeds = await self._async_client.embed( + texts=docs, input_type="search_document", model=self.name ) - # Check for unsupported type. if isinstance(embeds, self._embed_type): raise NotImplementedError( "Handling of EmbedByTypeResponseEmbeddings is not implemented." diff --git a/semantic_router/encoders/tfidf.py b/semantic_router/encoders/tfidf.py index 3d3cbe575e2c94a40b16e2588e0f2809952118bb..59a9cffd261778f77a7ac03054322a9eae0d0dc4 100644 --- a/semantic_router/encoders/tfidf.py +++ b/semantic_router/encoders/tfidf.py @@ -7,7 +7,6 @@ import numpy as np from semantic_router.encoders import SparseEncoder from semantic_router.encoders.base import FittableMixin -from semantic_router.encoders.encode_input_type import EncodeInputType from semantic_router.route import Route from semantic_router.schema import SparseEmbedding @@ -24,9 +23,7 @@ class TfidfEncoder(SparseEncoder, FittableMixin): self.word_index = {} self.idf = np.array([]) - def __call__( - self, docs: List[str], input_type: EncodeInputType - ) -> list[SparseEmbedding]: + def __call__(self, docs: List[str]) -> list[SparseEmbedding]: if len(self.word_index) == 0 or self.idf.size == 0: raise ValueError("Vectorizer is not initialized.") if len(docs) == 0: @@ -38,27 +35,9 @@ class TfidfEncoder(SparseEncoder, FittableMixin): return self._array_to_sparse_embeddings(tfidf) async def acall( - self, docs: List[str], input_type: EncodeInputType + self, docs: List[str] ) -> Coroutine[Any, Any, List[SparseEmbedding]]: - return asyncio.to_thread(lambda: self.__call__(docs, input_type)) - - def encode_queries(self, docs: List[str]) -> List[SparseEmbedding]: - """Encode documents using TF-IDF""" - # TF-IDF uses same method for docs and queries - return self.__call__(docs, input_type="queries") - - def encode_documents(self, docs: List[str]) -> List[SparseEmbedding]: - """Encode documents using TF-IDF""" - # TF-IDF uses same method for docs and queries - return self.__call__(docs, input_type="documents") - - async def aencode_queries(self, docs: List[str]) -> List[SparseEmbedding]: - """Async version of encode_queries""" - return self.__call__(docs, input_type="queries") - - async def aencode_documents(self, docs: List[str]) -> List[SparseEmbedding]: - """Async version of encode_documents""" - return self.__call__(docs, input_type="documents") + return asyncio.to_thread(lambda: self.__call__(docs)) def fit(self, routes: List[Route]): """Trains the encoder weights on the provided routes. diff --git a/semantic_router/routers/base.py b/semantic_router/routers/base.py index 61229b1e1bb90865ecb58ed8719d69d32e21dc48..c63d61876b21b7838e362e9245f05e00aad32925 100644 --- a/semantic_router/routers/base.py +++ b/semantic_router/routers/base.py @@ -3,7 +3,7 @@ import importlib import json import os import random -from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import yaml # type: ignore @@ -17,6 +17,7 @@ from semantic_router.encoders import ( OpenAIEncoder, SparseEncoder, ) +from semantic_router.encoders.encode_input_type import EncodeInputType from semantic_router.index.base import BaseIndex from semantic_router.index.local import LocalIndex from semantic_router.index.pinecone import PineconeIndex @@ -1336,7 +1337,9 @@ class BaseRouter(BaseModel): return route_names, utterances, function_schemas def _encode( - self, text: list[str], input_type: Literal["queries", "documents"] + self, + text: list[str], + input_type: EncodeInputType, ) -> Any: """Generates embeddings for a given text. @@ -1345,7 +1348,7 @@ class BaseRouter(BaseModel): :param text: The text to encode. :type text: list[str] :param input_type: Specify whether encoding 'queries' or 'documents', used in asymmetric retrieval - :type input_type: Literal["queries", "documents"] + :type input_type: semantic_router.encoders.encode_input_type.EncodeInputType :return: The embeddings of the text. :rtype: Any """ @@ -1353,7 +1356,9 @@ class BaseRouter(BaseModel): raise NotImplementedError("This method should be implemented by subclasses.") async def _async_encode( - self, text: list[str], input_type: Literal["queries", "documents"] + self, + text: list[str], + input_type: EncodeInputType, ) -> Any: """Asynchronously generates embeddings for a given text. @@ -1362,7 +1367,7 @@ class BaseRouter(BaseModel): :param text: The text to encode. :type text: list[str] :param input_type: Specify whether encoding 'queries' or 'documents', used in asymmetric retrieval - :type input_type: Literal["queries", "documents"] + :type input_type: semantic_router.encoders.encode_input_type.EncodeInputType :return: The embeddings of the text. :rtype: Any """ diff --git a/semantic_router/routers/hybrid.py b/semantic_router/routers/hybrid.py index ec210f25d3aeef9a88b7a1ce62d5b37d05a32fcd..f89804aae6783702b4856917ef47369d2b4e84aa 100644 --- a/semantic_router/routers/hybrid.py +++ b/semantic_router/routers/hybrid.py @@ -10,7 +10,11 @@ from semantic_router.encoders import ( DenseEncoder, SparseEncoder, ) -from semantic_router.encoders.base import FittableMixin +from semantic_router.encoders.base import ( + AsymmetricDenseMixin, + AsymmetricSparseMixin, + FittableMixin, +) from semantic_router.encoders.encode_input_type import EncodeInputType from semantic_router.index import BaseIndex, HybridLocalIndex from semantic_router.llms import BaseLLM @@ -219,15 +223,23 @@ class HybridRouter(BaseRouter): if self.sparse_encoder is None: raise ValueError("self.sparse_encoder is not set.") - # Create dense vector - xq_d = np.array(self.encoder(text, input_type=input_type)) - - # Create sparse vector, handling BM25's query/document distinction - match input_type: - case "queries": - xq_s = self.sparse_encoder.encode_queries(text) - case "documents": - xq_s = self.sparse_encoder.encode_documents(text) + if isinstance(self.encoder, AsymmetricDenseMixin): + match input_type: + case "queries": + xq_d = self.encoder.encode_queries(text) + case "documents": + xq_d = self.encoder.encode_documents(text) + else: + xq_d = self.encoder(text) + + if isinstance(self.sparse_encoder, AsymmetricSparseMixin): + match input_type: + case "queries": + xq_s = self.sparse_encoder.encode_queries(text) + case "documents": + xq_s = self.sparse_encoder.encode_documents(text) + else: + xq_s = self.sparse_encoder(text) # Convex scaling xq_d, xq_s = self._convex_scaling(dense=xq_d, sparse=xq_s) @@ -251,8 +263,24 @@ class HybridRouter(BaseRouter): # TODO: should encode "content" rather than text # TODO: add alpha as a parameter # async encode both dense and sparse - dense_coro = self.encoder.acall(text, input_type=input_type) - sparse_coro = self.sparse_encoder.acall(text, input_type=input_type) + + if isinstance(self.encoder, AsymmetricDenseMixin): + match input_type: + case "queries": + dense_coro = self.encoder.aencode_queries(text) + case "documents": + dense_coro = self.encoder.aencode_documents(text) + else: + dense_coro = self.encoder.acall(text) + + if isinstance(self.sparse_encoder, AsymmetricSparseMixin): + match input_type: + case "queries": + sparse_coro = self.sparse_encoder.aencode_queries(text) + case "documents": + sparse_coro = self.sparse_encoder.aencode_documents(text) + else: + sparse_coro = self.sparse_encoder.acall(text) dense_vec, xq_s = await asyncio.gather(dense_coro, sparse_coro) # create dense query vector xq_d = np.array(dense_vec) @@ -290,8 +318,8 @@ class HybridRouter(BaseRouter): if vector is None: if text is None: raise ValueError("Either text or vector must be provided") - xq_d = np.array(self.encoder([text], input_type="queries")) - xq_s = self.sparse_encoder([text], input_type="queries") + xq_d = np.array(self.encoder([text])) + xq_s = self.sparse_encoder([text]) vector, potential_sparse_vector = self._convex_scaling( dense=xq_d, sparse=xq_s ) @@ -386,8 +414,16 @@ class HybridRouter(BaseRouter): routes.append(utterance.route) utterances.append(utterance.utterance) metadata.append(utterance.metadata) - embeddings = self.encoder(utterances, input_type="documents") - sparse_embeddings = self.sparse_encoder.encode_documents(utterances) + embeddings = ( + self.encoder(utterances) + if not isinstance(self.encoder, AsymmetricDenseMixin) + else self.encoder.encode_documents(utterances) + ) + sparse_embeddings = ( + self.sparse_encoder(utterances) + if not isinstance(self.sparse_encoder, AsymmetricSparseMixin) + else self.sparse_encoder.encode_documents(utterances) + ) self.index = HybridLocalIndex() self.index.add( embeddings=embeddings, @@ -402,11 +438,17 @@ class HybridRouter(BaseRouter): Xq_s: List[SparseEmbedding] = [] for i in tqdm(range(0, len(X), batch_size), desc="Generating embeddings"): emb_d = np.array( - self.encoder(X[i : i + batch_size], input_type="documents") + self.encoder(X[i : i + batch_size]) + if not isinstance(self.encoder, AsymmetricDenseMixin) + else self.encoder.encode_queries(X[i : i + batch_size]) ) # TODO JB: for some reason the sparse encoder is receiving a tuple # like `("Hello",)` - emb_s = self.sparse_encoder.encode_documents(X[i : i + batch_size]) + emb_s = ( + self.sparse_encoder(X[i : i + batch_size]) + if not isinstance(self.sparse_encoder, AsymmetricSparseMixin) + else self.sparse_encoder.encode_queries(X[i : i + batch_size]) + ) Xq_d.extend(emb_d) Xq_s.extend(emb_s) # initial eval (we will iterate from here) @@ -452,8 +494,16 @@ class HybridRouter(BaseRouter): Xq_d: List[List[float]] = [] Xq_s: List[SparseEmbedding] = [] for i in tqdm(range(0, len(X), batch_size), desc="Generating embeddings"): - emb_d = np.array(self.encoder(X[i : i + batch_size], input_type="queries")) - emb_s = self.sparse_encoder(X[i : i + batch_size], input_type="queries") + emb_d = np.array( + self.encoder(X[i : i + batch_size]) + if not isinstance(self.encoder, AsymmetricDenseMixin) + else self.encoder.encode_queries(X[i : i + batch_size]) + ) + emb_s = ( + self.sparse_encoder(X[i : i + batch_size]) + if not isinstance(self.sparse_encoder, AsymmetricSparseMixin) + else self.sparse_encoder.encode_queries(X[i : i + batch_size]) + ) Xq_d.extend(emb_d) Xq_s.extend(emb_s) diff --git a/semantic_router/routers/semantic.py b/semantic_router/routers/semantic.py index 45e2aedd73de28cb5d4935bf74b503d6c24d05a4..8b939f33115fcf089ff4103c1ebbd7ca194b4bf8 100644 --- a/semantic_router/routers/semantic.py +++ b/semantic_router/routers/semantic.py @@ -3,6 +3,8 @@ from typing import Any, List, Optional import numpy as np from semantic_router.encoders import DenseEncoder +from semantic_router.encoders.base import AsymmetricDenseMixin +from semantic_router.encoders.encode_input_type import EncodeInputType from semantic_router.index.base import BaseIndex from semantic_router.llms import BaseLLM from semantic_router.route import Route @@ -35,28 +37,60 @@ class SemanticRouter(BaseRouter): auto_sync=auto_sync, ) - def _encode(self, text: list[str]) -> Any: + def _encode(self, text: list[str], input_type: EncodeInputType) -> Any: """Given some text, encode it. :param text: The text to encode. :type text: list[str] + :param input_type: Specify whether encoding 'queries' or 'documents', used in asymmetric retrieval + :type input_type: semantic_router.encoders.encode_input_type.EncodeInputType :return: The encoded text. :rtype: Any """ # create query vector - xq = np.array(self.encoder(text)) + match input_type: + case "queries": + xq = np.array( + self.encoder(text) + if not isinstance(self.encoder, AsymmetricDenseMixin) + else self.encoder.encode_queries(text) + ) + case "documents": + xq = np.array( + self.encoder(text) + if not isinstance(self.encoder, AsymmetricDenseMixin) + else self.encoder.encode_documents(text) + ) return xq - async def _async_encode(self, text: list[str]) -> Any: + async def _async_encode(self, text: list[str], input_type: EncodeInputType) -> Any: """Given some text, encode it. :param text: The text to encode. :type text: list[str] + :param input_type: Specify whether encoding 'queries' or 'documents', used in asymmetric retrieval + :type input_type: semantic_router.encoders.encode_input_type.EncodeInputType :return: The encoded text. :rtype: Any """ # create query vector - xq = np.array(await self.encoder.acall(docs=text)) + match input_type: + case "queries": + xq = np.array( + await ( + self.encoder.acall(docs=text) + if not isinstance(self.encoder, AsymmetricDenseMixin) + else self.encoder.aencode_queries(docs=text) + ) + ) + case "documents": + xq = np.array( + await ( + self.encoder.acall(docs=text) + if not isinstance(self.encoder, AsymmetricDenseMixin) + else self.encoder.aencode_documents(docs=text) + ) + ) return xq def add(self, routes: List[Route] | Route): @@ -79,7 +113,7 @@ class SemanticRouter(BaseRouter): all_function_schemas, all_metadata, ) = self._extract_routes_details(routes, include_metadata=True) - dense_emb = self._encode(all_utterances) + dense_emb = self._encode(all_utterances, input_type="documents") self.index.add( embeddings=dense_emb.tolist(), routes=route_names,