diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-nomic/llama_index/embeddings/nomic/base.py b/llama-index-integrations/embeddings/llama-index-embeddings-nomic/llama_index/embeddings/nomic/base.py index 1a38b9f1a8bd13e257151d89929a31741e470ba4..b88aeed10c2c1c241dec6672eb31012f6138023b 100644 --- a/llama-index-integrations/embeddings/llama-index-embeddings-nomic/llama_index/embeddings/nomic/base.py +++ b/llama-index-integrations/embeddings/llama-index-embeddings-nomic/llama_index/embeddings/nomic/base.py @@ -27,6 +27,7 @@ class NomicEmbedding(BaseEmbedding): # Instance variables initialized via Pydantic's mechanism query_task_type: Optional[str] = Field(description="Query Embedding prefix") document_task_type: Optional[str] = Field(description="Document Embedding prefix") + dimensionality: Optional[int] = Field(description="Dimension of the Embedding") model_name: str = Field(description="Embedding model name") _model: Any = PrivateAttr() @@ -38,6 +39,7 @@ class NomicEmbedding(BaseEmbedding): callback_manager: Optional[CallbackManager] = None, query_task_type: Optional[str] = "search_query", document_task_type: Optional[str] = "search_document", + dimensionality: Optional[int] = 768, **kwargs: Any, ) -> None: if query_task_type not in TASK_TYPES or document_task_type not in TASK_TYPES: @@ -63,12 +65,14 @@ class NomicEmbedding(BaseEmbedding): _model=embed, query_task_type=query_task_type, document_task_type=document_task_type, + dimensionality=dimensionality, **kwargs, ) self._model = embed self.model_name = model_name self.query_task_type = query_task_type self.document_task_type = document_task_type + self.dimensionality = dimensionality @classmethod def class_name(cls) -> str: @@ -78,7 +82,12 @@ class NomicEmbedding(BaseEmbedding): self, texts: List[str], task_type: Optional[str] = None ) -> List[List[float]]: """Embed sentences using NomicAI.""" - result = self._model.text(texts, model=self.model_name, task_type=task_type) + result = self._model.text( + texts, + model=self.model_name, + task_type=task_type, + dimensionality=self.dimensionality, + ) return result["embeddings"] def _get_query_embedding(self, query: str) -> List[float]: