Skip to content
Snippets Groups Projects
Commit 28d69316 authored by James Briggs's avatar James Briggs
Browse files

fix: types and arrays for hybrid

parent dc14624e
No related branches found
No related tags found
No related merge requests found
...@@ -14,7 +14,7 @@ from semantic_router.route import Route ...@@ -14,7 +14,7 @@ from semantic_router.route import Route
from semantic_router.index import BaseIndex, HybridLocalIndex from semantic_router.index import BaseIndex, HybridLocalIndex
from semantic_router.schema import RouteChoice, SparseEmbedding, Utterance from semantic_router.schema import RouteChoice, SparseEmbedding, Utterance
from semantic_router.utils.logger import logger from semantic_router.utils.logger import logger
from semantic_router.routers.base import BaseRouter from semantic_router.routers.base import BaseRouter, xq_reshape
from semantic_router.llms import BaseLLM from semantic_router.llms import BaseLLM
...@@ -197,18 +197,19 @@ class HybridRouter(BaseRouter): ...@@ -197,18 +197,19 @@ class HybridRouter(BaseRouter):
def __call__( def __call__(
self, self,
text: Optional[str] = None, text: Optional[str] = None,
vector: Optional[List[float]] = None, vector: Optional[List[float] | np.ndarray] = None,
simulate_static: bool = False, simulate_static: bool = False,
route_filter: Optional[List[str]] = None, route_filter: Optional[List[str]] = None,
sparse_vector: dict[int, float] | SparseEmbedding | None = None, sparse_vector: dict[int, float] | SparseEmbedding | None = None,
) -> RouteChoice: ) -> RouteChoice:
vector_arr: np.ndarray | None = None
potential_sparse_vector: List[SparseEmbedding] | None = None potential_sparse_vector: List[SparseEmbedding] | None = None
# if no vector provided, encode text to get vector # if no vector provided, encode text to get vector
if vector is None: if vector is None:
if text is None: if text is None:
raise ValueError("Either text or vector must be provided") raise ValueError("Either text or vector must be provided")
vector_arr, potential_sparse_vector = self._encode(text=[text]) vector, potential_sparse_vector = self._encode(text=[text])
# convert to numpy array if not already
vector = xq_reshape(vector)
if sparse_vector is None: if sparse_vector is None:
if text is None: if text is None:
raise ValueError("Either text or sparse_vector must be provided") raise ValueError("Either text or sparse_vector must be provided")
...@@ -217,10 +218,9 @@ class HybridRouter(BaseRouter): ...@@ -217,10 +218,9 @@ class HybridRouter(BaseRouter):
) )
if sparse_vector is None: if sparse_vector is None:
raise ValueError("Sparse vector is required for HybridLocalIndex.") raise ValueError("Sparse vector is required for HybridLocalIndex.")
vector_arr = vector_arr if vector_arr is not None else np.array(vector)
# TODO: add alpha as a parameter # TODO: add alpha as a parameter
scores, route_names = self.index.query( scores, route_names = self.index.query(
vector=vector_arr, vector=vector,
top_k=self.top_k, top_k=self.top_k,
route_filter=route_filter, route_filter=route_filter,
sparse_vector=sparse_vector, sparse_vector=sparse_vector,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment