From c45b679d1e0fc4ce8f818bfe4f4ff572c5b46da4 Mon Sep 17 00:00:00 2001
From: Anush008 <anushshetty90@gmail.com>
Date: Mon, 18 Mar 2024 16:41:47 +0530
Subject: [PATCH] refactor: Addd metric enum

---
 semantic_router/index/qdrant.py | 24 ++++++++++++++++++++----
 semantic_router/schema.py       |  7 +++++++
 2 files changed, 27 insertions(+), 4 deletions(-)

diff --git a/semantic_router/index/qdrant.py b/semantic_router/index/qdrant.py
index cc1ddaef..e112fb35 100644
--- a/semantic_router/index/qdrant.py
+++ b/semantic_router/index/qdrant.py
@@ -4,6 +4,7 @@ import numpy as np
 from pydantic.v1 import Field
 
 from semantic_router.index.base import BaseIndex
+from semantic_router.schema import Metric
 
 DEFAULT_COLLECTION_NAME = "semantic-router-index"
 DEFAULT_UPLOAD_BATCH_SIZE = 100
@@ -71,8 +72,9 @@ class QdrantIndex(BaseIndex):
         default=None,
         description="Embedding dimensions. Defaults to the embedding length of the configured encoder.",
     )
-    metric: str = Field(
-        default="Cosine", description="Distance metric to use for similarity search."
+    metric: Metric = Field(
+        default=Metric.COSINE,
+        description="Distance metric to use for similarity search.",
     )
     collection_options: Optional[Dict[str, Any]] = Field(
         default={},
@@ -124,8 +126,7 @@ class QdrantIndex(BaseIndex):
             self.client.create_collection(
                 collection_name=self.index_name,
                 vectors_config=models.VectorParams(
-                    size=self.dimensions,
-                    distance=self.metric,  # type: ignore
+                    size=self.dimensions, distance=self.convert_metric(self.metric)
                 ),
                 **self.collection_options,
             )
@@ -222,5 +223,20 @@ class QdrantIndex(BaseIndex):
     def delete_index(self):
         self.client.delete_collection(self.index_name)
 
+    def convert_metric(self, metric: Metric):
+        from qdrant_client.models import Distance
+
+        mapping = {
+            Metric.COSINE: Distance.COSINE,
+            Metric.EUCLIDEAN: Distance.EUCLID,
+            Metric.DOTPRODUCT: Distance.DOT,
+            Metric.MANHATTAN: Distance.MANHATTAN,
+        }
+
+        if metric not in mapping:
+            raise ValueError(f"Unsupported Qdrant similarity metric: {metric}")
+
+        return mapping[metric]
+
     def __len__(self):
         return self.client.get_collection(self.index_name).points_count
diff --git a/semantic_router/schema.py b/semantic_router/schema.py
index 3e0cd5e5..85d428ef 100644
--- a/semantic_router/schema.py
+++ b/semantic_router/schema.py
@@ -85,3 +85,10 @@ class DocumentSplit(BaseModel):
     @property
     def content(self) -> str:
         return " ".join(self.docs)
+
+
+class Metric(Enum):
+    COSINE = "cosine"
+    DOTPRODUCT = "dotproduct"
+    EUCLIDEAN = "euclidean"
+    MANHATTAN = "manhattan"
-- 
GitLab