diff --git a/semantic_router/routers/hybrid.py b/semantic_router/routers/hybrid.py index 994fcb2dfc8c5e4d85589ca5cf777c2ab7a26d26..54901d5e50a0116be8a88496869b17ec00aa06e8 100644 --- a/semantic_router/routers/hybrid.py +++ b/semantic_router/routers/hybrid.py @@ -14,7 +14,7 @@ from semantic_router.route import Route from semantic_router.index import BaseIndex, HybridLocalIndex from semantic_router.schema import RouteChoice, SparseEmbedding, Utterance from semantic_router.utils.logger import logger -from semantic_router.routers.base import BaseRouter +from semantic_router.routers.base import BaseRouter, xq_reshape from semantic_router.llms import BaseLLM @@ -197,18 +197,19 @@ class HybridRouter(BaseRouter): def __call__( self, text: Optional[str] = None, - vector: Optional[List[float]] = None, + vector: Optional[List[float] | np.ndarray] = None, simulate_static: bool = False, route_filter: Optional[List[str]] = None, sparse_vector: dict[int, float] | SparseEmbedding | None = None, ) -> RouteChoice: - vector_arr: np.ndarray | None = None potential_sparse_vector: List[SparseEmbedding] | None = None # if no vector provided, encode text to get vector if vector is None: if text is None: raise ValueError("Either text or vector must be provided") - vector_arr, potential_sparse_vector = self._encode(text=[text]) + vector, potential_sparse_vector = self._encode(text=[text]) + # convert to numpy array if not already + vector = xq_reshape(vector) if sparse_vector is None: if text is None: raise ValueError("Either text or sparse_vector must be provided") @@ -217,10 +218,9 @@ class HybridRouter(BaseRouter): ) if sparse_vector is None: raise ValueError("Sparse vector is required for HybridLocalIndex.") - vector_arr = vector_arr if vector_arr is not None else np.array(vector) # TODO: add alpha as a parameter scores, route_names = self.index.query( - vector=vector_arr, + vector=vector, top_k=self.top_k, route_filter=route_filter, sparse_vector=sparse_vector,