From 1a0e705d4a25f523b0612cddf13a8a70b4fc7190 Mon Sep 17 00:00:00 2001
From: Matthew Farrellee <matt@cs.wisc.edu>
Date: Thu, 9 May 2024 16:01:14 -0400
Subject: [PATCH] add dynamic model listing support (#13398)

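Replace the static API_CATALOG_MODELS lookup with a live query of the
service's model listing endpoint, so available_models reflects what the
endpoint actually serves. Models known not to sit behind a
chat-completions endpoint (currently mistralai/mixtral-8x22b-v0.1) are
filtered out, except in "nim" mode, where the administrator controls
which model names are deployed.

A minimal usage sketch (illustrative only; assumes the API key is
available via the NVIDIA_API_KEY environment variable):

    from llama_index.llms.nvidia import NVIDIA

    # NVIDIA_API_KEY is assumed to be set in the environment
    llm = NVIDIA()

    # list the chat-capable models reported by the service
    print([model.id for model in llm.available_models])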
---
 .../llama_index/llms/nvidia/base.py              | 16 ++++++++++++----
 .../llms/llama-index-llms-nvidia/pyproject.toml  |  2 +-
 .../tests/test_integration.py                    | 13 +++++++++++++
 3 files changed, 26 insertions(+), 5 deletions(-)

diff --git a/llama-index-integrations/llms/llama-index-llms-nvidia/llama_index/llms/nvidia/base.py b/llama-index-integrations/llms/llama-index-llms-nvidia/llama_index/llms/nvidia/base.py
index 61d28b47d..80b32393c 100644
--- a/llama-index-integrations/llms/llama-index-llms-nvidia/llama_index/llms/nvidia/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-nvidia/llama_index/llms/nvidia/base.py
@@ -10,7 +10,6 @@ from llama_index.core.base.llms.generic_utils import (
     get_from_param_or_env,
 )
 
-from llama_index.llms.nvidia.utils import API_CATALOG_MODELS
 
 from llama_index.llms.openai_like import OpenAILike
 
@@ -54,10 +53,19 @@ class NVIDIA(OpenAILike):
 
     @property
     def available_models(self) -> List[Model]:
-        ids = API_CATALOG_MODELS.keys()
+        exclude = {
+            "mistralai/mixtral-8x22b-v0.1",  # not a /chat/completion endpoint
+        }
+        # do not exclude models in nim mode. the nim administrator has control
+        # over the model name and may deploy an excluded name on the nim's
+        # /chat/completions endpoint.
         if self._mode == "nim":
-            ids = [model.id for model in self._get_client().models.list()]
-        return [Model(id=name) for name in ids]
+            exclude = set()
+        return [
+            Model(id=model.id)
+            for model in self._get_client().models.list().data
+            if model.id not in exclude
+        ]
 
     @classmethod
     def class_name(cls) -> str:
diff --git a/llama-index-integrations/llms/llama-index-llms-nvidia/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-nvidia/pyproject.toml
index e6a493855..9321b6536 100644
--- a/llama-index-integrations/llms/llama-index-llms-nvidia/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-nvidia/pyproject.toml
@@ -30,7 +30,7 @@ license = "MIT"
 name = "llama-index-llms-nvidia"
 packages = [{include = "llama_index/"}]
 readme = "README.md"
-version = "0.1.1"
+version = "0.1.2"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
diff --git a/llama-index-integrations/llms/llama-index-llms-nvidia/tests/test_integration.py b/llama-index-integrations/llms/llama-index-llms-nvidia/tests/test_integration.py
index 6915a7968..f26fb1e56 100644
--- a/llama-index-integrations/llms/llama-index-llms-nvidia/tests/test_integration.py
+++ b/llama-index-integrations/llms/llama-index-llms-nvidia/tests/test_integration.py
@@ -73,3 +73,16 @@ async def test_astream_complete(chat_model: str, mode: dict) -> None:
     responses = [response async for response in gen]
     assert all(isinstance(response, CompletionResponse) for response in responses)
     assert all(isinstance(response.delta, str) for response in responses)
+
+
+@pytest.mark.integration()
+@pytest.mark.parametrize(
+    "excluded",
+    [
+        "mistralai/mixtral-8x22b-v0.1",  # not a /chat/completion endpoint
+    ],
+)
+def test_exclude_models(mode: dict, excluded: str) -> None:
+    assert excluded not in [
+        model.id for model in NVIDIA().mode(**mode).available_models
+    ]
-- 
GitLab