From dc14624e6ea1bed2a2c1166f99966d19e8d6d13e Mon Sep 17 00:00:00 2001
From: James Briggs <35938317+jamescalam@users.noreply.github.com>
Date: Sun, 15 Dec 2024 16:28:43 +0400
Subject: [PATCH] fix: types and arrays

Normalise query vectors before retrieval:

- add `xq_reshape`, which converts a list or numpy array into a numpy
  array of shape (1, x) and raises ValueError for any other shape
- `__call__`, `acall` and `retrieve_multiple_routes` accept either a
  list of floats or a numpy array and pass it through `xq_reshape`
- `_retrieve_top_route` and `_async_retrieve_top_route` hand the (1, x)
  array on unchanged; the single row is taken with `xq[0]` only where
  the index is queried (`index.query` / `index.aquery`)
- mark `retrieve_multiple_routes` as deprecated
- wait PINECONE_SLEEP in the multiple-routes test when the index class
  is PineconeIndex

The union annotations are spelled with `typing.Union`, matching the
`Optional[...]` / `List[...]` style the module already uses, so they do
not rely on the `X | Y` operator being usable at definition time.
---
 semantic_router/routers/base.py | 59 +++++++++++++++++++++++---------
 tests/unit/test_router.py       |  2 ++
 2 files changed, 45 insertions(+), 16 deletions(-)

diff --git a/semantic_router/routers/base.py b/semantic_router/routers/base.py
index 18d71ced..0bfc4eea 100644
--- a/semantic_router/routers/base.py
+++ b/semantic_router/routers/base.py
@@ -4,6 +4,7 @@ import os
 import random
 import hashlib
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing_extensions import deprecated
 
 from pydantic import BaseModel, Field
 import numpy as np
@@ -280,6 +281,27 @@ class RouterConfig:
         )
 
 
+def xq_reshape(xq: Union[List[float], np.ndarray]) -> np.ndarray:
+    """Normalise a query vector to a 2D numpy array of shape (1, x).
+
+    Lists are converted to numpy arrays and 1D input of shape (x,) is
+    expanded to (1, x). Raises a ValueError for anything that is not a
+    single query vector, e.g. a scalar or a batch of several vectors.
+    """
+    # convert to numpy array if not already
+    if not isinstance(xq, np.ndarray):
+        xq = np.array(xq)
+    # check if vector is 1D and expand to 2D if necessary
+    if xq.ndim == 1:
+        xq = np.expand_dims(xq, axis=0)
+    # checking ndim first also rejects 0-d and 3D+ input cleanly
+    if xq.ndim != 2 or xq.shape[0] != 1:
+        raise ValueError(
+            f"Expected (1, x) dimensional input for query, got {xq.shape}."
+        )
+    return xq
+
+
 class BaseRouter(BaseModel):
     encoder: DenseEncoder = Field(default_factory=OpenAIEncoder)
     index: BaseIndex = Field(default_factory=BaseIndex)
@@ -402,7 +424,7 @@ class BaseRouter(BaseModel):
     def __call__(
         self,
         text: Optional[str] = None,
-        vector: Optional[List[float]] = None,
+        vector: Optional[Union[List[float], np.ndarray]] = None,
         simulate_static: bool = False,
         route_filter: Optional[List[str]] = None,
     ) -> RouteChoice:
@@ -411,6 +433,9 @@ class BaseRouter(BaseModel):
             if text is None:
                 raise ValueError("Either text or vector must be provided")
             vector = self._encode(text=[text])
+        # convert to numpy array if not already
+        vector = xq_reshape(vector)
+        # calculate semantics
         route, top_class_scores = self._retrieve_top_route(vector, route_filter)
         passed = self._check_threshold(top_class_scores, route)
         if passed and route is not None and not simulate_static:
@@ -444,7 +469,7 @@ class BaseRouter(BaseModel):
     async def acall(
         self,
         text: Optional[str] = None,
-        vector: Optional[List[float]] = None,
+        vector: Optional[Union[List[float], np.ndarray]] = None,
         simulate_static: bool = False,
         route_filter: Optional[List[str]] = None,
     ) -> RouteChoice:
@@ -453,7 +478,9 @@ class BaseRouter(BaseModel):
             if text is None:
                 raise ValueError("Either text or vector must be provided")
             vector = await self._async_encode(text=[text])
-
+        # convert to numpy array if not already
+        vector = xq_reshape(vector)
+        # calculate semantics
         route, top_class_scores = await self._async_retrieve_top_route(
             vector, route_filter
         )
@@ -483,19 +510,21 @@ class BaseRouter(BaseModel):
             # if no route passes threshold, return empty route choice
             return RouteChoice()
 
+    # TODO: add multiple routes return to __call__ and acall
+    @deprecated("This method is deprecated. Use `__call__` instead.")
     def retrieve_multiple_routes(
         self,
         text: Optional[str] = None,
-        vector: Optional[List[float]] = None,
+        vector: Optional[Union[List[float], np.ndarray]] = None,
     ) -> List[RouteChoice]:
         if vector is None:
             if text is None:
                 raise ValueError("Either text or vector must be provided")
-            vector_arr = self._encode(text=[text])
-        else:
-            vector_arr = np.array(vector)
+            vector = self._encode(text=[text])
+        # convert to numpy array if not already
+        vector = xq_reshape(vector)
         # get relevant utterances
-        results = self._retrieve(xq=vector_arr)
+        results = self._retrieve(xq=vector)
         # decide most relevant routes
         categories_with_scores = self._semantic_classify_multiple_routes(results)
         return [
@@ -514,16 +543,14 @@ class BaseRouter(BaseModel):
         # return route_choices
 
     def _retrieve_top_route(
-        self, vector: List[float], route_filter: Optional[List[str]] = None
+        self, vector: np.ndarray, route_filter: Optional[List[str]] = None
     ) -> Tuple[Optional[Route], List[float]]:
         """
         Retrieve the top matching route based on the given vector.
         Returns a tuple of the route (if any) and the scores of the top class.
         """
         # get relevant results (scores and routes)
-        results = self._retrieve(
-            xq=vector[0], top_k=self.top_k, route_filter=route_filter
-        )
+        results = self._retrieve(xq=vector, top_k=self.top_k, route_filter=route_filter)
         # decide most relevant routes
         top_class, top_class_scores = self._semantic_classify(results)
         # TODO do we need this check?
@@ -531,11 +558,11 @@ class BaseRouter(BaseModel):
         return route, top_class_scores
 
     async def _async_retrieve_top_route(
-        self, vector: List[float], route_filter: Optional[List[str]] = None
+        self, vector: np.ndarray, route_filter: Optional[List[str]] = None
     ) -> Tuple[Optional[Route], List[float]]:
         # get relevant results (scores and routes)
         results = await self._async_retrieve(
-            xq=vector[0], top_k=self.top_k, route_filter=route_filter
+            xq=vector, top_k=self.top_k, route_filter=route_filter
         )
         # decide most relevant routes
         top_class, top_class_scores = await self._async_semantic_classify(results)
@@ -939,7 +966,7 @@ class BaseRouter(BaseModel):
         """Given a query vector, retrieve the top_k most similar records."""
         # get scores and routes
         scores, routes = self.index.query(
-            vector=xq, top_k=top_k, route_filter=route_filter
+            vector=xq[0], top_k=top_k, route_filter=route_filter
         )
         return [{"route": d, "score": s.item()} for d, s in zip(routes, scores)]
 
@@ -949,7 +976,7 @@ class BaseRouter(BaseModel):
         """Given a query vector, retrieve the top_k most similar records."""
         # get scores and routes
         scores, routes = await self.index.aquery(
-            vector=xq, top_k=top_k, route_filter=route_filter
+            vector=xq[0], top_k=top_k, route_filter=route_filter
         )
         return [{"route": d, "score": s.item()} for d, s in zip(routes, scores)]
 
diff --git a/tests/unit/test_router.py b/tests/unit/test_router.py
index 1741865b..1f743f1c 100644
--- a/tests/unit/test_router.py
+++ b/tests/unit/test_router.py
@@ -826,6 +826,8 @@ class TestSemanticRouter:
             auto_sync="local",
         )
         vector = [0.1, 0.2, 0.3]
+        if index_cls is PineconeIndex:
+            time.sleep(PINECONE_SLEEP)  # allow for index to be populated
         results = route_layer.retrieve_multiple_routes(vector=vector)
         assert len(results) >= 1, "Expected at least one result"
         assert any(
-- 
GitLab