From d460becf151943262ddaecff443a56be8c38e79d Mon Sep 17 00:00:00 2001
From: Anush008 <anushshetty90@gmail.com>
Date: Thu, 13 Jun 2024 15:29:30 +0530
Subject: [PATCH] chore: Qdrant async query

---
 semantic_router/index/qdrant.py | 71 +++++++++++++++++++++++++++++----
 1 file changed, 64 insertions(+), 7 deletions(-)

diff --git a/semantic_router/index/qdrant.py b/semantic_router/index/qdrant.py
index bb49f1fd..165a4c93 100644
--- a/semantic_router/index/qdrant.py
+++ b/semantic_router/index/qdrant.py
@@ -6,6 +6,8 @@ from pydantic.v1 import Field
 from semantic_router.index.base import BaseIndex
 from semantic_router.schema import Metric
 
+from semantic_router.utils.logger import logger
+
 DEFAULT_COLLECTION_NAME = "semantic-router-index"
 DEFAULT_UPLOAD_BATCH_SIZE = 100
 SCROLL_SIZE = 1000
@@ -86,17 +88,18 @@ class QdrantIndex(BaseIndex):
         description="Collection options passed to `QdrantClient#create_collection`.",
     )
     client: Any = Field(default=None, exclude=True)
+    aclient: Any = Field(default=None, exclude=True)
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self.type = "qdrant"
-        self.client = self._initialize_client()
+        self.client, self.aclient = self._initialize_clients()
 
-    def _initialize_client(self):
+    def _initialize_clients(self):
         try:
-            from qdrant_client import QdrantClient
+            from qdrant_client import QdrantClient, AsyncQdrantClient
 
-            return QdrantClient(
+            sync_client = QdrantClient(
                 location=self.location,
                 url=self.url,
                 port=self.port,
@@ -111,6 +114,27 @@ class QdrantIndex(BaseIndex):
                 grpc_options=self.grpc_options,
             )
 
+            async_client: Optional[AsyncQdrantClient] = None
+
+            if all([self.location != ":memory:", self.path is None]):
+                # Local Qdrant cannot interoperate with sync and async clients
+                # We fallback to sync operations in this case
+                async_client = AsyncQdrantClient(
+                    location=self.location,
+                    url=self.url,
+                    port=self.port,
+                    grpc_port=self.grpc_port,
+                    prefer_grpc=self.prefer_grpc,
+                    https=self.https,
+                    api_key=self.api_key,
+                    prefix=self.prefix,
+                    timeout=self.timeout,
+                    host=self.host,
+                    path=self.path,
+                    grpc_options=self.grpc_options,
+                )
+
+            return sync_client, async_client
         except ImportError as e:
             raise ImportError(
                 "Please install 'qdrant-client' to use QdrantIndex."
@@ -223,11 +247,44 @@ class QdrantIndex(BaseIndex):
         top_k: int = 5,
         route_filter: Optional[List[str]] = None,
     ) -> Tuple[np.ndarray, List[str]]:
-        from qdrant_client import models
+        from qdrant_client import models, QdrantClient
+
+        self.client: QdrantClient
+        filter = None
+        if route_filter is not None:
+            filter = models.Filter(
+                must=[
+                    models.FieldCondition(
+                        key=SR_ROUTE_PAYLOAD_KEY,
+                        match=models.MatchAny(any=route_filter),
+                    )
+                ]
+            )
 
         results = self.client.search(
-            self.index_name, query_vector=vector, limit=top_k, with_payload=True
+            self.index_name,
+            query_vector=vector,
+            limit=top_k,
+            with_payload=True,
+            query_filter=filter,
         )
+        scores = [result.score for result in results]
+        route_names = [result.payload[SR_ROUTE_PAYLOAD_KEY] for result in results]
+        return np.array(scores), route_names
+
+    async def aquery(
+        self,
+        vector: np.ndarray,
+        top_k: int = 5,
+        route_filter: Optional[List[str]] = None,
+    ) -> Tuple[np.ndarray, List[str]]:
+        from qdrant_client import models, AsyncQdrantClient
+
+        self.aclient: Optional[AsyncQdrantClient]
+        if self.aclient is None:
+            logger.warning("Cannot use async query with an in-memory Qdrant instance")
+            return self.query(vector, top_k, route_filter)
+
         filter = None
         if route_filter is not None:
             filter = models.Filter(
@@ -239,7 +296,7 @@ class QdrantIndex(BaseIndex):
                 ]
             )
 
-        results = self.client.search(
+        results = await self.aclient.search(
             self.index_name,
             query_vector=vector,
             limit=top_k,
-- 
GitLab