Unverified commit 024c418f, authored by Matthew Farrellee, committed by GitHub

add snowflake/arctic-embed-l support (#13555)

* add snowflake/arctic-embed-l to set of NVIDIAEmbedding models

* bump version to 0.1.2
Parent: e6daed5f
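
With this change, the hosted snowflake/arctic-embed-l retrieval endpoint can be selected by model name. A minimal usage sketch, assuming llama-index-embeddings-nvidia 0.1.2 is installed and a valid NVIDIA API key is available (the key value below is a placeholder):

import os

from llama_index.embeddings.nvidia import NVIDIAEmbedding

# Placeholder; substitute a real NVIDIA API key.
os.environ.setdefault("NVIDIA_API_KEY", "nvapi-...")

# The new model name is routed to its dedicated retrieval endpoint via
# MODEL_ENDPOINT_MAP rather than the default NV-Embed-QA URL.
embedder = NVIDIAEmbedding(model="snowflake/arctic-embed-l")
embedding = embedder.get_query_embedding("Hello, world!")
print(len(embedding))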
@@ -15,6 +15,11 @@ from openai import OpenAI, AsyncOpenAI
 BASE_RETRIEVAL_URL = "https://ai.api.nvidia.com/v1/retrieval/nvidia"
 DEFAULT_MODEL = "NV-Embed-QA"
+MODEL_ENDPOINT_MAP = {
+    DEFAULT_MODEL: BASE_RETRIEVAL_URL,
+    "snowflake/arctic-embed-l": "https://ai.api.nvidia.com/v1/retrieval/snowflake/arctic-embed-l",
+}

 class Model(BaseModel):
     id: str
@@ -69,9 +74,15 @@ class NVIDIAEmbedding(BaseEmbedding):
             "api_key", nvidia_api_key or api_key, "NVIDIA_API_KEY", "none"
         )
+        # TODO: we should not assume unknown models are at the base url, but
+        #       we cannot error out here because
+        #       NVIDIAEmbedding(model="special").mode("nim", base_url=...)
+        #       is valid usage
+        base_url = MODEL_ENDPOINT_MAP.get(model, BASE_RETRIEVAL_URL)
         self._client = OpenAI(
             api_key=api_key,
-            base_url=BASE_RETRIEVAL_URL,
+            base_url=base_url,
             timeout=timeout,
             max_retries=max_retries,
         )
@@ -79,7 +90,7 @@ class NVIDIAEmbedding(BaseEmbedding):
         self._aclient = AsyncOpenAI(
             api_key=api_key,
-            base_url=BASE_RETRIEVAL_URL,
+            base_url=base_url,
             timeout=timeout,
             max_retries=max_retries,
         )
@@ -95,7 +106,7 @@ class NVIDIAEmbedding(BaseEmbedding):
     @property
     def available_models(self) -> List[Model]:
         """Get available models."""
-        ids = [DEFAULT_MODEL]
+        ids = MODEL_ENDPOINT_MAP.keys()
         if self._mode == "nim":
             ids = [model.id for model in self._client.models.list()]
         return [Model(id=id) for id in ids]
@@ -116,7 +127,8 @@ class NVIDIAEmbedding(BaseEmbedding):
             if not base_url:
                 raise ValueError("base_url is required for nim mode")
         if not base_url:
-            base_url = BASE_RETRIEVAL_URL
+            # TODO: we should not assume unknown models are at the base url
+            base_url = MODEL_ENDPOINT_MAP.get(model or self.model, BASE_RETRIEVAL_URL)
         self._mode = mode
         if base_url:
pyproject.toml:
@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-embeddings-nvidia"
 readme = "README.md"
-version = "0.1.1"
+version = "0.1.2"

 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
New test file:
import pytest
from llama_index.embeddings.nvidia import NVIDIAEmbedding


@pytest.mark.integration()
def test_basic(model: str, mode: dict) -> None:
    client = NVIDIAEmbedding(model=model).mode(**mode)
    response = client.get_query_embedding("Hello, world!")
    assert isinstance(response, list)
    assert len(response) > 0
    assert isinstance(response[0], float)
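
The TODO comments in the diff note that unknown model names still fall back to BASE_RETRIEVAL_URL because self-hosted deployments are configured through mode(). A sketch of that NIM path, with a hypothetical local endpoint URL and model name:

from llama_index.embeddings.nvidia import NVIDIAEmbedding

# "http://localhost:8080/v1" and "special" are placeholders for whatever a
# self-hosted NIM deployment actually serves.
embedder = NVIDIAEmbedding(model="special").mode(
    "nim", base_url="http://localhost:8080/v1"
)
embedding = embedder.get_query_embedding("Hello, world!")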