Skip to content
Snippets Groups Projects
Unverified Commit d460becf authored by Anush008's avatar Anush008
Browse files

chore: Qdrant async query

parent 4da3da10
No related branches found
No related tags found
No related merge requests found
...@@ -6,6 +6,8 @@ from pydantic.v1 import Field ...@@ -6,6 +6,8 @@ from pydantic.v1 import Field
from semantic_router.index.base import BaseIndex from semantic_router.index.base import BaseIndex
from semantic_router.schema import Metric from semantic_router.schema import Metric
from semantic_router.utils.logger import logger
DEFAULT_COLLECTION_NAME = "semantic-router-index" DEFAULT_COLLECTION_NAME = "semantic-router-index"
DEFAULT_UPLOAD_BATCH_SIZE = 100 DEFAULT_UPLOAD_BATCH_SIZE = 100
SCROLL_SIZE = 1000 SCROLL_SIZE = 1000
...@@ -86,17 +88,18 @@ class QdrantIndex(BaseIndex): ...@@ -86,17 +88,18 @@ class QdrantIndex(BaseIndex):
description="Collection options passed to `QdrantClient#create_collection`.", description="Collection options passed to `QdrantClient#create_collection`.",
) )
client: Any = Field(default=None, exclude=True) client: Any = Field(default=None, exclude=True)
aclient: Any = Field(default=None, exclude=True)
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.type = "qdrant" self.type = "qdrant"
self.client = self._initialize_client() self.client, self.aclient = self._initialize_clients()
def _initialize_client(self): def _initialize_clients(self):
try: try:
from qdrant_client import QdrantClient from qdrant_client import QdrantClient, AsyncQdrantClient
return QdrantClient( sync_client = QdrantClient(
location=self.location, location=self.location,
url=self.url, url=self.url,
port=self.port, port=self.port,
...@@ -111,6 +114,27 @@ class QdrantIndex(BaseIndex): ...@@ -111,6 +114,27 @@ class QdrantIndex(BaseIndex):
grpc_options=self.grpc_options, grpc_options=self.grpc_options,
) )
async_client: Optional[AsyncQdrantClient] = None
if all([self.location != ":memory:", self.path is None]):
# Local Qdrant cannot interoperate with sync and async clients
# We fallback to sync operations in this case
async_client = AsyncQdrantClient(
location=self.location,
url=self.url,
port=self.port,
grpc_port=self.grpc_port,
prefer_grpc=self.prefer_grpc,
https=self.https,
api_key=self.api_key,
prefix=self.prefix,
timeout=self.timeout,
host=self.host,
path=self.path,
grpc_options=self.grpc_options,
)
return sync_client, async_client
except ImportError as e: except ImportError as e:
raise ImportError( raise ImportError(
"Please install 'qdrant-client' to use QdrantIndex." "Please install 'qdrant-client' to use QdrantIndex."
...@@ -223,11 +247,44 @@ class QdrantIndex(BaseIndex): ...@@ -223,11 +247,44 @@ class QdrantIndex(BaseIndex):
top_k: int = 5, top_k: int = 5,
route_filter: Optional[List[str]] = None, route_filter: Optional[List[str]] = None,
) -> Tuple[np.ndarray, List[str]]: ) -> Tuple[np.ndarray, List[str]]:
from qdrant_client import models from qdrant_client import models, QdrantClient
self.client: QdrantClient
filter = None
if route_filter is not None:
filter = models.Filter(
must=[
models.FieldCondition(
key=SR_ROUTE_PAYLOAD_KEY,
match=models.MatchAny(any=route_filter),
)
]
)
results = self.client.search( results = self.client.search(
self.index_name, query_vector=vector, limit=top_k, with_payload=True self.index_name,
query_vector=vector,
limit=top_k,
with_payload=True,
query_filter=filter,
) )
scores = [result.score for result in results]
route_names = [result.payload[SR_ROUTE_PAYLOAD_KEY] for result in results]
return np.array(scores), route_names
async def aquery(
self,
vector: np.ndarray,
top_k: int = 5,
route_filter: Optional[List[str]] = None,
) -> Tuple[np.ndarray, List[str]]:
from qdrant_client import models, AsyncQdrantClient
self.aclient: Optional[AsyncQdrantClient]
if self.aclient is None:
logger.warning("Cannot use async query with an in-memory Qdrant instance")
return self.query(vector, top_k, route_filter)
filter = None filter = None
if route_filter is not None: if route_filter is not None:
filter = models.Filter( filter = models.Filter(
...@@ -239,7 +296,7 @@ class QdrantIndex(BaseIndex): ...@@ -239,7 +296,7 @@ class QdrantIndex(BaseIndex):
] ]
) )
results = self.client.search( results = await self.aclient.search(
self.index_name, self.index_name,
query_vector=vector, query_vector=vector,
limit=top_k, limit=top_k,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment