diff --git a/semantic_router/index/qdrant.py b/semantic_router/index/qdrant.py index 531eb345b561ce884f9a96829bb7a7906211c6a3..cc1ddaefe702dd56e22330db48516abbe4387f17 100644 --- a/semantic_router/index/qdrant.py +++ b/semantic_router/index/qdrant.py @@ -5,7 +5,7 @@ from pydantic.v1 import Field from semantic_router.index.base import BaseIndex -DEFAULT_COLLECTION_NAME = "semantic-router-collection" +DEFAULT_COLLECTION_NAME = "semantic-router-index" DEFAULT_UPLOAD_BATCH_SIZE = 100 SCROLL_SIZE = 1000 SR_UTTERANCE_PAYLOAD_KEY = "sr_utterance" @@ -15,7 +15,7 @@ SR_ROUTE_PAYLOAD_KEY = "sr_route" class QdrantIndex(BaseIndex): "The name of the collection to use" - collection_name: str = Field( + index_name: str = Field( default=DEFAULT_COLLECTION_NAME, description=f"The name of the Qdrant collection to use. Defaults to '{DEFAULT_COLLECTION_NAME}'", ) @@ -67,11 +67,11 @@ class QdrantIndex(BaseIndex): default=None, description="Options to be passed to the low-level Qdrant GRPC client, if used.", ) - size: Union[int, None] = Field( + dimensions: Union[int, None] = Field( default=None, description="Embedding dimensions. Defaults to the embedding length of the configured encoder.", ) - distance: str = Field( + metric: str = Field( default="Cosine", description="Distance metric to use for similarity search." ) collection_options: Optional[Dict[str, Any]] = Field( @@ -115,17 +115,17 @@ class QdrantIndex(BaseIndex): from qdrant_client import QdrantClient, models self.client: QdrantClient - if not self.client.collection_exists(self.collection_name): + if not self.client.collection_exists(self.index_name): if not self.dimensions: raise ValueError( "Cannot create a collection without specifying the dimensions." ) self.client.create_collection( - collection_name=self.collection_name, + collection_name=self.index_name, vectors_config=models.VectorParams( size=self.dimensions, - distance=self.distance, # type: ignore + distance=self.metric, # type: ignore ), **self.collection_options, ) @@ -147,7 +147,7 @@ class QdrantIndex(BaseIndex): # UUIDs are autogenerated by qdrant-client if not provided explicitly self.client.upload_collection( - self.collection_name, + self.index_name, vectors=embeddings, payload=payloads, batch_size=batch_size, @@ -168,7 +168,7 @@ class QdrantIndex(BaseIndex): stop_scrolling = False while not stop_scrolling: records, next_offset = self.client.scroll( - self.collection_name, + self.index_name, limit=SCROLL_SIZE, offset=next_offset, with_payload=True, @@ -191,7 +191,7 @@ class QdrantIndex(BaseIndex): from qdrant_client import models self.client.delete( - self.collection_name, + self.index_name, points_selector=models.Filter( must=[ models.FieldCondition( @@ -203,7 +203,7 @@ class QdrantIndex(BaseIndex): ) def describe(self) -> dict: - collection_info = self.client.get_collection(self.collection_name) + collection_info = self.client.get_collection(self.index_name) return { "type": self.type, @@ -213,14 +213,14 @@ class QdrantIndex(BaseIndex): def query(self, vector: np.ndarray, top_k: int = 5) -> Tuple[np.ndarray, List[str]]: results = self.client.search( - self.collection_name, query_vector=vector, limit=top_k, with_payload=True + self.index_name, query_vector=vector, limit=top_k, with_payload=True ) scores = [result.score for result in results] route_names = [result.payload["sr_route"] for result in results] return np.array(scores), route_names def delete_index(self): - self.client.delete_collection(self.collection_name) + self.client.delete_collection(self.index_name) def __len__(self): - return self.client.get_collection(self.collection_name).points_count + return self.client.get_collection(self.index_name).points_count