Skip to content
Snippets Groups Projects
Unverified Commit a59e7d18 authored by James Briggs's avatar James Briggs Committed by GitHub
Browse files

Merge pull request #355 from aurelio-labs/ashraq/sparse-vector

feat: add sparse vector for PineconeIndex
parents ae08fab8 75a17993
Branches
Tags v0.0.53
No related merge requests found
[tool.poetry] [tool.poetry]
name = "semantic-router" name = "semantic-router"
version = "0.0.52" version = "0.0.53"
description = "Super fast semantic router for AI decision making" description = "Super fast semantic router for AI decision making"
authors = [ authors = [
"James Briggs <james@aurelio.ai>", "James Briggs <james@aurelio.ai>",
......
...@@ -4,4 +4,4 @@ from semantic_router.route import Route ...@@ -4,4 +4,4 @@ from semantic_router.route import Route
__all__ = ["RouteLayer", "HybridRouteLayer", "Route", "LayerConfig"] __all__ = ["RouteLayer", "HybridRouteLayer", "Route", "LayerConfig"]
__version__ = "0.0.50" __version__ = "0.0.53"
...@@ -99,6 +99,35 @@ class LocalIndex(BaseIndex): ...@@ -99,6 +99,35 @@ class LocalIndex(BaseIndex):
route_names = [self.routes[i] for i in idx] route_names = [self.routes[i] for i in idx]
return scores, route_names return scores, route_names
async def aquery(
self,
vector: np.ndarray,
top_k: int = 5,
route_filter: Optional[List[str]] = None,
) -> Tuple[np.ndarray, List[str]]:
"""
Search the index for the query and return top_k results.
"""
if self.index is None or self.routes is None:
raise ValueError("Index or routes are not populated.")
if route_filter is not None:
filtered_index = []
filtered_routes = []
for route, vec in zip(self.routes, self.index):
if route in route_filter:
filtered_index.append(vec)
filtered_routes.append(route)
if not filtered_routes:
raise ValueError("No routes found matching the filter criteria.")
sim = similarity_matrix(vector, np.array(filtered_index))
scores, idx = top_scores(sim, top_k)
route_names = [filtered_routes[i] for i in idx]
else:
sim = similarity_matrix(vector, self.index)
scores, idx = top_scores(sim, top_k)
route_names = [self.routes[i] for i in idx]
return scores, route_names
def delete(self, route_name: str): def delete(self, route_name: str):
""" """
Delete all records of a specific route from the index. Delete all records of a specific route from the index.
......
...@@ -464,7 +464,25 @@ class PineconeIndex(BaseIndex): ...@@ -464,7 +464,25 @@ class PineconeIndex(BaseIndex):
vector: np.ndarray, vector: np.ndarray,
top_k: int = 5, top_k: int = 5,
route_filter: Optional[List[str]] = None, route_filter: Optional[List[str]] = None,
**kwargs: Any,
) -> Tuple[np.ndarray, List[str]]: ) -> Tuple[np.ndarray, List[str]]:
"""
Search the index for the query vector and return the top_k results.
:param vector: The query vector to search for.
:type vector: np.ndarray
:param top_k: The number of top results to return, defaults to 5.
:type top_k: int, optional
:param route_filter: A list of route names to filter the search results, defaults to None.
:type route_filter: Optional[List[str]], optional
:param kwargs: Additional keyword arguments for the query, including sparse_vector.
:type kwargs: Any
:keyword sparse_vector: An optional sparse vector to include in the query.
:type sparse_vector: Optional[dict]
:return: A tuple containing an array of scores and a list of route names.
:rtype: Tuple[np.ndarray, List[str]]
:raises ValueError: If the index is not populated.
"""
if self.index is None: if self.index is None:
raise ValueError("Index is not populated.") raise ValueError("Index is not populated.")
query_vector_list = vector.tolist() query_vector_list = vector.tolist()
...@@ -474,6 +492,7 @@ class PineconeIndex(BaseIndex): ...@@ -474,6 +492,7 @@ class PineconeIndex(BaseIndex):
filter_query = None filter_query = None
results = self.index.query( results = self.index.query(
vector=[query_vector_list], vector=[query_vector_list],
sparse_vector=kwargs.get("sparse_vector", None),
top_k=top_k, top_k=top_k,
filter=filter_query, filter=filter_query,
include_metadata=True, include_metadata=True,
...@@ -488,7 +507,25 @@ class PineconeIndex(BaseIndex): ...@@ -488,7 +507,25 @@ class PineconeIndex(BaseIndex):
vector: np.ndarray, vector: np.ndarray,
top_k: int = 5, top_k: int = 5,
route_filter: Optional[List[str]] = None, route_filter: Optional[List[str]] = None,
**kwargs: Any,
) -> Tuple[np.ndarray, List[str]]: ) -> Tuple[np.ndarray, List[str]]:
"""
Asynchronously search the index for the query vector and return the top_k results.
:param vector: The query vector to search for.
:type vector: np.ndarray
:param top_k: The number of top results to return, defaults to 5.
:type top_k: int, optional
:param route_filter: A list of route names to filter the search results, defaults to None.
:type route_filter: Optional[List[str]], optional
:param kwargs: Additional keyword arguments for the query, including sparse_vector.
:type kwargs: Any
:keyword sparse_vector: An optional sparse vector to include in the query.
:type sparse_vector: Optional[dict]
:return: A tuple containing an array of scores and a list of route names.
:rtype: Tuple[np.ndarray, List[str]]
:raises ValueError: If the index is not populated.
"""
if self.async_client is None or self.host is None: if self.async_client is None or self.host is None:
raise ValueError("Async client or host are not initialized.") raise ValueError("Async client or host are not initialized.")
query_vector_list = vector.tolist() query_vector_list = vector.tolist()
...@@ -498,6 +535,7 @@ class PineconeIndex(BaseIndex): ...@@ -498,6 +535,7 @@ class PineconeIndex(BaseIndex):
filter_query = None filter_query = None
results = await self._async_query( results = await self._async_query(
vector=query_vector_list, vector=query_vector_list,
sparse_vector=kwargs.get("sparse_vector", None),
namespace=self.namespace or "", namespace=self.namespace or "",
filter=filter_query, filter=filter_query,
top_k=top_k, top_k=top_k,
...@@ -514,6 +552,7 @@ class PineconeIndex(BaseIndex): ...@@ -514,6 +552,7 @@ class PineconeIndex(BaseIndex):
async def _async_query( async def _async_query(
self, self,
vector: list[float], vector: list[float],
sparse_vector: Optional[dict] = None,
namespace: str = "", namespace: str = "",
filter: Optional[dict] = None, filter: Optional[dict] = None,
top_k: int = 5, top_k: int = 5,
...@@ -521,6 +560,7 @@ class PineconeIndex(BaseIndex): ...@@ -521,6 +560,7 @@ class PineconeIndex(BaseIndex):
): ):
params = { params = {
"vector": vector, "vector": vector,
"sparse_vector": sparse_vector,
"namespace": namespace, "namespace": namespace,
"filter": filter, "filter": filter,
"top_k": top_k, "top_k": top_k,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment