From 024c418fb1d5bab927ee03dd36f579c52b4e5afe Mon Sep 17 00:00:00 2001 From: Matthew Farrellee <matt@cs.wisc.edu> Date: Fri, 17 May 2024 15:21:42 -0400 Subject: [PATCH] add snowflake/arctic-embed-l support (#13555) * add snowflake/arctic-embed-l to set of NVIDIAEmbedding models * bump version to 0.1.2 --- .../llama_index/embeddings/nvidia/base.py | 20 +++++++++++++++---- .../pyproject.toml | 2 +- .../tests/test_integration.py | 12 +++++++++++ 3 files changed, 29 insertions(+), 5 deletions(-) create mode 100644 llama-index-integrations/embeddings/llama-index-embeddings-nvidia/tests/test_integration.py diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-nvidia/llama_index/embeddings/nvidia/base.py b/llama-index-integrations/embeddings/llama-index-embeddings-nvidia/llama_index/embeddings/nvidia/base.py index c266847808..cdf2cc7dc9 100644 --- a/llama-index-integrations/embeddings/llama-index-embeddings-nvidia/llama_index/embeddings/nvidia/base.py +++ b/llama-index-integrations/embeddings/llama-index-embeddings-nvidia/llama_index/embeddings/nvidia/base.py @@ -15,6 +15,11 @@ from openai import OpenAI, AsyncOpenAI BASE_RETRIEVAL_URL = "https://ai.api.nvidia.com/v1/retrieval/nvidia" DEFAULT_MODEL = "NV-Embed-QA" +MODEL_ENDPOINT_MAP = { + DEFAULT_MODEL: BASE_RETRIEVAL_URL, + "snowflake/arctic-embed-l": "https://ai.api.nvidia.com/v1/retrieval/snowflake/arctic-embed-l", +} + class Model(BaseModel): id: str @@ -69,9 +74,15 @@ class NVIDIAEmbedding(BaseEmbedding): "api_key", nvidia_api_key or api_key, "NVIDIA_API_KEY", "none" ) + # TODO: we should not assume unknown models are at the base url, but + # we cannot error out here because + # NVIDIAEmbedding(model="special").mode("nim", base_url=...) + # is valid usage + base_url = MODEL_ENDPOINT_MAP.get(model, BASE_RETRIEVAL_URL) + self._client = OpenAI( api_key=api_key, - base_url=BASE_RETRIEVAL_URL, + base_url=base_url, timeout=timeout, max_retries=max_retries, ) @@ -79,7 +90,7 @@ class NVIDIAEmbedding(BaseEmbedding): self._aclient = AsyncOpenAI( api_key=api_key, - base_url=BASE_RETRIEVAL_URL, + base_url=base_url, timeout=timeout, max_retries=max_retries, ) @@ -95,7 +106,7 @@ class NVIDIAEmbedding(BaseEmbedding): @property def available_models(self) -> List[Model]: """Get available models.""" - ids = [DEFAULT_MODEL] + ids = MODEL_ENDPOINT_MAP.keys() if self._mode == "nim": ids = [model.id for model in self._client.models.list()] return [Model(id=id) for id in ids] @@ -116,7 +127,8 @@ class NVIDIAEmbedding(BaseEmbedding): if not base_url: raise ValueError("base_url is required for nim mode") if not base_url: - base_url = BASE_RETRIEVAL_URL + # TODO: we should not assume unknown models are at the base url + base_url = MODEL_ENDPOINT_MAP.get(model or self.model, BASE_RETRIEVAL_URL) self._mode = mode if base_url: diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-nvidia/pyproject.toml b/llama-index-integrations/embeddings/llama-index-embeddings-nvidia/pyproject.toml index bb337a0c0b..7256f262c5 100644 --- a/llama-index-integrations/embeddings/llama-index-embeddings-nvidia/pyproject.toml +++ b/llama-index-integrations/embeddings/llama-index-embeddings-nvidia/pyproject.toml @@ -27,7 +27,7 @@ exclude = ["**/BUILD"] license = "MIT" name = "llama-index-embeddings-nvidia" readme = "README.md" -version = "0.1.1" +version = "0.1.2" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-nvidia/tests/test_integration.py b/llama-index-integrations/embeddings/llama-index-embeddings-nvidia/tests/test_integration.py new file mode 100644 index 0000000000..146a017bf8 --- /dev/null +++ b/llama-index-integrations/embeddings/llama-index-embeddings-nvidia/tests/test_integration.py @@ -0,0 +1,12 @@ +import pytest + +from llama_index.embeddings.nvidia import NVIDIAEmbedding + + +@pytest.mark.integration() +def test_basic(model: str, mode: dict) -> None: + client = NVIDIAEmbedding(model=model).mode(**mode) + response = client.get_query_embedding("Hello, world!") + assert isinstance(response, list) + assert len(response) > 0 + assert isinstance(response[0], float) -- GitLab