From bdfe5ea123a15e14dd1130a6764a38846aa09b05 Mon Sep 17 00:00:00 2001
From: Bogdan Buduroiu <bogdan@buduroiu.com>
Date: Mon, 17 Feb 2025 22:24:08 +0800
Subject: [PATCH] chore: add `input_type` to encoders
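
Dense and sparse encoders now take an `input_type` argument ("queries"
or "documents") so that asymmetric retrievers such as BM25 can encode
queries and documents differently. `HybridRouter` passes
`input_type="documents"` when adding or syncing utterances and
`input_type="queries"` at query time. `BM25Encoder` dispatches to
`encode_queries`/`encode_documents` based on `input_type`, while
`TfidfEncoder` accepts the argument for interface compatibility
(TF-IDF encodes both the same way).

Rough sketch of the new calling convention (assumes `BM25Encoder` is
re-exported from `semantic_router.encoders` and has already been
fitted; the example strings are illustrative only):

    from semantic_router.encoders import BM25Encoder

    encoder = BM25Encoder()
    # ...fit the encoder on your routes before encoding (illustrative)...
    doc_embs = encoder(["route utterance"], input_type="documents")
    query_embs = encoder(["user query"], input_type="queries")
    # both calls return a list of SparseEmbedding objects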

---
 semantic_router/encoders/base.py              | 41 +++++++++++++++---
 semantic_router/encoders/bm25.py              | 26 ++++++++++-
 semantic_router/encoders/encode_input_type.py |  3 ++
 semantic_router/encoders/tfidf.py             | 31 ++++++++-----
 semantic_router/routers/base.py               | 20 ++++++---
 semantic_router/routers/hybrid.py             | 43 ++++++++++++-------
 6 files changed, 126 insertions(+), 38 deletions(-)
 create mode 100644 semantic_router/encoders/encode_input_type.py

diff --git a/semantic_router/encoders/base.py b/semantic_router/encoders/base.py
index a9b460ad..3c14f945 100644
--- a/semantic_router/encoders/base.py
+++ b/semantic_router/encoders/base.py
@@ -3,6 +3,7 @@ from typing import Any, Coroutine, List, Optional
 import numpy as np
 from pydantic import BaseModel, Field, field_validator
 
+from semantic_router.encoders.encode_input_type import EncodeInputType
 from semantic_router.route import Route
 from semantic_router.schema import SparseEmbedding
 
@@ -26,25 +27,33 @@ class DenseEncoder(BaseModel):
         """
         return float(v) if v is not None else None
 
-    def __call__(self, docs: List[Any]) -> List[List[float]]:
+    def __call__(
+        self, docs: List[Any], input_type: EncodeInputType
+    ) -> List[List[float]]:
         """Encode a list of documents. Documents can be any type, but the encoder must
         be built to handle that data type. Typically, these types are strings or
         arrays representing images.
 
         :param docs: The documents to encode.
         :type docs: List[Any]
+        :param input_type: Specify whether encoding 'queries' or 'documents', used in asymmetric retrieval
+        :type input_type: semantic_router.encoders.encode_input_type.EncodeInputType
         :return: The encoded documents.
         :rtype: List[List[float]]
         """
         raise NotImplementedError("Subclasses must implement this method")
 
-    def acall(self, docs: List[Any]) -> Coroutine[Any, Any, List[List[float]]]:
+    def acall(
+        self, docs: List[Any], input_type: EncodeInputType
+    ) -> Coroutine[Any, Any, List[List[float]]]:
         """Encode a list of documents asynchronously. Documents can be any type, but the
         encoder must be built to handle that data type. Typically, these types are
         strings or arrays representing images.
 
         :param docs: The documents to encode.
         :type docs: List[Any]
+        :param input_type: Specify whether encoding 'queries' or 'documents', used in asymmetric retrieval
+        :type input_type: semantic_router.encoders.encode_input_type.EncodeInputType
         :return: The encoded documents.
         :rtype: List[List[float]]
         """
@@ -60,13 +69,35 @@ class SparseEncoder(BaseModel):
     class Config:
         arbitrary_types_allowed = True
 
-    def __call__(self, docs: List[str]) -> List[SparseEmbedding]:
+    def __call__(
+        self,
+        docs: List[str],
+        input_type: EncodeInputType = "queries",
+    ) -> List[SparseEmbedding]:
         """Sparsely encode a list of documents. Documents can be any type, but the encoder must
         be built to handle that data type. Typically, these types are strings or
         arrays representing images.
 
         :param docs: The documents to encode.
         :type docs: List[Any]
+        :param input_type: Specify whether encoding 'queries' or 'documents', used in asymmetric retrieval
+        :type input_type: semantic_router.encoders.encode_input_type.EncodeInputType
+        :return: The encoded documents.
+        :rtype: List[SparseEmbedding]
+        """
+        raise NotImplementedError("Subclasses must implement this method")
+
+    def acall(
+        self, docs: List[Any], input_type: EncodeInputType
+    ) -> Coroutine[Any, Any, List[SparseEmbedding]]:
+        """Encode a list of documents asynchronously. Documents can be any type, but the
+        encoder must be built to handle that data type. Typically, these types are
+        strings or arrays representing images.
+
+        :param docs: The documents to encode.
+        :type docs: List[Any]
+        :param input_type: Specify whether encoding 'queries' or 'documents', used in asymmetric retrieval
+        :type input_type: semantic_router.encoders.encode_input_type.EncodeInputType
         :return: The encoded documents.
         :rtype: List[SparseEmbedding]
         """
@@ -80,11 +111,11 @@ class SparseEncoder(BaseModel):
         """Convert document texts to sparse embeddings optimized for storage"""
         raise NotImplementedError("Subclasses must implement this method")
 
-    async def aencode_queries(self, docs: List[str]) -> list[SparseEmbedding]:
+    async def aencode_queries(self, docs: List[str]) -> List[SparseEmbedding]:
         """Async version of encode_queries"""
         raise NotImplementedError("Subclasses must implement this method")
 
-    async def aencode_documents(self, docs: List[str]) -> list[SparseEmbedding]:
+    async def aencode_documents(self, docs: List[str]) -> List[SparseEmbedding]:
         """Async version of encode_documents"""
         raise NotImplementedError("Subclasses must implement this method")
 
diff --git a/semantic_router/encoders/bm25.py b/semantic_router/encoders/bm25.py
index 5377e2e7..00c6039d 100644
--- a/semantic_router/encoders/bm25.py
+++ b/semantic_router/encoders/bm25.py
@@ -1,5 +1,6 @@
+import asyncio
 from functools import partial
-from typing import List
+from typing import Any, Coroutine, List, Literal
 
 import numpy as np
 
@@ -269,5 +270,26 @@ class BM25Encoder(SparseEncoder, FittableMixin):
 
         return self.encode_queries(docs)
 
-    def __call__(self, docs: List[str]) -> list[SparseEmbedding]:
+    async def aencode_queries(self, docs: List[str]) -> list[SparseEmbedding]:
+        # Although this is a CPU-bound operation and doesn't benefit from asyncio,
+        # we provide this method to satisfy the `SparseEncoder` interface.
         return self.encode_queries(docs)
+
+    async def aencode_documents(self, docs: List[str]) -> list[SparseEmbedding]:
+        # Although this is a CPU-bound operation and doesn't benefit from asyncio,
+        # we provide this method to satisfy the `SparseEncoder` interface.
+        return self.encode_documents(docs)
+
+    def __call__(
+        self, docs: List[str], input_type: Literal["queries", "documents"]
+    ) -> list[SparseEmbedding]:
+        match input_type:
+            case "queries":
+                return self.encode_queries(docs)
+            case "documents":
+                return self.encode_documents(docs)
+
+    def acall(
+        self, docs: List[Any], input_type: Literal["queries", "documents"]
+    ) -> Coroutine[Any, Any, List[SparseEmbedding]]:
+        return asyncio.to_thread(lambda: self.__call__(docs, input_type))
diff --git a/semantic_router/encoders/encode_input_type.py b/semantic_router/encoders/encode_input_type.py
new file mode 100644
index 00000000..1cfe9c87
--- /dev/null
+++ b/semantic_router/encoders/encode_input_type.py
@@ -0,0 +1,3 @@
+from typing import Literal
+
+EncodeInputType = Literal["queries", "documents"]
diff --git a/semantic_router/encoders/tfidf.py b/semantic_router/encoders/tfidf.py
index 0d6642b7..3d3cbe57 100644
--- a/semantic_router/encoders/tfidf.py
+++ b/semantic_router/encoders/tfidf.py
@@ -1,11 +1,13 @@
+import asyncio
 import string
 from collections import Counter
-from typing import Dict, List
+from typing import Any, Coroutine, Dict, List
 
 import numpy as np
 
 from semantic_router.encoders import SparseEncoder
 from semantic_router.encoders.base import FittableMixin
+from semantic_router.encoders.encode_input_type import EncodeInputType
 from semantic_router.route import Route
 from semantic_router.schema import SparseEmbedding
 
@@ -22,7 +24,9 @@ class TfidfEncoder(SparseEncoder, FittableMixin):
         self.word_index = {}
         self.idf = np.array([])
 
-    def __call__(self, docs: List[str]) -> list[SparseEmbedding]:
+    def __call__(
+        self, docs: List[str], input_type: EncodeInputType
+    ) -> list[SparseEmbedding]:
         if len(self.word_index) == 0 or self.idf.size == 0:
             raise ValueError("Vectorizer is not initialized.")
         if len(docs) == 0:
@@ -33,21 +37,28 @@ class TfidfEncoder(SparseEncoder, FittableMixin):
         tfidf = tf * self.idf
         return self._array_to_sparse_embeddings(tfidf)
 
-    def encode_queries(self, docs: List[str]) -> list[SparseEmbedding]:
+    def acall(
+        self, docs: List[str], input_type: EncodeInputType
+    ) -> Coroutine[Any, Any, List[SparseEmbedding]]:
+        return asyncio.to_thread(lambda: self.__call__(docs, input_type))
+
+    def encode_queries(self, docs: List[str]) -> List[SparseEmbedding]:
         """Encode documents using TF-IDF"""
-        return self.__call__(docs)  # TF-IDF uses same method for docs and queries
+        # TF-IDF uses the same method for documents and queries
+        return self.__call__(docs, input_type="queries")
 
-    def encode_documents(self, docs: List[str]) -> list[SparseEmbedding]:
+    def encode_documents(self, docs: List[str]) -> List[SparseEmbedding]:
         """Encode documents using TF-IDF"""
-        return self.__call__(docs)  # TF-IDF uses same method for docs and queries
+        # TF-IDF uses the same method for documents and queries
+        return self.__call__(docs, input_type="documents")
 
-    async def aencode_queries(self, docs: List[str]) -> list[SparseEmbedding]:
+    async def aencode_queries(self, docs: List[str]) -> List[SparseEmbedding]:
         """Async version of encode_queries"""
-        return self.__call__(docs)
+        return self.__call__(docs, input_type="queries")
 
-    async def aencode_documents(self, docs: List[str]) -> list[SparseEmbedding]:
+    async def aencode_documents(self, docs: List[str]) -> List[SparseEmbedding]:
         """Async version of encode_documents"""
-        return self.__call__(docs)
+        return self.__call__(docs, input_type="documents")
 
     def fit(self, routes: List[Route]):
         """Trains the encoder weights on the provided routes.
diff --git a/semantic_router/routers/base.py b/semantic_router/routers/base.py
index 3d5c18af..61229b1e 100644
--- a/semantic_router/routers/base.py
+++ b/semantic_router/routers/base.py
@@ -3,7 +3,7 @@ import importlib
 import json
 import os
 import random
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
 
 import numpy as np
 import yaml  # type: ignore
@@ -822,7 +822,8 @@ class BaseRouter(BaseModel):
                 routes=[utt.route for utt in strategy["remote"]["upsert"]],
                 utterances=utterances_text,
                 function_schemas=[
-                    utt.function_schemas for utt in strategy["remote"]["upsert"]  # type: ignore
+                    utt.function_schemas
+                    for utt in strategy["remote"]["upsert"]  # type: ignore
                 ],
                 metadata_list=[utt.metadata for utt in strategy["remote"]["upsert"]],
             )
@@ -855,7 +856,8 @@ class BaseRouter(BaseModel):
                 routes=[utt.route for utt in strategy["remote"]["upsert"]],
                 utterances=utterances_text,
                 function_schemas=[
-                    utt.function_schemas for utt in strategy["remote"]["upsert"]  # type: ignore
+                    utt.function_schemas
+                    for utt in strategy["remote"]["upsert"]  # type: ignore
                 ],
                 metadata_list=[utt.metadata for utt in strategy["remote"]["upsert"]],
             )
@@ -1333,26 +1335,34 @@ class BaseRouter(BaseModel):
             return route_names, utterances, function_schemas, metadata
         return route_names, utterances, function_schemas
 
-    def _encode(self, text: list[str]) -> Any:
+    def _encode(
+        self, text: list[str], input_type: Literal["queries", "documents"]
+    ) -> Any:
         """Generates embeddings for a given text.
 
         Must be implemented by a subclass.
 
         :param text: The text to encode.
         :type text: list[str]
+        :param input_type: Specify whether encoding 'queries' or 'documents', used in asymmetric retrieval
+        :type input_type: Literal["queries", "documents"]
         :return: The embeddings of the text.
         :rtype: Any
         """
         # TODO: should encode "content" rather than text
         raise NotImplementedError("This method should be implemented by subclasses.")
 
-    async def _async_encode(self, text: list[str]) -> Any:
+    async def _async_encode(
+        self, text: list[str], input_type: Literal["queries", "documents"]
+    ) -> Any:
         """Asynchronously generates embeddings for a given text.
 
         Must be implemented by a subclass.
 
         :param text: The text to encode.
         :type text: list[str]
+        :param input_type: Specify whether encoding 'queries' or 'documents', used in asymmetric retrieval
+        :type input_type: Literal["queries", "documents"]
         :return: The embeddings of the text.
         :rtype: Any
         """
diff --git a/semantic_router/routers/hybrid.py b/semantic_router/routers/hybrid.py
index 9c6680c6..ec210f25 100644
--- a/semantic_router/routers/hybrid.py
+++ b/semantic_router/routers/hybrid.py
@@ -11,6 +11,7 @@ from semantic_router.encoders import (
     SparseEncoder,
 )
 from semantic_router.encoders.base import FittableMixin
+from semantic_router.encoders.encode_input_type import EncodeInputType
 from semantic_router.index import BaseIndex, HybridLocalIndex
 from semantic_router.llms import BaseLLM
 from semantic_router.route import Route
@@ -112,7 +113,7 @@ class HybridRouter(BaseRouter):
         ) = self._extract_routes_details(routes, include_metadata=True)
         # TODO: to merge, self._encode should probably output a special
         # TODO Embedding type that can be either dense or hybrid
-        dense_emb, sparse_emb = self._encode(all_utterances)
+        dense_emb, sparse_emb = self._encode(all_utterances, input_type="documents")
         self.index.add(
             embeddings=dense_emb.tolist(),
             routes=route_names,
@@ -148,7 +149,9 @@ class HybridRouter(BaseRouter):
             self.index._remove_and_sync(data_to_delete)
         if strategy["remote"]["upsert"]:
             utterances_text = [utt.utterance for utt in strategy["remote"]["upsert"]]
-            dense_emb, sparse_emb = self._encode(utterances_text)
+            dense_emb, sparse_emb = self._encode(
+                utterances_text, input_type="documents"
+            )
             self.index.add(
                 embeddings=dense_emb.tolist(),
                 routes=[utt.route for utt in strategy["remote"]["upsert"]],
@@ -202,38 +205,44 @@ class HybridRouter(BaseRouter):
         return sparse_encoder
 
     def _encode(
-        self,
-        text: list[str],
+        self, text: list[str], input_type: EncodeInputType
     ) -> tuple[np.ndarray, list[SparseEmbedding]]:
         """Given some text, generates dense and sparse embeddings, then scales them
         using the chosen alpha value.
 
         :param text: List of texts to encode
         :type text: List[str]
+        :param input_type: Specify whether encoding 'queries' or 'documents', used in asymmetric retrieval
+        :type input_type: semantic_router.encoders.encode_input_type.EncodeInputType
         :return: Tuple of dense and sparse embeddings
         """
         if self.sparse_encoder is None:
             raise ValueError("self.sparse_encoder is not set.")
 
         # Create dense vector
-        xq_d = np.array(self.encoder(text))
+        xq_d = np.array(self.encoder(text, input_type=input_type))
 
         # Create sparse vector, handling BM25's query/document distinction
-        xq_s = self.sparse_encoder.encode_documents(text)
+        match input_type:
+            case "queries":
+                xq_s = self.sparse_encoder.encode_queries(text)
+            case "documents":
+                xq_s = self.sparse_encoder.encode_documents(text)
 
         # Convex scaling
         xq_d, xq_s = self._convex_scaling(dense=xq_d, sparse=xq_s)
         return xq_d, xq_s
 
     async def _async_encode(
-        self,
-        text: List[str],
+        self, text: List[str], input_type: EncodeInputType
     ) -> tuple[np.ndarray, list[SparseEmbedding]]:
         """Given some text, generates dense and sparse embeddings, then scales them
         using the chosen alpha value.
 
         :param text: The text to encode.
         :type text: List[str]
+        :param input_type: Specify whether encoding 'queries' or 'documents', used in asymmetric retrieval
+        :type input_type: semantic_router.encoders.encode_input_type.EncodeInputType
         :return: A tuple of the dense and sparse embeddings.
         :rtype: tuple[np.ndarray, list[SparseEmbedding]]
         """
@@ -242,8 +251,8 @@ class HybridRouter(BaseRouter):
         # TODO: should encode "content" rather than text
         # TODO: add alpha as a parameter
         # async encode both dense and sparse
-        dense_coro = self.encoder.acall(text)
-        sparse_coro = self.sparse_encoder.aencode_documents(text)
+        dense_coro = self.encoder.acall(text, input_type=input_type)
+        sparse_coro = self.sparse_encoder.acall(text, input_type=input_type)
         dense_vec, xq_s = await asyncio.gather(dense_coro, sparse_coro)
         # create dense query vector
         xq_d = np.array(dense_vec)
@@ -281,8 +290,8 @@ class HybridRouter(BaseRouter):
         if vector is None:
             if text is None:
                 raise ValueError("Either text or vector must be provided")
-            xq_d = np.array(self.encoder([text]))
-            xq_s = self.sparse_encoder.encode_queries([text])
+            xq_d = np.array(self.encoder([text], input_type="queries"))
+            xq_s = self.sparse_encoder([text], input_type="queries")
             vector, potential_sparse_vector = self._convex_scaling(
                 dense=xq_d, sparse=xq_s
             )
@@ -377,7 +386,7 @@ class HybridRouter(BaseRouter):
                 routes.append(utterance.route)
                 utterances.append(utterance.utterance)
                 metadata.append(utterance.metadata)
-            embeddings = self.encoder(utterances)
+            embeddings = self.encoder(utterances, input_type="documents")
             sparse_embeddings = self.sparse_encoder.encode_documents(utterances)
             self.index = HybridLocalIndex()
             self.index.add(
@@ -392,7 +401,9 @@ class HybridRouter(BaseRouter):
         Xq_d: List[List[float]] = []
         Xq_s: List[SparseEmbedding] = []
         for i in tqdm(range(0, len(X), batch_size), desc="Generating embeddings"):
-            emb_d = np.array(self.encoder(X[i : i + batch_size]))
+            emb_d = np.array(
+                self.encoder(X[i : i + batch_size], input_type="documents")
+            )
             # TODO JB: for some reason the sparse encoder is receiving a tuple
             # like `("Hello",)`
             emb_s = self.sparse_encoder.encode_documents(X[i : i + batch_size])
@@ -441,8 +452,8 @@ class HybridRouter(BaseRouter):
         Xq_d: List[List[float]] = []
         Xq_s: List[SparseEmbedding] = []
         for i in tqdm(range(0, len(X), batch_size), desc="Generating embeddings"):
-            emb_d = np.array(self.encoder(X[i : i + batch_size]))
-            emb_s = self.sparse_encoder.encode_queries(X[i : i + batch_size])
+            emb_d = np.array(self.encoder(X[i : i + batch_size], input_type="queries"))
+            emb_s = self.sparse_encoder(X[i : i + batch_size], input_type="queries")
             Xq_d.extend(emb_d)
             Xq_s.extend(emb_s)
 
-- 
GitLab