Skip to content
Snippets Groups Projects
Unverified Commit d460becf authored by Anush008's avatar Anush008
Browse files

chore: Qdrant async query

parent 4da3da10
No related branches found
No related tags found
No related merge requests found
......@@ -6,6 +6,8 @@ from pydantic.v1 import Field
from semantic_router.index.base import BaseIndex
from semantic_router.schema import Metric
from semantic_router.utils.logger import logger
DEFAULT_COLLECTION_NAME = "semantic-router-index"
DEFAULT_UPLOAD_BATCH_SIZE = 100
SCROLL_SIZE = 1000
......@@ -86,17 +88,18 @@ class QdrantIndex(BaseIndex):
description="Collection options passed to `QdrantClient#create_collection`.",
)
client: Any = Field(default=None, exclude=True)
aclient: Any = Field(default=None, exclude=True)
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.type = "qdrant"
self.client = self._initialize_client()
self.client, self.aclient = self._initialize_clients()
def _initialize_client(self):
def _initialize_clients(self):
try:
from qdrant_client import QdrantClient
from qdrant_client import QdrantClient, AsyncQdrantClient
return QdrantClient(
sync_client = QdrantClient(
location=self.location,
url=self.url,
port=self.port,
......@@ -111,6 +114,27 @@ class QdrantIndex(BaseIndex):
grpc_options=self.grpc_options,
)
async_client: Optional[AsyncQdrantClient] = None
if all([self.location != ":memory:", self.path is None]):
# Local Qdrant cannot interoperate with sync and async clients
# We fallback to sync operations in this case
async_client = AsyncQdrantClient(
location=self.location,
url=self.url,
port=self.port,
grpc_port=self.grpc_port,
prefer_grpc=self.prefer_grpc,
https=self.https,
api_key=self.api_key,
prefix=self.prefix,
timeout=self.timeout,
host=self.host,
path=self.path,
grpc_options=self.grpc_options,
)
return sync_client, async_client
except ImportError as e:
raise ImportError(
"Please install 'qdrant-client' to use QdrantIndex."
......@@ -223,11 +247,44 @@ class QdrantIndex(BaseIndex):
top_k: int = 5,
route_filter: Optional[List[str]] = None,
) -> Tuple[np.ndarray, List[str]]:
from qdrant_client import models
from qdrant_client import models, QdrantClient
self.client: QdrantClient
filter = None
if route_filter is not None:
filter = models.Filter(
must=[
models.FieldCondition(
key=SR_ROUTE_PAYLOAD_KEY,
match=models.MatchAny(any=route_filter),
)
]
)
results = self.client.search(
self.index_name, query_vector=vector, limit=top_k, with_payload=True
self.index_name,
query_vector=vector,
limit=top_k,
with_payload=True,
query_filter=filter,
)
scores = [result.score for result in results]
route_names = [result.payload[SR_ROUTE_PAYLOAD_KEY] for result in results]
return np.array(scores), route_names
async def aquery(
self,
vector: np.ndarray,
top_k: int = 5,
route_filter: Optional[List[str]] = None,
) -> Tuple[np.ndarray, List[str]]:
from qdrant_client import models, AsyncQdrantClient
self.aclient: Optional[AsyncQdrantClient]
if self.aclient is None:
logger.warning("Cannot use async query with an in-memory Qdrant instance")
return self.query(vector, top_k, route_filter)
filter = None
if route_filter is not None:
filter = models.Filter(
......@@ -239,7 +296,7 @@ class QdrantIndex(BaseIndex):
]
)
results = self.client.search(
results = await self.aclient.search(
self.index_name,
query_vector=vector,
limit=top_k,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment