From d460becf151943262ddaecff443a56be8c38e79d Mon Sep 17 00:00:00 2001 From: Anush008 <anushshetty90@gmail.com> Date: Thu, 13 Jun 2024 15:29:30 +0530 Subject: [PATCH] chore: Qdrant async query --- semantic_router/index/qdrant.py | 71 +++++++++++++++++++++++++++++---- 1 file changed, 64 insertions(+), 7 deletions(-) diff --git a/semantic_router/index/qdrant.py b/semantic_router/index/qdrant.py index bb49f1fd..165a4c93 100644 --- a/semantic_router/index/qdrant.py +++ b/semantic_router/index/qdrant.py @@ -6,6 +6,8 @@ from pydantic.v1 import Field from semantic_router.index.base import BaseIndex from semantic_router.schema import Metric +from semantic_router.utils.logger import logger + DEFAULT_COLLECTION_NAME = "semantic-router-index" DEFAULT_UPLOAD_BATCH_SIZE = 100 SCROLL_SIZE = 1000 @@ -86,17 +88,18 @@ class QdrantIndex(BaseIndex): description="Collection options passed to `QdrantClient#create_collection`.", ) client: Any = Field(default=None, exclude=True) + aclient: Any = Field(default=None, exclude=True) def __init__(self, **kwargs): super().__init__(**kwargs) self.type = "qdrant" - self.client = self._initialize_client() + self.client, self.aclient = self._initialize_clients() - def _initialize_client(self): + def _initialize_clients(self): try: - from qdrant_client import QdrantClient + from qdrant_client import QdrantClient, AsyncQdrantClient - return QdrantClient( + sync_client = QdrantClient( location=self.location, url=self.url, port=self.port, @@ -111,6 +114,27 @@ class QdrantIndex(BaseIndex): grpc_options=self.grpc_options, ) + async_client: Optional[AsyncQdrantClient] = None + + if all([self.location != ":memory:", self.path is None]): + # Local Qdrant cannot interoperate with sync and async clients + # We fallback to sync operations in this case + async_client = AsyncQdrantClient( + location=self.location, + url=self.url, + port=self.port, + grpc_port=self.grpc_port, + prefer_grpc=self.prefer_grpc, + https=self.https, + api_key=self.api_key, + prefix=self.prefix, + timeout=self.timeout, + host=self.host, + path=self.path, + grpc_options=self.grpc_options, + ) + + return sync_client, async_client except ImportError as e: raise ImportError( "Please install 'qdrant-client' to use QdrantIndex." @@ -223,11 +247,44 @@ class QdrantIndex(BaseIndex): top_k: int = 5, route_filter: Optional[List[str]] = None, ) -> Tuple[np.ndarray, List[str]]: - from qdrant_client import models + from qdrant_client import models, QdrantClient + + self.client: QdrantClient + filter = None + if route_filter is not None: + filter = models.Filter( + must=[ + models.FieldCondition( + key=SR_ROUTE_PAYLOAD_KEY, + match=models.MatchAny(any=route_filter), + ) + ] + ) results = self.client.search( - self.index_name, query_vector=vector, limit=top_k, with_payload=True + self.index_name, + query_vector=vector, + limit=top_k, + with_payload=True, + query_filter=filter, ) + scores = [result.score for result in results] + route_names = [result.payload[SR_ROUTE_PAYLOAD_KEY] for result in results] + return np.array(scores), route_names + + async def aquery( + self, + vector: np.ndarray, + top_k: int = 5, + route_filter: Optional[List[str]] = None, + ) -> Tuple[np.ndarray, List[str]]: + from qdrant_client import models, AsyncQdrantClient + + self.aclient: Optional[AsyncQdrantClient] + if self.aclient is None: + logger.warning("Cannot use async query with an in-memory Qdrant instance") + return self.query(vector, top_k, route_filter) + filter = None if route_filter is not None: filter = models.Filter( @@ -239,7 +296,7 @@ class QdrantIndex(BaseIndex): ] ) - results = self.client.search( + results = await self.aclient.search( self.index_name, query_vector=vector, limit=top_k, -- GitLab