diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py
index 5271b897298af0d99efb5e6f23f87c94da4e7b12..f8a965df015139cd20afb3bbf3a6b9b0f87531cc 100644
--- a/semantic_router/index/base.py
+++ b/semantic_router/index/base.py
@@ -43,7 +43,12 @@ class BaseIndex(BaseModel):
         """
         raise NotImplementedError("This method should be implemented by subclasses.")
 
-    def query(self, vector: np.ndarray, top_k: int = 5) -> Tuple[np.ndarray, List[str]]:
+    def query(
+        self,
+        vector: np.ndarray,
+        top_k: int = 5,
+        route_filter: Optional[List[str]] = None,
+    ) -> Tuple[np.ndarray, List[str]]:
         """
         Search the index for the query_vector and return top_k results.
         This method should be implemented by subclasses.
diff --git a/semantic_router/index/local.py b/semantic_router/index/local.py
index 81da71d1d6114cb58849ec364a0fa175254020c6..4bf212dcbe9742759f4ac2624c0d3ad8b02a9601 100644
--- a/semantic_router/index/local.py
+++ b/semantic_router/index/local.py
@@ -56,17 +56,34 @@ class LocalIndex(BaseIndex):
             "vectors": self.index.shape[0] if self.index is not None else 0,
         }
 
-    def query(self, vector: np.ndarray, top_k: int = 5) -> Tuple[np.ndarray, List[str]]:
+    def query(
+        self,
+        vector: np.ndarray,
+        top_k: int = 5,
+        route_filter: Optional[List[str]] = None,
+    ) -> Tuple[np.ndarray, List[str]]:
         """
         Search the index for the query and return top_k results.
+
+        When route_filter is given, only vectors whose route name appears in
+        it are scored; a ValueError is raised if no stored route matches.
         """
         if self.index is None or self.routes is None:
             raise ValueError("Index or routes are not populated.")
-        sim = similarity_matrix(vector, self.index)
+        if route_filter is not None:
+            # keep only the vectors whose route name is in the filter
+            mask = np.isin(self.routes, route_filter)
+            if not mask.any():
+                raise ValueError("No routes found matching the filter criteria.")
+            candidate_routes = self.routes[mask]
+            sim = similarity_matrix(vector, self.index[mask])
+        else:
+            candidate_routes = self.routes
+            sim = similarity_matrix(vector, self.index)
         # extract the index values of top scoring vectors
         scores, idx = top_scores(sim, top_k)
         # get routes from index values
-        route_names = self.routes[idx].copy()
+        route_names = [candidate_routes[i] for i in idx]
         return scores, route_names
 
     def delete(self, route_name: str):
diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py
index 321dee3045d373354a313f84e52db22eefa06689..500b78e9e866ee3bafdd7cddf3648c1b4e4302ef 100644
--- a/semantic_router/index/pinecone.py
+++ b/semantic_router/index/pinecone.py
@@ -219,17 +219,33 @@ class PineconeIndex(BaseIndex):
         else:
             raise ValueError("Index is None, cannot describe index stats.")
 
-    def query(self, vector: np.ndarray, top_k: int = 5) -> Tuple[np.ndarray, List[str]]:
+    def query(
+        self,
+        vector: np.ndarray,
+        top_k: int = 5,
+        route_filter: Optional[List[str]] = None,
+    ) -> Tuple[np.ndarray, List[str]]:
+        """
+        Query the index and return the scores and route names of the matches.
+
+        When route_filter is given, the query is restricted to records whose
+        "sr_route" metadata value is one of the listed route names.
+        """
         if self.index is None:
             raise ValueError("Index is not populated.")
         query_vector_list = vector.tolist()
+        if route_filter is not None:
+            filter_query = {"sr_route": {"$in": route_filter}}
+        else:
+            filter_query = None
         results = self.index.query(
             vector=[query_vector_list],
             top_k=top_k,
+            filter=filter_query,
             include_metadata=True,
         )
         scores = [result["score"] for result in results["matches"]]
         route_names = [result["metadata"]["sr_route"] for result in results["matches"]]
         return np.array(scores), route_names
 
     def delete_index(self):
diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index 221de2bef3f02e5992750fd98d63fd1e02e94659..2d5eee0c86bfb3a67568592a5735772c4198f3b3 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -232,6 +232,7 @@ class RouteLayer:
         text: Optional[str] = None,
         vector: Optional[List[float]] = None,
         simulate_static: bool = False,
+        route_filter: Optional[List[str]] = None,
     ) -> RouteChoice:
         # if no vector provided, encode text to get vector
         if vector is None:
@@ -239,7 +240,7 @@ class RouteLayer:
                 raise ValueError("Either text or vector must be provided")
             vector = self._encode(text=text)
 
-        route, top_class_scores = self._retrieve_top_route(vector)
+        route, top_class_scores = self._retrieve_top_route(vector, route_filter)
         passed = self._check_threshold(top_class_scores, route)
 
         if passed and route is not None and not simulate_static:
@@ -271,14 +272,17 @@ class RouteLayer:
             return RouteChoice()
 
     def _retrieve_top_route(
-        self, vector: List[float]
+        self, vector: List[float], route_filter: Optional[List[str]] = None
     ) -> Tuple[Optional[Route], List[float]]:
         """
         Retrieve the top matching route based on the given vector.
         Returns a tuple of the route (if any) and the scores of the top class.
+        route_filter, when given, is forwarded to the index query.
         """
         # get relevant results (scores and routes)
-        results = self._retrieve(xq=np.array(vector), top_k=self.top_k)
+        results = self._retrieve(
+            xq=np.array(vector), top_k=self.top_k, route_filter=route_filter
+        )
         # decide most relevant routes
         top_class, top_class_scores = self._semantic_classify(results)
         # TODO do we need this check?
@@ -397,10 +401,14 @@ class RouteLayer:
         xq = np.squeeze(xq)  # Reduce to 1d array.
         return xq
 
-    def _retrieve(self, xq: Any, top_k: int = 5) -> List[dict]:
+    def _retrieve(
+        self, xq: Any, top_k: int = 5, route_filter: Optional[List[str]] = None
+    ) -> List[dict]:
         """Given a query vector, retrieve the top_k most similar records."""
         # get scores and routes
-        scores, routes = self.index.query(vector=xq, top_k=top_k)
+        scores, routes = self.index.query(
+            vector=xq, top_k=top_k, route_filter=route_filter
+        )
         return [{"route": d, "score": s.item()} for d, s in zip(routes, scores)]
 
     def _set_aggregation_method(self, aggregation: str = "sum"):