Skip to content
Snippets Groups Projects
Unverified Commit c45b679d authored by Anush008's avatar Anush008
Browse files

refactor: Addd metric enum

parent a5703d5d
No related branches found
No related tags found
No related merge requests found
...@@ -4,6 +4,7 @@ import numpy as np ...@@ -4,6 +4,7 @@ import numpy as np
from pydantic.v1 import Field from pydantic.v1 import Field
from semantic_router.index.base import BaseIndex from semantic_router.index.base import BaseIndex
from semantic_router.schema import Metric
DEFAULT_COLLECTION_NAME = "semantic-router-index" DEFAULT_COLLECTION_NAME = "semantic-router-index"
DEFAULT_UPLOAD_BATCH_SIZE = 100 DEFAULT_UPLOAD_BATCH_SIZE = 100
...@@ -71,8 +72,9 @@ class QdrantIndex(BaseIndex): ...@@ -71,8 +72,9 @@ class QdrantIndex(BaseIndex):
default=None, default=None,
description="Embedding dimensions. Defaults to the embedding length of the configured encoder.", description="Embedding dimensions. Defaults to the embedding length of the configured encoder.",
) )
metric: str = Field( metric: Metric = Field(
default="Cosine", description="Distance metric to use for similarity search." default=Metric.COSINE,
description="Distance metric to use for similarity search.",
) )
collection_options: Optional[Dict[str, Any]] = Field( collection_options: Optional[Dict[str, Any]] = Field(
default={}, default={},
...@@ -124,8 +126,7 @@ class QdrantIndex(BaseIndex): ...@@ -124,8 +126,7 @@ class QdrantIndex(BaseIndex):
self.client.create_collection( self.client.create_collection(
collection_name=self.index_name, collection_name=self.index_name,
vectors_config=models.VectorParams( vectors_config=models.VectorParams(
size=self.dimensions, size=self.dimensions, distance=self.convert_metric(self.metric)
distance=self.metric, # type: ignore
), ),
**self.collection_options, **self.collection_options,
) )
...@@ -222,5 +223,20 @@ class QdrantIndex(BaseIndex): ...@@ -222,5 +223,20 @@ class QdrantIndex(BaseIndex):
def delete_index(self): def delete_index(self):
self.client.delete_collection(self.index_name) self.client.delete_collection(self.index_name)
def convert_metric(self, metric: Metric):
from qdrant_client.models import Distance
mapping = {
Metric.COSINE: Distance.COSINE,
Metric.EUCLIDEAN: Distance.EUCLID,
Metric.DOTPRODUCT: Distance.DOT,
Metric.MANHATTAN: Distance.MANHATTAN,
}
if metric not in mapping:
raise ValueError(f"Unsupported Qdrant similarity metric: {metric}")
return mapping[metric]
def __len__(self): def __len__(self):
return self.client.get_collection(self.index_name).points_count return self.client.get_collection(self.index_name).points_count
...@@ -85,3 +85,10 @@ class DocumentSplit(BaseModel): ...@@ -85,3 +85,10 @@ class DocumentSplit(BaseModel):
@property @property
def content(self) -> str: def content(self) -> str:
return " ".join(self.docs) return " ".join(self.docs)
class Metric(Enum):
COSINE = "cosine"
DOTPRODUCT = "dotproduct"
EUCLIDEAN = "euclidean"
MANHATTAN = "manhattan"
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment