Commit 8f1e9787 authored by Jared Van Bortel, committed by GitHub

Implement local Nomic Embed with the inference_mode parameter (#13607)

parent e648b8e8
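
For orientation, here is a minimal usage sketch of the inference_mode parameter this commit introduces, assuming the post-change API shown in the diff below; the model name, device, and dimensionality values are illustrative assumptions, not values taken from the commit.

# Hypothetical usage sketch of local Nomic Embed via the new inference_mode parameter.
# model_name, device, and dimensionality below are illustrative, not from this commit.
from llama_index.embeddings.nomic import NomicEmbedding

embed_model = NomicEmbedding(
    model_name="nomic-embed-text-v1.5",
    inference_mode="local",   # "remote" (default), "local", or "dynamic"
    device="cpu",             # device used for local inference
    dimensionality=768,       # resizable output for Matryoshka-capable models
)

vector = embed_model.get_text_embedding("Nomic Embed can now run on the local machine.")
print(len(vector))
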
from enum import Enum
from typing import Any, List, Optional, Union

+import nomic
+import nomic.embed
import torch
from llama_index.core.base.embeddings.base import (
    BaseEmbedding,
    DEFAULT_EMBED_BATCH_SIZE,
)
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.callbacks import CallbackManager
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.embeddings.huggingface.pooling import Pooling
import logging
import warnings

DEFAULT_HUGGINGFACE_LENGTH = 512

logger = logging.getLogger(__name__)
-class NomicAITaskType(str, Enum):
+class NomicTaskType(str, Enum):
    SEARCH_QUERY = "search_query"
    SEARCH_DOCUMENT = "search_document"
    CLUSTERING = "clustering"
    CLASSIFICATION = "classification"


-TASK_TYPES = [
-    NomicAITaskType.SEARCH_QUERY,
-    NomicAITaskType.SEARCH_DOCUMENT,
-    NomicAITaskType.CLUSTERING,
-    NomicAITaskType.CLASSIFICATION,
-]


+class NomicInferenceMode(str, Enum):
+    REMOTE = "remote"
+    LOCAL = "local"
+    DYNAMIC = "dynamic"
class NomicEmbedding(BaseEmbedding):
    """NomicEmbedding uses the Nomic API to generate embeddings."""

    # Instance variables initialized via Pydantic's mechanism
-    query_task_type: Optional[str] = Field(description="Query Embedding prefix")
-    document_task_type: Optional[str] = Field(description="Document Embedding prefix")
-    dimensionality: Optional[int] = Field(description="Dimension of the Embedding")
+    query_task_type: Optional[NomicTaskType] = Field(
+        description="Task type for queries",
+    )
+    document_task_type: Optional[NomicTaskType] = Field(
+        description="Task type for documents",
+    )
+    dimensionality: Optional[int] = Field(
+        description="Embedding dimension, for use with Matryoshka-capable models",
+    )
    model_name: str = Field(description="Embedding model name")
-    _model: Any = PrivateAttr()
+    inference_mode: NomicInferenceMode = Field(
+        description="Whether to generate embeddings locally",
+    )
+    device: Optional[str] = Field(description="Device to use for local embeddings")

    def __init__(
        self,
@@ -53,39 +56,22 @@ class NomicEmbedding(BaseEmbedding):
        query_task_type: Optional[str] = "search_query",
        document_task_type: Optional[str] = "search_document",
        dimensionality: Optional[int] = 768,
-        **kwargs: Any,
-    ) -> None:
-        if query_task_type not in TASK_TYPES or document_task_type not in TASK_TYPES:
-            raise ValueError(
-                f"Invalid task type {query_task_type}, {document_task_type}. Must be one of {TASK_TYPES}"
-            )
-        try:
-            import nomic
-            from nomic import embed
-        except ImportError:
-            raise ImportError(
-                "NomicEmbedding requires the 'nomic' package to be installed.\n"
-                "Please install it with `pip install nomic`."
-            )
+        inference_mode: str = "remote",
+        device: Optional[str] = None,
+    ):
        if api_key is not None:
-            nomic.cli.login(api_key)
+            nomic.login(api_key)

        super().__init__(
            model_name=model_name,
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
-            _model=embed,
            query_task_type=query_task_type,
            document_task_type=document_task_type,
            dimensionality=dimensionality,
-            **kwargs,
+            inference_mode=inference_mode,
+            device=device,
        )
-        self._model = embed
-        self.model_name = model_name
-        self.query_task_type = query_task_type
-        self.document_task_type = document_task_type
-        self.dimensionality = dimensionality

    @classmethod
    def class_name(cls) -> str:
@@ -94,35 +80,38 @@ class NomicEmbedding(BaseEmbedding):
    def _embed(
        self, texts: List[str], task_type: Optional[str] = None
    ) -> List[List[float]]:
        """Embed sentences using NomicAI."""
-        result = self._model.text(
+        result = nomic.embed.text(
            texts,
            model=self.model_name,
            task_type=task_type,
            dimensionality=self.dimensionality,
+            inference_mode=self.inference_mode,
+            device=self.device,
        )
        return result["embeddings"]

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        return self._embed([query], task_type=self.query_task_type)[0]

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Get query embedding async."""
        self._warn_async()
        return self._get_query_embedding(query)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        return self._embed([text], task_type=self.document_task_type)[0]

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Get text embedding async."""
        self._warn_async()
        return self._get_text_embedding(text)

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings."""
        return self._embed(texts, task_type=self.document_task_type)

    def _warn_async(self) -> None:
        warnings.warn(
            f"{self.class_name()} does not implement async embeddings, falling back to sync method.",
        )


class NomicHFEmbedding(HuggingFaceEmbedding):
    tokenizer_name: str = Field(description="Tokenizer name from HuggingFace.")
......
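
For reference, a standalone sketch of the nomic.embed.text call that the reworked _embed method now delegates to; the argument names mirror the diff above, while the sample texts and the model choice are placeholders.

# Standalone sketch of the nomic.embed.text call used by the new _embed method.
# Argument names mirror the diff above; sample texts and model choice are placeholders.
import nomic.embed

output = nomic.embed.text(
    ["first passage to embed", "second passage to embed"],
    model="nomic-embed-text-v1.5",
    task_type="search_document",   # or "search_query", "clustering", "classification"
    dimensionality=768,
    inference_mode="local",        # "remote", "local", or "dynamic"
    device="cpu",
)
embeddings = output["embeddings"]  # one float vector per input text
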
@@ -21,20 +21,19 @@ ignore_missing_imports = true
python_version = "3.8"

[tool.poetry]
-authors = ["Your Name <you@example.com>"]
+authors = ["Jared Van Bortel <jared@nomic.ai>", "Zach Nussbaum <zach@nomic.ai>"]
description = "llama-index embeddings nomic integration"
exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-embeddings-nomic"
readme = "README.md"
-version = "0.1.6"
+version = "0.2.0"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
llama-index-core = "^0.10.11.post1"
llama-index-embeddings-huggingface = "^0.1.3"
einops = "^0.7.0"
-nomic = "^3.0.12"
+nomic = "^3.0.29"

[tool.poetry.group.dev.dependencies]
ipython = "8.10.0"
......
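
As a closing usage note, a sketch of how the query/document task types and the dynamic inference mode from the diff above fit together; the API key and input strings are placeholders.

# Sketch combining the task-type prefixes and the "dynamic" inference mode.
# The API key and input strings are placeholders.
from llama_index.embeddings.nomic import NomicEmbedding

embed_model = NomicEmbedding(
    api_key="YOUR_NOMIC_API_KEY",         # placeholder; used when remote inference is chosen
    inference_mode="dynamic",             # defer the remote-vs-local choice to the nomic client
    query_task_type="search_query",       # prefix applied to queries
    document_task_type="search_document", # prefix applied to documents
)

query_vector = embed_model.get_query_embedding("how do I run Nomic Embed locally?")
doc_vectors = embed_model.get_text_embedding_batch(["passage one", "passage two"])
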