Skip to content
Snippets Groups Projects
Unverified Commit 622d422c authored by Anush008's avatar Anush008
Browse files

chore: collection_options to config

parent c45b679d
Branches
Tags
No related merge requests found
...@@ -18,15 +18,18 @@ class QdrantIndex(BaseIndex): ...@@ -18,15 +18,18 @@ class QdrantIndex(BaseIndex):
index_name: str = Field( index_name: str = Field(
default=DEFAULT_COLLECTION_NAME, default=DEFAULT_COLLECTION_NAME,
description=f"The name of the Qdrant collection to use. Defaults to '{DEFAULT_COLLECTION_NAME}'", description="Name of the Qdrant collection."
f"Default: '{DEFAULT_COLLECTION_NAME}'",
) )
location: Optional[str] = Field( location: Optional[str] = Field(
default=":memory:", default=":memory:",
description="If ':memory:' - use an in-memory Qdrant instance. Used as 'url' value otherwise", description="If ':memory:' - use an in-memory Qdrant instance."
"Used as 'url' value otherwise",
) )
url: Optional[str] = Field( url: Optional[str] = Field(
default=None, default=None,
description="Qualified URL of the Qdrant instance. Optional[scheme], host, Optional[port], Optional[prefix]", description="Qualified URL of the Qdrant instance."
"Optional[scheme], host, Optional[port], Optional[prefix]",
) )
port: Optional[int] = Field( port: Optional[int] = Field(
default=6333, default=6333,
...@@ -58,7 +61,8 @@ class QdrantIndex(BaseIndex): ...@@ -58,7 +61,8 @@ class QdrantIndex(BaseIndex):
) )
host: Optional[str] = Field( host: Optional[str] = Field(
default=None, default=None,
description="Host name of Qdrant service. If url and host are None, set to 'localhost'.", description="Host name of Qdrant service."
"If url and host are None, set to 'localhost'.",
) )
path: Optional[str] = Field( path: Optional[str] = Field(
default=None, default=None,
...@@ -66,24 +70,25 @@ class QdrantIndex(BaseIndex): ...@@ -66,24 +70,25 @@ class QdrantIndex(BaseIndex):
) )
grpc_options: Optional[Dict[str, Any]] = Field( grpc_options: Optional[Dict[str, Any]] = Field(
default=None, default=None,
description="Options to be passed to the low-level Qdrant GRPC client, if used.", description="Options to be passed to the low-level GRPC client, if used.",
) )
dimensions: Union[int, None] = Field( dimensions: Union[int, None] = Field(
default=None, default=None,
description="Embedding dimensions. Defaults to the embedding length of the configured encoder.", description="Embedding dimensions."
"Defaults to the embedding length of the configured encoder.",
) )
metric: Metric = Field( metric: Metric = Field(
default=Metric.COSINE, default=Metric.COSINE,
description="Distance metric to use for similarity search.", description="Distance metric to use for similarity search.",
) )
collection_options: Optional[Dict[str, Any]] = Field( config: Optional[Dict[str, Any]] = Field(
default={}, default={},
description="Additonal options to be passed to `QdrantClient#create_collection`.", description="Collection options passed to `QdrantClient#create_collection`.",
) )
client: Any = Field(default=None, exclude=True) client: Any = Field(default=None, exclude=True)
def __init__(self, **data): def __init__(self, **kwargs):
super().__init__(**data) super().__init__(**kwargs)
self.type = "qdrant" self.type = "qdrant"
self.client = self._initialize_client() self.client = self._initialize_client()
...@@ -128,7 +133,7 @@ class QdrantIndex(BaseIndex): ...@@ -128,7 +133,7 @@ class QdrantIndex(BaseIndex):
vectors_config=models.VectorParams( vectors_config=models.VectorParams(
size=self.dimensions, distance=self.convert_metric(self.metric) size=self.dimensions, distance=self.convert_metric(self.metric)
), ),
**self.collection_options, **self.config,
) )
def add( def add(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment