diff --git a/semantic_router/encoders/base.py b/semantic_router/encoders/base.py
index 70a3f6ee12904d19e23466b98a82e2f25f03543e..3fe68d8ca7c104e2ebcbb686a0762b169d66cf4a 100644
--- a/semantic_router/encoders/base.py
+++ b/semantic_router/encoders/base.py
@@ -14,6 +14,6 @@ class BaseEncoder(BaseModel):
 
     def __call__(self, docs: List[Any]) -> List[List[float]]:
         raise NotImplementedError("Subclasses must implement this method")
-
+
     def acall(self, docs: List[Any]) -> List[List[float]]:
         raise NotImplementedError("Subclasses must implement this method")
diff --git a/semantic_router/encoders/openai.py b/semantic_router/encoders/openai.py
index 96303155151c4e43b5e3937f801f42120414be7f..6b2aeab38bf143687cb08bfdc92d742237aafd77 100644
--- a/semantic_router/encoders/openai.py
+++ b/semantic_router/encoders/openai.py
@@ -173,4 +173,3 @@ class OpenAIEncoder(BaseEncoder):
 
         embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
         return embeddings
-
diff --git a/semantic_router/encoders/zure.py b/semantic_router/encoders/zure.py
index 6968583aacde861250a2a8155ad414732bb2b9eb..dba936004d3f9faf83fa61be9f87056aac3f7021 100644
--- a/semantic_router/encoders/zure.py
+++ b/semantic_router/encoders/zure.py
@@ -130,7 +130,7 @@ class AzureOpenAIEncoder(BaseEncoder):
 
         embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
         return embeddings
-
+
     async def acall(self, docs: List[str]) -> List[List[float]]:
         if self.async_client is None:
             raise ValueError("Azure OpenAI async client is not initialized.")
diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py
index 04cfc796fdab4058b13d85a824287fc96cadd527..3572144ccd956134166e5172ceb909860f94b9cd 100644
--- a/semantic_router/index/pinecone.py
+++ b/semantic_router/index/pinecone.py
@@ -100,14 +100,14 @@ class PineconeIndex(BaseIndex):
             pinecone_args["namespace"] = self.namespace
 
         return Pinecone(**pinecone_args)
-
+
     def _initialize_async_client(self, api_key: Optional[str] = None):
         async_client = aiohttp.ClientSession(
             headers={
                 "Api-Key": api_key,
                 "Content-Type": "application/json",
                 "X-Pinecone-API-Version": "2024-07",
-                "User-Agent": "source_tag=semanticrouter"
+                "User-Agent": "source_tag=semanticrouter",
             }
         )
         return async_client
@@ -158,7 +158,7 @@ class PineconeIndex(BaseIndex):
         if index is not None:
             self.host = self.client.describe_index(self.index_name)["host"]
         return index
-
+
     async def _init_async_index(self, force_create: bool = False) -> Union[Any, None]:
         index_stats = None
         indexes = await self._async_list_indexes()
@@ -171,7 +171,7 @@ class PineconeIndex(BaseIndex):
                     dimension=self.dimensions,
                     metric=self.metric,
                     cloud=self.cloud,
-                    region=self.region
+                    region=self.region,
                 )
                 # TODO describe index and async sleep
                 index_ready = "false"
@@ -332,7 +332,7 @@ class PineconeIndex(BaseIndex):
         scores = [result["score"] for result in results["matches"]]
         route_names = [result["metadata"]["sr_route"] for result in results["matches"]]
         return np.array(scores), route_names
-
+
     async def aquery(
        self,
        vector: np.ndarray,
@@ -356,7 +356,6 @@ class PineconeIndex(BaseIndex):
         scores = [result["score"] for result in results["matches"]]
         route_names = [result["metadata"]["sr_route"] for result in results["matches"]]
         return np.array(scores), route_names
-
 
     def delete_index(self):
         self.client.delete_index(self.index_name)
@@ -386,7 +385,7 @@ class PineconeIndex(BaseIndex):
     async def _async_list_indexes(self):
         async with self.async_client.get(f"{self.base_url}/indexes") as response:
             return await response.json(content_type=None)
-
+
     async def _async_create_index(
         self,
         name: str,
@@ -399,12 +398,7 @@ class PineconeIndex(BaseIndex):
             "name": name,
             "dimension": dimension,
             "metric": metric,
-            "spec": {
-                "serverless": {
-                    "cloud": cloud,
-                    "region": region
-                }
-            },
+            "spec": {"serverless": {"cloud": cloud, "region": region}},
         }
         async with self.async_client.post(
             f"{self.base_url}/indexes",
@@ -412,7 +406,7 @@ class PineconeIndex(BaseIndex):
             json=params,
         ) as response:
             return await response.json(content_type=None)
-
+
     async def _async_describe_index(self, name: str):
         async with self.async_client.get(f"{self.base_url}/indexes/{name}") as response:
             return await response.json(content_type=None)
diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index 0b882ff67a440db1dc9d33e2eec3f30622d4ade3..6804352392bcc4ccab6772b40ae2cda5b3f3ab02 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -185,7 +185,6 @@ class RouteLayer:
         top_k: int = 5,
         aggregation: str = "sum",
     ):
-        logger.info("local")
         self.index: BaseIndex = index if index is not None else LocalIndex()
         if encoder is None:
             logger.warning(
@@ -270,6 +269,43 @@ class RouteLayer:
         # if no route passes threshold, return empty route choice
         return RouteChoice()
 
+    async def acall(
+        self,
+        text: Optional[str] = None,
+        vector: Optional[List[float]] = None,
+        simulate_static: bool = False,
+        route_filter: Optional[List[str]] = None,
+    ) -> RouteChoice:
+        # if no vector provided, encode text to get vector
+        if vector is None:
+            if text is None:
+                raise ValueError("Either text or vector must be provided")
+            vector = await self._async_encode(text=text)
+
+        route, top_class_scores = await self._async_retrieve_top_route(
+            vector, route_filter
+        )
+        passed = self._check_threshold(top_class_scores, route)
+        if passed and route is not None and not simulate_static:
+            if route.function_schemas and text is None:
+                raise ValueError(
+                    "Route has a function schema, but no text was provided."
+                )
+            if route.function_schemas and not isinstance(route.llm, BaseLLM):
+                raise NotImplementedError(
+                    "Dynamic routes not yet supported for async calls."
+                )
+            return route(text)
+        elif passed and route is not None and simulate_static:
+            return RouteChoice(
+                name=route.name,
+                function_call=None,
+                similarity_score=None,
+            )
+        else:
+            # if no route passes threshold, return empty route choice
+            return RouteChoice()
+
     def retrieve_multiple_routes(
         self,
         text: Optional[str] = None,
@@ -313,6 +349,19 @@ class RouteLayer:
         route = self.check_for_matching_routes(top_class)
         return route, top_class_scores
 
+    async def _async_retrieve_top_route(
+        self, vector: List[float], route_filter: Optional[List[str]] = None
+    ) -> Tuple[Optional[Route], List[float]]:
+        # get relevant results (scores and routes)
+        results = await self._async_retrieve(
+            xq=np.array(vector), top_k=self.top_k, route_filter=route_filter
+        )
+        # decide most relevant routes
+        top_class, top_class_scores = await self._async_semantic_classify(results)
+        # TODO do we need this check?
+        route = self.check_for_matching_routes(top_class)
+        return route, top_class_scores
+
     def _check_threshold(self, scores: List[float], route: Optional[Route]) -> bool:
         """
         Check if the route's score passes the specified threshold.
@@ -425,6 +474,13 @@ class RouteLayer:
         xq = np.squeeze(xq)  # Reduce to 1d array.
         return xq
 
+    async def _async_encode(self, text: str) -> Any:
+        """Given some text, encode it."""
+        # create query vector
+        xq = np.array(await self.encoder.acall(docs=[text]))
+        xq = np.squeeze(xq)  # Reduce to 1d array.
+        return xq
+
     def _retrieve(
         self, xq: Any, top_k: int = 5, route_filter: Optional[List[str]] = None
     ) -> List[Dict]:
@@ -435,6 +491,16 @@ class RouteLayer:
         )
         return [{"route": d, "score": s.item()} for d, s in zip(routes, scores)]
 
+    async def _async_retrieve(
+        self, xq: Any, top_k: int = 5, route_filter: Optional[List[str]] = None
+    ) -> List[Dict]:
+        """Given a query vector, retrieve the top_k most similar records."""
+        # get scores and routes
+        scores, routes = await self.index.aquery(
+            vector=xq, top_k=top_k, route_filter=route_filter
+        )
+        return [{"route": d, "score": s.item()} for d, s in zip(routes, scores)]
+
     def _set_aggregation_method(self, aggregation: str = "sum"):
         if aggregation == "sum":
             return lambda x: sum(x)
@@ -464,6 +530,25 @@ class RouteLayer:
             logger.warning("No classification found for semantic classifier.")
             return "", []
 
+    async def _async_semantic_classify(
+        self, query_results: List[Dict]
+    ) -> Tuple[str, List[float]]:
+        scores_by_class = await self.async_group_scores_by_class(query_results)
+
+        # Calculate total score for each class
+        total_scores = {
+            route: self.aggregation_method(scores)
+            for route, scores in scores_by_class.items()
+        }
+        top_class = max(total_scores, key=lambda x: total_scores[x], default=None)
+
+        # Return the top class and its associated scores
+        if top_class is not None:
+            return str(top_class), scores_by_class.get(top_class, [])
+        else:
+            logger.warning("No classification found for semantic classifier.")
+            return "", []
+
     def get(self, name: str) -> Optional[Route]:
         for route in self.routes:
             if route.name == name:
@@ -507,6 +592,19 @@ class RouteLayer:
                 scores_by_class[route] = [score]
         return scores_by_class
 
+    async def async_group_scores_by_class(
+        self, query_results: List[Dict]
+    ) -> Dict[str, List[float]]:
+        scores_by_class: Dict[str, List[float]] = {}
+        for result in query_results:
+            score = result["score"]
+            route = result["route"]
+            if route in scores_by_class:
+                scores_by_class[route].append(score)
+            else:
+                scores_by_class[route] = [score]
+        return scores_by_class
+
     def _pass_threshold(self, scores: List[float], threshold: float) -> bool:
         if scores:
             return max(scores) > threshold
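
Not part of the patch: a minimal usage sketch of the new async path (`RouteLayer.acall` encoding via `encoder.acall` and querying via `index.aquery`). The route, utterances, and index name below are illustrative only, and the snippet assumes the installed `OpenAIEncoder` implements `acall` and that the OpenAI and Pinecone API keys are provided via environment variables.

```python
import asyncio

from semantic_router import Route, RouteLayer
from semantic_router.encoders import OpenAIEncoder
from semantic_router.index.pinecone import PineconeIndex


async def main():
    # Hypothetical static route used only for this sketch.
    chitchat = Route(name="chitchat", utterances=["how are you?", "lovely weather today"])

    layer = RouteLayer(
        encoder=OpenAIEncoder(),  # assumes OPENAI_API_KEY is set
        routes=[chitchat],
        index=PineconeIndex(index_name="demo-index"),  # assumes PINECONE_API_KEY is set
    )

    # acall() awaits encoder.acall() for the query vector, then index.aquery()
    # for the nearest routes, and returns a RouteChoice.
    choice = await layer.acall("how's it going?")
    print(choice.name)


asyncio.run(main())
```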