From 1a0e705d4a25f523b0612cddf13a8a70b4fc7190 Mon Sep 17 00:00:00 2001
From: Matthew Farrellee <matt@cs.wisc.edu>
Date: Thu, 9 May 2024 16:01:14 -0400
Subject: [PATCH] add dynamic model listing support (#13398)

---
 .../llama_index/llms/nvidia/base.py             | 16 ++++++++++++----
 .../llms/llama-index-llms-nvidia/pyproject.toml |  2 +-
 .../tests/test_integration.py                   | 13 +++++++++++++
 3 files changed, 26 insertions(+), 5 deletions(-)

diff --git a/llama-index-integrations/llms/llama-index-llms-nvidia/llama_index/llms/nvidia/base.py b/llama-index-integrations/llms/llama-index-llms-nvidia/llama_index/llms/nvidia/base.py
index 61d28b47d..80b32393c 100644
--- a/llama-index-integrations/llms/llama-index-llms-nvidia/llama_index/llms/nvidia/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-nvidia/llama_index/llms/nvidia/base.py
@@ -10,7 +10,6 @@
 from llama_index.core.base.llms.generic_utils import (
     get_from_param_or_env,
 )
-from llama_index.llms.nvidia.utils import API_CATALOG_MODELS
 from llama_index.llms.openai_like import OpenAILike
 
 
@@ -54,10 +53,19 @@ class NVIDIA(OpenAILike):
 
     @property
     def available_models(self) -> List[Model]:
-        ids = API_CATALOG_MODELS.keys()
+        exclude = {
+            "mistralai/mixtral-8x22b-v0.1",  # not a /chat/completion endpoint
+        }
+        # do not exclude models in nim mode. the nim administrator has control
+        # over the model name and may deploy an excluded name on the nim's
+        # /chat/completion endpoint.
         if self._mode == "nim":
-            ids = [model.id for model in self._get_client().models.list()]
-        return [Model(id=name) for name in ids]
+            exclude = set()
+        return [
+            model
+            for model in self._get_client().models.list().data
+            if model.id not in exclude
+        ]
 
     @classmethod
     def class_name(cls) -> str:
diff --git a/llama-index-integrations/llms/llama-index-llms-nvidia/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-nvidia/pyproject.toml
index e6a493855..9321b6536 100644
--- a/llama-index-integrations/llms/llama-index-llms-nvidia/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-nvidia/pyproject.toml
@@ -30,7 +30,7 @@ license = "MIT"
 name = "llama-index-llms-nvidia"
 packages = [{include = "llama_index/"}]
 readme = "README.md"
-version = "0.1.1"
+version = "0.1.2"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
diff --git a/llama-index-integrations/llms/llama-index-llms-nvidia/tests/test_integration.py b/llama-index-integrations/llms/llama-index-llms-nvidia/tests/test_integration.py
index 6915a7968..f26fb1e56 100644
--- a/llama-index-integrations/llms/llama-index-llms-nvidia/tests/test_integration.py
+++ b/llama-index-integrations/llms/llama-index-llms-nvidia/tests/test_integration.py
@@ -73,3 +73,16 @@ async def test_astream_complete(chat_model: str, mode: dict) -> None:
     responses = [response async for response in gen]
     assert all(isinstance(response, CompletionResponse) for response in responses)
     assert all(isinstance(response.delta, str) for response in responses)
+
+
+@pytest.mark.integration()
+@pytest.mark.parametrize(
+    "excluded",
+    [
+        "mistralai/mixtral-8x22b-v0.1",  # not a /chat/completion endpoint
+    ],
+)
+def test_exclude_models(mode: dict, excluded: str) -> None:
+    assert excluded not in [
+        model.id for model in NVIDIA().mode(**mode).available_models
+    ]
-- 
GitLab
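
Usage sketch (not part of the patch): after this change, available_models queries
the live /models endpoint instead of the static API_CATALOG_MODELS table, filtering
out catalog models that lack a /chat/completion endpoint unless the client is in
"nim" mode. The snippet below is a minimal illustration; it assumes the mode()
helper accepts a mode name plus a base_url keyword, as the test's
NVIDIA().mode(**mode) call suggests, and the localhost URL is a placeholder.

    from llama_index.llms.nvidia import NVIDIA

    # API catalog mode (default): completions-only models such as
    # mistralai/mixtral-8x22b-v0.1 are excluded from the listing.
    catalog_llm = NVIDIA()
    print([model.id for model in catalog_llm.available_models])

    # NIM mode: nothing is excluded, since the NIM administrator controls
    # which model names are served on the /chat/completion endpoint.
    nim_llm = NVIDIA().mode("nim", base_url="http://localhost:8000/v1")  # placeholder URL
    print([model.id for model in nim_llm.available_models])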