Unverified commit 1a0e705d, authored by Matthew Farrellee, committed by GitHub

add dynamic model listing support (#13398)

parent 0928dc25
@@ -10,7 +10,6 @@ from llama_index.core.base.llms.generic_utils import (
     get_from_param_or_env,
 )
-from llama_index.llms.nvidia.utils import API_CATALOG_MODELS
 from llama_index.llms.openai_like import OpenAILike
@@ -54,10 +53,19 @@ class NVIDIA(OpenAILike):
     @property
     def available_models(self) -> List[Model]:
-        ids = API_CATALOG_MODELS.keys()
+        exclude = {
+            "mistralai/mixtral-8x22b-v0.1",  # not a /chat/completion endpoint
+        }
+        # do not exclude models in nim mode. the nim administrator has control
+        # over the model name and may deploy an excluded name on the nim's
+        # /chat/completion endpoint.
         if self._mode == "nim":
-            ids = [model.id for model in self._get_client().models.list()]
-        return [Model(id=name) for name in ids]
+            exclude = set()
+        return [
+            model
+            for model in self._get_client().models.list().data
+            if model.id not in exclude
+        ]

     @classmethod
     def class_name(cls) -> str:
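For context, a minimal sketch of the behavior this hunk introduces. It is not part of the commit; the NVIDIA_API_KEY environment variable and the commented-out mode() call are assumptions inferred from the integration tests further down.

# Sketch only: illustrates the new available_models behavior.
# Assumes NVIDIA_API_KEY is set in the environment.
from llama_index.llms.nvidia import NVIDIA

llm = NVIDIA()

# available_models is now built from a live call to the /models endpoint
# (via self._get_client().models.list()) instead of the static
# API_CATALOG_MODELS table, with the non-chat model filtered out.
for model in llm.available_models:
    print(model.id)

# In "nim" mode the exclude set is emptied, so whatever model names the
# NIM administrator deployed are listed verbatim. The exact mode()
# signature is an assumption based on the tests below:
# nim_llm = NVIDIA().mode("nim", base_url="http://localhost:8000/v1")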
@@ -30,7 +30,7 @@ license = "MIT"
 name = "llama-index-llms-nvidia"
 packages = [{include = "llama_index/"}]
 readme = "README.md"
-version = "0.1.1"
+version = "0.1.2"

 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
@@ -73,3 +73,16 @@ async def test_astream_complete(chat_model: str, mode: dict) -> None:
     responses = [response async for response in gen]
     assert all(isinstance(response, CompletionResponse) for response in responses)
     assert all(isinstance(response.delta, str) for response in responses)
+
+
+@pytest.mark.integration()
+@pytest.mark.parametrize(
+    "excluded",
+    [
+        "mistralai/mixtral-8x22b-v0.1",  # not a /chat/completion endpoint
+    ],
+)
+def test_exclude_models(mode: dict, excluded: str) -> None:
+    assert excluded not in [
+        model.id for model in NVIDIA().mode(**mode).available_models
+    ]
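The new test is gated behind the integration marker, so it should be selectable with pytest's standard marker filtering, e.g. pytest -m integration -k test_exclude_models, with the suite's mode fixture configured; the exact invocation is an assumption, not part of the commit.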