diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-nvidia/llama_index/embeddings/nvidia/base.py b/llama-index-integrations/embeddings/llama-index-embeddings-nvidia/llama_index/embeddings/nvidia/base.py
index c2668478089d868f0f65ac5a2892afcab3f10642..cdf2cc7dc9dca9f1e09eb3fe42db038b730d4e40 100644
--- a/llama-index-integrations/embeddings/llama-index-embeddings-nvidia/llama_index/embeddings/nvidia/base.py
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-nvidia/llama_index/embeddings/nvidia/base.py
@@ -15,6 +15,11 @@ from openai import OpenAI, AsyncOpenAI
 BASE_RETRIEVAL_URL = "https://ai.api.nvidia.com/v1/retrieval/nvidia"
 DEFAULT_MODEL = "NV-Embed-QA"
 
+MODEL_ENDPOINT_MAP = {
+    DEFAULT_MODEL: BASE_RETRIEVAL_URL,
+    "snowflake/arctic-embed-l": "https://ai.api.nvidia.com/v1/retrieval/snowflake/arctic-embed-l",
+}
+
 
 class Model(BaseModel):
     id: str
@@ -69,9 +74,15 @@ class NVIDIAEmbedding(BaseEmbedding):
             "api_key", nvidia_api_key or api_key, "NVIDIA_API_KEY", "none"
         )
 
+        # TODO: we should not assume unknown models are at the base url, but
+        #       we cannot error out here because
+        #       NVIDIAEmbedding(model="special").mode("nim", base_url=...)
+        #       is valid usage
+        base_url = MODEL_ENDPOINT_MAP.get(model, BASE_RETRIEVAL_URL)
+
         self._client = OpenAI(
             api_key=api_key,
-            base_url=BASE_RETRIEVAL_URL,
+            base_url=base_url,
             timeout=timeout,
             max_retries=max_retries,
         )
@@ -79,7 +90,7 @@ class NVIDIAEmbedding(BaseEmbedding):
 
         self._aclient = AsyncOpenAI(
             api_key=api_key,
-            base_url=BASE_RETRIEVAL_URL,
+            base_url=base_url,
             timeout=timeout,
             max_retries=max_retries,
         )
@@ -95,7 +106,7 @@ class NVIDIAEmbedding(BaseEmbedding):
     @property
     def available_models(self) -> List[Model]:
         """Get available models."""
-        ids = [DEFAULT_MODEL]
+        ids = MODEL_ENDPOINT_MAP.keys()
         if self._mode == "nim":
             ids = [model.id for model in self._client.models.list()]
         return [Model(id=id) for id in ids]
@@ -116,7 +127,8 @@ class NVIDIAEmbedding(BaseEmbedding):
             if not base_url:
                 raise ValueError("base_url is required for nim mode")
         if not base_url:
-            base_url = BASE_RETRIEVAL_URL
+            # TODO: we should not assume unknown models are at the base url
+            base_url = MODEL_ENDPOINT_MAP.get(model or self.model, BASE_RETRIEVAL_URL)
 
         self._mode = mode
         if base_url:
diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-nvidia/pyproject.toml b/llama-index-integrations/embeddings/llama-index-embeddings-nvidia/pyproject.toml
index bb337a0c0b7055380eed3ab1f2b7e1ecdbceb6b8..7256f262c5b28c270658cdcefdca0a76be932922 100644
--- a/llama-index-integrations/embeddings/llama-index-embeddings-nvidia/pyproject.toml
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-nvidia/pyproject.toml
@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-embeddings-nvidia"
 readme = "README.md"
-version = "0.1.1"
+version = "0.1.2"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-nvidia/tests/test_integration.py b/llama-index-integrations/embeddings/llama-index-embeddings-nvidia/tests/test_integration.py
new file mode 100644
index 0000000000000000000000000000000000000000..146a017bf8f22c69e5bee978c60b7a47156620d1
--- /dev/null
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-nvidia/tests/test_integration.py
@@ -0,0 +1,12 @@
+import pytest
+
+from llama_index.embeddings.nvidia import NVIDIAEmbedding
+
+
+@pytest.mark.integration()
+def test_basic(model: str, mode: dict) -> None:
+    client = NVIDIAEmbedding(model=model).mode(**mode)
+    response = client.get_query_embedding("Hello, world!")
+    assert isinstance(response, list)
+    assert len(response) > 0
+    assert isinstance(response[0], float)
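For reference, a minimal caller-side sketch of the endpoint routing introduced above. The mode("nim", base_url=...) pattern is the usage named in the TODO comments; the localhost URL below is an illustrative assumption, not part of this change.

from llama_index.embeddings.nvidia import NVIDIAEmbedding

# A model listed in MODEL_ENDPOINT_MAP is routed to its dedicated endpoint,
# e.g. snowflake/arctic-embed-l -> .../v1/retrieval/snowflake/arctic-embed-l.
embedder = NVIDIAEmbedding(model="snowflake/arctic-embed-l")

# A model not in the map falls back to BASE_RETRIEVAL_URL (see the TODOs),
# so a self-hosted NIM must be targeted explicitly via mode():
embedder = NVIDIAEmbedding(model="special").mode("nim", base_url="http://localhost:8080/v1")

vector = embedder.get_query_embedding("Hello, world!")  # List[float], as the new test asserts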