From c45b679d1e0fc4ce8f818bfe4f4ff572c5b46da4 Mon Sep 17 00:00:00 2001 From: Anush008 <anushshetty90@gmail.com> Date: Mon, 18 Mar 2024 16:41:47 +0530 Subject: [PATCH] refactor: Addd metric enum --- semantic_router/index/qdrant.py | 24 ++++++++++++++++++++---- semantic_router/schema.py | 7 +++++++ 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/semantic_router/index/qdrant.py b/semantic_router/index/qdrant.py index cc1ddaef..e112fb35 100644 --- a/semantic_router/index/qdrant.py +++ b/semantic_router/index/qdrant.py @@ -4,6 +4,7 @@ import numpy as np from pydantic.v1 import Field from semantic_router.index.base import BaseIndex +from semantic_router.schema import Metric DEFAULT_COLLECTION_NAME = "semantic-router-index" DEFAULT_UPLOAD_BATCH_SIZE = 100 @@ -71,8 +72,9 @@ class QdrantIndex(BaseIndex): default=None, description="Embedding dimensions. Defaults to the embedding length of the configured encoder.", ) - metric: str = Field( - default="Cosine", description="Distance metric to use for similarity search." + metric: Metric = Field( + default=Metric.COSINE, + description="Distance metric to use for similarity search.", ) collection_options: Optional[Dict[str, Any]] = Field( default={}, @@ -124,8 +126,7 @@ class QdrantIndex(BaseIndex): self.client.create_collection( collection_name=self.index_name, vectors_config=models.VectorParams( - size=self.dimensions, - distance=self.metric, # type: ignore + size=self.dimensions, distance=self.convert_metric(self.metric) ), **self.collection_options, ) @@ -222,5 +223,20 @@ class QdrantIndex(BaseIndex): def delete_index(self): self.client.delete_collection(self.index_name) + def convert_metric(self, metric: Metric): + from qdrant_client.models import Distance + + mapping = { + Metric.COSINE: Distance.COSINE, + Metric.EUCLIDEAN: Distance.EUCLID, + Metric.DOTPRODUCT: Distance.DOT, + Metric.MANHATTAN: Distance.MANHATTAN, + } + + if metric not in mapping: + raise ValueError(f"Unsupported Qdrant similarity metric: {metric}") + + return mapping[metric] + def __len__(self): return self.client.get_collection(self.index_name).points_count diff --git a/semantic_router/schema.py b/semantic_router/schema.py index 3e0cd5e5..85d428ef 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -85,3 +85,10 @@ class DocumentSplit(BaseModel): @property def content(self) -> str: return " ".join(self.docs) + + +class Metric(Enum): + COSINE = "cosine" + DOTPRODUCT = "dotproduct" + EUCLIDEAN = "euclidean" + MANHATTAN = "manhattan" -- GitLab